Coverage for ase / optimize / fire.py: 93.02%
86 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3from collections.abc import Callable
4from typing import IO, Any
6import numpy as np
8from ase import Atoms
9from ase.optimize.optimize import Optimizer
10from ase.utils import deprecated
13def _forbid_maxmove(args: list, kwargs: dict[str, Any]) -> bool:
14 """Set maxstep with maxmove if not set."""
15 maxstep_index = 6
16 maxmove_index = 7
18 def _pop_arg(name: str) -> Any:
19 to_pop = None
20 if len(args) > maxmove_index:
21 to_pop = args[maxmove_index]
22 args[maxmove_index] = None
24 elif name in kwargs:
25 to_pop = kwargs[name]
26 del kwargs[name]
27 return to_pop
29 if len(args) > maxstep_index and args[maxstep_index] is None:
30 value = args[maxstep_index] = _pop_arg("maxmove")
31 elif kwargs.get("maxstep", None) is None:
32 value = kwargs["maxstep"] = _pop_arg("maxmove")
33 else:
34 return False
36 return value is not None
class FIRE(Optimizer):
    """Fast Inertial Relaxation Engine (FIRE) optimizer.

    Damped-dynamics minimizer following Bitzek et al., "Structural
    Relaxation Made Simple", Phys. Rev. Lett. 97, 170201 (2006).  Each step
    accelerates a velocity along the negative gradient, mixes the velocity
    towards that direction while the motion is downhill, and stops the
    motion and shrinks the time step as soon as it is not.
    """

    @deprecated(
        "Use of `maxmove` is deprecated. Use `maxstep` instead.",
        category=FutureWarning,
        callback=_forbid_maxmove,
    )
    def __init__(
        self,
        atoms: Atoms,
        restart: str | None = None,
        logfile: IO | str = '-',
        trajectory: str | None = None,
        dt: float = 0.1,
        maxstep: float | None = None,
        maxmove: float | None = None,
        dtmax: float = 1.0,
        Nmin: int = 5,
        finc: float = 1.1,
        fdec: float = 0.5,
        astart: float = 0.1,
        fa: float = 0.99,
        a: float = 0.1,
        downhill_check: bool = False,
        position_reset_callback: Callable | None = None,
        **kwargs,
    ):
        """Create a FIRE optimizer.

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: str
            Restart file. :meth:`step` dumps the current velocities and
            time step to it, and :meth:`read` restores them.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: str
            Trajectory file used to store optimisation path.

        dt: float
            Initial time step. Default value is 0.1. Ignored in favour of
            the saved value when the time step is restored from *restart*.

        maxstep: float
            Maximum Euclidean norm of the displacement of all degrees of
            freedom together in one iteration (this is not a per-atom
            limit). If None, ``self.defaults["maxstep"]`` of the base class
            is used.

        maxmove: float
            Deprecated alias of *maxstep* (see note below).

        dtmax: float
            Maximum time step. Default value is 1.0.

        Nmin: int
            The time step is increased and *a* decreased only after more
            than *Nmin* consecutive steps with positive power P = v.f
            (P in the FIRE article). Default value is 5.

        finc: float
            Factor to increase the time step. Default value is 1.1.

        fdec: float
            Factor to decrease the time step. Default value is 0.5.

        astart: float
            Value the mixing coefficient *a* is reset to whenever
            P = v.f <= 0 (or the downhill check rejects a step).
            Default value 0.1.

        fa: float
            Factor to decrease *a*. Default value is 0.99.

        a: float
            Initial coefficient for mixing the velocity and the force.
            Called alpha in the FIRE article. Default value 0.1.

        downhill_check: bool
            Downhill check directly compares potential energies of subsequent
            steps of the FIRE algorithm rather than relying on the current
            product v*f that is positive if the FIRE dynamics moves downhill.
            This can detect numerical issues where at large time steps the step
            is uphill in energy even though locally v*f is positive, i.e. the
            algorithm jumps over a valley because of a too large time step.

        position_reset_callback: callable
            Called as ``callback(optimizable, r, e, e_last)`` right before
            an uphill step is reverted (only with *downhill_check*).
            *optimizable* is ``self.optimizable`` (not the Atoms object),
            *r* is the coordinate vector from ``optimizable.get_x()`` that
            will be restored, *e* is the current value and *e_last* the
            value of the previous step. Only called if e > e_last.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        .. deprecated:: 3.19.3
            Use of ``maxmove`` is deprecated; please use ``maxstep``.
        """
        # read() restores dt from the restart file.  If the base constructor
        # calls read(), assigning dt only afterwards would overwrite the
        # restored value, so set the default first.
        self.dt = dt
        super().__init__(atoms, restart, logfile, trajectory, **kwargs)

        # Number of consecutive steps with positive power v.f.
        self.Nsteps = 0

        if maxstep is not None:
            self.maxstep = maxstep
        else:
            self.maxstep = self.defaults["maxstep"]

        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.downhill_check = downhill_check
        self.position_reset_callback = position_reset_callback

    def initialize(self):
        """Start from rest, with no downhill-check reference point."""
        self.vel = None
        self._clear_downhill_reference()

    def read(self):
        """Restore velocities and time step saved by :meth:`step`."""
        self.vel, self.dt = self.load()
        # Values and coordinates are not saved, so a fresh downhill
        # reference is recorded on the next step.
        self._clear_downhill_reference()

    def _clear_downhill_reference(self) -> None:
        """Forget the reference point used by the downhill check."""
        self.e_last = None
        self.r_last = None
        self.vel_last = None

    def _store_downhill_reference(self) -> None:
        """Remember the value, coordinates and velocity of the current point."""
        optimizable = self.optimizable
        self.e_last = optimizable.get_value()
        self.r_last = optimizable.get_x()
        self.vel_last = self.vel.copy()

    def _revert_if_uphill(self) -> bool:
        """Restore the stored coordinates if the value went up.

        Returns True when the last step was uphill and has been reverted.
        """
        optimizable = self.optimizable
        e = optimizable.get_value()
        if e > self.e_last:
            if self.position_reset_callback is not None:
                self.position_reset_callback(
                    optimizable, self.r_last, e, self.e_last)
            optimizable.set_x(self.r_last)
            return True
        return False

    def _update_velocity(self, neg_grad, is_uphill: bool) -> None:
        """Apply the FIRE velocity-mixing and time-step adaptation rules."""
        vf = np.vdot(neg_grad, self.vel)
        if vf > 0.0 and not is_uphill:
            # Turn the velocity towards the downhill direction:
            # v <- (1 - a) v + a |v| F / |F|.  vf > 0 implies F != 0.
            grad2 = np.vdot(neg_grad, neg_grad)
            speed = np.sqrt(np.vdot(self.vel, self.vel))
            self.vel = (
                (1.0 - self.a) * self.vel
                + self.a * neg_grad / np.sqrt(grad2) * speed)
            if self.Nsteps > self.Nmin:
                self.dt = min(self.dt * self.finc, self.dtmax)
                self.a *= self.fa
            self.Nsteps += 1
        else:
            # Moving uphill: stop, restart mixing, and shrink the time step.
            self.vel[:] *= 0.0
            self.a = self.astart
            self.dt *= self.fdec
            self.Nsteps = 0

    def step(self, f=None):
        """Take one FIRE step.

        Parameters
        ----------
        f: array, optional
            Forwarded to ``self._get_gradient``; presumably a gradient/force
            already evaluated by the caller (TODO confirm in Optimizer).
        """
        # _get_gradient() yields the gradient; its negative is the downhill
        # direction that drives the dynamics.
        neg_grad = -self._get_gradient(f)
        optimizable = self.optimizable

        if self.vel is None:
            # First step after initialize(): start from rest.
            self.vel = np.zeros(optimizable.ndofs())
            if self.downhill_check:
                self._store_downhill_reference()
        else:
            is_uphill = False
            if self.downhill_check:
                # Without a reference (e.g. velocities restored by read())
                # there is nothing to compare against yet.
                if self.e_last is not None:
                    is_uphill = self._revert_if_uphill()
                self._store_downhill_reference()
            # NOTE(review): after a reverted uphill step, neg_grad was still
            # evaluated at the rejected configuration; kept as-is for
            # backward compatibility.
            self._update_velocity(neg_grad, is_uphill)

        # Update the velocity first, then move with the new velocity.  The
        # Euclidean norm of the whole displacement vector is capped at
        # maxstep.
        self.vel += self.dt * neg_grad
        dr = self.dt * self.vel
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxstep:
            dr = self.maxstep * dr / normdr
        r = optimizable.get_x()
        optimizable.set_x(r + dr)
        self.dump((self.vel, self.dt))