Coverage for /builds/ase/ase/ase/optimize/fire.py: 92.94%

85 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3from typing import IO, Any, Callable, Dict, List, Optional, Union 

4 

5import numpy as np 

6 

7from ase import Atoms 

8from ase.optimize.optimize import Optimizer 

9from ase.utils import deprecated 

10 

11 

12def _forbid_maxmove(args: List, kwargs: Dict[str, Any]) -> bool: 

13 """Set maxstep with maxmove if not set.""" 

14 maxstep_index = 6 

15 maxmove_index = 7 

16 

17 def _pop_arg(name: str) -> Any: 

18 to_pop = None 

19 if len(args) > maxmove_index: 

20 to_pop = args[maxmove_index] 

21 args[maxmove_index] = None 

22 

23 elif name in kwargs: 

24 to_pop = kwargs[name] 

25 del kwargs[name] 

26 return to_pop 

27 

28 if len(args) > maxstep_index and args[maxstep_index] is None: 

29 value = args[maxstep_index] = _pop_arg("maxmove") 

30 elif kwargs.get("maxstep", None) is None: 

31 value = kwargs["maxstep"] = _pop_arg("maxmove") 

32 else: 

33 return False 

34 

35 return value is not None 

36 

37 

class FIRE(Optimizer):
    """FIRE (Fast Inertial Relaxation Engine) structure optimizer.

    Damped dynamics with an adaptive time step and mixing of the
    velocity with the force direction; the quantities referred to below
    as "the FIRE article" quantities (P, alpha) are documented per
    parameter in ``__init__``.
    """

    @deprecated(
        "Use of `maxmove` is deprecated. Use `maxstep` instead.",
        category=FutureWarning,
        callback=_forbid_maxmove,
    )
    def __init__(
        self,
        atoms: Atoms,
        restart: Optional[str] = None,
        logfile: Union[IO, str] = '-',
        trajectory: Optional[str] = None,
        dt: float = 0.1,
        maxstep: Optional[float] = None,
        maxmove: Optional[float] = None,
        dtmax: float = 1.0,
        Nmin: int = 5,
        finc: float = 1.1,
        fdec: float = 0.5,
        astart: float = 0.1,
        fa: float = 0.99,
        a: float = 0.1,
        downhill_check: bool = False,
        position_reset_callback: Optional[Callable] = None,
        **kwargs,
    ):
        """

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: str
            JSON file used to store hessian matrix. If set, file with
            such a name will be searched and hessian matrix stored will
            be used, if the file exists.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: str
            Trajectory file used to store optimisation path.

        dt: float
            Initial time step. Default value is 0.1

        maxstep: float
            Used to set the maximum distance an atom can move per
            iteration (default value is 0.2).

        maxmove: float
            Deprecated alias for *maxstep*; handled by the
            ``deprecated`` decorator above.

            .. deprecated:: 3.19.3
                Use ``maxstep`` instead.

        dtmax: float
            Maximum time step. Default value is 1.0

        Nmin: int
            Number of steps to wait after the last time the dot product of
            the velocity and force is negative (P in The FIRE article) before
            increasing the time step. Default value is 5.

        finc: float
            Factor to increase the time step. Default value is 1.1

        fdec: float
            Factor to decrease the time step. Default value is 0.5

        astart: float
            Initial value of the parameter a. a is the Coefficient for
            mixing the velocity and the force. Called alpha in the FIRE article.
            Default value 0.1.

        fa: float
            Factor to decrease the parameter alpha. Default value is 0.99

        a: float
            Coefficient for mixing the velocity and the force. Called
            alpha in the FIRE article. Default value 0.1.

        downhill_check: bool
            Downhill check directly compares potential energies of subsequent
            steps of the FIRE algorithm rather than relying on the current
            product v*f that is positive if the FIRE dynamics moves downhill.
            This can detect numerical issues where at large time steps the step
            is uphill in energy even though locally v*f is positive, i.e. the
            algorithm jumps over a valley because of a too large time step.

        position_reset_callback: function(atoms, r, e, e_last)
            Function that takes current *atoms* object, an array of position
            *r* that the optimizer will revert to, current energy *e* and
            energy of last step *e_last*. This is only called if e > e_last.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        .. deprecated:: 3.19.3
            Use of ``maxmove`` is deprecated; please use ``maxstep``.

        """
        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)

        self.dt = dt

        # Steps taken since v·f was last negative (P in the FIRE article).
        self.Nsteps = 0

        if maxstep is not None:
            self.maxstep = maxstep
        else:
            # Fall back to the shared default from the Optimizer base class.
            self.maxstep = self.defaults["maxstep"]

        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.downhill_check = downhill_check
        self.position_reset_callback = position_reset_callback

    def initialize(self):
        """Reset the dynamical state; velocities are created lazily in step()."""
        self.v = None

    def read(self):
        """Restore velocities and time step from the restart file."""
        self.v, self.dt = self.load()

    def step(self, f=None):
        """Advance the damped dynamics by one FIRE step.

        Parameters
        ----------
        f:
            Forces as an (N, 3)-shaped array.  If None, they are obtained
            from the wrapped optimizable (flat gradient viewed as (N, 3)).
        """
        optimizable = self.optimizable

        if f is None:
            f = optimizable.get_gradient().reshape(-1, 3)

        if self.v is None:
            # First call: start the dynamics from rest.
            self.v = np.zeros(optimizable.ndofs()).reshape(-1, 3)
            if self.downhill_check:
                self.e_last = optimizable.get_value()
                self.r_last = optimizable.get_x().reshape(-1, 3).copy()
                self.v_last = self.v.copy()
        else:
            is_uphill = False
            if self.downhill_check:
                e = optimizable.get_value()
                # Check if the energy actually decreased
                if e > self.e_last:
                    # If not, reset to old positions...
                    if self.position_reset_callback is not None:
                        self.position_reset_callback(
                            optimizable, self.r_last, e,
                            self.e_last)
                    optimizable.set_x(self.r_last.ravel())
                    is_uphill = True
                self.e_last = optimizable.get_value()
                self.r_last = optimizable.get_x().reshape(-1, 3).copy()
                self.v_last = self.v.copy()

            vf = np.vdot(f, self.v)
            if vf > 0.0 and not is_uphill:
                # Moving downhill: mix the velocity with the unit force
                # direction (coefficient ``a``, alpha in the FIRE article).
                self.v = (1.0 - self.a) * self.v + self.a * f / np.sqrt(
                    np.vdot(f, f)) * np.sqrt(np.vdot(self.v, self.v))
                if self.Nsteps > self.Nmin:
                    # Downhill for long enough: grow dt (capped) and
                    # reduce the mixing coefficient.
                    self.dt = min(self.dt * self.finc, self.dtmax)
                    self.a *= self.fa
                self.Nsteps += 1
            else:
                # Uphill (or v·f <= 0): stop, reset mixing, shrink dt.
                self.v[:] *= 0.0
                self.a = self.astart
                self.dt *= self.fdec
                self.Nsteps = 0

        # Euler velocity update followed by a position update.
        self.v += self.dt * f
        dr = self.dt * self.v
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxstep:
            # Cap the total displacement of this iteration.
            dr = self.maxstep * dr / normdr
        r = optimizable.get_x().reshape(-1, 3)
        optimizable.set_x((r + dr).ravel())
        # Persist dynamical state so the run can be restarted.
        self.dump((self.v, self.dt))