Coverage for ase / spectrum / dosdata.py: 100.00%
153 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3# Refactor of DOS-like data objects
4# towards replacing ase.dft.dos and ase.dft.pdos
5import warnings
6from abc import ABCMeta, abstractmethod
7from collections.abc import Sequence
8from typing import Any, TypeVar
10import numpy as np
11from matplotlib.axes import Axes
13from ase.utils.plotting import SimplePlottingAxes
# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
# Info is the type of the metadata dict attached to every DOSData object.
Info = dict[str, str]

# Still no good solution to type checking with arrays.
# Floats is used to annotate energy/weight inputs and outputs: either a plain
# sequence of floats or a NumPy array.
Floats = Sequence[float] | np.ndarray
class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    Subclasses must implement get_energies(), get_weights() and copy().
    """

    def __init__(self,
                 info: Info | None = None) -> None:
        """Store the metadata dict

        Args:
            info: metadata dict; if None, a new empty dict is created.

        Raises
        ------
        TypeError
            If info is neither a dict nor None.
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            # The dict is stored by reference, not copied.
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns
        -------
        Weights sampled from a broadened DOS at values corresponding to
        ``energies``

        Raises
        ------
        ValueError
            If ``width`` is zero or negative, or if the smearing type is
            rejected by _delta().
        """
        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Accumulate one broadened peak per stored (energy, weight) pair.
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Returns True only if ``other`` is an instance of this object's type,
        the info dicts are equal, and weights and energies agree within
        the default np.allclose tolerances.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False
        if not np.allclose(self.get_weights(), other.get_weights()):
            return False
        return np.allclose(self.get_energies(), other.get_energies())

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        The smearing name is matched case-insensitively.

        Raises
        ------
        ValueError
            If the smearing type is not 'Gauss'.
        """
        if smearing.lower() == 'gauss':
            # Normalised Gaussian with standard deviation 'width'
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)

        msg = f'Requested smearing type not recognized. Got {smearing}'
        raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width: float) -> None:
        """Raise ValueError unless width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float | None = None,
                    xmax: float | None = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns
        -------
        GridDOSData holding the energy grid and the sampled DOS; its info
        dict is a copy of the info dict of this object.
        """

        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a copy: inserting the automatic label into the caller's
        # dict would make that label stick if the dict is reused for
        # another plot.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        return self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                width=width,
                                smearing=smearing
                                ).plot(ax=ax, xmin=xmin, xmax=xmax,
                                       show=show, filename=filename,
                                       mplargs=mplargs)

    @staticmethod
    def label_from_info(info: dict[str, str]) -> str:
        """Generate an automatic legend label from info dict

        The value of the 'label' key is used if present; otherwise all
        items are joined as 'key: value' pairs separated by '; '.
        """
        if 'label' in info:
            return info['label']
        return '; '.join(f'{key}: {value}' for key, value in info.items())
class GeneralDOSData(DOSData):
    """Concrete storage for one series of DOS-like data

    The energies and weights are fixed at construction; only 'info' is
    meant to be modified afterwards.

    This is the common parent of the DOSData classes that are initialised
    from separate "energies" and "weights" sequences of equal length.

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info | None = None) -> None:
        super().__init__(info=info)

        n_points = len(energies)
        if n_points != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # One float array with two rows: row 0 is energy, row 1 is weight
        self._data = np.empty((2, n_points), dtype=float, order='C')
        self._data[0] = energies
        self._data[1] = weights

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies"""
        energy_row = self._data[0, :]
        return energy_row.copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights"""
        weight_row = self._data[1, :]
        return weight_row.copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same type with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(),
                   self.get_weights(),
                   info=self.info.copy())
class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    Use this container for density-of-states (DOS) or spectral data whose
    energy values do not form a known regular grid. The sample_grid() and
    plot() methods resample or plot the data for further analysis. Weights
    that share an energy value are *only* combined in output data; the
    data stored in RawDOSData is never resampled. The raw data itself can
    be plotted with plot_deltas().

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate the data of two RawDOSData, keeping shared info"""
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Only (key, value) pairs present in both info dicts survive
        own_items = set(self.info.items())
        other_items = set(other.info.items())
        shared_info = dict(own_items & other_items)

        # Start from an empty object, then attach the joined
        # (energy, weight) columns directly.
        combined = RawDOSData([], [], info=shared_info)
        combined._data = np.concatenate((self._data, other._data), axis=1)

        return combined

    def plot_deltas(self,
                    ax: Axes | None = None,
                    show: bool = False,
                    filename: str | None = None,
                    mplargs: dict | None = None) -> Axes:
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        line_kwargs = {} if mplargs is None else mplargs

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            # One vertical line from zero up to each weight
            ax.vlines(self.get_energies(), 0, self.get_weights(),
                      **line_kwargs)

        return ax
class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info | None = None) -> None:
        """Validate the grid and store the data

        Raises
        ------
        ValueError
            If the energies are not evenly spaced, or if energies and
            weights differ in length.
        """
        n_entries = len(energies)
        # The grid is accepted if it matches a linspace between its own
        # first and last values (within np.allclose tolerances).
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if width is below twice it"""
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the broadened data, scaled by the current grid spacing

        Arguments are as for DOSData._sample(); the result of the parent
        method is multiplied by the spacing of the stored energy grid.
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two GridDOSData on the same energy grid

        Raises
        ------
        TypeError
            If other is not a GridDOSData.
        ValueError
            If the energy grids differ in length or in values.
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the energy/weight data
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float | None = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> tuple[int, float | None]:
        """Figure out what the user intended: resample if width provided

        Returns (npts, width) with defaults filled in for whichever of the
        two was not given, or (0, None) when neither was given.
        """
        if width is not None:
            if npts:
                return (npts, float(width))
            else:
                return (default_npts, float(width))
        else:
            if npts:
                return (npts, default_width)
            else:
                return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float | None = None,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a copy: inserting the automatic label into the caller's
        # dict would make that label stick if the dict is reused for
        # another plot.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # _interpret_smearing_args returns a float width whenever
            # npts is non-zero; the assert narrows the type for checkers.
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax