Coverage for /builds/ase/ase/ase/optimize/fire.py: 92.94%
85 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3from typing import IO, Any, Callable, Dict, List, Optional, Union
5import numpy as np
7from ase import Atoms
8from ase.optimize.optimize import Optimizer
9from ase.utils import deprecated
12def _forbid_maxmove(args: List, kwargs: Dict[str, Any]) -> bool:
13 """Set maxstep with maxmove if not set."""
14 maxstep_index = 6
15 maxmove_index = 7
17 def _pop_arg(name: str) -> Any:
18 to_pop = None
19 if len(args) > maxmove_index:
20 to_pop = args[maxmove_index]
21 args[maxmove_index] = None
23 elif name in kwargs:
24 to_pop = kwargs[name]
25 del kwargs[name]
26 return to_pop
28 if len(args) > maxstep_index and args[maxstep_index] is None:
29 value = args[maxstep_index] = _pop_arg("maxmove")
30 elif kwargs.get("maxstep", None) is None:
31 value = kwargs["maxstep"] = _pop_arg("maxmove")
32 else:
33 return False
35 return value is not None
class FIRE(Optimizer):
    """FIRE (Fast Inertial Relaxation Engine) optimizer.

    Damped dynamics: the velocity is mixed towards the force direction
    while the motion is downhill (v*f > 0). Otherwise the atoms are stopped
    and the time step is reduced.
    """

    @deprecated(
        "Use of `maxmove` is deprecated. Use `maxstep` instead.",
        category=FutureWarning,
        callback=_forbid_maxmove,
    )
    def __init__(
        self,
        atoms: Atoms,
        restart: Optional[str] = None,
        logfile: Union[IO, str] = '-',
        trajectory: Optional[str] = None,
        dt: float = 0.1,
        maxstep: Optional[float] = None,
        maxmove: Optional[float] = None,
        dtmax: float = 1.0,
        Nmin: int = 5,
        finc: float = 1.1,
        fdec: float = 0.5,
        astart: float = 0.1,
        fa: float = 0.99,
        a: float = 0.1,
        downhill_check: bool = False,
        position_reset_callback: Optional[Callable] = None,
        **kwargs,
    ):
        """Create a FIRE optimizer.

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: str
            File used to store the optimizer state (the current velocities
            and time step). If set, a file with this name is searched for,
            and if it exists the stored state is used to continue the
            optimization.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: str
            Trajectory file used to store the optimisation path.

        dt: float
            Initial time step. Default value is 0.1. A time step restored
            from *restart* takes precedence.

        maxstep: float
            Upper bound on the norm of the full displacement vector (all
            atoms combined) in a single step. Longer steps are rescaled to
            this length, which also bounds how far any single atom can move.
            If None, ``self.defaults["maxstep"]`` (0.2) is used.

        maxmove: float
            Deprecated alias of *maxstep*.

        dtmax: float
            Maximum time step. Default value is 1.0.

        Nmin: int
            Number of steps to wait after the last time the dot product of
            the velocity and force is negative (P in the FIRE article)
            before increasing the time step. Default value is 5.

        finc: float
            Factor to increase the time step. Default value is 1.1.

        fdec: float
            Factor to decrease the time step. Default value is 0.5.

        astart: float
            Initial value of the parameter a. a is the coefficient for
            mixing the velocity and the force (called alpha in the FIRE
            article). Default value is 0.1.

        fa: float
            Factor to decrease the parameter alpha. Default value is 0.99.

        a: float
            Coefficient for mixing the velocity and the force (called alpha
            in the FIRE article). Default value is 0.1.

        downhill_check: bool
            Downhill check directly compares the potential energies of
            subsequent steps of the FIRE algorithm, rather than relying on
            the current product v*f being positive when the FIRE dynamics
            moves downhill. This can detect numerical issues where, at large
            time steps, the step is uphill in energy even though locally v*f
            is positive, i.e. the algorithm jumps over a valley because the
            time step is too large.

        position_reset_callback: function(optimizable, r, e, e_last)
            Called as ``position_reset_callback(optimizable, r, e, e_last)``.
            *optimizable* is ``self.optimizable``, not the Atoms object
            itself. *r* is the (N, 3) array of positions the optimizer is
            about to revert to, *e* is the current energy and *e_last* is the
            energy of the last step. It is only called when *downhill_check*
            is enabled and e > e_last, before the positions are reset.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        .. deprecated:: 3.19.3
            Use of ``maxmove`` is deprecated; please use ``maxstep``.
        """
        # Assign dt *before* Optimizer.__init__. When a restart file exists,
        # the base initialiser is expected to call self.read(), which restores
        # the saved (v, dt). Assigning dt afterwards silently discarded the
        # restored time step while keeping the restored velocities.
        self.dt = dt
        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)

        # Number of consecutive downhill steps since the last reset.
        self.Nsteps = 0

        if maxstep is not None:
            self.maxstep = maxstep
        else:
            self.maxstep = self.defaults["maxstep"]

        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.downhill_check = downhill_check
        self.position_reset_callback = position_reset_callback

    def initialize(self):
        """Start without velocities; step() creates them on first use."""
        self.v = None

    def read(self):
        """Restore the (velocity, time step) pair saved by step()."""
        self.v, self.dt = self.load()

    def step(self, f=None):
        """Perform one FIRE iteration and move the atoms.

        Parameters
        ----------
        f: array, optional
            Driving vector for the dynamics (the atoms are moved along +f).
            If None, ``self.optimizable.get_gradient()`` reshaped to
            (-1, 3) is used.
        """
        optimizable = self.optimizable

        if f is None:
            f = optimizable.get_gradient().reshape(-1, 3)

        if self.v is None:
            # First step: start from rest.
            self.v = np.zeros(optimizable.ndofs()).reshape(-1, 3)
            if self.downhill_check:
                self.e_last = optimizable.get_value()
                self.r_last = optimizable.get_x().reshape(-1, 3).copy()
                self.v_last = self.v.copy()
        else:
            is_uphill = False
            if self.downhill_check:
                e = optimizable.get_value()
                # Check whether the energy actually decreased.
                if e > self.e_last:
                    # If not, reset to the old positions.
                    if self.position_reset_callback is not None:
                        self.position_reset_callback(
                            optimizable, self.r_last, e,
                            self.e_last)
                    optimizable.set_x(self.r_last.ravel())
                    is_uphill = True
                    # NOTE(review): f was evaluated at the rejected (uphill)
                    # positions but is still used below to restart the
                    # dynamics from r_last; consider recomputing it here.
                self.e_last = optimizable.get_value()
                self.r_last = optimizable.get_x().reshape(-1, 3).copy()
                # v_last is not read within this class; kept as public state.
                self.v_last = self.v.copy()

            vf = np.vdot(f, self.v)
            if vf > 0.0 and not is_uphill:
                # Downhill: mix v with a vector of magnitude |v| along f.
                # After more than Nmin such steps, grow dt and shrink a.
                self.v = (1.0 - self.a) * self.v + self.a * f / np.sqrt(
                    np.vdot(f, f)) * np.sqrt(np.vdot(self.v, self.v))
                if self.Nsteps > self.Nmin:
                    self.dt = min(self.dt * self.finc, self.dtmax)
                    self.a *= self.fa
                self.Nsteps += 1
            else:
                # Uphill or rejected step: stop the atoms, reset the mixing
                # coefficient and shrink the time step. Assign rather than
                # multiply by 0 so an inf/NaN velocity cannot survive.
                self.v[:] = 0.0
                self.a = self.astart
                self.dt *= self.fdec
                self.Nsteps = 0

        # Euler step; the full displacement vector is capped at maxstep.
        self.v += self.dt * f
        dr = self.dt * self.v
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxstep:
            dr = self.maxstep * dr / normdr
        r = optimizable.get_x().reshape(-1, 3)
        optimizable.set_x((r + dr).ravel())
        self.dump((self.v, self.dt))