Coverage for /builds/ase/ase/ase/calculators/checkpoint.py: 86.58%

149 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

# fmt: off

"""Checkpointing and restart functionality for scripts using ASE Atoms objects.

Initialize checkpoint object:

    CP = Checkpoint('checkpoints.db')

Checkpointed code block in try ... except notation. ``CP.load()`` raises
``NoCheckpoint`` when nothing has been stored for the current checkpoint yet;
the block is then executed and its results are stored with ``CP.save()``:

    try:
        a, C, C_err = CP.load()
    except NoCheckpoint:
        C, C_err = fit_elastic_constants(a)
        CP.save(a, C, C_err)

Checkpointed code block, shorthand notation (the function is only called if
no checkpoint exists yet; otherwise the stored values are returned):

    C, C_err = CP(fit_elastic_constants)(a)

Example for checkpointing within an iterative loop, e.g. for searching crack
tip position. ``CP.flush()`` rewrites the current checkpoint on every
iteration without advancing to the next one:

    try:
        a, converged, tip_x, tip_y = CP.load()
    except NoCheckpoint:
        converged = False
        tip_x = tip_x0
        tip_y = tip_y0
    while not converged:
        ... do something to find better crack tip position ...
        converged = ...
        CP.flush(a, converged, tip_x, tip_y)

The simplest way to use checkpointing is through the CheckpointCalculator. It
wraps any calculator object and writes a checkpoint whenever a calculation
is performed:

    calc = ...
    cp_calc = CheckpointCalculator(calc)
    atoms.calc = cp_calc
    e = atoms.get_potential_energy() # 1st time, does calc, writes to checkfile
                                     # subsequent runs, reads from checkpoint
"""

45 

46from typing import Any, Dict 

47 

48import numpy as np 

49 

50import ase 

51from ase.calculators.calculator import Calculator 

52from ase.db import connect 

53 

54 

class NoCheckpoint(Exception):
    """Raised by ``Checkpoint.load`` when the database holds no entry for the
    current checkpoint id, signalling that the checkpointed work must run."""

57 

58 

class DevNull:
    """Write sink that discards everything; used when no logfile is given."""

    def write(self, *args):
        """Discard the given text (the first parameter was misleadingly
        named ``str``, shadowing the builtin, although it is the instance)."""

    def flush(self):
        """Do nothing; lets DevNull stand in wherever a file object is
        expected."""

62 

63 

class Checkpoint:
    """Save and restore intermediate results of a script in an ASE database.

    Checkpoints are identified by a hierarchical id, a list of integers such
    as ``[3, 2, 6]``, stored in the database as the string ``'check3:2:6'``.
    :meth:`load` opens a checkpoint region, :meth:`save` closes it, and
    :meth:`flush` rewrites the current checkpoint without advancing the id.
    A :meth:`load` issued while a region is still open starts a nested one.

    Parameters
    ----------
    db : str
        Database passed to ``ase.db.connect`` (default ``'checkpoints.db'``).
    logfile : object with a ``write`` method, optional
        Receives progress messages; they are discarded when ``None``.
    """
    # Prefix of the data keys under which non-Atoms values are stored.
    _value_prefix = '_values_'

    def __init__(self, db='checkpoints.db', logfile=None):
        self.db = db
        if logfile is None:
            logfile = DevNull()
        self.logfile = logfile

        # Hierarchical checkpoint counter; one entry per nesting level.
        self.checkpoint_id = [0]
        # True while a region opened by load() has not yet been closed by
        # save(), flush() or a successful restore in load().
        self.in_checkpointed_region = False

    def __call__(self, func, *args, **kwargs):
        """Return a wrapper around ``func`` whose result is checkpointed.

        When the wrapper is called, the stored value(s) for the next
        checkpoint are returned if they exist; otherwise ``func`` is called
        and its return value is saved (a tuple element-wise, so a 1-tuple is
        restored as its single element). The first ``ase.Atoms`` positional
        argument, if any, is stored with the checkpoint and its calculator is
        attached to the restored Atoms.

        ``*args`` and ``**kwargs`` are accepted for backward compatibility
        and ignored; call arguments go to the returned wrapper instead.
        """
        import functools

        checkpoint_func_name = str(func)

        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            # Use the first ase.Atoms positional argument, if any.
            atoms = next((a for a in args if isinstance(a, ase.Atoms)), None)

            try:
                retvals = self.load(atoms=atoms)
            except NoCheckpoint:
                retvals = func(*args, **kwargs)
                values = retvals if isinstance(retvals, tuple) else (retvals,)
                self.save(*values, atoms=atoms,
                          checkpoint_func_name=checkpoint_func_name)
            return retvals

        return decorated_func

    def _increase_checkpoint_id(self):
        """Open a checkpoint region.

        If a region is already open, descend one nesting level (append 1);
        otherwise advance the counter at the current level.
        """
        if self.in_checkpointed_region:
            self.checkpoint_id += [1]
        else:
            self.checkpoint_id[-1] += 1
        self.logfile.write('Entered checkpoint region '
                           '{}.\n'.format(self.checkpoint_id))

        self.in_checkpointed_region = True

    def _decrease_checkpoint_id(self):
        """Close the current checkpoint region.

        If no region is open (the inner region was already closed), the
        innermost nesting level is dropped first.
        """
        self.logfile.write('Leaving checkpoint region '
                           '{}.\n'.format(self.checkpoint_id))
        if not self.in_checkpointed_region:
            self.checkpoint_id = self.checkpoint_id[:-1]
            assert len(self.checkpoint_id) >= 1
        self.in_checkpointed_region = False
        assert self.checkpoint_id[-1] >= 1

    def _mangled_checkpoint_id(self):
        """
        Returns a mangled checkpoint id string:
            check_c_1:c_2:c_3:...
        E.g. if checkpoint is nested and id is [3,2,6] it returns:
            'check3:2:6'
        """
        return 'check' + ':'.join(str(id) for id in self.checkpoint_id)

    def load(self, atoms=None):
        """
        Retrieve checkpoint data from file. If an atoms object is specified,
        its calculator is attached to the returned Atoms object.

        Returns the values as passed to flush or save during checkpoint
        write: a tuple, or the bare value if exactly one was stored.

        Raises NoCheckpoint if the database holds no entry for the new
        checkpoint id; the region is then left open so the caller can
        compute the values and call save() or flush().
        """
        self._increase_checkpoint_id()

        retvals = []
        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=self._mangled_checkpoint_id())
            except KeyError:
                # The KeyError is an implementation detail of the lookup.
                raise NoCheckpoint from None

            data = dbentry.data
            # Position at which the Atoms were passed to save()/flush();
            # -1 if they were supplied through the ``atoms`` keyword.
            atomsi = data['checkpoint_atoms_args_index']
            i = 0
            # Rebuild the values in their original order: the Atoms slot is
            # at index atomsi, the others are under consecutive value keys.
            while (i == atomsi or
                   f'{self._value_prefix}{i}' in data):
                if i == atomsi:
                    newatoms = dbentry.toatoms()
                    if atoms is not None:
                        # Assign calculator
                        newatoms.calc = atoms.calc
                    retvals += [newatoms]
                else:
                    retvals += [data[f'{self._value_prefix}{i}']]
                i += 1

        self.logfile.write('Successfully restored checkpoint '
                           '{}.\n'.format(self.checkpoint_id))
        self._decrease_checkpoint_id()
        if len(retvals) == 1:
            return retvals[0]
        else:
            return tuple(retvals)

    def _flush(self, *args, **kwargs):
        """Write ``args`` and ``kwargs`` under the current checkpoint id.

        The first ``ase.Atoms`` in ``args`` (or, failing that, the ``atoms``
        keyword argument) is written as the database row; every other
        positional value goes into the row's data under ``'_values_<i>'``
        and the remaining keyword arguments under their own names. An
        existing row with the same checkpoint id is replaced.

        Raises RuntimeError if ``args`` contains no Atoms and no ``atoms``
        keyword argument is given.
        """
        data = {f'{self._value_prefix}{i}': v
                for i, v in enumerate(args)}

        # Index of the first Atoms among the positional values; -1 if none.
        atomsi = next((i for i, v in enumerate(args)
                       if isinstance(v, ase.Atoms)), -1)
        if atomsi >= 0:
            atoms = args[atomsi]
            # The Atoms are stored as the row itself, not in the data dict.
            del data[f'{self._value_prefix}{atomsi}']
        elif 'atoms' in kwargs:
            atoms = kwargs['atoms']
        else:
            raise RuntimeError('No atoms object provided in arguments.')
        # The ``atoms`` keyword is never stored as data.
        kwargs.pop('atoms', None)

        data['checkpoint_atoms_args_index'] = atomsi
        data.update(kwargs)

        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=self._mangled_checkpoint_id())
                del db[dbentry.id]
            except KeyError:
                # Nothing stored under this id yet.
                pass
            db.write(atoms, checkpoint_id=self._mangled_checkpoint_id(),
                     data=data)

        self.logfile.write('Successfully stored checkpoint '
                           '{}.\n'.format(self.checkpoint_id))

    def flush(self, *args, **kwargs):
        """
        Store data to a checkpoint without increasing the checkpoint id. This
        is useful to continuously update the checkpoint state in an iterative
        loop.
        """
        # load() leaves in_checkpointed_region True when no checkpoint was
        # found (and False after a successful restore). Flushing marks the
        # region as closed without advancing the id, so the next load()
        # moves on to the following checkpoint instead of nesting inside
        # this one.
        self.in_checkpointed_region = False
        self._flush(*args, **kwargs)

    def save(self, *args, **kwargs):
        """
        Store data to the current checkpoint and close its region; the next
        load() then moves on to the following checkpoint id.
        """
        self._decrease_checkpoint_id()
        self._flush(*args, **kwargs)

223 

224 

def atoms_almost_equal(a, b, tol=1e-9):
    """Return True if two Atoms objects describe the same structure.

    Positions and cell must agree element-wise with an absolute difference
    strictly below ``tol``; atomic numbers and periodic boundary conditions
    must match exactly. Structures with different numbers of atoms compare
    unequal instead of raising, and two empty structures compare equal.

    Parameters
    ----------
    a, b : ase.Atoms
        Structures to compare (anything exposing ``positions``, ``numbers``,
        ``cell`` and ``pbc`` works).
    tol : float
        Absolute tolerance for positions and cell entries.
    """
    positions_a = np.asarray(a.positions)
    positions_b = np.asarray(b.positions)
    if positions_a.shape != positions_b.shape:
        return False
    cell_diff = np.asarray(a.cell) - np.asarray(b.cell)
    # np.all(... < tol) instead of (...).max() < tol: max() raises on the
    # zero-size difference of two empty structures.
    return bool(np.all(np.abs(positions_a - positions_b) < tol) and
                np.array_equal(a.numbers, b.numbers) and
                np.all(np.abs(cell_diff) < tol) and
                np.array_equal(a.pbc, b.pbc))

230 

231 

class CheckpointCalculator(Calculator):
    """
    This wraps any calculator object to checkpoint whenever a calculation
    is performed.

    This is particularly useful for expensive calculators, e.g. DFT and
    allows usage of complex workflows.

    Every call to ``calculate`` opens the next checkpoint of an internal
    :class:`Checkpoint`. If that checkpoint already exists in the database,
    the stored results are returned (after checking that the atoms match);
    otherwise the wrapped calculator is run and its results are saved.
    Checkpoints are matched by call order and results by position, so a
    restarted script must request the same properties in the same sequence.

    Example usage:

        calc = ...
        cp_calc = CheckpointCalculator(calc)
        atoms.calc = cp_calc
        e = atoms.get_potential_energy()
        # 1st time, does calc, writes to checkfile
        # subsequent runs, reads from checkpoint file

    Parameters
    ----------
    calculator
        Calculator to wrap: either a ``Calculator`` subclass instance or an
        old-style object exposing the getters in ``property_to_method_name``.
    db : str
        Checkpoint database (default ``'checkpoints.db'``).
    logfile : object with a ``write`` method, optional
        Receives progress messages; they are discarded when ``None``.
    """
    implemented_properties = ase.calculators.calculator.all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'CheckpointCalculator'

    # Getter used for each property when the wrapped calculator does not
    # derive from Calculator (old-style interface).
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, db='checkpoints.db', logfile=None):
        Calculator.__init__(self)
        self.calculator = calculator
        if logfile is None:
            logfile = DevNull()
        # The same log sink is shared with the underlying Checkpoint.
        self.checkpoint = Checkpoint(db, logfile)
        self.logfile = logfile

    def calculate(self, atoms, properties, system_changes):
        """Restore ``properties`` from the next checkpoint, or compute and
        store them.

        Raises RuntimeError if the checkpoint holds a structure that differs
        from ``atoms``.
        """
        Calculator.calculate(self, atoms, properties, system_changes)
        try:
            restored = self.checkpoint.load(atoms)
        except NoCheckpoint:
            results = self._run_wrapped_calculator(atoms, properties,
                                                   system_changes)
            self._save_checkpoint(atoms, results)
        else:
            # Checkpoint.load unwraps a single stored value, so a checkpoint
            # holding only the Atoms comes back as Atoms, not as a tuple.
            if not isinstance(restored, tuple):
                restored = (restored,)
            prev_atoms, results = restored[0], restored[1:]
            if not atoms_almost_equal(atoms, prev_atoms):
                raise RuntimeError('mismatch between current atoms and '
                                   'those read from checkpoint file')
            self.logfile.write('retrieved results for {} from checkpoint\n'
                               .format(properties))
            # save results in calculator for next time
            if isinstance(self.calculator, Calculator):
                if not hasattr(self.calculator, 'results'):
                    self.calculator.results = {}
                self.calculator.results.update(dict(zip(properties, results)))

        self.results = dict(zip(properties, results))

    def _run_wrapped_calculator(self, atoms, properties, system_changes):
        """Compute ``properties`` with the wrapped calculator; return the
        values as a list in the order of ``properties``."""
        if isinstance(self.calculator, Calculator):
            self.logfile.write('doing calculation of {} with new-style '
                               'calculator interface\n'.format(properties))
            self.calculator.calculate(atoms, properties, system_changes)
            return [self.calculator.results[prop] for prop in properties]

        self.logfile.write('doing calculation of {} with old-style '
                           'calculator interface\n'.format(properties))
        results = []
        for prop in properties:
            # A property missing from property_to_method_name raises
            # KeyError here.
            method = getattr(self.calculator,
                             self.property_to_method_name[prop])
            results.append(method(atoms))
        return results

    def _save_checkpoint(self, atoms, results):
        """Store ``atoms`` followed by ``results`` as the current checkpoint."""
        # Attach the wrapped calculator while writing (presumably so the
        # database row records its results too -- confirm against ase.db),
        # and restore the original calculator even if writing fails.
        previous_calc = atoms.calc
        try:
            atoms.calc = self.calculator
            self.checkpoint.save(atoms, *results)
        finally:
            atoms.calc = previous_calc

301 self.checkpoint.save(atoms, *results) 

302 finally: 

303 atoms.calc = _calculator 

304 

305 self.results = dict(zip(properties, results))