Coverage for ase / spectrum / band_structure.py: 87.83%
189 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
1# fmt: off
3import numpy as np
5import ase # Annotations
6from ase.calculators.calculator import PropertyNotImplementedError
7from ase.utils import jsonable
def calculate_band_structure(atoms, path=None, scf_kwargs=None,
                             bs_kwargs=None, kpts_tol=1e-6, cell_tol=1e-6):
    """Calculate band structure.

    The purpose of this function is to abstract a band structure calculation
    so the workflow does not depend on the calculator.

    First trigger SCF calculation if necessary, then set arguments
    on the calculator for band structure calculation, then return
    calculated band structure.

    The difference from get_band_structure() is that the latter
    expects the calculation to already have been done.

    Parameters:

    atoms: Atoms
        Atoms object with an attached calculator (``atoms.calc``).
    path: BandPath or None
        Band path to evaluate.  Defaults to ``atoms.cell.bandpath()``.
    scf_kwargs: dict or None
        Keywords passed to ``calc.set()`` before the SCF step.
    bs_kwargs: dict or None
        Extra keywords passed to ``calc.set()`` together with ``kpts=path``
        for the band-structure step.  Not used when the calculator
        advertises ``accepts_bandpath_keyword``.
    kpts_tol: float
        Maximum allowed absolute deviation between the path k-points and
        the k-points reported by the calculator.
    cell_tol: float
        Maximum allowed ``celldiff`` between the path cell and the atoms cell.

    Returns the band structure: ``calc.band_structure()`` for calculators
    accepting the bandpath keyword, otherwise the result of
    ``get_band_structure()`` referenced to the SCF Fermi level (or 0.0).

    Raises ValueError if the path does not match the periodicity or cell
    of the atoms, or if the atoms have no calculator.  Raises RuntimeError
    if the calculator's k-points do not match the band path.
    """
    if path is None:
        path = atoms.cell.bandpath()

    from ase.lattice import celldiff  # Should this be a method on cell?
    if any(path.cell.any(1) != atoms.pbc):
        raise ValueError('The band path\'s cell, {}, does not match the '
                         'periodicity {} of the atoms'
                         .format(path.cell, atoms.pbc))
    cell_err = celldiff(path.cell, atoms.cell.uncomplete(atoms.pbc))
    if cell_err > cell_tol:
        raise ValueError('Atoms and band path have different unit cells. '
                         'Please reduce atoms to standard form. '
                         'Cell lengths and angles are {} vs {}'
                         .format(atoms.cell.cellpar(), path.cell.cellpar()))

    calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    if scf_kwargs is not None:
        calc.set(**scf_kwargs)

    # Proposed standard mechanism for calculators to advertise that they
    # use the bandpath keyword to handle band structures rather than
    # a double (SCF + BS) run.
    if getattr(calc, 'accepts_bandpath_keyword', False):
        calc.set(bandpath=path)
        atoms.get_potential_energy()
        return calc.band_structure()

    # SCF step; the Fermi level (if any) is read from this calculation.
    atoms.get_potential_energy()

    if hasattr(calc, 'get_fermi_level'):
        # What is the protocol for a calculator to tell whether
        # it has fermi_energy?
        eref = calc.get_fermi_level()
    else:
        eref = 0.0

    if bs_kwargs is None:
        bs_kwargs = {}

    calc.set(kpts=path, **bs_kwargs)
    calc.results.clear()  # XXX get rid of me

    # Calculators are too inconsistent here:
    # * atoms.get_potential_energy() will fail when total energy is
    #   not in results after BS calculation (Espresso)
    # * calc.calculate(atoms) doesn't ask for any quantity, so some
    #   calculators may not calculate anything at all
    # * 'bandstructure' is not a recognized property we can ask for
    # Hence the energy request below is best-effort only.
    try:
        atoms.get_potential_energy()
    except PropertyNotImplementedError:
        pass

    path_kpts = np.asarray(path.kpts)
    ibzkpts = np.asarray(calc.get_ibz_k_points())
    # Compare shapes before values: otherwise numpy broadcasting either
    # raises an opaque error or, e.g. for a single calculator k-point,
    # silently compares every path k-point against that one point.
    if ibzkpts.shape != path_kpts.shape:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'shapes {} vs {}'.format(ibzkpts.shape,
                                                    path_kpts.shape))
    kpts_err = np.abs(path_kpts - ibzkpts).max()
    if kpts_err > kpts_tol:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'err={} > tol={}'.format(kpts_err, kpts_tol))

    return get_band_structure(atoms, path=path, reference=eref)
def get_band_structure(atoms=None, calc=None, path=None, reference=None):
    """Create band structure object from Atoms or calculator.

    The eigenvalues must already have been computed by the calculator;
    this function only collects them.

    Parameters:

    atoms: Atoms or None
        Atoms object; if omitted, ``calc.atoms`` is used.
    calc: Calculator or None
        Calculator holding the eigenvalues; if omitted, ``atoms.calc``.
    path: BandPath or None
        Band path matching the calculator's k-points.  If omitted, a path
        is reconstructed from the k-points and the standard band path of
        ``atoms.cell``.
    reference: float or None
        Reference energy.  If omitted, ``calc.get_fermi_level()`` is used,
        falling back to 0.0 when that returns None.

    Returns a BandStructure with energies of shape
    (nspins, nkpts, nbands).

    Raises ValueError if neither atoms nor calc is given, if no calculator
    can be found, or if the path must be inferred but no atoms are
    available.
    """
    # path and reference are used internally at the moment, but
    # the exact implementation will probably change. WIP.
    #
    # XXX We throw away info about the bandpath when we create the calculator.
    # If we have kept the bandpath, we can provide it as an argument here.
    # It would be wise to check that the bandpath kpoints are the same as
    # those stored in the calculator.
    if atoms is None and calc is None:
        raise ValueError('Either atoms or calc must be given')
    atoms = atoms if atoms is not None else calc.atoms
    calc = calc if calc is not None else atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    kpts = calc.get_ibz_k_points()
    nkpts = len(kpts)
    energies = np.array([[calc.get_eigenvalues(kpt=k, spin=s)
                          for k in range(nkpts)]
                         for s in range(calc.get_number_of_spins())])

    if path is None:
        if atoms is None:
            raise ValueError('Cannot infer the band path: no atoms available')
        from ase.dft.kpoints import (
            BandPath,
            find_bandpath_kinks,
            resolve_custom_points,
        )
        standard_path = atoms.cell.bandpath(npoints=0)
        # Kpoints are already evaluated, we just need to put them into
        # the path (whether they fit our idea of what the path is, or not).
        #
        # Depending on how the path was established, the kpoints might
        # be valid high-symmetry points, but since there are multiple
        # high-symmetry points of each type, they may not coincide
        # with ours if the bandpath was generated by another code.
        #
        # Here we hack it so the BandPath has proper points even if they
        # come from some weird source.
        #
        # This operation (manually hacking the bandpath) is liable to break.
        # TODO: Make it available as a proper (documented) bandpath method.
        kinks = find_bandpath_kinks(atoms.cell, kpts, eps=1e-5)
        pathspec, special_points = resolve_custom_points(
            kpts[kinks], standard_path.special_points, eps=1e-5)
        path = BandPath(standard_path.cell,
                        kpts=kpts,
                        path=pathspec,
                        special_points=special_points)

    # XXX If we *did* get the path, now would be a good time to check
    # that it matches the cell! Although the path can only be passed
    # because we internally want to not re-evaluate the Bravais
    # lattice type. (We actually need an eps parameter, too.)

    if reference is None:
        # Fermi level should come from the GS calculation, not the BS one!
        reference = calc.get_fermi_level()

    if reference is None:
        # Fermi level may not be available, e.g., with non-Fermi smearing.
        # XXX Actually get_fermi_level() should raise an error when Fermi
        # level wasn't available, so we should fix that.
        reference = 0.0

    return BandStructure(path=path,
                         energies=energies,
                         reference=reference)
class BandStructurePlot:
    """Draw a band structure with matplotlib.

    The first call to :meth:`plot` (or an explicit :meth:`prepare_plot`)
    sets up the axes (special-point ticks, vertical separators, reference
    line, limits) and stores them in ``self.ax``.  Later calls to
    :meth:`plot` without an ``ax`` argument draw onto those same axes.
    """

    def __init__(self, bs):
        self.bs = bs
        self.ax = None  # matplotlib Axes, set by prepare_plot()
        self.xcoords = None  # linear k-point axis, set by prepare_plot()

    def plot(self, ax=None, *, spin=None, emin=-10, emax=5, filename=None,
             show=False, ylabel=None, colors=None, point_colors=None,
             label=None, loc=None,
             cmap=None, cmin=-1.0, cmax=1.0, sortcolors=False,
             colorbar=True, clabel='$s_z$', cax=None,
             **plotkwargs):
        """Plot band-structure.

        ax: Axes
            MatPlotLib Axes object. Will be created if not supplied and no
            axes have been prepared by an earlier call; otherwise the
            previously prepared axes are reused.
        spin: int or None
            If given (0 or 1), only plot the specified spin channel.
            If None, plot all spins.
            Default: None, i.e., plot all spins.
        emin, emax: float
            Minimum and maximum energy above reference.  Only applied on
            the call that sets up the axes.
        filename: str
            If given, write image to a file.
        show: bool
            Show the image (not needed in notebooks).
        ylabel: str
            The label along the y-axis. Defaults to 'energies [eV]'.
            Only applied on the call that sets up the axes.
        colors: str or sequence of str
            A sequence of one or two color specifications, depending on
            whether there is spin.  When a single spin is plotted, a single
            color string (e.g. 'red') is also accepted.
            Default: green if no spin, yellow and blue if spin is present.
        point_colors: ndarray
            An array of numbers of the shape (nspins, n_kpts, nbands) which
            are then mapped onto colors by the colormap (see ``cmap``).
            ``colors`` and ``point_colors`` are mutually exclusive.
        label: str or list of str
            Label for the curves on the legend. A string if one spin is
            present, a list of two strings if two spins are present.
            Default: If no spin is given, no legend is made; if spin is
            present default labels 'spin up' and 'spin down' are used, but
            can be suppressed by setting ``label=False``.
        loc: str
            Location of the legend.

        If ``point_colors`` is given, the following arguments can be specified.

        cmap:
            Only used if point_colors is given. A matplotlib
            colormap object, or a string naming a standard colormap.
            Default: The matplotlib default, typically 'viridis'.
        cmin, cmax: float
            Minimal and maximal values used for colormap translation.
            Default: -1.0 and 1.0
        colorbar: bool
            Whether to make a colorbar.
        clabel: str
            Label for the colorbar (default '$s_z$'; set to None to suppress).
        cax: Axes
            Axes object used for plotting colorbar. Default: split off a
            new one.
        sortcolors (bool or callable):
            Sort points so highest color values are in front. If a callable is
            given, then it is called on the color values to determine the sort
            order.

        Any additional keyword arguments are passed directly to matplotlib's
        plot() or scatter() methods, depending on whether point_colors is
        given.

        Returns the Axes that were drawn on.

        Raises ValueError for conflicting ``colors``/``point_colors``, an
        invalid ``spin``, or ``colors``/``label`` of the wrong length.
        """
        if colors is not None and point_colors is not None:
            raise ValueError("Don't give both 'colors' and 'point_colors'")

        if self.ax is None:
            ax = self.prepare_plot(ax, emin, emax, ylabel)
        elif ax is None:
            # Axes were prepared by an earlier call: draw onto them instead
            # of failing on ``None.plot``.
            ax = self.ax

        if spin is None:
            e_skn = self.bs.energies
        elif spin not in [0, 1]:
            raise ValueError(f"spin should be 0 or 1, not {spin}")
        else:
            # Select only one spin channel, keeping the leading spin axis.
            e_skn = self.bs.energies[spin, np.newaxis]

        if point_colors is None:
            show_legend = self._plot_lines(ax, e_skn, colors, label,
                                           plotkwargs)
        else:
            self._plot_points(ax, e_skn, point_colors, cmap, cmin, cmax,
                              sortcolors, colorbar, clabel, cax, plotkwargs)
            show_legend = False

        self.finish_plot(filename, show, loc, show_legend)

        return ax

    def _plot_lines(self, ax, e_skn, colors, label, plotkwargs):
        """Draw one line per band and spin; return whether to show a legend."""
        nspins = len(e_skn)

        if colors is None:
            colors = 'g' if nspins == 1 else 'yb'
        elif nspins == 1 and isinstance(colors, str):
            # One color name for one spin, mirroring the ``label`` handling.
            colors = [colors]
        elif len(colors) != nspins:
            raise ValueError(
                f"colors should be a sequence of {nspins} colors"
            )

        # Default values for label
        if label is None and nspins == 2:
            label = ['spin up', 'spin down']

        if label:
            if nspins == 1 and isinstance(label, str):
                label = [label]
            elif len(label) != nspins:
                raise ValueError(
                    f'label should be a list of {nspins} strings'
                )

        for s, e_kn in enumerate(e_skn):
            kwargs = dict(color=colors[s])
            kwargs.update(plotkwargs)
            # Only the first band of each spin carries the legend label;
            # a falsy ``label`` (None or False) leaves the lines unlabeled.
            lbl = label[s] if label else None
            ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)
            for e_k in e_kn.T[1:]:
                ax.plot(self.xcoords, e_k, **kwargs)

        # No labels means no legend; requesting one would only make
        # matplotlib warn about missing labelled artists.
        return bool(label)

    def _plot_points(self, ax, e_skn, point_colors, cmap, cmin, cmax,
                     sortcolors, colorbar, clabel, cax, plotkwargs):
        """Scatter every eigenvalue, colored by ``point_colors``."""
        kwargs = dict(vmin=cmin, vmax=cmax, cmap=cmap, s=1)
        kwargs.update(plotkwargs)
        point_colors = np.asarray(point_colors)
        shape = e_skn.shape
        xcoords = np.zeros(shape)
        xcoords += self.xcoords[np.newaxis, :, np.newaxis]
        if sortcolors:
            keys = sortcolors(point_colors) if callable(sortcolors) \
                else point_colors
            # Later points are drawn on top, so ascending order puts the
            # highest values in front.
            perm = np.asarray(keys).argsort(axis=None)
            e_skn = e_skn.ravel()[perm].reshape(shape)
            point_colors = point_colors.ravel()[perm].reshape(shape)
            xcoords = xcoords.ravel()[perm].reshape(shape)

        things = ax.scatter(xcoords, e_skn, c=point_colors, **kwargs)
        if colorbar:
            import matplotlib.pyplot as plt
            cbar = plt.colorbar(things, cax=cax)
            if clabel:
                cbar.set_label(clabel)

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up axes for the band structure and remember them in self.ax.

        Creates a new figure when ``ax`` is None, stores the linear k-point
        axis in ``self.xcoords``, puts ticks at the special points (with
        vertical separators between the end points), draws a dotted line at
        the reference energy and sets the axis limits.  Returns the axes.
        """
        if ax is None:
            import matplotlib.pyplot as plt
            ax = plt.figure().add_subplot(111)

        def pretty(kpt):
            if kpt == 'G':
                kpt = r'$\Gamma$'
            elif len(kpt) == 2:
                kpt = kpt[0] + '$_' + kpt[1] + '$'
            return kpt

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        labels = [pretty(name) for name in orig_labels]

        # Consecutive special points that fall at the same x coordinate
        # share one tick, with their names joined by a comma.
        tick_xcoords = []
        tick_labels = []
        for x, name in zip(label_xcoords, labels):
            if tick_xcoords and tick_xcoords[-1] == x:
                tick_labels[-1] += ',' + name
            else:
                tick_xcoords.append(x)
                tick_labels.append(name)

        for x in tick_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        ylabel = ylabel if ylabel is not None else 'energies [eV]'

        ax.set_xticks(tick_xcoords)
        ax.set_xticklabels(tick_labels)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        ax.axis(xmin=0, xmax=self.xcoords[-1], ymin=emin, ymax=emax)
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc, show_legend=False):
        """Optionally add a legend, save to ``filename`` and show the plot."""
        if show_legend:
            leg = self.ax.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            self.ax.figure.savefig(filename)

        if show:
            import matplotlib.pyplot as plt
            plt.show()
@jsonable('bandstructure')
class BandStructure:
    """A band structure consists of an array of eigenvalues and a bandpath.

    BandStructure objects support JSON I/O.
    """

    def __init__(self, path, energies, reference=0.0):
        """Create a band structure.

        path: BandPath
            Band path; ``len(path.kpts)`` must equal the number of
            k-points in ``energies``.
        energies: array-like
            Eigenvalues of shape (nspins, nkpoints, nbands), nspins 1 or 2.
        reference: float
            Scalar reference energy (typically a Fermi level or zero).

        Raises ValueError if the shapes are inconsistent or ``reference``
        is not a scalar.
        """
        self._path = path
        self._energies = np.asarray(energies)
        # Explicit checks rather than ``assert`` so validation also happens
        # under ``python -O``.
        if self._energies.ndim < 2 or self._energies.shape[0] not in (1, 2):
            raise ValueError(
                'energies must have shape (nspins, nkpts, nbands) with '
                'nspins 1 or 2; got shape {}'.format(self._energies.shape))
        if self._energies.shape[1] != len(path.kpts):
            raise ValueError(
                'energies have {} k-points but the path has {}'
                .format(self._energies.shape[1], len(path.kpts)))
        if not np.isscalar(reference):
            raise ValueError(
                'reference must be a scalar, got {!r}'.format(reference))
        self._reference = reference

    @property
    def energies(self) -> np.ndarray:
        """The energies of this band structure.

        This is a numpy array of shape (nspins, nkpoints, nbands)."""
        return self._energies

    @property
    def path(self) -> 'ase.dft.kpoints.BandPath':
        """The :class:`~ase.dft.kpoints.BandPath` of this band structure."""
        return self._path

    @property
    def reference(self) -> float:
        """The reference energy.

        Semantics may vary; typically a Fermi energy or zero,
        depending on how the band structure was created."""
        return self._reference

    def subtract_reference(self) -> 'BandStructure':
        """Return new band structure with reference energy subtracted."""
        return BandStructure(self.path, self.energies - self.reference,
                             reference=0.0)

    def todict(self):
        """Return the constructor arguments (path, energies, reference)."""
        return dict(path=self.path,
                    energies=self.energies,
                    reference=self.reference)

    def get_labels(self, eps=1e-5):
        """Return ``(xcoords, label_xcoords, labels)`` for plotting.

        Delegates to :meth:`ase.dft.kpoints.BandPath.get_linear_kpoint_axis`
        of :attr:`path`.
        """
        return self.path.get_linear_kpoint_axis(eps=eps)

    def plot(self, *args, **kwargs):
        """Plot this band structure.

        All arguments are forwarded to :meth:`BandStructurePlot.plot`.
        """
        bsp = BandStructurePlot(self)
        return bsp.plot(*args, **kwargs)

    def __repr__(self):
        # Join the shape generically so repr works for any array rank.
        return ('{}(path={!r}, energies=[{} values], reference={})'
                .format(self.__class__.__name__, self.path,
                        'x'.join(str(n) for n in self.energies.shape),
                        self.reference))