Coverage for ase / spectrum / doscollection.py: 100.00%
178 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3import collections
4from collections.abc import Iterable, Sequence
5from functools import reduce, singledispatch
6from typing import (
7 Any,
8 TypeVar,
9 overload,
10)
12import numpy as np
13from matplotlib.axes import Axes
15from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData
16from ase.utils.plotting import SimplePlottingAxes
class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    Entries are held in a list and exposed through the Sequence protocol:
    an integer index returns one DOSData, a slice returns a new collection
    of the same type.
    """

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        # Copy into a list so one-shot iterables (e.g. generators) can be used
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to x and rows corresponding to the collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns
        -------
        Weights sampled from a broadened DOS at values corresponding to x,
        in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        The data are broadened exactly once, by self.sample_grid(), and the
        resulting GridDOSCollection is then plotted without further
        resampling. Each line is labelled from the info of the corresponding
        DOSData (via DOSData.label_from_info) and a legend is drawn.

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid()
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """
        grid = self.sample_grid(npts,
                                xmin=xmin, xmax=xmax,
                                width=width, smearing=smearing)
        # The grid is already broadened: plot it directly (npts=0, no width)
        # instead of letting GridDOSCollection.plot() resample and broaden it
        # a second time, which would widen the peaks and re-pad the x-range.
        return grid.plot(npts=0,
                         xmin=xmin, xmax=xmax,
                         ax=ax, show=show, filename=filename,
                         mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: float | None = None,
                    xmax: float | None = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, the lowest
                energy found in any contained DOSData minus padding * width
            xmax: Maximum sampled energy value; if unspecified, the highest
                energy found in any contained DOSData plus padding * width
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to the sample_grid()
                method of each contained DOSData
            smearing: selection of broadening kernel, passed to the
                sample_grid() method of each contained DOSData

        Returns
        -------
        GridDOSCollection with one entry per contained DOSData, all sampled
        on the same energy grid

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # A single xmin/xmax for every entry keeps the output grids identical,
        # as GridDOSCollection requires
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] | None = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns
        -------
        Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Sequence[Info] | None,
                                ) -> Sequence[Info]:
        """Return info, or one empty dict per weights row if info is None

        Raises:
            ValueError: if info is given with a length different from weights
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one DOSData for an integer index, a new collection for a slice

        NumPy integer scalars are accepted as well as Python ints.
        """
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif len(self) != len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            # Copy so that callers (e.g. total()) may modify the result
            # without touching the stored item
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: dict[str, str],  # misses 'D' def
                        negative: bool = False) -> list[D]:  # noqa: F821
        """List items whose info contains all of info_selection

        With negative=True, list the items that do NOT contain all of it.
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    # Use typehint *info_keys: str from python3.11+
    def sum_by(self, *info_keys) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Sorting inside info matching helps set() to remove redundant matches;
        # combos are then sorted() to ensure consistent output across sessions.
        all_combos = map(_matching_info_tuples, self)
        unique_combos = sorted(set(all_combos))

        # For each key/value combination, perform a select() to obtain all
        # the matching entries and sum them together.
        collection_data = [self.select(**dict(combo)).sum_all()
                           for combo in unique_combos]
        return type(self)(collection_data)

    def __add__(self, other: 'DOSCollection | DOSData'
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)
@singledispatch
def _add_to_collection(other: DOSData | DOSCollection,
                       collection: DOSCollection) -> DOSCollection:
    """Generic handler for ``collection + other``

    DOSData operands are dispatched to the handler registered for them.
    This fallback concatenates two collections when ``other`` is an instance
    of the class of ``collection``. Any other operand raises TypeError.
    """
    target_cls = type(collection)
    if isinstance(other, target_cls):
        return target_cls([*collection, *other])

    if isinstance(other, DOSCollection):
        # A collection, but of an incompatible kind (e.g. Raw vs Grid)
        raise TypeError("Only DOSCollection objects of the same type may "
                        "be joined with '+'.")

    raise TypeError("DOSCollection may only be joined to DOSData or "
                    "DOSCollection objects with '+'.")
@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Build a new collection of the same type with ``other`` appended

    The input collection itself is not modified.
    """
    extended = [*collection, other]
    return type(collection)(extended)
class RawDOSCollection(DOSCollection):
    """DOSCollection whose entries must all be RawDOSData"""

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        # Let the base class consume dos_series once, then validate what
        # was stored
        super().__init__(dos_series)
        if not all(isinstance(entry, RawDOSData) for entry in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")
class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData objects sharing a single energy grid

    Rather than a list of DOSData, the data are stored as one energy array,
    a 2-D weights array (one row per dataset) and a list of info dicts.
    GridDOSData objects are rebuilt on demand by __getitem__.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Floats | None = None) -> None:
        """Store a series of GridDOSData with a common energy axis

        Args:
            dos_series: GridDOSData objects; all must share the same energies
            energies: common energy axis. If omitted it is taken from the
                first item of dos_series, which must then be non-empty.

        Raises:
            ValueError: if energies is omitted and dos_series is empty, or if
                an item's energy axis differs from the common one
            TypeError: if an item is not a GridDOSData
        """
        dos_list = list(dos_series)
        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights: np.ndarray = np.empty(
            (len(dos_list), len(self._energies)), float,
        )
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            if (dos_data.get_energies().shape != self._energies.shape
                    or not np.allclose(dos_data.get_energies(),
                                       self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the common energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Sequence[Floats] | np.ndarray:
        """Return a copy of the 2-D weights array (rows: datasets)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return a GridDOSData for an integer index, a collection for a slice

        NumPy integer scalars are accepted as well as Python ints. The
        returned GridDOSData shares its info dict with this collection.
        """
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            # Pass the energy axis explicitly: an empty slice must give an
            # empty collection rather than fail in __init__ for lack of data
            return type(self)([self[i] for i in range(len(self))[item]],
                              energies=self._energies)
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] | None = None) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns
        -------
        Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, has no rows, or its rows do
                not match the length of energies
            ValueError: if info is given with a length different from the
                number of weights rows
        """

        weights_array = np.asarray(weights, dtype=float)
        if len(weights_array.shape) != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Build from the first row only (to set up and validate the energy
        # axis), then install the full arrays directly to avoid per-row
        # GridDOSData construction and checking.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float | None = None,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple plot of collected DOS data, optionally resampled onto a grid

        Each line is labelled from the info of the corresponding entry (via
        DOSData.label_from_info). These labels are applied after mplargs,
        so they take precedence over any 'label' given there, and a legend
        is always drawn.

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly.
            xmin, xmax:
                output data range; only used when resampling, where they limit
                the resampled grid (and therefore the plotted range)
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)

        return ax

    @staticmethod
    def _plot_broadened(ax: Axes,
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: dict | None):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method.

        Args:
            ax: axes to draw on
            energies: x values shared by all lines
            all_y: 2-D array, one row per line to plot
            all_labels: one label per row of all_y
            mplargs: extra keyword arguments for ax.plot(), or None
        """
        if mplargs is None:
            mplargs = {}

        # Transpose so each row of all_y becomes one plotted line
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)