Coverage for /builds/ase/ase/ase/spectrum/band_structure.py: 84.32%

185 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import numpy as np 

4 

5import ase # Annotations 

6from ase.calculators.calculator import PropertyNotImplementedError 

7from ase.utils import jsonable 

8 

9 

def calculate_band_structure(atoms, path=None, scf_kwargs=None,
                             bs_kwargs=None, kpts_tol=1e-6, cell_tol=1e-6):
    """Calculate band structure.

    The purpose of this function is to abstract a band structure calculation
    so the workflow does not depend on the calculator.

    First trigger SCF calculation if necessary, then set arguments
    on the calculator for band structure calculation, then return
    calculated band structure.

    The difference from get_band_structure() is that the latter
    expects the calculation to already have been done.

    Parameters:

    atoms: Atoms
        Atoms object with an attached calculator.
    path: BandPath
        Band path to evaluate.  Defaults to ``atoms.cell.bandpath()``.
    scf_kwargs: dict
        Keyword arguments passed to ``calc.set()`` before the SCF run.
    bs_kwargs: dict
        Extra keyword arguments passed to ``calc.set()`` together with
        ``kpts=path`` for the band structure run.
    kpts_tol: float
        Maximum allowed absolute deviation between the calculator's
        IBZ k-points and the band path k-points.
    cell_tol: float
        Maximum allowed ``celldiff`` between the band path's cell and
        the cell of the atoms.

    Raises ValueError if the band path does not match the periodicity or
    cell of the atoms, or if the atoms have no calculator.  Raises
    RuntimeError if the k-points used by the calculator differ from those
    of the band path.
    """
    if path is None:
        path = atoms.cell.bandpath()

    from ase.lattice import celldiff  # Should this be a method on cell?
    if any(path.cell.any(1) != atoms.pbc):
        raise ValueError('The band path\'s cell, {}, does not match the '
                         'periodicity {} of the atoms'
                         .format(path.cell, atoms.pbc))
    cell_err = celldiff(path.cell, atoms.cell.uncomplete(atoms.pbc))
    if cell_err > cell_tol:
        raise ValueError('Atoms and band path have different unit cells.  '
                         'Please reduce atoms to standard form.  '
                         'Cell lengths and angles are {} vs {}'
                         .format(atoms.cell.cellpar(), path.cell.cellpar()))

    calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    if scf_kwargs is not None:
        calc.set(**scf_kwargs)

    # Proposed standard mechanism for calculators to advertise that they
    # use the bandpath keyword to handle band structures rather than
    # a double (SCF + BS) run.
    use_bandpath_kw = getattr(calc, 'accepts_bandpath_keyword', False)
    if use_bandpath_kw:
        calc.set(bandpath=path)
        atoms.get_potential_energy()
        return calc.band_structure()

    # SCF run; the Fermi level must be taken from this ground state
    # calculation, not from the non-selfconsistent band structure run.
    atoms.get_potential_energy()

    if hasattr(calc, 'get_fermi_level'):
        # What is the protocol for a calculator to tell whether
        # it has fermi_energy?
        eref = calc.get_fermi_level()
    else:
        eref = 0.0

    if bs_kwargs is None:
        bs_kwargs = {}

    calc.set(kpts=path, **bs_kwargs)
    calc.results.clear()  # XXX get rid of me

    # Calculators are too inconsistent here:
    # * atoms.get_potential_energy() will fail when total energy is
    #   not in results after BS calculation (Espresso)
    # * calc.calculate(atoms) doesn't ask for any quantity, so some
    #   calculators may not calculate anything at all
    # * 'bandstructure' is not a recognized property we can ask for
    try:
        atoms.get_potential_energy()
    except PropertyNotImplementedError:
        pass

    ibzkpts = calc.get_ibz_k_points()
    # Compare shapes first: otherwise numpy broadcasting would either
    # raise an opaque error or (for a single calculator k-point) silently
    # pass the tolerance check below.
    if np.shape(ibzkpts) != np.shape(path.kpts):
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'shape {} != {}'.format(np.shape(ibzkpts),
                                                   np.shape(path.kpts)))
    kpts_err = np.abs(path.kpts - ibzkpts).max()
    if kpts_err > kpts_tol:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'err={} > tol={}'.format(kpts_err, kpts_tol))

    bs = get_band_structure(atoms, path=path, reference=eref)
    return bs

89 

90 

def get_band_structure(atoms=None, calc=None, path=None, reference=None):
    """Create band structure object from Atoms or calculator.

    At least one of ``atoms`` and ``calc`` must be given; the missing one
    is taken from the other (``calc.atoms`` or ``atoms.calc``).

    path: BandPath
        Band path matching the calculator's k-points.  If omitted, a path
        is reconstructed from the calculator's IBZ k-points.
    reference: float
        Reference energy.  Defaults to the calculator's Fermi level, or
        0.0 if that is None.

    Raises ValueError if neither atoms nor calc is given, or if no
    calculator can be found.
    """
    # path and reference are used internally at the moment, but
    # the exact implementation will probably change.  WIP.
    #
    # XXX We throw away info about the bandpath when we create the calculator.
    # If we have kept the bandpath, we can provide it as an argument here.
    # It would be wise to check that the bandpath kpoints are the same as
    # those stored in the calculator.
    if atoms is None and calc is None:
        raise ValueError('Either atoms or calc must be given')
    atoms = atoms if atoms is not None else calc.atoms
    calc = calc if calc is not None else atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    kpts = calc.get_ibz_k_points()
    nkpts = len(kpts)

    # Shape (nspins, nkpts, nbands).
    energies = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                          for k in range(nkpts)]
                         for s in range(calc.get_number_of_spins())])

    if path is None:
        from ase.dft.kpoints import (
            BandPath,
            find_bandpath_kinks,
            resolve_custom_points,
        )
        standard_path = atoms.cell.bandpath(npoints=0)
        # Kpoints are already evaluated, we just need to put them into
        # the path (whether they fit our idea of what the path is, or not).
        #
        # Depending on how the path was established, the kpoints might
        # be valid high-symmetry points, but since there are multiple
        # high-symmetry points of each type, they may not coincide
        # with ours if the bandpath was generated by another code.
        #
        # Here we hack it so the BandPath has proper points even if they
        # come from some weird source.
        #
        # This operation (manually hacking the bandpath) is liable to break.
        # TODO: Make it available as a proper (documented) bandpath method.
        kinks = find_bandpath_kinks(atoms.cell, kpts, eps=1e-5)
        pathspec, special_points = resolve_custom_points(
            kpts[kinks], standard_path.special_points, eps=1e-5)
        path = BandPath(standard_path.cell,
                        kpts=kpts,
                        path=pathspec,
                        special_points=special_points)

    # XXX If we *did* get the path, now would be a good time to check
    # that it matches the cell!  Although the path can only be passed
    # because we internally want to not re-evaluate the Bravais
    # lattice type.  (We actually need an eps parameter, too.)

    if reference is None:
        # Fermi level should come from the GS calculation, not the BS one!
        reference = calc.get_fermi_level()

    if reference is None:
        # Fermi level may not be available, e.g., with non-Fermi smearing.
        # XXX Actually get_fermi_level() should raise an error when Fermi
        # level wasn't available, so we should fix that.
        reference = 0.0

    return BandStructure(path=path,
                         energies=energies,
                         reference=reference)

157 

158 

class BandStructurePlot:
    """Matplotlib plotter for a band structure object.

    The wrapped object must provide ``energies`` (nspins x nkpts x nbands),
    ``reference`` and ``get_labels()``.
    """

    def __init__(self, bs):
        self.bs = bs
        self.ax = None  # Axes set up by prepare_plot()
        self.xcoords = None  # Linear k-point axis, set by prepare_plot()

    def plot(self, ax=None, emin=-10, emax=5, filename=None,
             show=False, ylabel=None, colors=None, point_colors=None,
             label=None, loc=None,
             cmap=None, cmin=-1.0, cmax=1.0, sortcolors=False,
             colorbar=True, clabel='$s_z$', cax=None,
             **plotkwargs):
        """Plot band-structure.

        ax: Axes
            Matplotlib Axes object.  Will be created if not supplied.
        emin, emax: float
            Minimum and maximum energy above reference.
        filename: str
            If given, write image to a file.
        show: bool
            Show the image (not needed in notebooks).
        ylabel: str
            The label along the y-axis.  Defaults to 'energies [eV]'
        colors: sequence of str
            A sequence of one or two color specifications, depending on
            whether there is spin.
            Default: green if no spin, yellow and blue if spin is present.
        point_colors: ndarray
            An array of numbers of the shape (nspins, n_kpts, nbands) which
            are then mapped onto colors by the colormap (see ``cmap``).
            ``colors`` and ``point_colors`` are mutually exclusive
        label: str or list of str
            Label for the curves on the legend.  A string if one spin is
            present, a list of two strings if two spins are present.
            Default: If no spin is given, no legend is made; if spin is
            present default labels 'spin up' and 'spin down' are used, but
            can be suppressed by setting ``label=False``.
        loc: str
            Location of the legend.

        If ``point_colors`` is given, the following arguments can be specified.

        cmap:
            Only used if ``point_colors`` is given.  A matplotlib
            colormap object, or a string naming a standard colormap.
            Default: The matplotlib default, typically 'viridis'.
        cmin, cmax: float
            Minimal and maximal values used for colormap translation.
            Default: -1.0 and 1.0
        colorbar: bool
            Whether to make a colorbar.
        clabel: str
            Label for the colorbar (default 's_z', set to None to suppress).
        cax: Axes
            Axes object used for plotting colorbar.  Default: split off a
            new one.
        sortcolors (bool or callable):
            Sort points so highest color values are in front.  If a callable is
            given, then it is called on the color values to determine the sort
            order.

        Any additional keyword arguments are passed directly to matplotlib's
        plot() or scatter() methods, depending on whether point_colors is
        given.

        Raises ValueError for inconsistent ``colors``, ``point_colors`` or
        ``label`` arguments; this happens before any figure is created.
        """
        if colors is not None and point_colors is not None:
            raise ValueError("Don't give both 'color' and 'point_color'")

        e_skn = self.bs.energies
        nspins = len(e_skn)

        # Validate/normalize arguments before touching matplotlib, so that
        # bad input does not leave an empty figure behind.
        if point_colors is None:
            if colors is None:
                colors = 'g' if nspins == 1 else 'yb'
            elif len(colors) != nspins:
                raise ValueError(
                    f'colors should be a sequence of {nspins} colors'
                )

            # Default values for label
            if label is None and nspins == 2:
                label = ['spin up', 'spin down']

            if label:
                if nspins == 1 and isinstance(label, str):
                    label = [label]
                elif len(label) != nspins:
                    raise ValueError(
                        f'label should be a list of {nspins} strings'
                    )

        if self.ax is None:
            ax = self.prepare_plot(ax, emin, emax, ylabel)
        elif ax is None:
            # Re-use the axes prepared by an earlier call to plot().
            ax = self.ax

        if point_colors is None:
            # Normal band structure plot
            for spin, e_kn in enumerate(e_skn):
                kwargs = dict(color=colors[spin])
                kwargs.update(plotkwargs)
                lbl = label[spin] if label else None  # None if label=False
                # Only the first band carries the legend label.
                ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)

                for e_k in e_kn.T[1:]:
                    ax.plot(self.xcoords, e_k, **kwargs)
            # label=False (or an empty label) suppresses the legend.
            show_legend = bool(label)

        else:
            import matplotlib.pyplot as plt

            # A color per datapoint.
            kwargs = dict(vmin=cmin, vmax=cmax, cmap=cmap, s=1)
            kwargs.update(plotkwargs)
            shape = e_skn.shape
            xcoords = np.zeros(shape)
            xcoords += self.xcoords[np.newaxis, :, np.newaxis]
            if sortcolors:
                keys = (sortcolors(point_colors) if callable(sortcolors)
                        else point_colors)
                # np.argsort/np.ravel also accept plain sequences.
                perm = np.argsort(keys, axis=None)
                e_skn = e_skn.ravel()[perm].reshape(shape)
                point_colors = np.ravel(point_colors)[perm].reshape(shape)
                xcoords = xcoords.ravel()[perm].reshape(shape)

            things = ax.scatter(xcoords, e_skn, c=point_colors, **kwargs)
            if colorbar:
                cbar = plt.colorbar(things, cax=cax)
                if clabel:
                    cbar.set_label(clabel)
            show_legend = False

        self.finish_plot(filename, show, loc, show_legend, ax=ax)

        return ax

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up the axes: special-point ticks, separators and limits.

        Creates a new figure/axes if ``ax`` is None, stores it in
        ``self.ax`` and the linear k-point axis in ``self.xcoords``.
        """
        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.figure().add_subplot(111)

        def pretty(kpt):
            # 'G' -> Gamma; two-character names get a subscript.
            if kpt == 'G':
                kpt = r'$\Gamma$'
            elif len(kpt) == 2:
                kpt = kpt[0] + '$_' + kpt[1] + '$'
            return kpt

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        label_xcoords = list(label_xcoords)
        labels = [pretty(name) for name in orig_labels]

        # Merge labels sharing an x-coordinate (path discontinuities)
        # into a single "A,B" tick label.
        i = 1
        while i < len(labels):
            if label_xcoords[i - 1] == label_xcoords[i]:
                labels[i - 1] = labels[i - 1] + ',' + labels[i]
                labels.pop(i)
                label_xcoords.pop(i)
            else:
                i += 1

        for x in label_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        ylabel = ylabel if ylabel is not None else 'energies [eV]'

        ax.set_xticks(label_xcoords)
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        ax.axis(xmin=0, xmax=self.xcoords[-1], ymin=emin, ymax=emax)
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc, show_legend=False, ax=None):
        """Add the legend, then optionally save and/or show the figure.

        ax: Axes
            Axes that receives the legend.  Defaults to the axes set up by
            prepare_plot(), falling back to matplotlib's current axes.
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = self.ax

        if show_legend:
            target = ax if ax is not None else plt.gca()
            leg = target.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            plt.savefig(filename)

        if show:
            plt.show()

349 

350 

351@jsonable('bandstructure') 

class BandStructure:
    """A band structure consists of an array of eigenvalues and a bandpath.

    BandStructure objects support JSON I/O.

    path: BandPath
        The band path; ``len(path.kpts)`` must equal the number of k-points
        in ``energies``.
    energies: array-like
        Eigenvalues of shape (nspins, nkpoints, nbands) with nspins 1 or 2.
    reference: float
        Reference energy (scalar), e.g. a Fermi level.

    Raises ValueError if ``energies`` or ``reference`` are malformed.
    """

    def __init__(self, path, energies, reference=0.0):
        energies = np.asarray(energies)
        # Explicit checks rather than ``assert`` so that validation is not
        # stripped when Python runs with -O.
        if energies.ndim < 2:
            raise ValueError('energies must have shape '
                             '(nspins, nkpoints, nbands), got {}'
                             .format(energies.shape))
        if energies.shape[0] not in (1, 2):
            raise ValueError('energies must have 1 or 2 spins, got {}'
                             .format(energies.shape[0]))
        if energies.shape[1] != len(path.kpts):
            raise ValueError('energies has {} k-points but path has {}'
                             .format(energies.shape[1], len(path.kpts)))
        if not np.isscalar(reference):
            raise ValueError('reference must be a scalar, got {!r}'
                             .format(reference))
        self._path = path
        self._energies = energies
        self._reference = reference

    @property
    def energies(self) -> np.ndarray:
        """The energies of this band structure.

        This is a numpy array of shape (nspins, nkpoints, nbands)."""
        return self._energies

    @property
    def path(self) -> 'ase.dft.kpoints.BandPath':
        """The :class:`~ase.dft.kpoints.BandPath` of this band structure."""
        return self._path

    @property
    def reference(self) -> float:
        """The reference energy.

        Semantics may vary; typically a Fermi energy or zero,
        depending on how the band structure was created."""
        return self._reference

    def subtract_reference(self) -> 'BandStructure':
        """Return new band structure with reference energy subtracted."""
        return BandStructure(self.path, self.energies - self.reference,
                             reference=0.0)

    def todict(self):
        """Return the constructor arguments as a dict (used for JSON I/O)."""
        return dict(path=self.path,
                    energies=self.energies,
                    reference=self.reference)

    def get_labels(self, eps=1e-5):
        """See :meth:`ase.dft.kpoints.BandPath.get_linear_kpoint_axis`."""
        return self.path.get_linear_kpoint_axis(eps=eps)

    def plot(self, *args, **kwargs):
        """Plot this band structure.

        Arguments are forwarded to :meth:`BandStructurePlot.plot`."""
        bsp = BandStructurePlot(self)
        return bsp.plot(*args, **kwargs)

    def __repr__(self):
        shape = 'x'.join(str(n) for n in self.energies.shape)
        return ('{}(path={!r}, energies=[{} values], reference={})'
                .format(self.__class__.__name__, self.path,
                        shape, self.reference))