Coverage for /builds/ase/ase/ase/spectrum/doscollection.py: 97.84%

185 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import collections 

4from functools import reduce, singledispatch 

5from typing import ( 

6 Any, 

7 Dict, 

8 Iterable, 

9 List, 

10 Optional, 

11 Sequence, 

12 TypeVar, 

13 Union, 

14 overload, 

15) 

16 

17import numpy as np 

18from matplotlib.axes import Axes 

19 

20from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData 

21from ase.utils.plotting import SimplePlottingAxes 

22 

23 

class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects"""

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the .sample()
        method of those DOSData items, returning a 2-D array with columns
        corresponding to x and rows corresponding to the collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at values corresponding to x,
            in rows corresponding to DOSData entries contained in this object

        Raises:
            IndexError: if the collection is empty
        """

        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Optional[Axes] = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        Every entry is broadened and sampled onto one common grid (see
        self.sample_grid()), and the resulting GridDOSCollection is plotted
        as-is, one line per entry. Line labels are derived from each entry's
        info dict.

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        grid_dos = self.sample_grid(npts,
                                    xmin=xmin, xmax=xmax,
                                    width=width, smearing=smearing)
        # grid_dos is already broadened: plot the stored grid directly
        # (npts=0, no width). Passing npts/width again would make
        # GridDOSCollection.plot resample and broaden a second time.
        return grid_dos.plot(npts=0,
                             xmin=xmin, xmax=xmax,
                             ax=ax, show=show, filename=filename,
                             mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: Optional[float] = None,
                    xmax: Optional[float] = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, the lowest
                energy of any entry minus padding * width is used
            xmax: Maximum sampled energy value; if unspecified, the highest
                energy of any entry plus padding * width is used
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to each entry's
                sample_grid()
            smearing: selection of broadening kernel, passed to each entry's
                sample_grid()

        Returns:
            GridDOSCollection holding one sampled entry per DOSData in this
            collection

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # One range is shared by all entries, because GridDOSCollection
        # requires every member to sit on the same energy axis.
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in RawDOSData format)

        Raises:
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Optional[Sequence[Info]],
                                ) -> Sequence[Info]:
        """Return info, or one fresh empty dict per weights row if None

        Raises:
            ValueError: if info is given but len(info) != len(weights)
        """
        if info is None:
            info = [{} for _ in range(len(weights))]
        else:
            if len(info) != len(weights):
                raise ValueError("Length of info must match number of rows in "
                                 "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one DOSData (integer index) or a sub-collection (slice)

        NumPy integer scalars are accepted as well as Python ints.
        """
        if isinstance(item, (int, np.integer)):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        elif len(self) != len(other):
            return False
        else:
            return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        # sum_all() returns a copy or a new sum, so updating its info does
        # not touch the stored entries.
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        Raises:
            IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: Dict[str, str],  # misses 'D' def
                        negative: bool = False) -> List[D]:  # noqa: F821
        """List entries whose info contains every key/value pair requested

        With negative=True, list the entries that do NOT contain them all.
        """
        query = set(info_selection.items())

        if negative:
            return [data for data in dos_collection
                    if not query.issubset(set(data.info.items()))]
        else:
            return [data for data in dos_collection
                    if query.issubset(set(data.info.items()))]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """

        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    # Use typehint *info_keys: str from python3.11+
    def sum_by(self, *info_keys) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Entries are grouped by the exact set of requested key/value pairs
        they carry, so every entry contributes to exactly one output item
        (entries missing some of the keys form their own groups).

        """

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = set(info_keys) & set(data.info)
            return tuple(sorted([(key, data.info[key])
                                 for key in matched_keys]))

        # Group by exact combination. Selecting with select(**combo) would
        # also match entries carrying *additional* requested keys, so an
        # entry lacking a key would be summed together with entries that
        # have it, and those entries would be counted twice.
        groups: Dict[tuple, List[DOSData]] = {}
        for data in self:
            groups.setdefault(_matching_info_tuples(data), []).append(data)

        # Sorting inside info matching gives canonical keys; combos are then
        # sorted() to ensure consistent output across sessions.
        collection_data = [type(self)(groups[combo]).sum_all()
                           for combo in sorted(groups)]
        return type(self)(collection_data)

    def __add__(self, other: Union['DOSCollection', DOSData]
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        return _add_to_collection(other, self)

359 

360 

@singledispatch
def _add_to_collection(other: Union[DOSData, DOSCollection],
                       collection: DOSCollection) -> DOSCollection:
    """Default implementation behind ``DOSCollection + other``

    Dispatch is on the type of ``other``. This fallback joins two
    collections of exactly compatible type into a new collection; any other
    operand is rejected. (DOSData operands are dispatched to a registered
    implementation elsewhere in this module.)

    Raises:
        TypeError: if ``other`` is a DOSCollection of a different type, or
            is not a DOSCollection at all
    """
    collection_type = type(collection)
    if isinstance(other, collection_type):
        return collection_type([*collection, *other])

    if isinstance(other, DOSCollection):
        message = ("Only DOSCollection objects of the same type may "
                   "be joined with '+'.")
    else:
        message = ("DOSCollection may only be joined to DOSData or "
                   "DOSCollection objects with '+'.")
    raise TypeError(message)

372 

373 

@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Build a new collection of the same type with ``other`` appended

    The input collection itself is left unchanged.
    """
    collection_type = type(collection)
    return collection_type([*collection, other])

378 

379 

class RawDOSCollection(DOSCollection):
    """DOSCollection whose members must all be RawDOSData"""

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        """Store the series, then reject it if any member is not RawDOSData

        Raises:
            TypeError: if an element of dos_series is not a RawDOSData
        """
        super().__init__(dos_series)
        all_raw = all(isinstance(member, RawDOSData) for member in self)
        if not all_raw:
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")

387 

388 

class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData objects sharing one energy axis

    Data are stored densely: a single energy array for the whole collection,
    a 2-D weights array with one row per entry, and a list of info dicts.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Optional[Floats] = None) -> None:
        """Build the collection from GridDOSData objects

        Args:
            dos_series: GridDOSData objects, all on the same energy axis
            energies: common energy axis. Required when dos_series is empty;
                otherwise, if omitted, it is taken from the first entry.

        Raises:
            TypeError: if any entry is not a GridDOSData
            ValueError: if no energies are available for an empty series, or
                an entry's energies do not match the common axis
        """
        dos_list = list(dos_series)

        # Check types before calling any entry method, so a wrong type is
        # always reported as TypeError (rather than e.g. AttributeError from
        # get_energies() on the first entry).
        for dos_data in dos_list:
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")

        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        # One row of weights per entry, all on the shared energy axis.
        self._weights: np.ndarray = np.empty(
            (len(dos_list), len(self._energies)), float,
        )
        self._info = []

        for i, dos_data in enumerate(dos_list):
            entry_energies = dos_data.get_energies()
            if (entry_energies.shape != self._energies.shape
                    or not np.allclose(entry_energies, self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the shared energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Union[Sequence[Floats], np.ndarray]:
        """Return a copy of the 2-D weights array (one row per entry)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one GridDOSData (integer index) or a sub-collection (slice)

        NumPy integer scalars are accepted as well as Python ints. Slices
        always keep the energy axis, so an empty slice yields an empty
        GridDOSCollection rather than an error.
        """
        if isinstance(item, (int, np.integer)):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            return type(self)([self[i] for i in range(len(self))[item]],
                              energies=self._energies)
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Optional[Sequence[Info]] = None
                  ) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns:
            Collection of DOS data (in GridDOSData format)

        Raises:
            IndexError: if weights is not 2-D, has no rows, or its rows do not
                match the length of energies
            ValueError: if info is given and its length differs from the
                number of weights rows
        """

        weights_array = np.asarray(weights, dtype=float)
        if len(weights_array.shape) != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Run the normal constructor on the first row only (cheap), then
        # install the full weights/info directly, skipping per-entry checks.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        If nothing matches, an empty collection on the same energy axis is
        returned.

        """

        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        If every item matches, an empty collection on the same energy axis
        is returned.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: Optional[float] = None,
             xmax: Optional[float] = None,
             width: Optional[float] = None,
             smearing: str = 'Gauss',
             ax: Optional[Axes] = None,
             show: bool = False,
             filename: Optional[str] = None,
             mplargs: Optional[dict] = None) -> Axes:
        """Simple plot of collected DOS data, optionally resampled onto a grid

        Each entry is drawn as its own line, labelled with
        DOSData.label_from_info() applied to that entry's info. The label is
        set after plotting, so it takes precedence over any 'label' passed in
        mplargs. A legend is added to the axes automatically.

        Args:
            npts, width:
                Resampling settings, normalised by
                GridDOSData._interpret_smearing_args(). If this yields
                npts == 0 (as with the defaults), no resampling is performed
                and the stored data is plotted directly; otherwise the data is
                broadened and resampled with self.sample_grid().
            xmin, xmax:
                output data range. When resampling, this bounds the new grid.
                In all cases any value given is applied as the x-axis limit of
                the plot.
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)
            # _plot_broadened() fits the x-axis to the plotted data; honour
            # an explicitly requested range, which otherwise would be ignored
            # when no resampling took place. A None side is left unchanged.
            if xmin is not None or xmax is not None:
                ax.set_xlim(left=xmin, right=xmax)

        return ax

    @staticmethod
    def _plot_broadened(ax: Axes,
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: Optional[Dict]):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method.

        Draws one line per row of all_y, labels each line from all_labels,
        adds a legend, fits the x-axis to the energy range and sets the lower
        y-limit to 0.
        """
        if mplargs is None:
            mplargs = {}

        # Transpose so each row of all_y (one entry) becomes one line.
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)