Coverage for ase / optimize / precon / fire.py: 84.82%

112 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1# fmt: off 

2 

3import time 

4 

5import numpy as np 

6 

7from ase.filters import UnitCellFilter 

8from ase.optimize.optimize import Optimizer 

9 

10 

class PreconFIRE(Optimizer):

    def __init__(self, atoms, restart=None, logfile='-', trajectory=None,
                 dt=0.1, maxmove=0.2, dtmax=1.0, Nmin=5, finc=1.1, fdec=0.5,
                 astart=0.1, fa=0.99, a=0.1, theta=0.1,
                 precon=None, use_armijo=True, variable_cell=False, **kwargs):
        """
        Preconditioned version of the FIRE optimizer

        In time this implementation is expected to replace
        :class:`~ase.optimize.fire.FIRE`.

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: string
            File used to store the optimizer state, namely the velocity
            array and the current time step ``dt``. If set and the file
            exists, that state is loaded through :meth:`read`.

        trajectory: string
            Trajectory file used to store optimisation path.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        dt: float
            Initial time step. It is multiplied by *finc* (capped at
            *dtmax*) after more than *Nmin* consecutive downhill steps, and
            by *fdec* whenever the velocity is reset.

        maxmove: float
            Upper bound on the norm of the whole displacement array applied
            in one step (a global bound, not a per-atom one).

        dtmax: float
            Maximum time step.

        Nmin: int
            Number of consecutive downhill steps that must be exceeded
            before *dt* is increased and *a* is decreased.

        finc, fdec: float
            Factors by which *dt* is increased / decreased.

        astart: float
            Value *a* is reset to whenever the velocity is zeroed.

        fa: float
            Factor by which *a* is multiplied after more than *Nmin*
            consecutive downhill steps.

        a: float
            Initial velocity-mixing parameter: the weight given to the
            (preconditioned) force direction when it is mixed into the
            velocity.

        theta: float
            Armijo sufficient-decrease parameter, used when *use_armijo*
            is True.

        precon: preconditioner or None
            If given, it must provide ``make_precon(atoms)``,
            ``solve(vector)`` and ``dot(x, y)``. The dynamics then use the
            preconditioned forces ``P^-1 f``. ``None`` gives plain FIRE.

        use_armijo: bool
            If True, each trial step is checked with an Armijo condition.
            On failure the velocity is reset and *dt* is shrunk.

        variable_cell: bool
            If True, wrap atoms in UnitCellFilter to relax cell and positions.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        """
        # Assign dt before calling the base-class constructor. If that
        # constructor restores state through ``read()`` (restart file
        # present), the restored dt must not be overwritten by the default.
        self.dt = dt

        if variable_cell:
            atoms = UnitCellFilter(atoms)
        super().__init__(atoms, restart, logfile, trajectory, **kwargs)

        # Object actually optimised: either the atoms or their filter.
        self._actual_atoms = atoms

        self.Nsteps = 0  # consecutive downhill steps since the last reset
        self.maxmove = maxmove
        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.theta = theta
        self.precon = precon
        self.use_armijo = use_armijo
        # Stress convergence threshold. ``run`` sets it; None means "same
        # as fmax", so convergence checks also work when run() is bypassed.
        self.smax = None

    def initialize(self):
        """Reset the optimizer state for a fresh (non-restarted) run."""
        self.v = None
        self.skip_flag = False
        # Energy at the most recent Armijo trial point (None before the
        # first line-search test).
        self.e1 = None

    def read(self):
        """Restore ``(v, dt)`` as written by :meth:`step` via ``dump``."""
        self.v, self.dt = self.load()

    def step(self, f=None):
        """Perform one (preconditioned) FIRE step and move the atoms.

        Parameters
        ----------
        f: array, optional
            Forces on the optimised object. They are computed if not given.
        """
        atoms = self._actual_atoms
        if f is None:
            f = atoms.get_forces()
        r = atoms.get_positions()

        if self.precon is not None:
            # Can this be moved out of the step method?
            self.precon.make_precon(atoms)
            invP_f = self.precon.solve(f.reshape(-1)).reshape(len(atoms), -1)

        if self.v is None:
            self.v = np.zeros((len(self._actual_atoms), 3))
        else:
            if self.use_armijo:

                if self.precon is None:
                    v_test = self.v + self.dt * f
                else:
                    v_test = self.v + self.dt * invP_f

                r_test = r + self.dt * v_test

                self.skip_flag = False
                # func() moves the atoms. They are first placed at r_test,
                # then back at r by the second func() call below.
                func_val = self.func(r_test)
                self.e1 = func_val
                # Armijo test: E(r_test) must not exceed
                # E(r) - theta * dt * <v_test, f> (f is minus the gradient).
                if (func_val > self.func(r) -
                        self.theta * self.dt * np.vdot(v_test, f)):
                    self.v[:] *= 0.0
                    self.a = self.astart
                    self.dt *= self.fdec
                    self.Nsteps = 0
                    self.skip_flag = True

            if not self.skip_flag:

                v_f = np.vdot(self.v, f)
                if v_f > 0.0:
                    # Downhill: steer the velocity towards the (precond.)
                    # force direction while keeping its magnitude.
                    if self.precon is None:
                        self.v = (1.0 - self.a) * self.v + self.a * f / \
                            np.sqrt(np.vdot(f, f)) * \
                            np.sqrt(np.vdot(self.v, self.v))
                    else:
                        self.v = (
                            (1.0 - self.a) * self.v +
                            self.a *
                            (np.sqrt(self.precon.dot(self.v.reshape(-1),
                                                     self.v.reshape(-1))) /
                             np.sqrt(np.dot(f.reshape(-1),
                                            invP_f.reshape(-1))) * invP_f))
                    if self.Nsteps > self.Nmin:
                        self.dt = min(self.dt * self.finc, self.dtmax)
                        self.a *= self.fa
                    self.Nsteps += 1
                else:
                    # Uphill: stop and restart with a smaller time step.
                    self.v[:] *= 0.0
                    self.a = self.astart
                    self.dt *= self.fdec
                    self.Nsteps = 0

        if self.precon is None:
            self.v += self.dt * f
        else:
            self.v += self.dt * invP_f
        dr = self.dt * self.v
        # Cap the norm of the whole displacement array at maxmove.
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxmove:
            dr = self.maxmove * dr / normdr
        atoms.set_positions(r + dr)
        self.dump((self.v, self.dt))

    def func(self, x):
        """Objective function for use of the optimizers

        Moves the optimised object to ``x`` (reshaped to ``(-1, 3)``) and
        returns its potential energy there.
        """
        self._actual_atoms.set_positions(x.reshape(-1, 3))
        potl = self._actual_atoms.get_potential_energy()
        return potl

    def run(self, fmax=0.05, steps=100000000, smax=None):
        """Run the optimizer until converged or out of steps.

        Parameters
        ----------
        fmax: float
            Convergence threshold on the largest per-atom force norm.
        steps: int
            Step limit passed on to the base-class ``run``.
        smax: float, optional
            Convergence threshold on the largest stress component. It is
            only used when the cell is relaxed (``variable_cell=True``)
            and defaults to *fmax*.
        """
        if smax is None:
            smax = fmax
        self.smax = smax
        return super().run(fmax, steps)

    def _squared_maxima(self):
        """Return ``(fmax_sq, smax_sq)`` for the current configuration.

        ``fmax_sq`` is the largest squared per-atom force norm. For a
        :class:`UnitCellFilter` the extra cell rows are excluded from the
        forces, and ``smax_sq`` is the largest squared component of the
        filter's ``stress``. Otherwise ``smax_sq`` is None.
        """
        atoms = self._actual_atoms
        forces = atoms.get_forces()
        smax_sq = None
        if isinstance(atoms, UnitCellFilter):
            natoms = len(atoms.atoms)
            # Read ``stress`` only after get_forces() above.
            forces, stress = forces[:natoms], atoms.stress
            smax_sq = (stress**2).max()
        fmax_sq = (forces**2).sum(axis=1).max()
        return fmax_sq, smax_sq

    def converged(self, gradient):
        """Did the optimization converge?"""
        # XXX ignoring gradient
        fmax_sq, smax_sq = self._squared_maxima()
        if smax_sq is None:
            return fmax_sq < self.fmax**2
        smax = self.fmax if self.smax is None else self.smax
        return (fmax_sq < self.fmax**2 and smax_sq < smax**2)

    def log(self, gradient):
        """Write one progress line (energy, max force[, max stress])."""
        if self.logfile is None:
            return
        fmax_sq, smax_sq = self._squared_maxima()
        fmax = np.sqrt(fmax_sq)
        # Log the energy of the configuration actually reached. ``self.e1``
        # is the energy at the Armijo trial point, which generally differs
        # from the positions set at the end of step(), so it is not reused.
        e = self._actual_atoms.get_potential_energy()
        T = time.localtime()
        name = self.__class__.__name__
        if smax_sq is not None:
            smax = np.sqrt(smax_sq)
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax, smax))

        else:
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax))