Coverage for /builds/ase/ase/ase/calculators/checkpoint.py: 86.58%
149 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
"""Checkpointing and restart functionality for scripts using ASE Atoms objects.

Initialize checkpoint object:

CP = Checkpoint('checkpoints.db')

Checkpointed code block in try ... except notation:

try:
    a, C, C_err = CP.load()
except NoCheckpoint:
    C, C_err = fit_elastic_constants(a)
    CP.save(a, C, C_err)

Checkpointed code block, shorthand notation:

C, C_err = CP(fit_elastic_constants)(a)

Example for checkpointing within an iterative loop, e.g. when searching for the
crack tip position:

try:
    a, converged, tip_x, tip_y = CP.load()
except NoCheckpoint:
    converged = False
    tip_x = tip_x0
    tip_y = tip_y0
while not converged:
    ... do something to find better crack tip position ...
    converged = ...
    CP.flush(a, converged, tip_x, tip_y)

The simplest way to use checkpointing is through the CheckpointCalculator. It
wraps any calculator object and creates a checkpoint whenever a calculation
is performed:

    calc = ...
    cp_calc = CheckpointCalculator(calc)
    atoms.calc = cp_calc
    e = atoms.get_potential_energy()  # 1st time, does calc, writes to the
                                      # checkpoint database; subsequent runs
                                      # read from the checkpoint
"""
import functools
from typing import Any, Dict

import numpy as np

import ase
from ase.calculators.calculator import Calculator
from ase.db import connect
class NoCheckpoint(Exception):
    """Raised by ``Checkpoint.load`` when no entry exists for the current id."""
class DevNull:
    """Minimal file-like sink that discards everything written to it.

    Used in this module as the default ``logfile`` so that logging calls
    never need a ``None`` check.
    """

    def write(self, *args):
        """Discard all arguments; always returns ``None``."""

    def flush(self):
        """No-op, so the sink can stand in wherever a file object is flushed."""
class Checkpoint:
    """Store and restore values together with an ``ase.Atoms`` object in an
    ASE database.

    Checkpoints are numbered by ``checkpoint_id``, a list of integers with one
    entry per nesting level. Each checkpoint is stored as a database row whose
    ``checkpoint_id`` key holds the mangled form of that list
    (e.g. ``[3, 2, 6]`` -> ``'check3:2:6'``). Stored values must be storable in
    the ``data`` field of an ASE database row.

    :meth:`load` opens a checkpoint region. If it raises :class:`NoCheckpoint`,
    the region stays open until the caller stores the results with
    :meth:`save` (which closes it) or :meth:`flush` (which keeps the same id so
    the entry can be overwritten repeatedly). A region opened while another
    one is still open becomes a nested checkpoint.

    Parameters
    ----------
    db : str
        Database argument passed to ``ase.db.connect`` on every access.
    logfile : object with a ``write(str)`` method, optional
        Receives progress messages. If ``None``, messages are discarded.
    """
    _value_prefix = '_values_'

    def __init__(self, db='checkpoints.db', logfile=None):
        self.db = db
        if logfile is None:
            logfile = DevNull()
        self.logfile = logfile

        # One entry per nesting level; the first load() turns [0] into [1].
        self.checkpoint_id = [0]
        # True while a region opened by load() has not been closed yet.
        self.in_checkpointed_region = False

    def __call__(self, func, *args, **kwargs):
        """Wrap *func* so that its return value(s) are checkpointed.

        The returned function first tries :meth:`load`. If no checkpoint
        exists, it calls *func* and stores the result with :meth:`save`. The
        first ``ase.Atoms`` instance among its positional arguments (if any) is
        passed to ``load``/``save`` as ``atoms``.

        ``*args`` and ``**kwargs`` given to this method are ignored; the
        arguments are taken from the call to the returned function.
        """
        checkpoint_func_name = str(func)

        @functools.wraps(func)
        def decorated_func(*args, **kwargs):
            # Use the first ase.Atoms object among the positional arguments.
            atoms = next((a for a in args if isinstance(a, ase.Atoms)), None)

            try:
                retvals = self.load(atoms=atoms)
            except NoCheckpoint:
                retvals = func(*args, **kwargs)
                # Tuples are stored element-wise, mirroring how load()
                # returns multiple values.
                values = retvals if isinstance(retvals, tuple) else (retvals,)
                self.save(*values, atoms=atoms,
                          checkpoint_func_name=checkpoint_func_name)
            return retvals
        return decorated_func

    def _increase_checkpoint_id(self):
        """Open a checkpoint region.

        If a region is still open, a nested level starting at 1 is appended.
        Otherwise the id at the current level is incremented.
        """
        if self.in_checkpointed_region:
            self.checkpoint_id += [1]
        else:
            self.checkpoint_id[-1] += 1
        self.logfile.write('Entered checkpoint region '
                           '{}.\n'.format(self.checkpoint_id))

        self.in_checkpointed_region = True

    def _decrease_checkpoint_id(self):
        """Close the current checkpoint region.

        If no region is open at this point (the innermost region was already
        closed), the innermost nesting level is dropped so that the id refers
        to the enclosing region again.
        """
        self.logfile.write('Leaving checkpoint region '
                           '{}.\n'.format(self.checkpoint_id))
        if not self.in_checkpointed_region:
            self.checkpoint_id = self.checkpoint_id[:-1]
            assert len(self.checkpoint_id) >= 1
        self.in_checkpointed_region = False
        assert self.checkpoint_id[-1] >= 1

    def _mangled_checkpoint_id(self):
        """
        Returns a mangled checkpoint id string:
        check<c_1>:<c_2>:<c_3>:...
        E.g. if checkpoint is nested and id is [3,2,6] it returns:
        'check3:2:6'
        """
        return 'check' + ':'.join(str(level) for level in self.checkpoint_id)

    def load(self, atoms=None):
        """Open a checkpoint region and retrieve its stored data.

        Parameters
        ----------
        atoms : ase.Atoms, optional
            If given, its calculator is attached to the restored Atoms object.

        Returns
        -------
        The values passed to :meth:`flush` or :meth:`save` when the checkpoint
        was written, in the same order (a stored Atoms object is rebuilt at
        its original position). A single value is returned as is; otherwise a
        tuple is returned.

        Raises
        ------
        NoCheckpoint
            If the database has no entry for the current checkpoint id. The
            region is then left open, and the caller is expected to compute
            the values and store them with :meth:`save` or :meth:`flush`.
        """
        self._increase_checkpoint_id()

        retvals = []
        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=self._mangled_checkpoint_id())
            except KeyError:
                raise NoCheckpoint from None

            data = dbentry.data
            # Position of the Atoms object among the saved values; -1 when the
            # atoms were given via the ``atoms`` keyword instead.
            atomsi = data['checkpoint_atoms_args_index']
            i = 0
            while (i == atomsi or
                   f'{self._value_prefix}{i}' in data):
                if i == atomsi:
                    newatoms = dbentry.toatoms()
                    if atoms is not None:
                        # Assign calculator
                        newatoms.calc = atoms.calc
                    retvals.append(newatoms)
                else:
                    retvals.append(data[f'{self._value_prefix}{i}'])
                i += 1

        self.logfile.write('Successfully restored checkpoint '
                           '{}.\n'.format(self.checkpoint_id))
        self._decrease_checkpoint_id()
        if len(retvals) == 1:
            return retvals[0]
        return tuple(retvals)

    def _flush(self, *args, **kwargs):
        """Write *args* (and extra keyword data) to the current checkpoint.

        The first ``ase.Atoms`` instance in *args* is written as the row's
        atoms, and its position is recorded. All other positional values are
        stored in the row's ``data``. If *args* contains no Atoms object,
        ``kwargs['atoms']`` is written instead. Remaining keyword arguments
        are stored in ``data`` as well. An existing entry for the current
        checkpoint id is replaced.

        Raises
        ------
        RuntimeError
            If *args* contains no Atoms object and no ``atoms`` keyword is
            given.
        """
        data = {f'{self._value_prefix}{i}': v
                for i, v in enumerate(args)}

        try:
            atomsi = [isinstance(v, ase.Atoms) for v in args].index(True)
        except ValueError:
            # No Atoms among the values: fall back to the atoms keyword.
            atomsi = -1
            try:
                atoms = kwargs['atoms']
            except KeyError:
                raise RuntimeError(
                    'No atoms object provided in arguments.') from None
        else:
            atoms = args[atomsi]
            # The Atoms object is written as the row itself, not into data.
            del data[f'{self._value_prefix}{atomsi}']

        # The atoms keyword is never stored as extra data.
        kwargs.pop('atoms', None)

        data['checkpoint_atoms_args_index'] = atomsi
        data.update(kwargs)

        with connect(self.db) as db:
            try:
                dbentry = db.get(checkpoint_id=self._mangled_checkpoint_id())
                del db[dbentry.id]
            except KeyError:
                pass
            db.write(atoms, checkpoint_id=self._mangled_checkpoint_id(),
                     data=data)

        self.logfile.write('Successfully stored checkpoint '
                           '{}.\n'.format(self.checkpoint_id))

    def flush(self, *args, **kwargs):
        """
        Store data to the current checkpoint, keeping its id. This is useful
        to continuously update the checkpoint state in an iterative loop; each
        call overwrites the previous entry for the same id.
        """
        # Mark the current region as closed without dropping a nesting level.
        # After load() raised NoCheckpoint the region is still open
        # (in_checkpointed_region is True). Clearing the flag makes the next
        # load() advance the id at this level instead of opening a nested
        # checkpoint, while repeated flushes keep overwriting the same entry.
        self.in_checkpointed_region = False
        self._flush(*args, **kwargs)

    def save(self, *args, **kwargs):
        """
        Store data to the current checkpoint and close its region. The
        checkpoint id is advanced by the next call to :meth:`load`.
        """
        self._decrease_checkpoint_id()
        self._flush(*args, **kwargs)
def atoms_almost_equal(a, b, tol=1e-9):
    """Return True if two Atoms-like objects describe the same structure.

    Atomic numbers and periodic boundary conditions must match exactly, while
    positions and cell must agree element-wise to within an absolute
    difference of *tol*. Objects with different numbers of atoms are never
    equal.
    """
    # Compare sizes first. Otherwise a one-atom object broadcasts against
    # every row of the other, and np.max() on the zero-size arrays of empty
    # objects raises ValueError.
    if len(a.numbers) != len(b.numbers):
        return False
    return bool(
        (np.abs(a.positions - b.positions) < tol).all()
        and (a.numbers == b.numbers).all()
        and (np.abs(a.cell - b.cell) < tol).all()
        and (a.pbc == b.pbc).all()
    )
class CheckpointCalculator(Calculator):
    """
    Calculator wrapper that creates a checkpoint for every calculation.

    Each call to ``calculate`` uses the next checkpoint of an internal
    ``Checkpoint`` object. When that checkpoint already exists in the
    database, its stored results are returned, after checking that the stored
    atoms match the current ones. Otherwise the wrapped calculator is run and
    its results are saved.

    This is particularly useful for expensive calculators, e.g. DFT, and
    allows usage of complex workflows.

    Stored results are matched to the requested properties by position, so a
    restarted script should request the same properties in the same order.

    Example usage:

        calc = ...
        cp_calc = CheckpointCalculator(calc)
        atoms.calc = cp_calc
        e = atoms.get_potential_energy()
        # 1st time, does calc, writes to the checkpoint database
        # subsequent runs, reads from the checkpoint database
    """
    implemented_properties = ase.calculators.calculator.all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'CheckpointCalculator'

    # For wrapped calculators that do not derive from Calculator: maps each
    # property name to the getter method used to compute it.
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, db='checkpoints.db', logfile=None):
        Calculator.__init__(self)
        self.calculator = calculator
        self.logfile = DevNull() if logfile is None else logfile
        self.checkpoint = Checkpoint(db, self.logfile)

    def calculate(self, atoms, properties, system_changes):
        Calculator.calculate(self, atoms, properties, system_changes)
        try:
            stored = self.checkpoint.load(atoms)
        except NoCheckpoint:
            values = self._run_and_store(atoms, properties, system_changes)
        else:
            values = self._restore_results(atoms, properties, stored)
        self.results = dict(zip(properties, values))

    def _restore_results(self, atoms, properties, stored):
        """Validate checkpointed data against *atoms* and return its values."""
        prev_atoms = stored[0]
        values = stored[1:]
        if not atoms_almost_equal(atoms, prev_atoms):
            raise RuntimeError('mismatch between current atoms and '
                               'those read from checkpoint file')
        self.logfile.write('retrieved results for {} from checkpoint\n'
                           .format(properties))
        # Seed the wrapped calculator's results with the restored values.
        inner = self.calculator
        if isinstance(inner, Calculator):
            if not hasattr(inner, 'results'):
                inner.results = {}
            inner.results.update(dict(zip(properties, values)))
        return values

    def _run_and_store(self, atoms, properties, system_changes):
        """Compute *properties* with the wrapped calculator and save them."""
        inner = self.calculator
        if isinstance(inner, Calculator):
            self.logfile.write('doing calculation of {} with new-style '
                               'calculator interface\n'.format(properties))
            inner.calculate(atoms, properties, system_changes)
            values = [inner.results[prop] for prop in properties]
        else:
            self.logfile.write('doing calculation of {} with old-style '
                               'calculator interface\n'.format(properties))
            values = [getattr(inner, self.property_to_method_name[prop])(atoms)
                      for prop in properties]
        # Temporarily attach the wrapped calculator while saving (presumably
        # so the stored row reflects it rather than this wrapper; TODO
        # confirm). The caller's calculator is always restored.
        previous_calc = atoms.calc
        try:
            atoms.calc = self.calculator
            self.checkpoint.save(atoms, *values)
        finally:
            atoms.calc = previous_calc
        return values