Coverage for /builds/ase/ase/ase/optimize/precon/fire.py: 84.96%

113 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import time 

4 

5import numpy as np 

6 

7from ase.filters import UnitCellFilter 

8from ase.optimize.optimize import Optimizer 

9 

10 

class PreconFIRE(Optimizer):
    """Preconditioned FIRE optimizer with an optional Armijo energy check."""

    def __init__(self, atoms, restart=None, logfile='-', trajectory=None,
                 dt=0.1, maxmove=0.2, dtmax=1.0, Nmin=5, finc=1.1, fdec=0.5,
                 astart=0.1, fa=0.99, a=0.1, theta=0.1,
                 precon=None, use_armijo=True, variable_cell=False, **kwargs):
        """
        Preconditioned version of the FIRE optimizer

        In time this implementation is expected to replace
        :class:`~ase.optimize.fire.FIRE`.

        Parameters
        ----------
        atoms: :class:`~ase.Atoms`
            The Atoms object to relax.

        restart: string
            File used to store the optimizer state, i.e. the velocity
            array and the current time step ``dt`` (written after every
            step). If set and the file exists, the stored state is loaded
            and the stored ``dt`` takes precedence over the ``dt``
            argument.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: string
            Trajectory file used to store optimisation path.

        dt: float
            Initial time step.

        maxmove: float
            Maximum length of the total displacement vector (all atoms
            taken together) in a single step; longer steps are rescaled.

        dtmax: float
            Upper bound for the time step when it is increased.

        Nmin: int
            ``dt`` is increased (and ``a`` decreased) only once more than
            ``Nmin`` consecutive steps have had a positive velocity-force
            overlap.

        finc: float
            Factor by which ``dt`` is multiplied when it is increased.

        fdec: float
            Factor by which ``dt`` is multiplied after a velocity reset.

        astart: float
            Value ``a`` is reset to after a velocity reset.

        fa: float
            Factor by which ``a`` is multiplied when ``dt`` is increased.

        a: float
            Initial velocity-mixing parameter.

        theta: float
            Sufficient-decrease coefficient of the Armijo check.

        precon: preconditioner or None
            If given, it must provide ``make_precon(atoms)``, ``solve(x)``
            and ``dot(x, y)``; the search direction becomes ``P^{-1} f``.
            If None, the raw forces are used.

        use_armijo: bool
            If True, a trial step whose energy does not decrease
            sufficiently is rejected (velocity reset, smaller ``dt``).

        variable_cell: bool
            If True, wrap atoms in UnitCellFilter to relax cell and positions.

        kwargs : dict, optional
            Extra arguments passed to
            :class:`~ase.optimize.optimize.Optimizer`.

        """
        if variable_cell:
            atoms = UnitCellFilter(atoms)

        # Configure the optimizer *before* Optimizer.__init__: the base
        # class presumably invokes initialize() or, when restarting, read()
        # (neither is called elsewhere in this class).  Assigning dt after
        # that call would silently discard the dt restored by read().
        self.dt = dt
        self.Nsteps = 0
        self.maxmove = maxmove
        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.theta = theta
        self.precon = precon
        self.use_armijo = use_armijo
        # Stress tolerance; set by run(), converged() falls back to fmax
        # when it is None (e.g. when irun() is used directly).
        self.smax = None

        Optimizer.__init__(self, atoms, restart, logfile, trajectory, **kwargs)

        self._actual_atoms = atoms

    def initialize(self):
        """Start from rest with no pending line-search state."""
        self.v = None
        self.skip_flag = False
        self.e1 = None

    def read(self):
        """Restore the velocities and time step from the restart file."""
        self.v, self.dt = self.load()
        # initialize() is not run on restart, so set the per-step
        # bookkeeping that step() and log() read.
        self.skip_flag = False
        self.e1 = None

    def step(self, f=None):
        """Advance the positions by one (preconditioned) FIRE step.

        Parameters
        ----------
        f: array, optional
            Forces at the current positions; read from the atoms if None.
        """
        atoms = self._actual_atoms

        if f is None:
            f = atoms.get_forces()

        r = atoms.get_positions()

        invP_f = None
        if self.precon is not None:
            # Can this be moved out of the step method?
            self.precon.make_precon(atoms)
            invP_f = self.precon.solve(f.reshape(-1)).reshape(len(atoms), -1)
        # Search direction: raw forces or preconditioned forces P^{-1} f.
        direction = f if self.precon is None else invP_f

        if self.v is None:
            # First step: start from rest, no line search or mixing.
            self.v = np.zeros((len(atoms), 3))
        else:
            if self.use_armijo:
                self.skip_flag = self._armijo_rejects(r, f, direction)
                if self.skip_flag:
                    self._reset_velocity()
            if not self.skip_flag:
                self._mix_velocity(f, invP_f)

        self.v += self.dt * direction
        dr = self.dt * self.v
        # Cap the length of the whole displacement vector at maxmove.
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxmove:
            dr = self.maxmove * dr / normdr
        atoms.set_positions(r + dr)
        self.dump((self.v, self.dt))

    def _armijo_rejects(self, r, f, direction):
        """Return True if the trial step fails the Armijo condition.

        The trial point is ``r + dt * v_test`` with
        ``v_test = v + dt * direction``; it is rejected when
        ``E(r_test) > E(r) - theta * dt * (v_test . f)``.  The trial
        energy is stored in ``self.e1``.  ``E(r)`` is evaluated last so
        that the atoms are left at ``r`` (func() moves them).
        """
        v_test = self.v + self.dt * direction
        r_test = r + self.dt * v_test
        self.e1 = self.func(r_test)
        e_current = self.func(r)
        return (self.e1 >
                e_current - self.theta * self.dt * np.vdot(v_test, f))

    def _reset_velocity(self):
        """Stop, reset the mixing parameter and shrink the time step."""
        # Assignment (not *= 0) so a non-finite velocity is also cleared.
        self.v[:] = 0.0
        self.a = self.astart
        self.dt *= self.fdec
        self.Nsteps = 0

    def _mix_velocity(self, f, invP_f):
        """FIRE velocity mixing and time-step adaptation.

        When the velocity points downhill (v . f > 0) the velocity is
        rotated towards the search direction while keeping its length
        (measured in the preconditioner metric if one is used); otherwise
        the velocity is reset.
        """
        v_f = np.vdot(self.v, f)
        if v_f > 0.0:
            if self.precon is None:
                # v <- (1 - a) v + a |v| f / |f|
                self.v = (1.0 - self.a) * self.v + self.a * f / \
                    np.sqrt(np.vdot(f, f)) * \
                    np.sqrt(np.vdot(self.v, self.v))
            else:
                v_norm = np.sqrt(self.precon.dot(self.v.reshape(-1),
                                                 self.v.reshape(-1)))
                f_norm = np.sqrt(np.dot(f.reshape(-1), invP_f.reshape(-1)))
                self.v = ((1.0 - self.a) * self.v +
                          self.a * (v_norm / f_norm * invP_f))
            if self.Nsteps > self.Nmin:
                self.dt = min(self.dt * self.finc, self.dtmax)
                self.a *= self.fa
            self.Nsteps += 1
        else:
            self._reset_velocity()

    def func(self, x):
        """Objective function for use of the optimizers.

        Moves the atoms to ``x`` (reshaped to N x 3) and returns their
        potential energy.
        """
        self._actual_atoms.set_positions(x.reshape(-1, 3))
        potl = self._actual_atoms.get_potential_energy()
        return potl

    def run(self, fmax=0.05, steps=100000000, smax=None):
        """Run the optimizer.

        ``smax`` is the stress tolerance used when the atoms are wrapped
        in a UnitCellFilter; it defaults to ``fmax``.
        """
        if smax is None:
            smax = fmax
        self.smax = smax
        return Optimizer.run(self, fmax, steps)

    def _max_force_and_stress_sq(self):
        """Return (max squared atomic force, max squared stress or None).

        For a UnitCellFilter the extra (cell) rows are excluded from the
        forces and the squared stress maximum is returned as well; for
        plain atoms the second value is None.
        """
        forces = self._actual_atoms.get_forces()
        smax_sq = None
        if isinstance(self._actual_atoms, UnitCellFilter):
            natoms = len(self._actual_atoms.atoms)
            forces = forces[:natoms]
            # .stress is read after get_forces(), as in the original code.
            smax_sq = (self._actual_atoms.stress**2).max()
        fmax_sq = (forces**2).sum(axis=1).max()
        return fmax_sq, smax_sq

    def converged(self, gradient):
        """Did the optimization converge?"""
        # XXX ignoring gradient
        fmax_sq, smax_sq = self._max_force_and_stress_sq()
        if smax_sq is None:
            return fmax_sq < self.fmax**2
        smax = self.fmax if self.smax is None else self.smax
        return (fmax_sq < self.fmax**2 and smax_sq < smax**2)

    def log(self, gradient):
        """Write one line (energy, fmax[, smax]) to the logfile, if any."""
        if self.logfile is None:
            return
        fmax_sq, smax_sq = self._max_force_and_stress_sq()
        fmax = np.sqrt(fmax_sq)
        # Energy at the positions actually reached.  self.e1 is the energy
        # of the Armijo *trial* point, which differs from the accepted
        # position (always so when the trial step was rejected).
        e = self._actual_atoms.get_potential_energy()
        T = time.localtime()
        name = self.__class__.__name__
        if smax_sq is not None:
            smax = np.sqrt(smax_sq)
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax, smax))
        else:
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax))
        self.logfile.flush()