Coverage for ase / spectrum / doscollection.py: 100.00%

178 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import collections 

4from collections.abc import Iterable, Sequence 

5from functools import reduce, singledispatch 

6from typing import ( 

7 Any, 

8 TypeVar, 

9 overload, 

10) 

11 

12import numpy as np 

13from matplotlib.axes import Axes 

14 

15from ase.spectrum.dosdata import DOSData, Floats, GridDOSData, Info, RawDOSData 

16from ase.utils.plotting import SimplePlottingAxes 

17 

18 

class DOSCollection(collections.abc.Sequence):
    """Base class for a collection of DOSData objects

    The entries are kept in a list and exposed through the Sequence protocol:
    integer indexing returns a single entry and slicing returns a new
    collection of the same type as this one.
    """

    def __init__(self, dos_series: Iterable[DOSData]) -> None:
        # Materialise the input so that generators can be passed in and the
        # collection can be iterated more than once.
        self._data = list(dos_series)

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        This samples the underlying DOS data in the same way as the ._sample()
        method of the contained DOSData items, returning a 2-D array with
        columns corresponding to energies and rows corresponding to the
        collected data series.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns
        -------
        Weights sampled from a broadened DOS at values corresponding to
        energies, in rows corresponding to DOSData entries contained in this
        object

        Raises
        ------
        IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        return np.asarray(
            [data._sample(energies, width=width, smearing=smearing)
             for data in self])

    def plot(self,
             npts: int = 1000,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        If the special key 'label' is present in the info of a contained
        entry, this will be set as the label for the plotted line (unless
        overruled in mplargs). The label is only seen if a legend is added to
        the plot (i.e. by calling `ax.legend()`).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """
        gridded = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        # NOTE(review): npts/width are forwarded again, so
        # GridDOSCollection.plot() resamples the already-sampled data a second
        # time -- confirm that applying the broadening twice is intended.
        return gridded.plot(npts=npts,
                            xmin=xmin, xmax=xmax,
                            width=width, smearing=smearing,
                            ax=ax, show=show, filename=filename,
                            mplargs=mplargs)

    def sample_grid(self,
                    npts: int,
                    xmin: float | None = None,
                    xmax: float | None = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSCollection':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled energy value; if unspecified, the lowest
                energy found in any entry, minus padding * width
            xmax: Maximum sampled energy value; if unspecified, the highest
                energy found in any entry, plus padding * width
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel, passed to the sample_grid()
                method of each entry
            smearing: selection of broadening kernel, passed to the
                sample_grid() method of each entry

        Returns
        -------
        GridDOSCollection built from the sample_grid() result of every entry

        Raises
        ------
        IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sample")

        # The limits are resolved once for the whole collection so that every
        # entry is sampled onto the same range.
        if xmin is None:
            xmin = (min(min(data.get_energies()) for data in self)
                    - (padding * width))
        if xmax is None:
            xmax = (max(max(data.get_energies()) for data in self)
                    + (padding * width))

        return GridDOSCollection(
            [data.sample_grid(npts, xmin=xmin, xmax=xmax, width=width,
                              smearing=smearing)
             for data in self])

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] | None = None) -> 'DOSCollection':
        """Create a DOSCollection from data sharing a common set of energies

        This is a convenience method to be used when all the DOS data in the
        collection has a common energy axis. There is no performance advantage
        in using this method for the generic DOSCollection, but for
        GridDOSCollection it is more efficient.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns
        -------
        Collection of DOS data (in RawDOSData format)

        Raises
        ------
        ValueError: if info is given and its length differs from the number
            of rows in weights
        """
        info = cls._check_weights_and_info(weights, info)

        return cls(RawDOSData(energies, row_weights, row_info)
                   for row_weights, row_info in zip(weights, info))

    @staticmethod
    def _check_weights_and_info(weights: Sequence[Floats],
                                info: Sequence[Info] | None,
                                ) -> Sequence[Info]:
        """Return one info entry per weights row, creating empty dicts if
        no info was given; raise ValueError on a length mismatch."""
        if info is None:
            # A separate dict per row, so that rows never share one object.
            return [{} for _ in range(len(weights))]
        if len(info) != len(weights):
            raise ValueError("Length of info must match number of rows in "
                             "weights")
        return info

    @overload
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'DOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one entry (int index) or a new collection (slice)"""
        if isinstance(item, int):
            return self._data[item]
        elif isinstance(item, slice):
            return type(self)(self._data[item])
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    def __len__(self) -> int:
        return len(self._data)

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSCollection for testing purposes"""
        if not isinstance(other, type(self)):
            return False
        if len(self) != len(other):
            return False
        return all(a._almost_equals(b) for a, b in zip(self, other))

    def total(self) -> DOSData:
        """Sum all the DOSData in this Collection and label it as 'Total'"""
        data = self.sum_all()
        data.info.update({'label': 'Total'})
        return data

    def sum_all(self) -> DOSData:
        """Sum all the DOSData contained in this Collection

        Raises
        ------
        IndexError: if the collection is empty
        """
        if len(self) == 0:
            raise IndexError("No data to sum")
        elif len(self) == 1:
            # Nothing to add up; hand back a copy of the only entry.
            data = self[0].copy()
        else:
            data = reduce(lambda x, y: x + y, self)
        return data

    D = TypeVar('D', bound=DOSData)

    @staticmethod
    def _select_to_list(dos_collection: Sequence[D],  # Bug in flakes
                        info_selection: dict[str, str],  # misses 'D' def
                        negative: bool = False) -> list[D]:  # noqa: F821
        """List the entries whose info contains every requested key/value

        With negative set, the entries that do *not* contain the full
        selection are returned instead.
        """
        query = set(info_selection.items())

        def _is_match(data) -> bool:
            return query.issubset(set(data.info.items()))

        if negative:
            return [data for data in dos_collection if not _is_match(data)]
        return [data for data in dos_collection if _is_match(data)]

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items with specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection)
        return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow DOSCollection to items without specified info

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        return type(self)(matches)

    # Use typehint *info_keys: str from python3.11+
    def sum_by(self, *info_keys) -> 'DOSCollection':
        """Return a DOSCollection with some data summed by common attributes

        For example, if ::

          dc = DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'}),
                              DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                              DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        then ::

          dc.sum_by('b')

        will return a collection equivalent to ::

          DOSCollection([DOSData(x1, y1, info={'a': '1', 'b': '1'})
                         + DOSData(x2, y2, info={'a': '2', 'b': '1'}),
                         DOSData(x3, y3, info={'a': '2', 'b': '2'})])

        where the resulting contained DOSData have info attributes of
        {'b': '1'} and {'b': '2'} respectively.

        dc.sum_by('a', 'b') on the other hand would return the full three-entry
        collection, as none of the entries have common 'a' *and* 'b' info.

        Every entry contributes to exactly one of the returned sums: entries
        are grouped by the key/value pairs they hold for the requested keys,
        so an entry lacking some (or all) of the keys forms a group with the
        other entries that hold exactly the same subset.

        """
        requested_keys = set(info_keys)

        def _matching_info_tuples(data: DOSData):
            """Get relevant dict entries in tuple form

            e.g. if data.info = {'a': 1, 'b': 2, 'c': 3}
                 and info_keys = ('a', 'c')

                 then return (('a', 1), ('c', 3))
            """
            matched_keys = requested_keys & set(data.info)
            return tuple(sorted((key, data.info[key])
                                for key in matched_keys))

        # Group the entries directly on their key/value combination. Going
        # through select() here would be wrong for entries with missing keys:
        # their shorter combination is also a subset of the info of entries
        # holding more keys, so those would be summed into several groups.
        grouped: dict[tuple, list[DOSData]] = {}
        for data in self:
            grouped.setdefault(_matching_info_tuples(data), []).append(data)

        # Combos are sorted to ensure consistent output across sessions;
        # within a group the original order of the entries is kept.
        collection_data = [type(self)(grouped[combo]).sum_all()
                           for combo in sorted(grouped)]
        return type(self)(collection_data)

    def __add__(self, other: 'DOSCollection | DOSData'
                ) -> 'DOSCollection':
        """Join entries between two DOSCollection objects of the same type

        It is also possible to add a single DOSData object without wrapping it
        in a new collection: i.e. ::

          DOSCollection([dosdata1]) + DOSCollection([dosdata2])

        or ::

          DOSCollection([dosdata1]) + dosdata2

        will return ::

          DOSCollection([dosdata1, dosdata2])

        """
        # Dispatch happens on the type of the right-hand operand.
        return _add_to_collection(other, self)

358 

359 

@singledispatch
def _add_to_collection(other: DOSData | DOSCollection,
                       collection: DOSCollection) -> DOSCollection:
    """Join another collection of the same type onto the given collection

    This is the fallback of the single-dispatch function: it only accepts
    an object that is an instance of the collection's own type and raises
    TypeError for everything else.
    """
    collection_type = type(collection)
    if isinstance(other, collection_type):
        return collection_type([*collection, *other])

    # Anything reaching this point cannot be joined; only the wording of the
    # error depends on whether it is some other kind of DOSCollection.
    if isinstance(other, DOSCollection):
        message = ("Only DOSCollection objects of the same type may "
                   "be joined with '+'.")
    else:
        message = ("DOSCollection may only be joined to DOSData or "
                   "DOSCollection objects with '+'.")
    raise TypeError(message)

371 

372 

@_add_to_collection.register(DOSData)
def _add_data(other: DOSData, collection: DOSCollection) -> DOSCollection:
    """Build a collection of the same type with one DOSData item appended"""
    extended = [*collection, other]
    return type(collection)(extended)

377 

378 

class RawDOSCollection(DOSCollection):
    """DOSCollection that only accepts RawDOSData entries"""

    def __init__(self, dos_series: Iterable[RawDOSData]) -> None:
        super().__init__(dos_series)
        # Validate after storing, so that a one-shot iterable is only
        # consumed once (by the parent constructor).
        if not all(isinstance(entry, RawDOSData) for entry in self):
            raise TypeError("RawDOSCollection can only store "
                            "RawDOSData objects.")

386 

387 

class GridDOSCollection(DOSCollection):
    """Collection of GridDOSData objects sharing a single energy axis

    The data is stored as one array of energies, a 2-D array of weights with
    one row per entry, and a list of info dicts. GridDOSData objects are
    created on demand when the collection is indexed.
    """

    def __init__(self, dos_series: Iterable[GridDOSData],
                 energies: Floats | None = None) -> None:
        """Collect GridDOSData entries onto a common energy axis

        Args:
            dos_series: GridDOSData objects to store
            energies: energy axis of the collection. If omitted, it is taken
                from the first entry of dos_series, which therefore must not
                be empty.

        Raises
        ------
        ValueError: if neither energies nor any data is provided, or if an
            entry has a different energy axis from the collection
        TypeError: if an entry is not a GridDOSData
        """
        dos_list = list(dos_series)
        if energies is None:
            if len(dos_list) == 0:
                raise ValueError("Must provide energies to create a "
                                 "GridDOSCollection without any DOS data.")
            self._energies = dos_list[0].get_energies()
        else:
            self._energies = np.asarray(energies)

        self._weights: np.ndarray = np.empty(
            (len(dos_list), len(self._energies)), float,
        )
        self._info = []

        for i, dos_data in enumerate(dos_list):
            if not isinstance(dos_data, GridDOSData):
                raise TypeError("GridDOSCollection can only store "
                                "GridDOSData objects.")
            # Shapes are compared first because np.allclose would otherwise
            # try to broadcast arrays of different lengths.
            if (dos_data.get_energies().shape != self._energies.shape
                or not np.allclose(dos_data.get_energies(),
                                   self._energies)):
                raise ValueError("All GridDOSData objects in GridDOSCollection"
                                 " must have the same energy axis.")
            self._weights[i, :] = dos_data.get_weights()
            self._info.append(dos_data.info)

    def get_energies(self) -> Floats:
        """Return a copy of the common energy axis"""
        return self._energies.copy()

    def get_all_weights(self) -> Sequence[Floats] | np.ndarray:
        """Return a copy of the weights array (one row per entry)"""
        return self._weights.copy()

    def __len__(self) -> int:
        return self._weights.shape[0]

    @overload  # noqa F811
    def __getitem__(self, item: int) -> DOSData:
        ...

    @overload  # noqa F811
    def __getitem__(self, item: slice) -> 'GridDOSCollection':  # noqa F811
        ...

    def __getitem__(self, item):  # noqa F811
        """Return one entry (int index) or a new collection (slice)"""
        if isinstance(item, int):
            return GridDOSData(self._energies, self._weights[item, :],
                               info=self._info[item])
        elif isinstance(item, slice):
            # The energy axis is passed explicitly so that a slice selecting
            # no entries gives an empty collection rather than a ValueError
            # from the constructor (as select()/select_not() already do).
            return type(self)([self[i] for i in range(len(self))[item]],
                              energies=self.get_energies())
        else:
            raise TypeError("index in DOSCollection must be an integer or "
                            "slice")

    @classmethod
    def from_data(cls,
                  energies: Floats,
                  weights: Sequence[Floats],
                  info: Sequence[Info] | None = None) -> 'GridDOSCollection':
        """Create a GridDOSCollection from data with a common set of energies

        This convenience method may also be more efficient as it limits
        redundant copying/checking of the data.

        Args:
            energies: common set of energy values for input data
            weights: array of DOS weights with rows corresponding to different
                datasets
            info: sequence of info dicts corresponding to weights rows.

        Returns
        -------
        Collection of DOS data (in GridDOSData format)

        Raises
        ------
        IndexError: if weights is not 2-D, has no rows, or its row length
            differs from the number of energies
        ValueError: if info is given and its length differs from the number
            of rows in weights
        """
        weights_array = np.asarray(weights, dtype=float)
        if len(weights_array.shape) != 2:
            raise IndexError("Weights must be a 2-D array or nested sequence")
        if weights_array.shape[0] < 1:
            raise IndexError("Weights cannot be empty")
        if weights_array.shape[1] != len(energies):
            raise IndexError("Length of weights rows must equal size of x")

        info = cls._check_weights_and_info(weights, info)

        # Run the regular constructor on the first row only, which sets up
        # the energy axis, then swap in the complete weights and info so that
        # the remaining rows skip the per-entry checks.
        dos_collection = cls([GridDOSData(energies, weights_array[0])])
        dos_collection._weights = weights_array
        dos_collection._info = list(info)

        return dos_collection

    def select(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items with specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select(b='1')

        will return an identical object to dc, while ::

          dc.select(a='1')

        will return a DOSCollection with only the first item and ::

          dc.select(a='2', b='1')

        will return a DOSCollection with only the second item.

        If nothing matches, an empty collection on the same energy axis is
        returned.

        """
        matches = self._select_to_list(self, info_selection)
        if len(matches) == 0:
            # Without any entries the constructor cannot infer the energies.
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def select_not(self, **info_selection: str) -> 'DOSCollection':
        """Narrow GridDOSCollection to items without specified info

        For example, if ::

          dc = GridDOSCollection([GridDOSData(x, y1,
                                              info={'a': '1', 'b': '1'}),
                                  GridDOSData(x, y2,
                                              info={'a': '2', 'b': '1'})])

        then ::

          dc.select_not(b='2')

        will return an identical object to dc, while ::

          dc.select_not(a='2')

        will return a DOSCollection with only the first item and ::

          dc.select_not(a='1', b='1')

        will return a DOSCollection with only the second item.

        If nothing remains, an empty collection on the same energy axis is
        returned.

        """
        matches = self._select_to_list(self, info_selection, negative=True)
        if len(matches) == 0:
            # Without any entries the constructor cannot infer the energies.
            return type(self)([], energies=self._energies)
        else:
            return type(self)(matches)

    def plot(self,
             npts: int = 0,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float | None = None,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple plot of collected DOS data, resampled onto a grid

        If the special key 'label' is present in the info of a contained
        entry, this will be set as the label for the plotted line (unless
        overruled in mplargs). The label is only seen if a legend is added to
        the plot (i.e. by calling `ax.legend()`).

        Args:
            npts:
                Number of points in resampled x-axis. If set to zero (default),
                no resampling is performed and the stored data is plotted
                directly.
            xmin, xmax:
                output data range, passed to self.sample_grid() when
                resampling
            width: Width of broadening kernel, passed to self.sample_grid()
            smearing: selection of broadening kernel, passed to
                self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line).

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        # Apply defaults if necessary
        npts, width = GridDOSData._interpret_smearing_args(npts, width)

        if npts:
            assert isinstance(width, float)
            dos = self.sample_grid(npts,
                                   xmin=xmin, xmax=xmax,
                                   width=width, smearing=smearing)
        else:
            dos = self

        energies, all_y = dos._energies, dos._weights

        # Labels are taken from the entries of this collection, not from the
        # resampled one.
        all_labels = [DOSData.label_from_info(data.info) for data in self]

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            self._plot_broadened(ax, energies, all_y, all_labels, mplargs)

        return ax

    @staticmethod
    def _plot_broadened(ax: Axes,
                        energies: Floats,
                        all_y: np.ndarray,
                        all_labels: Sequence[str],
                        mplargs: dict | None):
        """Plot DOS data with labels to axes

        This is separated into another function so that subclasses can
        manipulate broadening, labels etc in their plot() method."""
        if mplargs is None:
            mplargs = {}

        # all_y holds one row per data series; it is transposed so that each
        # series is drawn as its own line.
        all_lines = ax.plot(energies, all_y.T, **mplargs)
        for line, label in zip(all_lines, all_labels):
            line.set_label(label)
        ax.legend()

        ax.set_xlim(left=min(energies), right=max(energies))
        ax.set_ylim(bottom=0)