Coverage for /builds/ase/ase/ase/spectrum/dosdata.py: 100.00%
152 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3# Refactor of DOS-like data objects
4# towards replacing ase.dft.dos and ase.dft.pdos
5import warnings
6from abc import ABCMeta, abstractmethod
7from typing import Any, Dict, Sequence, Tuple, TypeVar, Union
9import numpy as np
10from matplotlib.axes import Axes
12from ase.utils.plotting import SimplePlottingAxes
# Metadata attached to a DOS series. The type is deliberately strict for the
# moment (str keys mapped to str values); other value types with dependable
# comparison semantics may be permitted later.
Info = Dict[str, str]

# Float data that may arrive either as a plain sequence or as a numpy array;
# there is still no satisfactory way to type-check arrays.
Floats = Union[Sequence[float], np.ndarray]
class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    Concrete subclasses provide get_energies(), get_weights() and copy();
    broadening, resampling and plotting are implemented here on top of those
    accessors.
    """

    def __init__(self,
                 info: Info = None) -> None:
        """Store the metadata dict

        Args:
            info: metadata for this data series. If None, a new empty dict is
                created; a dict is stored by reference (it is not copied).

        Raises:
            TypeError: info is neither a dict nor None
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to
            ``energies``

        Raises:
            ValueError: width is zero or negative, or (while broadening a data
                point) the smearing type is not recognized
        """

        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Accumulate one broadened peak per stored data point, scaled by the
        # weight of that point.
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Two objects match when ``other`` is an instance of this object's type,
        the info dicts are equal, and the energies and weights have the same
        shape and agree within the default np.allclose tolerances.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False

        weights = np.asarray(self.get_weights())
        other_weights = np.asarray(other.get_weights())
        energies = np.asarray(self.get_energies())
        other_energies = np.asarray(other.get_energies())

        # np.allclose broadcasts its arguments: without this guard a
        # one-point series could "equal" a longer one, and other length
        # mismatches would raise ValueError instead of returning False.
        if (weights.shape != other_weights.shape
                or energies.shape != other_energies.shape):
            return False

        return bool(np.allclose(weights, other_weights)
                    and np.allclose(energies, other_energies))

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Args:
            x: positions at which the broadened peak is evaluated
            x0: peak center
            width: standard deviation of the Gaussian
            smearing: kernel name, compared case-insensitively; only 'Gauss'
                is implemented

        Raises:
            ValueError: the smearing type is not recognized
        """
        if smearing.lower() == 'gauss':
            # Gaussian normalized to unit area
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)

        msg = f'Requested smearing type not recognized. Got {smearing}'
        raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width: float) -> None:
        """Raise ValueError unless the broadening width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            A GridDOSData holding the energy grid, the broadened weights
            sampled on that grid, and a copy of this object's info dict
        """

        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a private copy so that the automatic label is never written
        # into the caller's dict (which may be reused for other plots).
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        return self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                width=width,
                                smearing=smearing
                                ).plot(ax=ax, xmin=xmin, xmax=xmax,
                                       show=show, filename=filename,
                                       mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]) -> str:
        """Generate an automatic legend label from info dict

        The value stored under 'label' is used if that key is present;
        otherwise all entries are joined as 'key: value' pairs separated by
        '; '.
        """
        if 'label' in info:
            return info['label']

        return '; '.join(f'{key}: {value}' for key, value in info.items())
class GeneralDOSData(DOSData):
    """Base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    This is the base class for DOSData objects that accept/set separate
    "energies" and "weights" sequences of equal length at init.

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Copy the energies and weights into internal storage

        Raises:
            ValueError: energies and weights differ in length
        """
        super().__init__(info=info)

        if len(energies) != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # A single (2, N) float array holds everything: row 0 carries the
        # energies and row 1 the matching weights.
        storage = np.empty((2, len(energies)), dtype=float, order='C')
        storage[0, :] = energies
        storage[1, :] = weights
        self._data = storage

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies (row 0 of the data array)"""
        energy_row = self._data[0, :]
        return energy_row.copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights (row 1 of the data array)"""
        weight_row = self._data[1, :]
        return weight_row.copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Build a new object of the same type with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(),
                   self.get_weights(),
                   info=self.info.copy())
class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate two data sets, keeping only the shared metadata

        Raises:
            TypeError: other is not a RawDOSData
        """
        if not isinstance(other, RawDOSData):
            raise TypeError("RawDOSData can only be combined with other "
                            "RawDOSData objects")

        # Only (key, value) pairs present in both info dicts survive
        shared_items = set(self.info.items()) & set(other.info.items())

        # Start from an empty container, then attach the joined
        # energy/weight columns directly to it.
        combined = RawDOSData([], [], info=dict(shared_items))
        combined._data = np.concatenate((self._data, other._data), axis=1)
        return combined

    def plot_deltas(self,
                    ax: Axes = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> Axes:
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        vlines_kwargs = {} if mplargs is None else mplargs

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            # One vertical line per data point, from zero up to its weight
            ax.vlines(self.get_energies(), 0, self.get_weights(),
                      **vlines_kwargs)

        return ax
class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate the grid and store the data

        Args:
            energies: evenly-spaced energy values
            weights: one weight per energy value
            info: metadata dict, passed on to the parent class

        Raises:
            ValueError: energies do not match an evenly-spaced grid between
                their first and last values, or energies and weights differ
                in length
        """
        n_entries = len(energies)
        # The grid is regular if it matches a linspace between its end points
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        # NOTE(review): not read anywhere in this class; presumably consumed
        # elsewhere -- confirm before removing.
        self.sigma_cutoff = 3

    def _check_spacing(self, width: float) -> float:
        """Return the grid spacing, warning if width is small relative to it

        The spacing is taken from the first two energy values. A warning is
        issued when width is less than twice that spacing.
        """
        current_spacing = self._data[0, 1] - self._data[0, 0]
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the broadened data, scaled by the spacing of this grid

        Arguments are as for the parent-class _sample(); the parent result is
        multiplied by the current grid spacing.
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two data sets defined on the same energy grid

        Raises:
            TypeError: other is not a GridDOSData
            ValueError: the energy grids differ in length or in values
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")

        other_energies = other.get_energies()
        if len(self._data[0, :]) != len(other_energies):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other_energies):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the weights; the energy grid is shared and carried over as-is
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns:
            (npts, width) to use. Resampling is requested when either input is
            set, with the missing one filled from its default; if neither is
            set the result is (0, None), meaning "do not resample".
        """
        if width is not None:
            return (npts or default_npts, float(width))
        if npts:
            return (npts, default_width)
        return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero and no `width` is given, in which case:

        - no resampling takes place
        - `smearing` is ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a private copy so that the automatic label is never written
        # into the caller's dict (which may be reused for other plots).
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # _interpret_smearing_args pairs a non-zero npts with a width;
            # the assert narrows the Optional type for static checkers.
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax