Coverage for ase / utils / linesearch.py: 83.61%

238 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1# fmt: off 

2 

3# flake8: noqa 

4import numpy as np 

5 

6 

7def standard_gradient_norm(array): 

8 return np.linalg.norm(array.reshape(-1, 3), axis=1).max() 

9 

10 

11class LineSearch: 

12 def __init__(self, xtol=1e-14, get_gradient_norm=standard_gradient_norm): 

13 self.xtol = xtol 

14 self.task = 'START' 

15 self.isave = np.zeros((2,), np.intc) 

16 self.dsave = np.zeros((13,), float) 

17 self.fc = 0 

18 self.gc = 0 

19 self.case = 0 

20 self.old_stp = 0 

21 self.get_gradient_norm = get_gradient_norm 

22 

23 def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval, 

24 maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4., 

25 stpmax=50., stpmin=1e-8, args=()): 

26 self.stpmin = stpmin 

27 self.pk = pk 

28 # ??? p_size = np.sqrt((pk **2).sum()) 

29 self.stpmax = stpmax 

30 self.xtrapl = xtrapl 

31 self.xtrapu = xtrapu 

32 self.maxstep = maxstep 

33 phi0 = old_fval 

34 derphi0 = np.dot(gfk, pk) 

35 self.dim = len(pk) 

36 self.gms = np.sqrt(self.dim) * maxstep 

37 # alpha1 = pymin(maxstep,1.01*2*(phi0-old_old_fval)/derphi0) 

38 alpha1 = 1. 

39 self.no_update = False 

40 

41 if isinstance(myfprime, type(())): 

42 # eps = myfprime[1] 

43 fprime = myfprime[0] 

44 # ??? newargs = (f,eps) + args 

45 gradient = False 

46 else: 

47 fprime = myfprime 

48 newargs = args 

49 gradient = True 

50 

51 fval = old_fval 

52 gval = gfk 

53 self.steps = [] 

54 

55 while True: 

56 stp = self.step(alpha1, phi0, derphi0, c1, c2, 

57 self.xtol, 

58 self.isave, self.dsave) 

59 

60 if self.task[:2] == 'FG': 

61 alpha1 = stp 

62 fval = func(xk + stp * pk, *args) 

63 self.fc += 1 

64 gval = fprime(xk + stp * pk, *newargs) 

65 if gradient: 

66 self.gc += 1 

67 else: 

68 self.fc += len(xk) + 1 

69 phi0 = fval 

70 derphi0 = np.dot(gval, pk) 

71 self.old_stp = alpha1 

72 if self.no_update == True: 

73 break 

74 else: 

75 break 

76 

77 if self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN': 

78 stp = None # failed 

79 return stp, fval, old_fval, self.no_update 

80 

81 def step(self, stp, f, g, c1, c2, xtol, isave, dsave): 

82 if self.task[:5] == 'START': 

83 # Check the input arguments for errors. 

84 if stp < self.stpmin: 

85 self.task = 'ERROR: STP .LT. minstep' 

86 if stp > self.stpmax: 

87 self.task = 'ERROR: STP .GT. maxstep' 

88 if g >= 0: 

89 self.task = 'ERROR: INITIAL G >= 0' 

90 if c1 < 0: 

91 self.task = 'ERROR: c1 .LT. 0' 

92 if c2 < 0: 

93 self.task = 'ERROR: c2 .LT. 0' 

94 if xtol < 0: 

95 self.task = 'ERROR: XTOL .LT. 0' 

96 if self.stpmin < 0: 

97 self.task = 'ERROR: minstep .LT. 0' 

98 if self.stpmax < self.stpmin: 

99 self.task = 'ERROR: maxstep .LT. minstep' 

100 if self.task[:5] == 'ERROR': 

101 return stp 

102 

103 # Initialize local variables. 

104 self.bracket = False 

105 stage = 1 

106 finit = f 

107 ginit = g 

108 gtest = c1 * ginit 

109 width = self.stpmax - self.stpmin 

110 width1 = width / .5 

111# The variables stx, fx, gx contain the values of the step, 

112# function, and derivative at the best step. 

113# The variables sty, fy, gy contain the values of the step, 

114# function, and derivative at sty. 

115# The variables stp, f, g contain the values of the step, 

116# function, and derivative at stp. 

117 stx = 0 

118 fx = finit 

119 gx = ginit 

120 sty = 0 

121 fy = finit 

122 gy = ginit 

123 stmin = 0 

124 stmax = stp + self.xtrapu * stp 

125 self.task = 'FG' 

126 self.save((stage, ginit, gtest, gx, 

127 gy, finit, fx, fy, stx, sty, 

128 stmin, stmax, width, width1)) 

129 stp = self.determine_step(stp) 

130 # return stp, f, g 

131 return stp 

132 else: 

133 if self.isave[0] == 1: 

134 self.bracket = True 

135 else: 

136 self.bracket = False 

137 stage = self.isave[1] 

138 (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax, 

139 width, width1) = self.dsave 

140 

141# If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the 

142# algorithm enters the second stage. 

143 ftest = finit + stp * gtest 

144 if stage == 1 and f < ftest and g >= 0.: 

145 stage = 2 

146 

147# Test for warnings. 

148 if self.bracket and (stp <= stmin or stp >= stmax): 

149 self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS' 

150 if self.bracket and stmax - stmin <= self.xtol * stmax: 

151 self.task = 'WARNING: XTOL TEST SATISFIED' 

152 if stp == self.stpmax and f <= ftest and g <= gtest: 

153 self.task = 'WARNING: STP = maxstep' 

154 if stp == self.stpmin and (f > ftest or g >= gtest): 

155 self.task = 'WARNING: STP = minstep' 

156 

157# Test for convergence. 

158 if f <= ftest and abs(g) <= c2 * (- ginit): 

159 self.task = 'CONVERGENCE' 

160 

161# Test for termination. 

162 if self.task[:4] == 'WARN' or self.task[:4] == 'CONV': 

163 self.save((stage, ginit, gtest, gx, 

164 gy, finit, fx, fy, stx, sty, 

165 stmin, stmax, width, width1)) 

166 # return stp, f, g 

167 return stp 

168 

169# A modified function is used to predict the step during the 

170# first stage if a lower function value has been obtained but 

171# the decrease is not sufficient. 

172 # if stage == 1 and f <= fx and f > ftest: 

173# # Define the modified function and derivative values. 

174 # fm =f - stp * gtest 

175 # fxm = fx - stx * gtest 

176 # fym = fy - sty * gtest 

177 # gm = g - gtest 

178 # gxm = gx - gtest 

179 # gym = gy - gtest 

180 

181# Call step to update stx, sty, and to compute the new step. 

182 # stx, sty, stp, gxm, fxm, gym, fym = self.update (stx, fxm, gxm, sty, 

183 # fym, gym, stp, fm, gm, 

184 # stmin, stmax) 

185 

186# # Reset the function and derivative values for f. 

187 

188 # fx = fxm + stx * gtest 

189 # fy = fym + sty * gtest 

190 # gx = gxm + gtest 

191 # gy = gym + gtest 

192 

193 # else: 

194# Call step to update stx, sty, and to compute the new step. 

195 

196 stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty, 

197 fy, gy, stp, f, g, 

198 stmin, stmax) 

199 

200 

201# Decide if a bisection step is needed. 

202 

203 if self.bracket: 

204 if abs(sty - stx) >= .66 * width1: 

205 stp = stx + .5 * (sty - stx) 

206 width1 = width 

207 width = abs(sty - stx) 

208 

209# Set the minimum and maximum steps allowed for stp. 

210 

211 if self.bracket: 

212 stmin = min(stx, sty) 

213 stmax = max(stx, sty) 

214 else: 

215 stmin = stp + self.xtrapl * (stp - stx) 

216 stmax = stp + self.xtrapu * (stp - stx) 

217 

218# Force the step to be within the bounds maxstep and minstep. 

219 

220 stp = max(stp, self.stpmin) 

221 stp = min(stp, self.stpmax) 

222 

223 if (stx == stp and stp == self.stpmax and stmin > self.stpmax): 

224 self.no_update = True 

225# If further progress is not possible, let stp be the best 

226# point obtained during the search. 

227 

228 if (self.bracket and stp < stmin or stp >= stmax) \ 

229 or (self.bracket and stmax - stmin < self.xtol * stmax): 

230 stp = stx 

231 

232# Obtain another function and derivative. 

233 

234 self.task = 'FG' 

235 self.save((stage, ginit, gtest, gx, 

236 gy, finit, fx, fy, stx, sty, 

237 stmin, stmax, width, width1)) 

238 return stp 

239 

240 def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp, 

241 stpmin, stpmax): 

242 sign = gp * (gx / abs(gx)) 

243 

244# First case: A higher function value. The minimum is bracketed. 

245# If the cubic step is closer to stx than the quadratic step, the 

246# cubic step is taken, otherwise the average of the cubic and 

247# quadratic steps is taken. 

248 if fp > fx: # case1 

249 self.case = 1 

250 theta = 3. * (fx - fp) / (stp - stx) + gx + gp 

251 s = max(abs(theta), abs(gx), abs(gp)) 

252 gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s)) 

253 if stp < stx: 

254 gamma = -gamma 

255 p = (gamma - gx) + theta 

256 q = ((gamma - gx) + gamma) + gp 

257 r = p / q 

258 stpc = stx + r * (stp - stx) 

259 stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \ 

260 * (stp - stx) 

261 if (abs(stpc - stx) < abs(stpq - stx)): 

262 stpf = stpc 

263 else: 

264 stpf = stpc + (stpq - stpc) / 2. 

265 

266 self.bracket = True 

267 

268# Second case: A lower function value and derivatives of opposite 

269# sign. The minimum is bracketed. If the cubic step is farther from 

270# stp than the secant step, the cubic step is taken, otherwise the 

271# secant step is taken. 

272 

273 elif sign < 0: # case2 

274 self.case = 2 

275 theta = 3. * (fx - fp) / (stp - stx) + gx + gp 

276 s = max(abs(theta), abs(gx), abs(gp)) 

277 gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s)) 

278 if stp > stx: 

279 gamma = -gamma 

280 p = (gamma - gp) + theta 

281 q = ((gamma - gp) + gamma) + gx 

282 r = p / q 

283 stpc = stp + r * (stx - stp) 

284 stpq = stp + (gp / (gp - gx)) * (stx - stp) 

285 if (abs(stpc - stp) > abs(stpq - stp)): 

286 stpf = stpc 

287 else: 

288 stpf = stpq 

289 self.bracket = True 

290 

291# Third case: A lower function value, derivatives of the same sign, 

292# and the magnitude of the derivative decreases. 

293 

294 elif abs(gp) < abs(gx): # case3 

295 self.case = 3 

296# The cubic step is computed only if the cubic tends to infinity 

297# in the direction of the step or if the minimum of the cubic 

298# is beyond stp. Otherwise the cubic step is defined to be the 

299# secant step. 

300 

301 theta = 3. * (fx - fp) / (stp - stx) + gx + gp 

302 s = max(abs(theta), abs(gx), abs(gp)) 

303 

304# The case gamma = 0 only arises if the cubic does not tend 

305# to infinity in the direction of the step. 

306 

307 gamma = s * np.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s))) 

308 if stp > stx: 

309 gamma = -gamma 

310 p = (gamma - gp) + theta 

311 q = (gamma + (gx - gp)) + gamma 

312 r = p / q 

313 if r < 0. and gamma != 0: 

314 stpc = stp + r * (stx - stp) 

315 elif stp > stx: 

316 stpc = stpmax 

317 else: 

318 stpc = stpmin 

319 stpq = stp + (gp / (gp - gx)) * (stx - stp) 

320 

321 if self.bracket: 

322 

323 # A minimizer has been bracketed. If the cubic step is 

324 # closer to stp than the secant step, the cubic step is 

325 # taken, otherwise the secant step is taken. 

326 

327 if abs(stpc - stp) < abs(stpq - stp): 

328 stpf = stpc 

329 else: 

330 stpf = stpq 

331 if stp > stx: 

332 stpf = min(stp + .66 * (sty - stp), stpf) 

333 else: 

334 stpf = max(stp + .66 * (sty - stp), stpf) 

335 else: 

336 

337 # A minimizer has not been bracketed. If the cubic step is 

338 # farther from stp than the secant step, the cubic step is 

339 # taken, otherwise the secant step is taken. 

340 

341 if abs(stpc - stp) > abs(stpq - stp): 

342 stpf = stpc 

343 else: 

344 stpf = stpq 

345 stpf = min(stpmax, stpf) 

346 stpf = max(stpmin, stpf) 

347 

348# Fourth case: A lower function value, derivatives of the same sign, 

349# and the magnitude of the derivative does not decrease. If the 

350# minimum is not bracketed, the step is either minstep or maxstep, 

351# otherwise the cubic step is taken. 

352 

353 else: # case4 

354 self.case = 4 

355 if self.bracket: 

356 theta = 3. * (fp - fy) / (sty - stp) + gy + gp 

357 s = max(abs(theta), abs(gy), abs(gp)) 

358 gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s)) 

359 if stp > sty: 

360 gamma = -gamma 

361 p = (gamma - gp) + theta 

362 q = ((gamma - gp) + gamma) + gy 

363 r = p / q 

364 stpc = stp + r * (sty - stp) 

365 stpf = stpc 

366 elif stp > stx: 

367 stpf = stpmax 

368 else: 

369 stpf = stpmin 

370 

371# Update the interval which contains a minimizer. 

372 

373 if fp > fx: 

374 sty = stp 

375 fy = fp 

376 gy = gp 

377 else: 

378 if sign < 0: 

379 sty = stx 

380 fy = fx 

381 gy = gx 

382 stx = stp 

383 fx = fp 

384 gx = gp 

385# Compute the new step. 

386 

387 stp = self.determine_step(stpf) 

388 

389 return stx, sty, stp, gx, fx, gy, fy 

390 

391 def determine_step(self, stp): 

392 dr = stp - self.old_stp 

393 maxsteplength = self.get_gradient_norm(self.pk * dr) 

394 if maxsteplength >= self.maxstep: 

395 dr *= self.maxstep / maxsteplength 

396 stp = self.old_stp + dr 

397 return stp 

398 

399 def save(self, data): 

400 if self.bracket: 

401 self.isave[0] = 1 

402 else: 

403 self.isave[0] = 0 

404 self.isave[1] = data[0] 

405 self.dsave = data[1:]