Coverage for ase / _4 / calculators / calculator.py: 39.01%

141 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 15:52 +0000

1import copy 

2import os 

3import warnings 

4from abc import ABC, abstractmethod 

5from pathlib import Path 

6from typing import Any 

7 

8from ase import Atoms as V3Atoms 

9from ase._4.calculators.results import CalculationResults 

10from ase.calculators.calculator import BaseCalculator as V3BaseCalculator 

11from ase.calculators.calculator import ( 

12 Parameters, 

13 all_properties, # noqa: F401 

14 equal, 

15) 

16 

17special = { 

18 'emt': 'EMT', 

19} 

20 

21 

22class BaseCalculator(ABC): 

23 implemented_properties: list[str] = [] 

24 'Properties calculator can handle (energy, forces, ...)' 

25 

26 # Placeholder object for deprecated arguments. Let deprecated keywords 

27 # default to _deprecated and then issue a warning if the user passed 

28 # any other object (such as None). 

29 _deprecated = object() 

30 

31 def __init__(self, parameters=None): 

32 if parameters is None: 

33 parameters = {} 

34 

35 self.parameters = dict(parameters) 

36 

37 @abstractmethod 

38 def evaluate(self, atoms, properties): ... 

39 

40 def _get_name(self) -> str: # child class can override this 

41 return self.__class__.__name__.lower() 

42 

43 @property 

44 def name(self) -> str: 

45 return self._get_name() 

46 

47 def todict(self) -> dict[str, Any]: 

48 """Obtain a dictionary of parameter information""" 

49 return {} 

50 

51 

52class Calculator(BaseCalculator): 

53 """Base-class for all ASE calculators. 

54 

55 A calculator must raise PropertyNotImplementedError if asked for a 

56 property that it can't calculate. So, if calculation of the 

57 stress tensor has not been implemented, 

58 evaluate(atoms, properties=["stress"]) should raise 

59 PropertyNotImplementedError. This can be achieved simply by not 

60 including the string 'stress' in the list implemented_properties 

61 which is a class member. These are the names of the standard 

62 properties: 'energy', 'forces', 'stress', 'dipole', 'charges', 

63 'magmom' and 'magmoms'. 

64 """ 

65 

66 default_parameters: dict[str, Any] = {} 

67 'Default parameters' 

68 

69 def __init__( 

70 self, 

71 restart=None, 

72 label=None, 

73 directory='.', 

74 **kwargs, 

75 ): 

76 """Basic calculator implementation. 

77 

78 restart: str 

79 Prefix for restart file. May contain a directory. Default 

80 is None: don't restart. 

81 directory: str or PurePath 

82 Working directory in which to read and write files and 

83 perform calculations. 

84 label: str 

85 Name used for all files. Not supported by all calculators. 

86 May contain a directory, but please use the directory parameter 

87 for that instead. 

88 """ 

89 self.parameters = None # calculational parameters 

90 self._directory = None # Initialize 

91 

92 if restart is not None: 

93 # duplicated in transition implementation 

94 self.read(restart) # read parameters, atoms and results 

95 

96 self.directory = directory 

97 self.prefix = None 

98 if label is not None: 

99 if self.directory == '.' and '/' in label: 

100 # We specified directory in label, and nothing in the directory 

101 # key 

102 self.label = label 

103 elif '/' not in label: 

104 # We specified our directory in the directory keyword 

105 # or not at all 

106 self.label = '/'.join((self.directory, label)) 

107 else: 

108 raise ValueError( 

109 'Directory redundantly specified though ' 

110 'directory="{}" and label="{}". ' 

111 'Please omit "/" in label.'.format(self.directory, label) 

112 ) 

113 

114 if self.parameters is None: 

115 # Use default parameters if they were not read from file: 

116 self.parameters = self.get_default_parameters() 

117 

118 self.set_check_parameter_changes(**kwargs) 

119 

120 if not hasattr(self, 'get_spin_polarized'): 

121 self.get_spin_polarized = self._deprecated_get_spin_polarized 

122 # XXX We are very naughty and do not call super constructor! 

123 

124 @property 

125 def directory(self) -> str: 

126 return self._directory 

127 

128 @directory.setter 

129 def directory(self, directory: str | os.PathLike): 

130 self._directory = str(Path(directory)) # Normalize path. 

131 

132 @property 

133 def label(self): 

134 if self.directory == '.': 

135 return self.prefix 

136 

137 # Generally, label ~ directory/prefix 

138 # 

139 # We use '/' rather than os.pathsep because 

140 # 1) directory/prefix does not represent any actual path 

141 # 2) We want the same string to work the same on all platforms 

142 if self.prefix is None: 

143 return self.directory + '/' 

144 

145 return f'{self.directory}/{self.prefix}' 

146 

147 @label.setter 

148 def label(self, label): 

149 if label is None: 

150 self.directory = '.' 

151 self.prefix = None 

152 return 

153 

154 tokens = label.rsplit('/', 1) 

155 if len(tokens) == 2: 

156 directory, prefix = tokens 

157 else: 

158 assert len(tokens) == 1 

159 directory = '.' 

160 prefix = tokens[0] 

161 if prefix == '': 

162 prefix = None 

163 self.directory = directory 

164 self.prefix = prefix 

165 

166 def set_label(self, label): 

167 """Set label and convert label to directory and prefix. 

168 

169 Examples 

170 -------- 

171 

172 * label='abc': (directory='.', prefix='abc') 

173 * label='dir1/abc': (directory='dir1', prefix='abc') 

174 * label=None: (directory='.', prefix=None) 

175 """ 

176 self.label = label 

177 

178 def get_default_parameters(self): 

179 return Parameters(copy.deepcopy(self.default_parameters)) 

180 

181 def todict(self, skip_default=True): 

182 defaults = self.get_default_parameters() 

183 dct = {} 

184 for key, value in self.parameters.items(): 

185 if hasattr(value, 'todict'): 

186 value = value.todict() 

187 if skip_default: 

188 default = defaults.get(key, '_no_default_') 

189 if default != '_no_default_' and equal(value, default): 

190 continue 

191 dct[key] = value 

192 return dct 

193 

194 # EG: How should restarts work in v4? 

195 # It might be thought of as a type of caching, in which case it 

196 # shouldn't be part of v4 calculators and be its own function, 

197 # for example. Or an input to .evaluate . 

198 # Sticking with the original for now. 

199 def read(self, label): 

200 """To be updated or deprecated. 

201 

202 Read atoms, parameters and calculated properties from output file. 

203 

204 Read result from self.label file. Raise ReadError if the file 

205 is not there. If the file is corrupted or contains an error 

206 message from the calculation, a ReadError should also be 

207 raised. In case of success, these attributes must set: 

208 

209 atoms: Atoms object 

210 The state of the atoms from last calculation. 

211 parameters: Parameters object 

212 The parameter dictionary. 

213 results: CalculationResults 

214 Calculated properties like energy and forces. 

215 

216 The FileIOCalculator.read() method will typically read atoms 

217 and parameters and get the results dict by calling the 

218 read_results() method.""" 

219 

220 self.set_label(label) 

221 

222 # EG: not yet sure how to handle this. 

223 # The ase4 calculator needs a parameter setter function 

224 # if we only want to split out Atoms, but preserve the rest 

225 # of the behavirour unchanged. But ase3 calculator needs 

226 # to check whether the parameters have changed before updating 

227 # so that the calculator may be reset. How should that be 

228 # handled in the ASEv4 + Version3Adaptor during the transition? 

229 def set_check_parameter_changes(self, **kwargs): 

230 """Set parameters like set(key1=value1, key2=value2, ...). 

231 

232 A dictionary containing the parameters that have been changed 

233 is returned. 

234 

235 The special keyword 'parameters' can be used to read 

236 parameters from a file.""" 

237 

238 if 'parameters' in kwargs: 

239 filename = kwargs.pop('parameters') 

240 parameters = Parameters.read(filename) 

241 parameters.update(kwargs) 

242 kwargs = parameters 

243 

244 changed_parameters = {} 

245 

246 for key, value in kwargs.items(): 

247 oldvalue = self.parameters.get(key) 

248 if key not in self.parameters or not equal(value, oldvalue): 

249 changed_parameters[key] = value 

250 # set only here in v4 base class 

251 self.parameters[key] = value 

252 # also returned by the transition class 

253 return changed_parameters 

254 

255 def evaluate(self, atoms, properties=None): 

256 """Use the calculator to evaluate the structure and obtain properties. 

257 

258 atoms: Atoms 

259 Structure to be evaluated. 

260 properties: list of str 

261 List of what needs to be calculated. Can be any combination 

262 of 'energy', 'forces', 'stress', 'dipole', 'charges', 'magmom' 

263 and 'magmoms'. 

264 

265 Subclasses need to implement this, but can ignore properties 

266 if they want. Calculated properties should 

267 be returned as a CalculationResults object. 

268 

269 The subclass implementation should first call this 

270 implementation to create any missing directories. 

271 """ 

272 if properties is None: 

273 properties = ['energy'] 

274 

275 if not os.path.isdir(self._directory): 

276 try: 

277 os.makedirs(self._directory) 

278 except FileExistsError as e: 

279 # We can only end up here in case of a race condition if 

280 # multiple Calculators are running concurrently *and* use the 

281 # same _directory, which cannot be expected to work anyway. 

282 msg = ( 

283 'Concurrent use of directory ' 

284 + self._directory 

285 + 'by multiple Calculator instances detected. Please ' 

286 'use one directory per instance.' 

287 ) 

288 raise RuntimeError(msg) from e 

289 

290 def _deprecated_get_spin_polarized(self): 

291 msg = ( 

292 'This calculator does not implement get_spin_polarized(). ' 

293 'In the future, calc.get_spin_polarized() will work only on ' 

294 'calculator classes that explicitly implement this method or ' 

295 'inherit the method via specialized subclasses.' 

296 ) 

297 warnings.warn(msg, FutureWarning) 

298 return False 

299 

300 def band_structure(self): 

301 """Create band-structure object for plotting.""" 

302 from ase.spectrum.band_structure import get_band_structure 

303 

304 # XXX This calculator is supposed to just have done a band structure 

305 # calculation, but the calculator may not have the correct Fermi level 

306 # if it updated the Fermi level after changing k-points. 

307 # This will be a problem with some calculators (currently GPAW), and 

308 # the user would have to override this by providing the Fermi level 

309 # from the selfconsistent calculation. 

310 return get_band_structure(calc=self) 

311 

312 

313class Version4Adaptor(BaseCalculator): 

314 """A generic wrapper to make ASEv3 calculators work 

315 with ASE 4.x interface. 

316 """ 

317 

318 wrapped_class: type[V3BaseCalculator] 

319 

320 def __init__(self, *args, **kwargs): 

321 self._v3_calculator = self.wrapped_class(*args, **kwargs) 

322 

323 @property 

324 def parameters(self): 

325 return self._v3_calculator.parameters 

326 

327 def evaluate( 

328 self, atoms: V3Atoms, properties: list[str] | None = None 

329 ) -> CalculationResults: 

330 

331 if properties is None: 

332 properties = ['energy'] 

333 

334 # enforce no modification of the input atoms 

335 atoms = atoms.copy() 

336 self._v3_calculator.calculate(atoms=atoms, properties=properties) 

337 

338 valid = {} 

339 for prop, val in self._v3_calculator.results.items(): 

340 if prop in CalculationResults.recognised_properties: 

341 valid[prop] = val 

342 else: 

343 warnings.warn( 

344 f'Property {prop} was found in calculation results ' 

345 f'but is not one of the standard properties of ' 

346 f'ase.outputs.all_outputs, skipping.' 

347 ) 

348 

349 return CalculationResults(properties=valid)