Coverage for /builds/ase/ase/ase/calculators/singlepoint.py: 79.61%

206 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3from functools import cached_property 

4 

5import numpy as np 

6 

7from ase.calculators.calculator import ( 

8 Calculator, 

9 PropertyNotImplementedError, 

10 PropertyNotPresent, 

11 all_properties, 

12) 

13from ase.outputs import Properties 

14 

15 

class SinglePointCalculator(Calculator):
    """Calculator holding precomputed results for one fixed configuration.

    The results passed to the constructor are stored together with a copy
    of the atoms they belong to.  Requesting a property that was not
    stored, or requesting one for atoms whose state differs from the
    stored copy, raises an exception instead of triggering a calculation.
    """

    name = 'unknown'

    def __init__(self, atoms, **results):
        """Store the given results for the current configuration.

        Every keyword must be one of ``all_properties``.  ``None`` values
        are dropped; 'energy', 'magmom' and 'free_energy' are kept as
        given, everything else is converted to a float array.
        """
        Calculator.__init__(self)
        self.results = {}
        stored_as_is = ('energy', 'magmom', 'free_energy')
        for key, value in results.items():
            assert key in all_properties, key
            if value is None:
                continue
            if key not in stored_as_is:
                value = np.array(value, float)
            self.results[key] = value
        # Keep a private snapshot so later changes to ``atoms`` are detected.
        self.atoms = atoms.copy()

    def __str__(self):
        """Return ``ClassName(key=value, ...)`` with arrays shown as '...'."""
        parts = [
            f'{key}={val}' if np.isscalar(val) else f'{key}=...'
            for key, val in sorted(self.results.items())
        ]
        return '{}({})'.format(self.__class__.__name__, ', '.join(parts))

    def get_property(self, name, atoms=None, allow_calculation=True):
        """Return the stored value of ``name``.

        If the property is missing or ``check_state(atoms)`` reports
        changes, raise PropertyNotImplementedError when
        ``allow_calculation`` is true and return None otherwise.
        Arrays are returned as copies.
        """
        if atoms is None:
            atoms = self.atoms

        # check_state() is only consulted when the property was stored.
        usable = name in self.results and not self.check_state(atoms)
        if not usable:
            if not allow_calculation:
                return None
            raise PropertyNotImplementedError(
                f'The property "{name}" is not available.')

        value = self.results[name]
        # Hand out a copy so callers cannot modify the stored array.
        return value.copy() if isinstance(value, np.ndarray) else value

63 

64 

class SinglePointKPoint:
    """Container for the data of one (spin, k-point) pair.

    Attributes are ``weight``, the spin index ``s``, the k-point index
    ``k`` and the two per-band sequences ``eps_n`` and ``f_n``.
    """

    def __init__(self, weight, s, k, eps_n=None, f_n=None):
        """Store the arguments; omitted sequences become fresh empty lists."""
        self.weight = weight
        self.s = s  # spin index
        self.k = k  # k-point index
        # A new list per instance avoids sharing one mutable default.
        self.eps_n = [] if eps_n is None else eps_n
        self.f_n = [] if f_n is None else f_n

76 

77 

def arrays_to_kpoints(eigenvalues, occupations, weights):
    """Build a list of SinglePointKPoint objects from arrays.

    ``eigenvalues`` and ``occupations`` must have the same three-axis
    shape, unpacked as (nspins, nkpts, nbands), and ``weights`` must have
    one entry per k-point.  The list is ordered with spin as the outer
    index and k-point as the inner index.
    """
    nspins, nkpts, _nbands = eigenvalues.shape
    assert eigenvalues.shape == occupations.shape
    assert len(weights) == nkpts
    return [
        SinglePointKPoint(
            weight=weights[ik], s=ispin, k=ik,
            eps_n=eigenvalues[ispin, ik], f_n=occupations[ispin, ik])
        for ispin in range(nspins)
        for ik in range(nkpts)
    ]

94 

95 

class SinglePointDFTCalculator(SinglePointCalculator):
    """Single-point calculator that also stores electronic-structure data.

    Besides the results handled by SinglePointCalculator it keeps the
    Fermi level, the k-point sets, the BZ-to-IBZ map and ``kpts``, a list
    of objects exposing ``s``, ``weight``, ``eps_n`` and ``f_n`` (one per
    spin and k-point).  Most accessors return None when the corresponding
    data was not supplied.
    """

    def __init__(self, atoms,
                 efermi=None, bzkpts=None, ibzkpts=None, bz2ibz=None,
                 kpts=None,
                 **results):
        """Store the electronic-structure data and the regular results."""
        self.bz_kpts = bzkpts
        self.ibz_kpts = ibzkpts
        self.bz2ibz = bz2ibz
        self.eFermi = efermi

        SinglePointCalculator.__init__(self, atoms, **results)
        # Assigned after the base-class constructor has run.
        self.kpts = kpts

    def get_fermi_level(self):
        """Return the Fermi-level(s)."""
        return self.eFermi

    def get_bz_to_ibz_map(self):
        """Return the stored BZ-to-IBZ map (``bz2ibz``)."""
        return self.bz2ibz

    def get_bz_k_points(self):
        """Return the k-points."""
        return self.bz_kpts

    def get_number_of_spins(self):
        """Return the number of spins in the calculation.

        Spin-paired calculations: 1, spin-polarized calculation: 2.
        Returns None when no k-point objects are stored."""
        if self.kpts is None:
            return None
        # Count the distinct spin indices present among the k-points.
        return len({kpt.s for kpt in self.kpts})

    def get_number_of_bands(self):
        """Return the number of bands, or None if it is not available.

        Raises RuntimeError when the k-points hold eigenvalue arrays of
        different lengths.
        """
        # Without this guard a calculator built with kpts=None (the
        # default) raised TypeError here, unlike the other accessors.
        if self.kpts is None:
            return None
        values = {len(kpt.eps_n) for kpt in self.kpts}
        if not values:
            return None
        elif len(values) == 1:
            return values.pop()
        else:
            raise RuntimeError('Multiple array sizes')

    def get_spin_polarized(self):
        """Is it a spin-polarized calculation?"""
        nos = self.get_number_of_spins()
        if nos is not None:
            return nos == 2
        return None

    def get_ibz_k_points(self):
        """Return k-points in the irreducible part of the Brillouin zone."""
        return self.ibz_kpts

    def get_kpt(self, kpt=0, spin=0):
        """Return the ``kpt``-th stored k-point object with spin ``spin``.

        Returns None if there is no such object or no k-points are stored.
        """
        if self.kpts is not None:
            # ``counter`` only advances over k-points of the requested spin.
            counter = 0
            for kpoint in self.kpts:
                if kpoint.s == spin:
                    if kpt == counter:
                        return kpoint
                    counter += 1
        return None

    def get_k_point_weights(self):
        """Return the weights of the k-points as an array.

        Only the spin-0 entries are used, giving one weight per k-point.
        Returns None when no k-point objects are stored.
        """
        if self.kpts is None:
            return None
        return np.array(
            [kpoint.weight for kpoint in self.kpts if kpoint.s == 0])

    def get_occupation_numbers(self, kpt=0, spin=0):
        """Return occupation number array.

        Returns None if the k-point is missing or has no occupations."""
        kpoint = self.get_kpt(kpt, spin)
        if kpoint is not None:
            if len(kpoint.f_n):
                return kpoint.f_n
        return None

    def get_eigenvalues(self, kpt=0, spin=0):
        """Return eigenvalue array, or None if the k-point is missing."""
        kpoint = self.get_kpt(kpt, spin)
        if kpoint is not None:
            return kpoint.eps_n
        return None

    def get_homo_lumo(self):
        """Return HOMO and LUMO energies.

        The HOMO is the highest per-spin HOMO and the LUMO the lowest
        per-spin LUMO.  Raises RuntimeError if no k-points are stored.
        """
        if self.kpts is None:
            raise RuntimeError('No kpts')
        eH = -np.inf
        eL = np.inf
        for spin in range(self.get_number_of_spins()):
            homo, lumo = self.get_homo_lumo_by_spin(spin)
            eH = max(eH, homo)
            eL = min(eL, lumo)
        return eH, eL

    def get_homo_lumo_by_spin(self, spin=0):
        """Return HOMO and LUMO energies for a given spin.

        Eigenvalues at or below the Fermi level count towards the HOMO,
        all others towards the LUMO.  If one side has no eigenvalues its
        value stays at the sentinel -1e32 / 1e32.  Raises RuntimeError if
        there are no k-points, none with the given spin, or no Fermi level.
        """
        if self.kpts is None:
            raise RuntimeError('No kpts')
        if not any(kpt.s == spin for kpt in self.kpts):
            raise RuntimeError(f'No k-point with spin {spin}')
        if self.eFermi is None:
            raise RuntimeError('Fermi level is not available')
        eH = -1.e32
        eL = 1.e32
        for kpt in self.kpts:
            if kpt.s == spin:
                for e in kpt.eps_n:
                    if e <= self.eFermi:
                        eH = max(eH, e)
                    else:
                        eL = min(eL, e)
        return eH, eL

    def properties(self) -> Properties:
        """Return the stored data collected into a Properties object."""
        return OutputPropertyWrapper(self).properties()

222 

223 

def propertygetter(func):
    """Turn ``func`` into a cached property that rejects None.

    The wrapped method is evaluated through ``cached_property``.  If it
    returns None, PropertyNotPresent is raised with the method name;
    otherwise the value is returned.
    """
    from functools import wraps

    def _require_value(self):
        result = func(self)
        if result is not None:
            return result
        raise PropertyNotPresent(func.__name__)

    # wraps() keeps the original name and docstring on the getter.
    return cached_property(wraps(func)(_require_value))

234 

235 

class OutputPropertyWrapper:
    """Adapter exposing a calculator's getters as cached properties.

    Each ``propertygetter`` attribute calls the matching ``get_*`` method
    of the wrapped calculator and raises PropertyNotPresent when that
    method yields None.  ``properties()`` gathers what is available into
    a Properties object.
    """

    def __init__(self, calc):
        """Wrap ``calc``, the calculator whose data is exposed."""
        self.calc = calc

    @propertygetter
    def nspins(self):
        """Number of spins reported by the calculator."""
        return self.calc.get_number_of_spins()

    @propertygetter
    def nbands(self):
        """Number of bands reported by the calculator."""
        return self.calc.get_number_of_bands()

    @propertygetter
    def nkpts(self):
        """Number of k-points per spin channel."""
        kpts = self.calc.kpts
        # With kpts=None, len() raised TypeError; with an empty list
        # nspins is 0 and the division raised ZeroDivisionError, which
        # properties() does not catch.  Report "not present" instead.
        if kpts is None or len(kpts) == 0:
            return None
        return len(kpts) // self.nspins

    def _build_eig_occ_array(self, getter):
        """Collect ``getter(spin=s, kpt=k)`` into an array.

        The array has shape (nspins, nkpts, nbands).  Returns None as
        soon as the getter returns None for any (spin, k-point) pair.
        """
        arr = np.empty((self.nspins, self.nkpts, self.nbands))
        for s in range(self.nspins):
            for k in range(self.nkpts):
                value = getter(spin=s, kpt=k)
                if value is None:
                    return None
                arr[s, k, :] = value
        return arr

    @propertygetter
    def eigenvalues(self):
        """Eigenvalues as an array indexed by spin, k-point and band."""
        return self._build_eig_occ_array(self.calc.get_eigenvalues)

    @propertygetter
    def occupations(self):
        """Occupations as an array indexed by spin, k-point and band."""
        return self._build_eig_occ_array(self.calc.get_occupation_numbers)

    @propertygetter
    def fermi_level(self):
        """Fermi level reported by the calculator."""
        return self.calc.get_fermi_level()

    @propertygetter
    def kpoint_weights(self):
        """K-point weights reported by the calculator."""
        return self.calc.get_k_point_weights()

    @propertygetter
    def ibz_kpoints(self):
        """Irreducible-zone k-points reported by the calculator."""
        return self.calc.get_ibz_k_points()

    def properties(self) -> Properties:
        """Return a Properties object with every available quantity.

        Quantities whose getter raises PropertyNotPresent are left out.
        Entries from ``calc.results`` are added last, so they take
        precedence over a derived quantity with the same name.
        """
        dct = {}
        for name in ['eigenvalues', 'occupations', 'fermi_level',
                     'kpoint_weights', 'ibz_kpoints']:
            try:
                value = getattr(self, name)
            except PropertyNotPresent:
                pass
            else:
                dct[name] = value

        for name, value in self.calc.results.items():
            dct[name] = value

        return Properties(dct)
296 return Properties(dct)