Coverage for ase / spectrum / band_structure.py: 87.83%

189 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1# fmt: off 

2 

3import numpy as np 

4 

5import ase # Annotations 

6from ase.calculators.calculator import PropertyNotImplementedError 

7from ase.utils import jsonable 

8 

9 

def calculate_band_structure(atoms, path=None, scf_kwargs=None,
                             bs_kwargs=None, kpts_tol=1e-6, cell_tol=1e-6):
    """Calculate band structure.

    The purpose of this function is to abstract a band structure calculation
    so the workflow does not depend on the calculator.

    First trigger SCF calculation if necessary, then set arguments
    on the calculator for band structure calculation, then return
    calculated band structure.

    The difference from get_band_structure() is that the latter
    expects the calculation to already have been done.

    Parameters
    ----------
    atoms : Atoms
        Atoms object with an attached calculator (``atoms.calc``).
    path : BandPath, optional
        Band path to evaluate.  Defaults to ``atoms.cell.bandpath()``.
    scf_kwargs : dict, optional
        Keyword arguments passed to ``calc.set()`` before the SCF step.
    bs_kwargs : dict, optional
        Extra keyword arguments passed to ``calc.set()`` together with
        ``kpts=path`` for the band structure step.
    kpts_tol : float
        Maximum allowed absolute deviation between the band path k-points
        and the calculator's IBZ k-points after the band structure step.
    cell_tol : float
        Maximum allowed ``celldiff`` between the band path cell and the
        cell of ``atoms``.

    Returns
    -------
    The result of ``calc.band_structure()`` for calculators advertising
    ``accepts_bandpath_keyword``; otherwise the BandStructure built by
    get_band_structure() with the SCF Fermi level as reference.

    Raises
    ------
    ValueError
        If the band path does not match the periodicity or cell of
        ``atoms``, or if ``atoms`` has no calculator.
    RuntimeError
        If the calculator's k-points differ (in count or value) from those
        of the band path.
    """
    if path is None:
        path = atoms.cell.bandpath()

    from ase.lattice import celldiff  # Should this be a method on cell?
    if any(path.cell.any(1) != atoms.pbc):
        raise ValueError('The band path\'s cell, {}, does not match the '
                         'periodicity {} of the atoms'
                         .format(path.cell, atoms.pbc))
    cell_err = celldiff(path.cell, atoms.cell.uncomplete(atoms.pbc))
    if cell_err > cell_tol:
        raise ValueError('Atoms and band path have different unit cells. '
                         'Please reduce atoms to standard form. '
                         'Cell lengths and angles are {} vs {}'
                         .format(atoms.cell.cellpar(), path.cell.cellpar()))

    calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    if scf_kwargs is not None:
        calc.set(**scf_kwargs)

    # Proposed standard mechanism for calculators to advertise that they
    # use the bandpath keyword to handle band structures rather than
    # a double (SCF + BS) run.
    use_bandpath_kw = getattr(calc, 'accepts_bandpath_keyword', False)
    if use_bandpath_kw:
        calc.set(bandpath=path)
        atoms.get_potential_energy()
        return calc.band_structure()

    # SCF step.
    atoms.get_potential_energy()

    # Read the reference energy now, before the calculator is reconfigured
    # for the band structure run below.
    if hasattr(calc, 'get_fermi_level'):
        # What is the protocol for a calculator to tell whether
        # it has fermi_energy?
        eref = calc.get_fermi_level()
    else:
        eref = 0.0

    if bs_kwargs is None:
        bs_kwargs = {}

    calc.set(kpts=path, **bs_kwargs)
    calc.results.clear()  # XXX get rid of me

    # Calculators are too inconsistent here:
    #  * atoms.get_potential_energy() will fail when total energy is
    #    not in results after BS calculation (Espresso)
    #  * calc.calculate(atoms) doesn't ask for any quantity, so some
    #    calculators may not calculate anything at all
    #  * 'bandstructure' is not a recognized property we can ask for
    try:
        atoms.get_potential_energy()
    except PropertyNotImplementedError:
        pass

    path_kpts = np.asarray(path.kpts)
    ibzkpts = np.asarray(calc.get_ibz_k_points())
    # Compare shapes first: mismatched k-point counts would otherwise raise
    # an opaque broadcasting error, or silently broadcast a single k-point
    # against the whole path.
    if ibzkpts.shape != path_kpts.shape:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'shape {} != {}'.format(ibzkpts.shape,
                                                   path_kpts.shape))
    kpts_err = np.abs(path_kpts - ibzkpts).max()
    if kpts_err > kpts_tol:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'err={} > tol={}'.format(kpts_err, kpts_tol))

    bs = get_band_structure(atoms, path=path, reference=eref)
    return bs

89 

90 

def get_band_structure(atoms=None, calc=None, path=None, reference=None):
    """Create band structure object from Atoms or calculator.

    Eigenvalues are read from a calculator that has already run on the
    k-points of the band path.

    Parameters
    ----------
    atoms : Atoms, optional
        Atoms object.  If omitted, ``calc.atoms`` is used.
    calc : Calculator, optional
        Calculator to read eigenvalues from.  If omitted, ``atoms.calc``
        is used.
    path : BandPath, optional
        Band path matching the calculator's k-points.  If omitted, one is
        reconstructed from the calculator's IBZ k-points and the standard
        special points of ``atoms.cell``.
    reference : float, optional
        Reference energy.  Defaults to ``calc.get_fermi_level()``, or 0.0
        if that returns None.

    Returns
    -------
    BandStructure
        Energies have shape (nspins, nkpoints, nbands).

    Raises
    ------
    ValueError
        If neither ``atoms`` nor ``calc`` is given, or no calculator is
        available.
    """
    # path and reference are used internally at the moment, but
    # the exact implementation will probably change.  WIP.
    #
    # XXX We throw away info about the bandpath when we create the calculator.
    # If we had kept the bandpath, we could provide it as an argument here.
    # It would be wise to check that the bandpath kpoints are the same as
    # those stored in the calculator.
    if atoms is None and calc is None:
        raise ValueError('Either atoms or calc must be given')
    atoms = atoms if atoms is not None else calc.atoms
    calc = calc if calc is not None else atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    kpts = calc.get_ibz_k_points()
    nkpts = len(kpts)

    # Shape: (nspins, nkpoints, nbands).
    energies = np.array([
        [calc.get_eigenvalues(kpt=k, spin=s) for k in range(nkpts)]
        for s in range(calc.get_number_of_spins())
    ])

    if path is None:
        from ase.dft.kpoints import (
            BandPath,
            find_bandpath_kinks,
            resolve_custom_points,
        )
        standard_path = atoms.cell.bandpath(npoints=0)
        # Kpoints are already evaluated, we just need to put them into
        # the path (whether they fit our idea of what the path is, or not).
        #
        # Depending on how the path was established, the kpoints might
        # be valid high-symmetry points, but since there are multiple
        # high-symmetry points of each type, they may not coincide
        # with ours if the bandpath was generated by another code.
        #
        # Here we hack it so the BandPath has proper points even if they
        # come from some weird source.
        #
        # This operation (manually hacking the bandpath) is liable to break.
        # TODO: Make it available as a proper (documented) bandpath method.
        kinks = find_bandpath_kinks(atoms.cell, kpts, eps=1e-5)
        pathspec, special_points = resolve_custom_points(
            kpts[kinks], standard_path.special_points, eps=1e-5)
        path = BandPath(standard_path.cell,
                        kpts=kpts,
                        path=pathspec,
                        special_points=special_points)

    # XXX If we *did* get the path, now would be a good time to check
    # that it matches the cell!  Although the path can only be passed
    # because we internally want to not re-evaluate the Bravais
    # lattice type.  (We actually need an eps parameter, too.)

    if reference is None:
        # Fermi level should come from the GS calculation, not the BS one!
        reference = calc.get_fermi_level()

    if reference is None:
        # Fermi level may not be available, e.g., with non-Fermi smearing.
        # XXX Actually get_fermi_level() should raise an error when Fermi
        # level wasn't available, so we should fix that.
        reference = 0.0

    return BandStructure(path=path,
                         energies=energies,
                         reference=reference)

157 

158 

class BandStructurePlot:
    """Draw a band structure object with matplotlib.

    The axes are prepared by :meth:`prepare_plot` (special-point tick
    labels, vertical separators, reference-energy line) on the first call
    to :meth:`plot`.  Later calls without an explicit ``ax`` draw into the
    same prepared axes.
    """

    def __init__(self, bs):
        self.bs = bs
        # Axes set by prepare_plot(); None until the axes are prepared.
        self.ax = None
        # Linear k-point axis from bs.get_labels(); set by prepare_plot().
        self.xcoords = None

    def plot(self, ax=None, *, spin=None, emin=-10, emax=5, filename=None,
             show=False, ylabel=None, colors=None, point_colors=None,
             label=None, loc=None,
             cmap=None, cmin=-1.0, cmax=1.0, sortcolors=False,
             colorbar=True, clabel='$s_z$', cax=None,
             **plotkwargs):
        """Plot band-structure.

        ax: Axes
            MatPlotLib Axes object.  Will be created if not supplied.  On
            repeated calls without ``ax``, the axes prepared by the first
            call are reused (``emin``, ``emax`` and ``ylabel`` then have no
            further effect).
        spin: int or None
            If given, only plot the specified spin channel.
            If None, plot all spins.
            Default: None, i.e., plot all spins.
        emin, emax: float
            Minimum and maximum energy above reference.
        filename: str
            If given, write image to a file.
        show: bool
            Show the image (not needed in notebooks).
        ylabel: str
            The label along the y-axis.  Defaults to 'energies [eV]'
        colors: sequence of str
            A sequence of one or two color specifications, depending on
            whether there is spin.
            Default: green if no spin, yellow and blue if spin is present.
        point_colors: ndarray
            An array of numbers of the shape (nspins, n_kpts, nbands) which
            are then mapped onto colors by the colormap (see ``cmap``).
            If ``spin`` is given and ``point_colors`` covers all spin
            channels, only the selected channel is used.
            ``colors`` and ``point_colors`` are mutually exclusive.
        label: str or list of str
            Label for the curves on the legend.  A string if one spin is
            present, a list of two strings if two spins are present.
            Default: If no spin is given, no legend is made; if spin is
            present default labels 'spin up' and 'spin down' are used, but
            can be suppressed by setting ``label=False``.
        loc: str
            Location of the legend.

        If ``point_colors`` is given, the following arguments can be specified.

        cmap:
            Only used if ``point_colors`` is given.  A matplotlib
            colormap object, or a string naming a standard colormap.
            Default: The matplotlib default, typically 'viridis'.
        cmin, cmax: float
            Minimal and maximal values used for colormap translation.
            Default: -1.0 and 1.0
        colorbar: bool
            Whether to make a colorbar.
        clabel: str
            Label for the colorbar (default 's_z'; set to None to suppress).
        cax: Axes
            Axes object used for plotting colorbar.  Default: split off a
            new one.
        sortcolors (bool or callable):
            Sort points so highest color values are in front.  If a callable
            is given, then it is called on the color values to determine the
            sort order.

        Any additional keyword arguments are passed directly to matplotlib's
        plot() or scatter() methods, depending on whether point_colors is
        given.

        Returns the Axes that was drawn into.
        """
        if colors is not None and point_colors is not None:
            raise ValueError("Don't give both 'colors' and 'point_colors'")

        if self.ax is None:
            ax = self.prepare_plot(ax, emin, emax, ylabel)
        elif ax is None:
            # Axes were prepared by an earlier call: keep drawing into them
            # instead of falling through with ax=None.
            ax = self.ax

        if spin is None:
            e_skn = self.bs.energies
        elif spin not in [0, 1]:
            raise ValueError(f"spin should be 0 or 1, not {spin}")
        else:
            # Select only one spin channel, keeping the leading spin axis.
            e_skn = self.bs.energies[spin, np.newaxis]

        if point_colors is None:
            show_legend = self._plot_lines(ax, e_skn, colors, label,
                                           plotkwargs)
        else:
            self._plot_points(ax, e_skn, point_colors, spin, cmap, cmin,
                              cmax, sortcolors, colorbar, clabel, cax,
                              plotkwargs)
            show_legend = False

        self.finish_plot(filename, show, loc, show_legend)

        return ax

    def _plot_lines(self, ax, e_skn, colors, label, plotkwargs):
        """Draw one line per band and spin; return whether to show a legend."""
        nspins = len(e_skn)

        if colors is None:
            colors = 'g' if nspins == 1 else 'yb'
        elif len(colors) != nspins:
            raise ValueError(
                f"colors should be a sequence of {nspins} colors"
            )

        # Default values for label
        if label is None and nspins == 2:
            label = ['spin up', 'spin down']

        if label:
            if nspins == 1 and isinstance(label, str):
                label = [label]
            elif len(label) != nspins:
                raise ValueError(
                    f'label should be a list of {nspins} strings'
                )

        for s, e_kn in enumerate(e_skn):
            kwargs = dict(color=colors[s])
            kwargs.update(plotkwargs)
            # Label only the first band of each spin so the legend has one
            # entry per spin channel; label=False leaves lbl as None.
            lbl = label[s] if label else None
            ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)

            for e_k in e_kn.T[1:]:
                ax.plot(self.xcoords, e_k, **kwargs)

        # A falsy label (None for one spin, or label=False) means no labelled
        # artists were drawn, so a legend would be empty.
        return bool(label)

    def _plot_points(self, ax, e_skn, point_colors, spin, cmap, cmin, cmax,
                     sortcolors, colorbar, clabel, cax, plotkwargs):
        """Draw a scatter plot with one colour value per data point."""
        kwargs = dict(vmin=cmin, vmax=cmax, cmap=cmap, s=1)
        kwargs.update(plotkwargs)

        point_colors = np.asarray(point_colors)
        if spin is not None and point_colors.shape[0] == len(self.bs.energies):
            # Colours were given for all spin channels: select the same
            # channel as the energies so the shapes agree.
            point_colors = point_colors[spin, np.newaxis]

        shape = e_skn.shape
        xcoords = np.zeros(shape)
        xcoords += self.xcoords[np.newaxis, :, np.newaxis]
        if sortcolors:
            keys = (sortcolors(point_colors) if callable(sortcolors)
                    else point_colors)
            perm = np.asarray(keys).argsort(axis=None)
            e_skn = e_skn.ravel()[perm].reshape(shape)
            point_colors = point_colors.ravel()[perm].reshape(shape)
            xcoords = xcoords.ravel()[perm].reshape(shape)

        things = ax.scatter(xcoords, e_skn, c=point_colors, **kwargs)
        if colorbar:
            # Attach the colorbar to the figure owning ``ax`` rather than
            # whatever figure pyplot currently considers active.
            cbar = ax.figure.colorbar(things, ax=ax, cax=cax)
            if clabel:
                cbar.set_label(clabel)

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up ticks, separators and limits on ``ax`` (created if None).

        Stores the axes in ``self.ax`` and the linear k-point axis in
        ``self.xcoords``, and returns the axes.
        """
        if ax is None:
            import matplotlib.pyplot as plt
            ax = plt.figure().add_subplot(111)

        def pretty(kpt):
            # 'G' -> Gamma; two-character names like 'K1' -> 'K$_1$'.
            if kpt == 'G':
                kpt = r'$\Gamma$'
            elif len(kpt) == 2:
                kpt = kpt[0] + '$_' + kpt[1] + '$'
            return kpt

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        label_xcoords = list(label_xcoords)
        labels = [pretty(name) for name in orig_labels]

        # Merge labels sharing an x position (a jump in the path) into a
        # single comma-separated tick label.
        i = 1
        while i < len(labels):
            if label_xcoords[i - 1] == label_xcoords[i]:
                labels[i - 1] = labels[i - 1] + ',' + labels[i]
                labels.pop(i)
                label_xcoords.pop(i)
            else:
                i += 1

        for x in label_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        ylabel = ylabel if ylabel is not None else 'energies [eV]'

        ax.set_xticks(label_xcoords)
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        ax.axis(xmin=0, xmax=self.xcoords[-1], ymin=emin, ymax=emax)
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc, show_legend=False):
        """Optionally draw the legend, save to ``filename`` and show."""
        if show_legend:
            leg = self.ax.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            self.ax.figure.savefig(filename)

        if show:
            import matplotlib.pyplot as plt
            plt.show()

360 

361 

@jsonable('bandstructure')
class BandStructure:
    """A band structure consists of an array of eigenvalues and a bandpath.

    BandStructure objects support JSON I/O.

    Parameters
    ----------
    path : BandPath
        The k-point path.  ``len(path.kpts)`` must equal the number of
        k-points in ``energies``.
    energies : array_like
        Eigenvalues of shape (nspins, nkpoints, nbands), with nspins 1 or 2.
    reference : float
        Scalar reference energy, e.g. a Fermi level.  Default 0.0.
    """

    def __init__(self, path, energies, reference=0.0):
        self._path = path
        self._energies = np.asarray(energies)
        # Validation is done with assert (kept so callers still see
        # AssertionError), now with diagnostic messages.
        assert self.energies.shape[0] in [1, 2], (
            'energies must have shape (nspins, nkpts, nbands) with nspins '
            f'1 or 2, got shape {self.energies.shape}')
        assert self.energies.shape[1] == len(path.kpts), (
            f'energies have {self.energies.shape[1]} k-points but the path '
            f'has {len(path.kpts)}')
        assert np.isscalar(reference), (
            f'reference must be a scalar, got {reference!r}')
        self._reference = reference

    @property
    def energies(self) -> np.ndarray:
        """The energies of this band structure.

        This is a numpy array of shape (nspins, nkpoints, nbands)."""
        return self._energies

    @property
    def path(self) -> 'ase.dft.kpoints.BandPath':
        """The :class:`~ase.dft.kpoints.BandPath` of this band structure."""
        return self._path

    @property
    def reference(self) -> float:
        """The reference energy.

        Semantics may vary; typically a Fermi energy or zero,
        depending on how the band structure was created."""
        return self._reference

    def subtract_reference(self) -> 'BandStructure':
        """Return new band structure with reference energy subtracted."""
        return BandStructure(self.path, self.energies - self.reference,
                             reference=0.0)

    def todict(self):
        """Return the constructor arguments as a dict (used for JSON I/O)."""
        return dict(path=self.path,
                    energies=self.energies,
                    reference=self.reference)

    def get_labels(self, eps=1e-5):
        """Return the linear k-point axis and special-point labels.

        Delegates to ``self.path.get_linear_kpoint_axis(eps=eps)``; see
        :meth:`ase.dft.kpoints.BandPath.get_linear_kpoint_axis`.
        """
        return self.path.get_linear_kpoint_axis(eps=eps)

    def plot(self, *args, **kwargs):
        """Plot this band structure.

        Arguments are forwarded to :meth:`BandStructurePlot.plot`.
        """
        bsp = BandStructurePlot(self)
        return bsp.plot(*args, **kwargs)

    def __repr__(self):
        return ('{}(path={!r}, energies=[{} values], reference={})'
                .format(self.__class__.__name__, self.path,
                        '{}x{}x{}'.format(*self.energies.shape),
                        self.reference))