Coverage for /builds/ase/ase/ase/optimize/optimize.py: 94.80%
173 statements
coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3"""Structure optimization. """
4import time
5import warnings
6from collections.abc import Callable
7from functools import cached_property
8from os.path import isfile
9from pathlib import Path
10from typing import IO, Any, Dict, List, Optional, Tuple, Union
12from ase import Atoms
13from ase.calculators.calculator import PropertyNotImplementedError
14from ase.filters import UnitCellFilter
15from ase.parallel import world
16from ase.utils import IOContext
17from ase.utils.abc import Optimizable
# Default step budget used by run()/irun() when *steps* is not given.
DEFAULT_MAX_STEPS = 10 ** 8
class RestartError(RuntimeError):
    """Raised when an optimizer's restart file cannot be loaded."""
class OptimizableAtoms(Optimizable):
    """Expose an Atoms object through the Optimizable interface.

    The optimization vector is the flattened array of atomic positions,
    and ``get_gradient`` returns the flattened array from ``get_forces``.
    """

    def __init__(self, atoms):
        self.atoms = atoms

    def get_x(self):
        """Return the atomic positions as a flat 1-D array."""
        positions = self.atoms.get_positions()
        return positions.ravel()

    def set_x(self, x):
        """Set atomic positions from a flat vector (regrouped in rows of 3)."""
        self.atoms.set_positions(x.reshape(-1, 3))

    def get_gradient(self):
        """Return the forces on the atoms as a flat 1-D array."""
        forces = self.atoms.get_forces()
        return forces.ravel()

    @cached_property
    def _use_force_consistent_energy(self):
        """Whether the calculator supports ``force_consistent=True``.

        The answer is probed once and then cached.  NOTE: the cached value
        is NOT refreshed if the calculator is replaced later, which can give
        surprising results in multi-step optimizations.
        """
        try:
            self.atoms.get_potential_energy(force_consistent=True)
        except PropertyNotImplementedError:
            # A FutureWarning asking calculators to provide 'free_energy'
            # (even when equal to the ordinary energy) has been drafted for
            # this case but is currently disabled.
            return False
        return True

    def get_value(self):
        """Return the potential energy, force-consistent when supported."""
        return self.atoms.get_potential_energy(
            force_consistent=self._use_force_consistent_energy)

    def iterimages(self):
        # XXX document purpose of iterimages
        return self.atoms.iterimages()

    def ndofs(self):
        """Return the number of degrees of freedom (three per atom)."""
        return len(self.atoms) * 3
class Dynamics(IOContext):
    """Base-class for all MD and structure optimization classes."""

    def __init__(
        self,
        atoms: Atoms,
        logfile: Optional[Union[IO, Path, str]] = None,
        trajectory: Optional[Union[str, Path]] = None,
        append_trajectory: bool = False,
        master: Optional[bool] = None,
        comm=world,
        *,
        loginterval: int = 1,
    ):
        """Dynamics object.

        Parameters
        ----------
        atoms : Atoms object
            The Atoms object to operate on.

        logfile : file object, Path, or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory : Trajectory object, str, or Path
            Attach a trajectory object. If *trajectory* is a string/Path, a
            Trajectory will be constructed. Use *None* for no trajectory.

        append_trajectory : bool
            Defaults to False, which causes the trajectory file to be
            overwritten each time the dynamics is restarted from scratch.
            If True, the new structures are appended to the trajectory
            file instead.

        master : bool
            Defaults to None, which causes only rank 0 to save files. If set to
            true, this rank will save files.

        comm : Communicator object
            Communicator to handle parallel file reading and writing.

        loginterval : int, default: 1
            Only write a log line for every *loginterval* time steps.
        """
        self.atoms = atoms
        self.optimizable = atoms.__ase_optimizable__()
        self.logfile = self.openfile(file=logfile, comm=comm, mode='a')
        self.observers: List[Tuple[Callable, int, Tuple, Dict[str, Any]]] = []
        self.nsteps = 0
        self.max_steps = 0  # to be updated in run or irun
        self.comm = comm

        if trajectory is not None:
            if isinstance(trajectory, (str, Path)):
                # Local import, as in the original code.
                from ase.io.trajectory import Trajectory
                mode = "a" if append_trajectory else "w"
                # closelater() ties the file's lifetime to this IOContext.
                trajectory = self.closelater(Trajectory(
                    trajectory, mode=mode, master=master, comm=comm
                ))
            self.attach(
                trajectory,
                interval=loginterval,
                atoms=self.optimizable,
            )

        self.trajectory = trajectory

    def todict(self) -> Dict[str, Any]:
        """Describe this dynamics object; must be overridden by subclasses.

        Called by :meth:`attach` for observers providing ``set_description``.
        """
        raise NotImplementedError

    def get_number_of_steps(self):
        """Return the number of steps taken so far."""
        return self.nsteps

    def insert_observer(
        self, function, position=0, interval=1, *args, **kwargs
    ):
        """Insert an observer.

        This can be used for pre-processing before logging and dumping.

        Examples
        --------
        >>> from ase.build import bulk
        >>> from ase.calculators.emt import EMT
        >>> from ase.optimize import BFGS
        ...
        ...
        >>> def update_info(atoms, opt):
        ...     atoms.info["nsteps"] = opt.nsteps
        ...
        ...
        >>> atoms = bulk("Cu", cubic=True) * 2
        >>> atoms.rattle()
        >>> atoms.calc = EMT()
        >>> with BFGS(atoms, logfile=None, trajectory="opt.traj") as opt:
        ...     opt.insert_observer(update_info, atoms=atoms, opt=opt)
        ...     opt.run(fmax=0.05, steps=10)
        True
        """
        # Non-callable observers (e.g. trajectory writers) are invoked
        # through their ``write`` method.
        if not callable(function):
            function = function.write
        self.observers.insert(position, (function, interval, args, kwargs))

    def attach(self, function, interval=1, *args, **kwargs):
        """Attach callback function.

        If *interval > 0*, call *function* with arguments *args* and
        keyword arguments *kwargs* at every *interval* steps (including
        step 0).

        If *interval <= 0*, call *function* with arguments *args* and
        keyword arguments *kwargs* exactly once, at step ``abs(interval)``.
        Steps are counted from zero.

        Objects that are not callable are called through their ``write``
        method. Objects providing ``set_description`` are first given
        :meth:`todict` (plus the interval) as their description.
        """
        if hasattr(function, "set_description"):
            d = self.todict()
            d.update(interval=interval)
            function.set_description(d)
        if not callable(function):
            function = function.write
        self.observers.append((function, interval, args, kwargs))

    def call_observers(self):
        """Invoke every observer scheduled for the current step."""
        for function, interval, args, kwargs in self.observers:
            if interval > 0:
                # Periodic observer: every *interval* steps.
                call = self.nsteps % interval == 0
            else:
                # One-shot observer: only at step abs(interval).
                call = self.nsteps == abs(interval)
            if call:
                function(*args, **kwargs)

    def irun(self, steps=DEFAULT_MAX_STEPS):
        """Run dynamics algorithm as generator.

        Parameters
        ----------
        steps : int, default=DEFAULT_MAX_STEPS
            Number of dynamics steps to be run.

        Yields
        ------
        converged : bool
            True if the forces on atoms are converged.

        Examples
        --------
        This method allows, e.g., to run two optimizers or MD thermostats at
        the same time.
        >>> opt1 = BFGS(atoms)
        >>> opt2 = BFGS(StrainFilter(atoms)).irun()
        >>> for _ in opt2:
        ...     opt1.run()
        """
        # update the maximum number of steps
        self.max_steps = self.nsteps + steps

        # compute the initial step
        gradient = self.optimizable.get_gradient()

        # log the initial step
        if self.nsteps == 0:
            self.log(gradient)

            # we write a trajectory file if it is None
            if self.trajectory is None:
                self.call_observers()
            # We do not write on restart w/ an existing trajectory file
            # present. This duplicates the same entry twice
            elif len(self.trajectory) == 0:
                self.call_observers()

        # check convergence; the gradient is re-fetched after the observers
        # ran (NOTE(review): presumably because observers may modify atoms)
        gradient = self.optimizable.get_gradient()
        is_converged = self.converged(gradient)
        yield is_converged

        # run the algorithm until converged or max_steps reached
        while not is_converged and self.nsteps < self.max_steps:
            # compute the next step
            self.step()
            self.nsteps += 1

            # log the step
            gradient = self.optimizable.get_gradient()
            self.log(gradient)
            self.call_observers()

            # check convergence
            gradient = self.optimizable.get_gradient()
            is_converged = self.converged(gradient)
            yield is_converged

    def run(self, steps=DEFAULT_MAX_STEPS):
        """Run dynamics algorithm.

        This method will return when the forces on all individual
        atoms are less than *fmax* or when the number of steps exceeds
        *steps*.

        Parameters
        ----------
        steps : int, default=DEFAULT_MAX_STEPS
            Number of dynamics steps to be run.

        Returns
        -------
        converged : bool
            True if the forces on atoms are converged.
        """
        # irun always yields at least once, so *converged* is always bound.
        for converged in Dynamics.irun(self, steps=steps):
            pass
        return converged

    def converged(self, gradient):
        """A dummy function as placeholder for a real criterion, e.g. in
        Optimizer."""
        return False

    def log(self, *args, **kwargs):
        """A dummy function as placeholder for a real logger, e.g. in
        Optimizer."""
        return True

    def step(self):
        """Take one step; this needs to be implemented by subclasses."""
        raise RuntimeError("step not implemented.")
class Optimizer(Dynamics):
    """Base-class for all structure optimization classes."""

    # default maxstep for all optimizers
    defaults = {'maxstep': 0.2}
    _deprecated = object()

    def __init__(
        self,
        atoms: Atoms,
        restart: Optional[str] = None,
        logfile: Optional[Union[IO, str, Path]] = None,
        trajectory: Optional[Union[str, Path]] = None,
        append_trajectory: bool = False,
        **kwargs,
    ):
        """Structure optimizer object.

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: str
            Filename for restart file. Default value is *None*.
            If the file exists, the state is restored via :meth:`read`;
            otherwise :meth:`initialize` is called.

        logfile: file object, Path, or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: Trajectory object, Path, or str
            Attach trajectory object. If *trajectory* is a string a
            Trajectory will be constructed. Use *None* for no
            trajectory.

        append_trajectory: bool
            If True, append to the trajectory file instead of
            overwriting it.

        kwargs : dict, optional
            Extra arguments passed to :class:`~ase.optimize.optimize.Dynamics`.

        """
        super().__init__(
            atoms=atoms,
            logfile=logfile,
            trajectory=trajectory,
            append_trajectory=append_trajectory,
            **kwargs,
        )

        self.restart = restart

        self.fmax = None

        if restart is None or not isfile(restart):
            self.initialize()
        else:
            self.read()
            self.comm.barrier()

    def read(self):
        """Restore optimizer state from the restart file (subclass hook)."""
        raise NotImplementedError

    def todict(self):
        """Return a description dict of the optimizer and its settings."""
        description = {
            "type": "optimization",
            "optimizer": self.__class__.__name__,
        }
        # add custom attributes from subclasses
        for attr in ('maxstep', 'alpha', 'max_steps', 'restart', 'fmax'):
            if hasattr(self, attr):
                description[attr] = getattr(self, attr)
        return description

    def initialize(self):
        """Set up a fresh optimizer state (subclass hook; no-op here)."""
        pass

    def irun(self, fmax=0.05, steps=DEFAULT_MAX_STEPS):
        """Run optimizer as generator.

        Parameters
        ----------
        fmax : float
            Convergence criterion of the forces on atoms.
        steps : int, default=DEFAULT_MAX_STEPS
            Number of optimizer steps to be run.

        Yields
        ------
        converged : bool
            True if the forces on atoms are converged.
        """
        self.fmax = fmax
        return Dynamics.irun(self, steps=steps)

    def run(self, fmax=0.05, steps=DEFAULT_MAX_STEPS):
        """Run optimizer.

        Parameters
        ----------
        fmax : float
            Convergence criterion of the forces on atoms.
        steps : int, default=DEFAULT_MAX_STEPS
            Number of optimizer steps to be run.

        Returns
        -------
        converged : bool
            True if the forces on atoms are converged.
        """
        self.fmax = fmax
        return Dynamics.run(self, steps=steps)

    def converged(self, gradient):
        """Did the optimization converge?"""
        assert gradient.ndim == 1
        return self.optimizable.converged(gradient, self.fmax)

    def log(self, gradient):
        """Write one line (plus a header at step 0) to the logfile."""
        # Norm and energy are evaluated even when there is no logfile.
        # NOTE(review): this may be relied upon to make the energy
        # available to later observers (e.g. trajectory writers) — confirm
        # before moving these calls inside the logfile check.
        fmax = self.optimizable.gradient_norm(gradient)
        e = self.optimizable.get_value()
        T = time.localtime()
        if self.logfile is not None:
            name = self.__class__.__name__
            if self.nsteps == 0:
                args = (" " * len(name), "Step", "Time", "Energy", "fmax")
                msg = "%s %4s %8s %15s %12s\n" % args
                self.logfile.write(msg)

            args = (name, self.nsteps, T[3], T[4], T[5], e, fmax)
            msg = "%s: %3d %02d:%02d:%02d %15.6f %15.6f\n" % args
            self.logfile.write(msg)
            self.logfile.flush()

    def dump(self, data):
        """Write *data* as JSON to the restart file.

        Only rank 0 writes, and only when a restart filename was given.
        """
        from ase.io.jsonio import write_json
        if self.comm.rank == 0 and self.restart is not None:
            with open(self.restart, 'w') as fd:
                write_json(fd, data)

    def load(self):
        """Read and return the JSON data stored in the restart file.

        Raises
        ------
        RestartError
            If the file contents cannot be decoded as JSON.
        """
        from ase.io.jsonio import read_json

        # Local import, as in the original code.
        from ase.optimize import BFGS

        with open(self.restart) as fd:
            # Kept outside the ``try`` so that problems here (e.g. this
            # warning escalated to an error) are not misreported as a
            # JSON decoding failure.
            if not isinstance(self, BFGS) and isinstance(
                self.atoms, UnitCellFilter
            ):
                warnings.warn(
                    "WARNING: restart function is untested and may result "
                    "in unintended behavior. Namely orig_cell is not "
                    "loaded in the UnitCellFilter. Please test on your own"
                    " to ensure consistent results."
                )
            try:
                return read_json(fd, always_array=False)
            except Exception as ex:
                msg = ('Could not decode restart file as JSON. '
                       'You may need to delete the restart file '
                       f'{self.restart}')
                raise RestartError(msg) from ex