Coverage for ase / spectrum / dosdata.py: 100.00%

153 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3# Refactor of DOS-like data objects 

4# towards replacing ase.dft.dos and ase.dft.pdos 

5import warnings 

6from abc import ABCMeta, abstractmethod 

7from collections.abc import Sequence 

8from typing import Any, TypeVar 

9 

10import numpy as np 

11from matplotlib.axes import Axes 

12 

13from ase.utils.plotting import SimplePlottingAxes 

14 

# Type of the metadata ("info") dict attached to each DOS data series.
# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
Info = dict[str, str]

# Type used for energy/weight series: a sequence of floats or a numpy array.
# Still no good solution to type checking with arrays.
Floats = Sequence[float] | np.ndarray

21 

22 

class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init.

    Subclasses must implement get_energies(), get_weights() and copy().
    """

    def __init__(self,
                 info: Info | None = None) -> None:
        """Set up the metadata dict

        Args:
            info: metadata for this series. If None, a new empty dict is
                used; a dict is stored as-is (by reference, not copied).

        Raises:
            TypeError: if info is neither a dict nor None.
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns
        -------
        Weights sampled from a broadened DOS at values corresponding to
        energies

        Raises
        ------
        ValueError
            If width is not positive, or (from _delta()) if the smearing
            type is not recognized.
        """

        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Accumulate one broadened peak per stored (energy, weight) pair
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        The other object must be an instance of this object's type with an
        equal info dict; weights and energies are compared with np.allclose.
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False
        if not np.allclose(self.get_weights(), other.get_weights()):
            return False
        return np.allclose(self.get_energies(), other.get_energies())

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        Args:
            x: values at which the broadened delta function is evaluated
            x0: center of the delta function
            width: standard deviation of the Gaussian kernel
            smearing: kernel name, compared case-insensitively; only
                'Gauss' is recognized

        Raises:
            ValueError: if the smearing type is not recognized
        """
        if smearing.lower() != 'gauss':
            raise ValueError(
                f'Requested smearing type not recognized. Got {smearing}')

        # Normalized Gaussian: exp(-(x - x0)^2 / (2 w^2)) / (sqrt(2 pi) w)
        exponent = -0.5 * ((x - x0) / width)**2
        return np.exp(exponent) / (np.sqrt(2 * np.pi) * width)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float | None = None,
                    xmax: float | None = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns
        -------
        GridDOSData containing the energy grid and the sampled DOS, with a
        copy of this object's info dict
        """

        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.


        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a copy: inserting the automatic label into the caller's
        # dict would leak it into later plots that reuse the same dict.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        sampled = self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                   width=width,
                                   smearing=smearing)
        return sampled.plot(ax=ax, xmin=xmin, xmax=xmax,
                            show=show, filename=filename,
                            mplargs=mplargs)

    @staticmethod
    def label_from_info(info: dict[str, str]) -> str:
        """Generate an automatic legend label from info dict

        The value stored under 'label' is returned if that key is present;
        otherwise all items are joined as 'key: value' separated by '; '.
        """
        if 'label' in info:
            return info['label']
        return '; '.join(f'{key}: {value}' for key, value in info.items())

203 

204 

class GeneralDOSData(DOSData):
    """Base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init

    This is the base class for DOSData objects that accept/set separate
    "energies" and "weights" sequences of equal length at init.

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info | None = None) -> None:
        """Store energies and weights in a private float array

        Raises:
            ValueError: if energies and weights differ in length
        """
        super().__init__(info=info)

        size = len(energies)
        if size != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # One C-ordered (2, N) float array holds everything:
        # row 0 is the energies, row 1 is the weights.
        self._data = np.empty((2, size), dtype=float, order='C')
        self._data[0] = energies
        self._data[1] = weights

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies (row 0 of the data array)"""
        return self._data[0].copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights (row 1 of the data array)"""
        return self._data[1].copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same type with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(),
                   self.get_weights(),
                   info=self.info.copy())

241 

242 

class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate the data of two RawDOSData, keeping shared info only

        Raises:
            TypeError: if other is not a RawDOSData
        """
        if not isinstance(other, RawDOSData):
            raise TypeError(
                "RawDOSData can only be combined with other RawDOSData objects"
            )

        # Only (key, value) pairs present in both info dicts survive
        own_items = set(self.info.items())
        other_items = set(other.info.items())
        shared_info = dict(own_items & other_items)

        # Join the two (2, N) arrays side by side; no sorting or merging
        joined = np.concatenate((self._data, other._data), axis=1)

        # Build an empty container, then attach the joined array directly
        combined = RawDOSData([], [], info=shared_info)
        combined._data = joined
        return combined

    def plot_deltas(self,
                    ax: Axes | None = None,
                    show: bool = False,
                    filename: str | None = None,
                    mplargs: dict | None = None) -> Axes:
        """Simple plot of sparse DOS data as a set of delta functions

        Items at the same x-value can overlap and will not be summed together

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        vlines_kwargs = {} if mplargs is None else mplargs

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as axes:
            # One vertical line from 0 up to the weight at every energy
            axes.vlines(self.get_energies(), 0, self.get_weights(),
                        **vlines_kwargs)

        return axes

322 

323 

class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info | None = None) -> None:
        """Validate that energies form an even grid, then store the data

        Raises:
            ValueError: if energies are not evenly spaced, or if energies
                and weights differ in length
        """
        n_entries = len(energies)
        # An even grid is reproduced by linspace between its end points
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        self.sigma_cutoff = 3

    def _check_spacing(self, width: float) -> float:
        """Return the grid spacing, warning if width is small compared to it

        The spacing is taken from the first two grid points and returned as
        a magnitude: a descending grid also passes the even-spacing check in
        __init__, and a signed spacing would both disable the warning below
        and flip the sign of the resampled weights in _sample().
        """
        current_spacing = abs(self._data[0, 1] - self._data[0, 0])
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the broadened data, scaled by the original grid spacing

        Arguments are as for DOSData._sample(); the result of that method is
        multiplied by the spacing of the grid stored in this object.
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum the weights of two GridDOSData defined on the same grid

        Raises:
            TypeError: if other is not a GridDOSData
            ValueError: if the energy grids differ in length or in values
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the weight data; the energy grid is shared
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float | None = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> tuple[int, float | None]:
        """Figure out what the user intended: resample if width provided

        Returns
        -------
        (npts, width); a missing (falsy) npts or a missing (None) width is
        replaced by its default when the other one was given. (0, None) is
        returned if neither was given, i.e. no resampling was requested.
        """
        if width is not None:
            return (npts or default_npts, float(width))
        if npts:
            return (npts, default_width)
        return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float | None = None,
             xmax: float | None = None,
             width: float | None = None,
             smearing: str = 'Gauss',
             ax: Axes | None = None,
             show: bool = False,
             filename: str | None = None,
             mplargs: dict | None = None) -> Axes:
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The dict passed
                in is not modified.

        Returns
        -------
        Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a copy: inserting the automatic label into the caller's
        # dict would leak it into later plots that reuse the same dict.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            # _interpret_smearing_args always pairs a non-zero npts with a
            # float width; the assert narrows the type for static checkers.
            assert isinstance(width, float)
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax