Coverage for /builds/ase/ase/ase/spectrum/band_structure.py: 84.32%
185 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import numpy as np
5import ase # Annotations
6from ase.calculators.calculator import PropertyNotImplementedError
7from ase.utils import jsonable
def calculate_band_structure(atoms, path=None, scf_kwargs=None,
                             bs_kwargs=None, kpts_tol=1e-6, cell_tol=1e-6):
    """Calculate band structure.

    The purpose of this function is to abstract a band structure calculation
    so the workflow does not depend on the calculator.

    First trigger SCF calculation if necessary, then set arguments
    on the calculator for band structure calculation, then return
    calculated band structure.

    The difference from get_band_structure() is that the latter
    expects the calculation to already have been done.

    Parameters:

    atoms: Atoms
        Structure with an attached calculator (``atoms.calc``).
    path: BandPath
        Band path to evaluate.  Defaults to ``atoms.cell.bandpath()``.
    scf_kwargs: dict
        Keyword arguments passed to ``calc.set()`` before the SCF step.
    bs_kwargs: dict
        Extra keyword arguments passed to ``calc.set()`` together with
        ``kpts=path`` for the band structure step.  Not used when the
        calculator sets ``accepts_bandpath_keyword``.
    kpts_tol: float
        Maximum allowed absolute deviation between the calculator's
        k-points and those of ``path``.
    cell_tol: float
        Maximum allowed ``celldiff`` between ``path.cell`` and the
        (uncompleted) cell of ``atoms``.

    Raises ValueError if the band path does not match the periodicity or
    the cell of the atoms, or if the atoms have no calculator.  Raises
    RuntimeError if the calculator did not use the k-points of the path.
    """
    if path is None:
        path = atoms.cell.bandpath()

    from ase.lattice import celldiff  # Should this be a method on cell?
    if any(path.cell.any(1) != atoms.pbc):
        raise ValueError('The band path\'s cell, {}, does not match the '
                         'periodicity {} of the atoms'
                         .format(path.cell, atoms.pbc))
    cell_err = celldiff(path.cell, atoms.cell.uncomplete(atoms.pbc))
    if cell_err > cell_tol:
        raise ValueError('Atoms and band path have different unit cells. '
                         'Please reduce atoms to standard form. '
                         'Cell lengths and angles are {} vs {}'
                         .format(atoms.cell.cellpar(), path.cell.cellpar()))

    calc = atoms.calc
    if calc is None:
        raise ValueError('Atoms have no calculator')

    if scf_kwargs is not None:
        calc.set(**scf_kwargs)

    # Proposed standard mechanism for calculators to advertise that they
    # use the bandpath keyword to handle band structures rather than
    # a double (SCF + BS) run.
    use_bandpath_kw = getattr(calc, 'accepts_bandpath_keyword', False)
    if use_bandpath_kw:
        calc.set(bandpath=path)
        atoms.get_potential_energy()
        return calc.band_structure()

    # SCF step.
    atoms.get_potential_energy()

    # Take the reference energy from the SCF step, before the calculator
    # is reconfigured for the band path.
    if hasattr(calc, 'get_fermi_level'):
        # What is the protocol for a calculator to tell whether
        # it has fermi_energy?
        eref = calc.get_fermi_level()
    else:
        eref = 0.0

    if bs_kwargs is None:
        bs_kwargs = {}

    calc.set(kpts=path, **bs_kwargs)
    calc.results.clear()  # XXX get rid of me

    # Calculators are too inconsistent here:
    # * atoms.get_potential_energy() will fail when total energy is
    #   not in results after BS calculation (Espresso)
    # * calc.calculate(atoms) doesn't ask for any quantity, so some
    #   calculators may not calculate anything at all
    # * 'bandstructure' is not a recognized property we can ask for
    try:
        atoms.get_potential_energy()
    except PropertyNotImplementedError:
        pass

    ibzkpts = calc.get_ibz_k_points()
    # Compare shapes before subtracting: a mismatch (e.g. the calculator
    # reduced the k-points by symmetry) would otherwise raise an opaque
    # broadcasting error, or broadcast silently if one side has a single
    # k-point and compare the wrong points.
    if np.shape(ibzkpts) != np.shape(path.kpts):
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'shape {} != {}'
                           .format(np.shape(ibzkpts), np.shape(path.kpts)))
    kpts_err = np.abs(path.kpts - ibzkpts).max()
    if kpts_err > kpts_tol:
        raise RuntimeError('Kpoints of calculator differ from those '
                           'of the band path we just used; '
                           'err={} > tol={}'.format(kpts_err, kpts_tol))

    bs = get_band_structure(atoms, path=path, reference=eref)
    return bs
def get_band_structure(atoms=None, calc=None, path=None, reference=None):
    """Create band structure object from Atoms or calculator.

    Expects the calculation to have been done already; the eigenvalues are
    read from the calculator for every spin and every IBZ k-point.

    atoms: Atoms
        Structure whose ``calc`` is used when ``calc`` is not given.
    calc: Calculator
        Calculator to read from; its ``atoms`` are used when ``atoms`` is
        not given.
    path: BandPath
        Band path belonging to the calculator's k-points.  When omitted, a
        path is reconstructed from the calculated k-points.
    reference: float
        Reference energy.  Defaults to ``calc.get_fermi_level()``, falling
        back to 0.0 if that returns None.

    Returns a BandStructure.
    """
    # The path and reference arguments are mainly for internal use at the
    # moment; their exact handling is a work in progress.
    #
    # XXX The bandpath is thrown away when the calculator is created.  If
    # it were kept, it could be passed in here, and its k-points should then
    # be checked against those stored in the calculator.
    if atoms is None:
        atoms = calc.atoms
    if calc is None:
        calc = atoms.calc

    ibz_kpts = calc.get_ibz_k_points()
    nkpts = len(ibz_kpts)
    # Layout: spins x k-points x bands.
    energies = np.array([
        [calc.get_eigenvalues(kpt=ikpt, spin=ispin) for ikpt in range(nkpts)]
        for ispin in range(calc.get_number_of_spins())
    ])

    if path is None:
        from ase.dft.kpoints import (
            BandPath,
            find_bandpath_kinks,
            resolve_custom_points,
        )
        standard_path = atoms.cell.bandpath(npoints=0)
        # The k-points have been evaluated already; they only need to be
        # wrapped in a path, whether or not they match our notion of one.
        #
        # Codes that generated the path themselves may have used other
        # (equivalent) high-symmetry points than ours, so the special
        # points are resolved from the kinks of the actual k-point list.
        #
        # Hacking the bandpath together like this is fragile.
        # TODO: Make it available as a proper (documented) bandpath method.
        kinks = find_bandpath_kinks(atoms.cell, ibz_kpts, eps=1e-5)
        pathspec, special_points = resolve_custom_points(
            ibz_kpts[kinks], standard_path.special_points, eps=1e-5)
        path = BandPath(standard_path.cell, kpts=ibz_kpts, path=pathspec,
                        special_points=special_points)

    # XXX A path passed in by the caller is not checked against the cell.
    # It is only accepted so the Bravais lattice need not be re-evaluated
    # internally (and doing the check would need an eps parameter).

    if reference is None:
        # The Fermi level should come from the ground-state calculation,
        # not from the band-structure one!
        reference = calc.get_fermi_level()
        if reference is None:
            # No Fermi level available, e.g. with non-Fermi smearing.
            # XXX get_fermi_level() ought to raise instead; fix that.
            reference = 0.0

    return BandStructure(path=path, energies=energies, reference=reference)
class BandStructurePlot:
    """Matplotlib helper that draws a band structure object.

    The first call to :meth:`plot` prepares the axes (see
    :meth:`prepare_plot`); later calls draw onto those same axes.
    """

    def __init__(self, bs):
        self.bs = bs  # band structure providing energies/reference/labels
        self.ax = None  # Axes, set by prepare_plot()
        self.xcoords = None  # linear k-point coordinates, set by prepare_plot()

    def plot(self, ax=None, emin=-10, emax=5, filename=None,
             show=False, ylabel=None, colors=None, point_colors=None,
             label=None, loc=None,
             cmap=None, cmin=-1.0, cmax=1.0, sortcolors=False,
             colorbar=True, clabel='$s_z$', cax=None,
             **plotkwargs):
        """Plot band-structure.

        ax: Axes
            MatPlotLib Axes object. Will be created if not supplied.  On
            repeated calls without ``ax`` the previously prepared axes are
            reused.
        emin, emax: float
            Lower and upper limits of the energy axis.  The energies are
            plotted as stored, i.e. not shifted by the reference energy
            (which is marked by a dotted line); use
            ``BandStructure.subtract_reference()`` to plot relative to it.
            Only applied when the axes are prepared (first call).
        filename: str
            If given, write image to a file.
        show: bool
            Show the image (not needed in notebooks).
        ylabel: str
            The label along the y-axis. Defaults to 'energies [eV]'.
            Only applied when the axes are prepared (first call).
        colors: sequence of str
            A sequence of one or two color specifications, depending on
            whether there is spin.
            Default: green if no spin, yellow and blue if spin is present.
        point_colors: ndarray
            An array of numbers of the shape (nspins, n_kpts, nbands) which
            are then mapped onto colors by the colormap (see ``cmap``).
            ``colors`` and ``point_colors`` are mutually exclusive.
        label: str or list of str
            Label for the curves on the legend. A string if one spin is
            present, a list of two strings if two spins are present.
            Default: If no spin is given, no legend is made; if spin is
            present default labels 'spin up' and 'spin down' are used, but
            can be suppressed by setting ``label=False``.
        loc: str
            Location of the legend.

        If ``point_colors`` is given, the following arguments can be specified.

        cmap:
            A matplotlib colormap object, or a string naming a standard
            colormap, used to map ``point_colors`` onto colors.
            Default: The matplotlib default, typically 'viridis'.
        cmin, cmax: float
            Minimal and maximal values used for colormap translation.
            Default: -1.0 and 1.0
        colorbar: bool
            Whether to make a colorbar.
        clabel: str
            Label for the colorbar (default 's_z'; set to None to suppress).
        cax: Axes
            Axes object used for plotting colorbar. Default: split off a
            new one.
        sortcolors (bool or callable):
            Sort points so highest color values are in front. If a callable is
            given, then it is called on the color values to determine the sort
            order.

        Any additional keyword arguments are passed directly to matplotlib's
        plot() or scatter() methods, depending on whether point_colors is
        given.

        Returns the Axes that were drawn on.

        Raises ValueError if both ``colors`` and ``point_colors`` are given,
        or if ``colors``/``label`` do not have one entry per spin.
        """
        if colors is not None and point_colors is not None:
            raise ValueError("Don't give both 'colors' and 'point_colors'")

        if self.ax is None:
            ax = self.prepare_plot(ax, emin, emax, ylabel)
        elif ax is None:
            # Axes were prepared by an earlier call: draw on them again
            # instead of failing with ax=None.
            ax = self.ax

        e_skn = self.bs.energies
        nspins = len(e_skn)

        if point_colors is None:
            # Normal band structure plot
            if colors is None:
                colors = 'g' if nspins == 1 else 'yb'
            elif len(colors) != nspins:
                raise ValueError(
                    f'colors should be a sequence of {nspins} colors'
                )

            # Default values for label
            if label is None and nspins == 2:
                label = ['spin up', 'spin down']

            if label:
                if nspins == 1 and isinstance(label, str):
                    label = [label]
                elif len(label) != nspins:
                    raise ValueError(
                        f'label should be a list of {nspins} strings'
                    )

            for spin, e_kn in enumerate(e_skn):
                kwargs = dict(color=colors[spin])
                kwargs.update(plotkwargs)
                # Only the first band of each spin carries the legend label;
                # lbl stays None when label is falsy (e.g. label=False).
                lbl = label[spin] if label else None
                ax.plot(self.xcoords, e_kn[:, 0], label=lbl, **kwargs)

                for e_k in e_kn.T[1:]:
                    ax.plot(self.xcoords, e_k, **kwargs)
            # A legend only makes sense when some curve carries a label;
            # this also honours label=False as documented.
            show_legend = bool(label)

        else:
            # A color per datapoint.
            point_colors = np.asarray(point_colors)
            kwargs = dict(vmin=cmin, vmax=cmax, cmap=cmap, s=1)
            kwargs.update(plotkwargs)
            shape = e_skn.shape
            # Broadcast the linear k-axis to the (spin, kpt, band) layout.
            xcoords = np.zeros(shape)
            xcoords += self.xcoords[np.newaxis, :, np.newaxis]
            if sortcolors:
                if callable(sortcolors):
                    perm = np.asarray(sortcolors(point_colors)).argsort(
                        axis=None)
                else:
                    perm = point_colors.argsort(axis=None)
                # Later scatter points are drawn on top, so the highest
                # color values end up in front.
                e_skn = e_skn.ravel()[perm].reshape(shape)
                point_colors = point_colors.ravel()[perm].reshape(shape)
                xcoords = xcoords.ravel()[perm].reshape(shape)

            things = ax.scatter(xcoords, e_skn, c=point_colors, **kwargs)
            if colorbar:
                # Attach to the figure of the axes we drew on rather than
                # pyplot's current figure.
                cbar = ax.figure.colorbar(things, cax=cax, ax=ax)
                if clabel:
                    cbar.set_label(clabel)
            show_legend = False

        self.finish_plot(filename, show, loc, show_legend, ax=ax)

        return ax

    def prepare_plot(self, ax=None, emin=-10, emax=5, ylabel=None):
        """Set up axes for a band structure plot and return them.

        Creates a new figure if ``ax`` is None.  Puts the special-point
        labels on the x-axis (labels sharing an x position are merged as
        e.g. 'X,U'), draws vertical lines at the inner labels and a dotted
        horizontal line at the reference energy, and sets the axis limits
        and y label.  Stores the axes and the k-point x coordinates on self.
        """
        import matplotlib.pyplot as plt
        if ax is None:
            ax = plt.figure().add_subplot(111)

        def pretty(kpt):
            # 'G' -> Gamma; two-character names get a subscript ('K1').
            if kpt == 'G':
                kpt = r'$\Gamma$'
            elif len(kpt) == 2:
                kpt = kpt[0] + '$_' + kpt[1] + '$'
            return kpt

        self.xcoords, label_xcoords, orig_labels = self.bs.get_labels()
        label_xcoords = list(label_xcoords)
        labels = [pretty(name) for name in orig_labels]

        # Merge labels at identical x positions (path discontinuities).
        i = 1
        while i < len(labels):
            if label_xcoords[i - 1] == label_xcoords[i]:
                labels[i - 1] = labels[i - 1] + ',' + labels[i]
                labels.pop(i)
                label_xcoords.pop(i)
            else:
                i += 1

        for x in label_xcoords[1:-1]:
            ax.axvline(x, color='0.5')

        ylabel = ylabel if ylabel is not None else 'energies [eV]'

        ax.set_xticks(label_xcoords)
        ax.set_xticklabels(labels)
        ax.set_ylabel(ylabel)
        ax.axhline(self.bs.reference, color='k', ls=':')
        ax.axis(xmin=0, xmax=self.xcoords[-1], ymin=emin, ymax=emax)
        self.ax = ax
        return ax

    def finish_plot(self, filename, show, loc, show_legend=False, ax=None):
        """Add the legend, then optionally save and/or show the figure.

        ``ax`` selects the axes that get the legend; it defaults to the
        prepared axes (or pyplot's current axes if none were prepared).
        """
        import matplotlib.pyplot as plt

        if ax is None:
            ax = self.ax if self.ax is not None else plt.gca()

        if show_legend:
            # Put the legend on the axes we drew on, not pyplot's current
            # axes (they differ e.g. with subplots).
            leg = ax.legend(loc=loc)
            leg.get_frame().set_alpha(1)

        if filename:
            plt.savefig(filename)

        if show:
            plt.show()
@jsonable('bandstructure')
class BandStructure:
    """A band structure consists of an array of eigenvalues and a bandpath.

    BandStructure objects support JSON I/O.
    """

    def __init__(self, path, energies, reference=0.0):
        """Store the band path, eigenvalues and reference energy.

        path: BandPath
            The k-point path; ``len(path.kpts)`` must equal the second
            dimension of ``energies``.
        energies: array-like
            Eigenvalues of shape (nspins, nkpoints, nbands), with nspins
            equal to 1 or 2.
        reference: float
            Scalar reference energy, e.g. a Fermi level.
        """
        self._path = path
        self._energies = np.asarray(energies)
        assert self.energies.shape[0] in [1, 2]  # spins x kpts x bands
        assert self.energies.shape[1] == len(path.kpts)
        assert np.isscalar(reference)
        self._reference = reference

    @property
    def energies(self) -> np.ndarray:
        """The energies of this band structure.

        This is a numpy array of shape (nspins, nkpoints, nbands)."""
        return self._energies

    @property
    def path(self) -> 'ase.dft.kpoints.BandPath':
        """The :class:`~ase.dft.kpoints.BandPath` of this band structure."""
        return self._path

    @property
    def reference(self) -> float:
        """The reference energy.

        Semantics may vary; typically a Fermi energy or zero,
        depending on how the band structure was created."""
        return self._reference

    def subtract_reference(self) -> 'BandStructure':
        """Return new band structure with reference energy subtracted."""
        return BandStructure(self.path, self.energies - self.reference,
                             reference=0.0)

    def todict(self):
        """Return the constructor arguments as a dict (used for JSON I/O)."""
        return dict(path=self.path,
                    energies=self.energies,
                    reference=self.reference)

    def get_labels(self, eps=1e-5):
        """Return the linear k-point axis of the band path.

        Delegates to ``self.path.get_linear_kpoint_axis(eps=eps)``.  The
        result unpacks as ``(xcoords, label_xcoords, labels)``.
        See also :func:`ase.dft.kpoints.labels_from_kpts`.
        """
        return self.path.get_linear_kpoint_axis(eps=eps)

    def plot(self, *args, **kwargs):
        """Plot this band structure.

        All arguments are forwarded to ``BandStructurePlot.plot()``; the
        Axes it returns are returned.
        """
        bsp = BandStructurePlot(self)
        return bsp.plot(*args, **kwargs)

    def __repr__(self):
        return ('{}(path={!r}, energies=[{} values], reference={})'
                .format(self.__class__.__name__, self.path,
                        '{}x{}x{}'.format(*self.energies.shape),
                        self.reference))