Coverage for /builds/ase/ase/ase/spectrum/dosdata.py: 100.00%

152 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3# Refactor of DOS-like data objects 

4# towards replacing ase.dft.dos and ase.dft.pdos 

5import warnings 

6from abc import ABCMeta, abstractmethod 

7from typing import Any, Dict, Sequence, Tuple, TypeVar, Union 

8 

9import numpy as np 

10from matplotlib.axes import Axes 

11 

12from ase.utils.plotting import SimplePlottingAxes 

13 

# For now we will be strict about Info and say it has to be str->str. Perhaps
# later we will allow other types that have reliable comparison operations.
# Note: the ``__add__`` methods below combine metadata via
# ``set(info.items())``, so info values must at least be hashable.
Info = Dict[str, str]

# Still no good solution to type checking with arrays.
# Any 1-D float sequence or numpy array is accepted where Floats is used.
Floats = Union[Sequence[float], np.ndarray]

20 

21 

class DOSData(metaclass=ABCMeta):
    """Abstract base class for a single series of DOS-like data

    Only the 'info' is a mutable attribute; DOS data is set at init"""

    def __init__(self,
                 info: Info = None) -> None:
        """Store the metadata dict.

        Args:
            info: metadata dict (stored by reference, not copied). If None,
                a new empty dict is used.

        Raises:
            TypeError: if info is neither a dict nor None.
        """
        if info is None:
            self.info = {}
        elif isinstance(info, dict):
            self.info = info
        else:
            raise TypeError("Info must be a dict or None")

    @abstractmethod
    def get_energies(self) -> Floats:
        """Get energy data stored in this object"""

    @abstractmethod
    def get_weights(self) -> Floats:
        """Get DOS weights stored in this object"""

    @abstractmethod
    def copy(self) -> 'DOSData':
        """Returns a copy in which info dict can be safely mutated"""

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Sample the DOS data at chosen points, with broadening

        Note that no correction is made here for the sampling bin width; total
        intensity will vary with sampling density.

        Args:
            energies: energy values for sampling
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only "Gauss" is currently
                supported)

        Returns:
            Weights sampled from a broadened DOS at the values in `energies`

        Raises:
            ValueError: if width is not positive or smearing is unknown
        """

        self._check_positive_width(width)
        weights_grid = np.zeros(len(energies), float)
        weights = self.get_weights()
        energies = np.asarray(energies, float)

        # Accumulate one broadened state at a time so only 1-D temporaries
        # of the sampling-grid size are allocated.
        for i, raw_energy in enumerate(self.get_energies()):
            delta = self._delta(energies, raw_energy, width, smearing=smearing)
            weights_grid += weights[i] * delta
        return weights_grid

    def _almost_equals(self, other: Any) -> bool:
        """Compare with another DOSData for testing purposes

        Returns True only if `other` is an instance of this object's type,
        has an identical info dict, and has numerically close weights and
        energies (np.allclose).
        """
        if not isinstance(other, type(self)):
            return False
        if self.info != other.info:
            return False
        if not np.allclose(self.get_weights(), other.get_weights()):
            return False
        return np.allclose(self.get_energies(), other.get_energies())

    @staticmethod
    def _delta(x: np.ndarray,
               x0: float,
               width: float,
               smearing: str = 'Gauss') -> np.ndarray:
        """Return a delta-function centered at 'x0'.

        This function is used with numpy broadcasting; if x is a row and x0 is
        a column vector, the returned data will be a 2D array with each row
        corresponding to a different delta center.

        The 'gauss' kernel (case-insensitive) is a normalised Gaussian with
        standard deviation `width`.

        Raises:
            ValueError: if the smearing type is not recognised.
        """
        if smearing.lower() == 'gauss':
            x1 = -0.5 * ((x - x0) / width)**2
            return np.exp(x1) / (np.sqrt(2 * np.pi) * width)
        else:
            msg = 'Requested smearing type not recognized. Got {}'.format(
                smearing)
            raise ValueError(msg)

    @staticmethod
    def _check_positive_width(width):
        """Raise ValueError unless the broadening width is strictly positive"""
        if width <= 0.0:
            msg = 'Cannot add 0 or negative width smearing'
            raise ValueError(msg)

    def sample_grid(self,
                    npts: int,
                    xmin: float = None,
                    xmax: float = None,
                    padding: float = 3,
                    width: float = 0.1,
                    smearing: str = 'Gauss',
                    ) -> 'GridDOSData':
        """Sample the DOS data on an evenly-spaced energy grid

        Args:
            npts: Number of sampled points
            xmin: Minimum sampled x value; if unspecified, a default is chosen
            xmax: Maximum sampled x value; if unspecified, a default is chosen
            padding: If xmin/xmax is unspecified, default value will be padded
                by padding * width to avoid cutting off peaks.
            width: Width of broadening kernel
            smearing: selection of broadening kernel (only 'Gauss' is
                implemented)

        Returns:
            GridDOSData holding the evenly-spaced energy grid and the sampled
            DOS, with a shallow copy of this object's info dict.
        """

        if xmin is None:
            xmin = min(self.get_energies()) - (padding * width)
        if xmax is None:
            xmax = max(self.get_energies()) + (padding * width)
        energies_grid = np.linspace(xmin, xmax, npts)
        weights_grid = self._sample(energies_grid, width=width,
                                    smearing=smearing)

        return GridDOSData(energies_grid, weights_grid, info=self.info.copy())

    def plot(self,
             npts: int = 1000,
             xmin: float = None,
             xmax: float = None,
             width: float = 0.1,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data, resampled onto a grid

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel for self.sample_grid()
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The caller's dict
                is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        # Work on a private copy: adding the automatic label to the caller's
        # dict would leak this object's label into later plots that reuse it.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        return self.sample_grid(npts, xmin=xmin, xmax=xmax,
                                width=width,
                                smearing=smearing
                                ).plot(ax=ax, xmin=xmin, xmax=xmax,
                                       show=show, filename=filename,
                                       mplargs=mplargs)

    @staticmethod
    def label_from_info(info: Dict[str, str]):
        """Generate an automatic legend label from info dict

        Returns info['label'] if present, otherwise all entries joined as
        "key: value" pairs separated by "; ".
        """
        if 'label' in info:
            return info['label']
        else:
            return '; '.join(f'{key}: {value}'
                             for key, value in info.items())

199 

200 

class GeneralDOSData(DOSData):
    """Base class for DOSData objects holding paired energy/weight arrays

    Subclasses accept separate "energies" and "weights" sequences, which must
    have equal length, at init. As for all DOSData, only the 'info' attribute
    is intended to be mutable; the numeric data is fixed at construction.
    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate and store the energy/weight data.

        Raises:
            TypeError: if info is neither a dict nor None
            ValueError: if energies and weights differ in length
        """
        super().__init__(info=info)

        npoints = len(energies)
        if npoints != len(weights):
            raise ValueError("Energies and weights must be the same length")

        # Storage layout: a C-ordered float array whose row 0 holds the
        # energies and row 1 holds the matching weights.
        self._data = np.empty((2, npoints), dtype=float, order='C')
        for row, values in enumerate((energies, weights)):
            self._data[row, :] = values

    def get_energies(self) -> np.ndarray:
        """Return a copy of the stored energies (safe to modify)"""
        energy_row = self._data[0, :]
        return energy_row.copy()

    def get_weights(self) -> np.ndarray:
        """Return a copy of the stored weights (safe to modify)"""
        weight_row = self._data[1, :]
        return weight_row.copy()

    D = TypeVar('D', bound='GeneralDOSData')

    def copy(self: D) -> D:  # noqa F821
        """Return a new object of the same class with copied data and info"""
        cls = type(self)
        return cls(self.get_energies(), self.get_weights(),
                   info=self.info.copy())

237 

238 

class RawDOSData(GeneralDOSData):
    """A collection of weighted delta functions which sum to form a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the energy data values do not form a known regular
    grid. The data may be plotted or resampled for further analysis using the
    sample_grid() and plot() methods. Multiple weights at the same
    energy value will *only* be combined in output data, and data stored in
    RawDOSData is never resampled. A plot_deltas() function is also provided
    which plots the raw data.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When RawDOSData objects are combined with the addition operator::

      big_dos = raw_dos_1 + raw_dos_2

    the energy and weights data is *concatenated* (i.e. combined without
    sorting or replacement) and the new info dictionary consists of the
    *intersection* of the inputs: only key-value pairs that were common to both
    of the input objects will be retained in the new combined object. For
    example::

      (RawDOSData([x1], [y1], info={'symbol': 'O', 'index': '1'})
       + RawDOSData([x2], [y2], info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      RawDOSData([x1, x2], [y1, y2], info={'symbol': 'O'})

    """

    def __add__(self, other: 'RawDOSData') -> 'RawDOSData':
        """Concatenate two RawDOSData objects, keeping only shared info.

        Raises:
            TypeError: if other is not a RawDOSData.
        """
        if isinstance(other, RawDOSData):
            # Only (key, value) pairs present in both inputs survive
            shared_items = set(self.info.items()).intersection(
                other.info.items())
            # Energies and weights are appended column-wise, unsorted
            stacked = np.concatenate((self._data, other._data), axis=1)

            combined = RawDOSData([], [], info=dict(shared_items))
            combined._data = stacked
            return combined

        raise TypeError("RawDOSData can only be combined with other "
                        "RawDOSData objects")

    def plot_deltas(self,
                    ax: Axes = None,
                    show: bool = False,
                    filename: str = None,
                    mplargs: dict = None) -> Axes:
        """Draw the raw data as vertical lines, one per stored state

        Lines sharing an energy value are drawn on top of each other rather
        than being summed.

        Args:
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib Axes.vlines
                command (e.g. {'linewidth': 2} for a thicker line).

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """
        vline_kwargs = mplargs if mplargs is not None else {}

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.vlines(self.get_energies(), 0, self.get_weights(),
                      **vline_kwargs)

        return ax

317 

318 

class GridDOSData(GeneralDOSData):
    """A collection of regularly-sampled data which represents a DOS

    This is an appropriate data container for density-of-states (DOS) or
    spectral data where the intensity values form a regular grid. This
    is generally the result of sampling or integrating into discrete
    bins, rather than a collection of unique states. The data may be
    plotted or resampled for further analysis using the sample_grid()
    and plot() methods.

    Metadata may be stored in the info dict, in which keys and values must be
    strings. This data is used for selecting and combining multiple DOSData
    objects in a DOSCollection object.

    When GridDOSData objects are combined with the addition operator::

      big_dos = grid_dos_1 + grid_dos_2

    the weights data is *summed* (requiring a consistent energy grid) and the
    new info dictionary consists of the *intersection* of the inputs: only
    key-value pairs that were common to both of the input objects will be
    retained in the new combined object. For example::

      (GridDOSData([0.1, 0.2, 0.3], [y1, y2, y3],
                   info={'symbol': 'O', 'index': '1'})
       + GridDOSData([0.1, 0.2, 0.3], [y4, y5, y6],
                     info={'symbol': 'O', 'index': '2'}))

    will yield the equivalent of::

      GridDOSData([0.1, 0.2, 0.3], [y1+y4, y2+y5, y3+y6], info={'symbol': 'O'})

    """

    def __init__(self,
                 energies: Floats,
                 weights: Floats,
                 info: Info = None) -> None:
        """Validate and store an evenly-spaced energy grid and its weights.

        Raises:
            ValueError: if energies are not evenly spaced (checked against
                np.linspace with np.allclose), or if energies and weights
                differ in length.
        """
        n_entries = len(energies)
        if not np.allclose(energies,
                           np.linspace(energies[0], energies[-1], n_entries)):
            raise ValueError("Energies must be an evenly-spaced 1-D grid")

        if len(weights) != n_entries:
            raise ValueError("Energies and weights must be the same length")

        super().__init__(energies, weights, info=info)
        # NOTE(review): not read anywhere in this module; kept as a public
        # attribute for compatibility.
        self.sigma_cutoff = 3

    def _check_spacing(self, width) -> float:
        """Return the grid spacing, warning if `width` is too small for it.

        The magnitude of the spacing is returned so that a grid stored in
        descending order still gives a positive bin width; otherwise
        resampled intensities would change sign.
        """
        current_spacing = abs(self._data[0, 1] - self._data[0, 0])
        if width < (2 * current_spacing):
            warnings.warn(
                "The broadening width is small compared to the original "
                "sampling density. The results are unlikely to be smooth.")
        return current_spacing

    def _sample(self,
                energies: Floats,
                width: float = 0.1,
                smearing: str = 'Gauss') -> np.ndarray:
        """Broaden and resample, scaling by the original grid spacing.

        Each stored weight is treated as a bin of the original grid, so the
        summed kernels are multiplied by the (positive) bin width.
        """
        current_spacing = self._check_spacing(width)
        return super()._sample(energies=energies,
                               width=width, smearing=smearing
                               ) * current_spacing

    def __add__(self, other: 'GridDOSData') -> 'GridDOSData':
        """Sum weights of two GridDOSData on the same grid, keeping shared info.

        Raises:
            TypeError: if other is not a GridDOSData.
            ValueError: if the energy grids differ in length or values.
        """
        # This method uses direct access to the mutable energy and weights data
        # (self._data) to avoid redundant copying operations. The __init__
        # method of GridDOSData will write this to a new array, so on this
        # occasion it is safe to pass references to the mutable data.

        if not isinstance(other, GridDOSData):
            raise TypeError("GridDOSData can only be combined with other "
                            "GridDOSData objects")
        if len(self._data[0, :]) != len(other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different-"
                             "length energy grids.")

        if not np.allclose(self._data[0, :], other.get_energies()):
            raise ValueError("Cannot add GridDOSData objects with different "
                             "energy grids.")

        # Take intersection of metadata (i.e. only common entries are retained)
        new_info = dict(set(self.info.items()) & set(other.info.items()))

        # Sum the energy/weight data
        new_weights = self._data[1, :] + other.get_weights()

        new_object = GridDOSData(self._data[0, :], new_weights,
                                 info=new_info)
        return new_object

    @staticmethod
    def _interpret_smearing_args(npts: int,
                                 width: float = None,
                                 default_npts: int = 1000,
                                 default_width: float = 0.1
                                 ) -> Tuple[int, Union[float, None]]:
        """Figure out what the user intended: resample if width provided

        Returns (npts, width). A falsy npts with no width means "do not
        resample" and yields (0, None); otherwise missing values are filled
        from the defaults.
        """
        if width is not None:
            if npts:
                return (npts, float(width))
            else:
                return (default_npts, float(width))
        else:
            if npts:
                return (npts, default_width)
            else:
                return (0, None)

    def plot(self,
             npts: int = 0,
             xmin: float = None,
             xmax: float = None,
             width: float = None,
             smearing: str = 'Gauss',
             ax: Axes = None,
             show: bool = False,
             filename: str = None,
             mplargs: dict = None) -> Axes:
        """Simple 1-D plot of DOS data

        Data will be resampled onto a grid with `npts` points unless `npts` is
        set to zero, in which case:

        - no resampling takes place
        - `width` and `smearing` are ignored
        - `xmin` and `xmax` affect the axis limits of the plot, not the
          underlying data.

        If the special key 'label' is present in self.info, this will be set
        as the label for the plotted line (unless overruled in mplargs). The
        label is only seen if a legend is added to the plot (i.e. by calling
        ``ax.legend()``).

        Args:
            npts, xmin, xmax: output data range, as passed to self.sample_grid
            width: Width of broadening kernel, passed to self.sample_grid().
                If no npts was set but width is set, npts will be set to 1000.
            smearing: selection of broadening kernel for self.sample_grid()
            ax: existing Matplotlib axes object. If not provided, a new figure
                with one set of axes will be created using Pyplot
            show: show the figure on-screen
            filename: if a path is given, save the figure to this file
            mplargs: additional arguments to pass to matplotlib plot command
                (e.g. {'linewidth': 2} for a thicker line). The caller's dict
                is not modified.

        Returns:
            Plotting axes. If "ax" was set, this is the same object.
        """

        npts, width = self._interpret_smearing_args(npts, width)

        # Work on a private copy: adding the automatic label to the caller's
        # dict would leak this object's label into later plots that reuse it.
        mplargs = {} if mplargs is None else dict(mplargs)
        if 'label' not in mplargs:
            mplargs['label'] = self.label_from_info(self.info)

        if npts:
            assert isinstance(width, float)  # narrowing for type checkers
            dos = self.sample_grid(npts, xmin=xmin,
                                   xmax=xmax, width=width,
                                   smearing=smearing)
        else:
            dos = self

        energies, intensity = dos.get_energies(), dos.get_weights()

        with SimplePlottingAxes(ax=ax, show=show, filename=filename) as ax:
            ax.plot(energies, intensity, **mplargs)
            ax.set_xlim(left=xmin, right=xmax)

        return ax