Coverage for /builds/ase/ase/ase/spectrum/doscollection.py: 97.84%
185 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import collections
4from functools import reduce, singledispatch
5from typing import (
6 Any,
7 Dict,
8 Iterable,
9 List,
10 Optional,
11 Sequence,
12 TypeVar,
13 Union,
14 overload,
15)
17import numpy as np
18from matplotlib.axes import Axes
20from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData
21from ase.utils.plotting import SimplePlottingAxes
class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects"""

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        # Materialise the input once so generators can be passed in.
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to x and rows corresponding to the collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to x,
            in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        If the special key 'label' is present in an entry's info, this will
        be set as the label for the plotted line (unless overruled in
        mplargs). The label is only seen if a legend is added to the plot
        (i.e. by calling `ax.legend()`).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        # Broadening is applied exactly once, by sample_grid(). The resulting
        # GridDOSCollection is then plotted with its default npts=0, which
        # GridDOSCollection.plot documents as "no resampling". Forwarding
        # npts/width here would convolve the already-broadened data again.
        return self.sample_grid(npts,
                                xmin=xmin, xmax=xmax,
                                width=width, smearing=smearing
                                ).plot(xmin=xmin, xmax=xmax,
                                       ax=ax, show=show, filename=filename,
                                       mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, a default is
                chosen
            xmax: Maximum sampled energy value; if unspecified, a default is
                chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to each entry's
                sample_grid()
            smearing: selection of broadening kernel, passed to each entry's
                sample_grid()

        Returns:
            GridDOSCollection holding one resampled entry per item in this
            collection, all on the same energy grid.

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # A shared range is required so every entry lands on the same grid.
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Optional[Sequence[Info]],
                                ) -> Sequence[Info]:
        """Return info, or one empty dict per weights row if info is None.

        Raises:
            ValueError: if len(info) != len(weights)
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one entry (integer index) or a new collection (slice).

        NumPy integer scalars are accepted as well as Python ints.

        Raises:
            TypeError: for any other kind of index
        """
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif len(self) != len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            # Copy so that callers (e.g. total()) may modify the result
            # without touching the stored entry.
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """Filter entries by info key/value pairs.

        With negative=False, keep entries whose info contains every pair in
        info_selection; with negative=True, keep entries that do not contain
        all of them.
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    # Use typehint *info_keys: str from python3.11+
    def sum_by(self, *info_keys) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Each entry contributes to exactly one group: entries are grouped by
        the exact set of requested keys (and values) that they carry.

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Group entries by their *exact* combination of matching info.
        # Using select(**combo) here would be a subset match. An entry
        # carrying fewer of the requested keys (in the extreme none of them,
        # i.e. an empty combo) would then also pull in every entry with more
        # keys, so some data would be summed more than once.
        groups: Dict[tuple, List[DOSData]] = {}
        for data in self:
            groups.setdefault(_matching_info_tuples(data), []).append(data)

        # Sorting inside info matching makes equivalent combos identical;
        # combos are then sorted() to ensure consistent output across
        # sessions.
        collection_data = [type(self)(groups[combo]).sum_all()
                           for combo in sorted(groups)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)
@singledispatch
def _add_to_collection(other: Union[DOSData, DOSCollection],
                       collection: DOSCollection) -> DOSCollection:
    """Fallback implementation behind ``DOSCollection.__add__``.

    Dispatch happens on the type of ``other``. This generic version
    concatenates ``other`` onto ``collection`` when ``other`` is an instance
    of the collection's own class. Any other collection, or any other
    object, is rejected.

    Raises:
        TypeError: if other is a DOSCollection of an incompatible class, or
            is not a DOSCollection at all
    """
    target_cls = type(collection)

    if isinstance(other, target_cls):
        return target_cls([*collection, *other])

    if isinstance(other, DOSCollection):
        raise TypeError("Only DOSCollection objects of the same type may "
                        "be joined with '+'.")

    raise TypeError("DOSCollection may only be joined to DOSData or "
                    "DOSCollection objects with '+'.")
@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Build a collection of the same class with ``other`` as a final entry"""
    entries = list(collection)
    entries.append(other)
    return type(collection)(entries)
class RawDOSCollection(DOSCollection):
    """DOSCollection whose entries must all be RawDOSData objects."""

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        # Store first (the base class materialises the iterable into a
        # list), then validate the stored entries.
        super().__init__(dos_series)
        if not all(isinstance(entry, RawDOSData) for entry in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")
class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData entries sharing a single energy axis.

    Weights are stored together in a 2-D array (one row per entry) with a
    parallel list of info dicts; entries are rebuilt on indexing.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Floats] = None) -> None:
        """Build the collection from GridDOSData entries.

        Args:
            dos_series: GridDOSData entries, all on the same energy axis
            energies: common energy axis; required when dos_series is empty,
                otherwise taken from the first entry if not given

        Raises:
            ValueError: if no energies can be determined, or an entry's
                energy axis differs from the common one
            TypeError: if an entry is not a GridDOSData object
        """
        dos_list = list(dos_series)
        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights: np.ndarray = np.empty(
            (len(dos_list), len(self._energies)), float,
        )
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            if (dos_data.get_energies().shape != self._energies.shape
                    or not np.allclose(dos_data.get_energies(),
                                       self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the shared energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Union[Sequence[Floats], np.ndarray]:
        """Return a copy of the weights array (one row per entry)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one GridDOSData (integer index) or a new collection (slice).

        NumPy integer scalars are accepted as well as Python ints. A slice
        selecting no entries gives an empty collection on the same energy
        axis.

        Raises:
            TypeError: for any other kind of index
        """
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            entries = [self[i] for i in range(len(self))[item]]
            if len(entries) == 0:
                # Without explicit energies the constructor cannot build an
                # empty collection and would raise ValueError.
                return type(self)([], energies=self._energies)
            return type(self)(entries)
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] = None) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, has no rows, or its row length
                differs from len(energies)
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        weights_array = np.asarray(weights, dtype=float)
        if weights_array.ndim != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Build from a single row so the energy axis is validated once,
        # then install the full weights/info directly without per-row checks.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        If the special key 'label' is present in an entry's info, this will
        be set as the label for the plotted line (unless overruled in
        mplargs). The label is only seen if a legend is added to the plot
        (i.e. by calling `ax.legend()`).

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly.
            xmin, xmax:
                output data range; this limits the resampling range as well as
                the plotting output
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)

        return ax

    @staticmethod
    def _plot_broadened(ax: Axes,
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Optional[Dict]):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method."""
        if mplargs is None:
            mplargs = {}

        # Transpose so each row of all_y becomes one plotted line.
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)