Coverage for /builds/ase/ase/ase/utils/linesearch.py: 83.68%

239 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3# flake8: noqa 

4import numpy as np 

5 

# Module-level aliases of the builtin ``min`` and ``max``.
pymin, pymax = min, max

8 

9 

class LineSearch:
    """More & Thuente line search, ported from MINPACK-2 ``dcsrch``/``dcstep``.

    ``_line_search`` drives the reverse-communication routine ``step``.
    ``step`` sets ``self.task`` as follows:

    * ``'FG'`` when it needs the objective and directional derivative at a
      new trial step;
    * ``'CONVERGENCE'`` once ``f <= finit + stp*c1*ginit`` and
      ``|g| <= c2*(-ginit)`` hold;
    * a ``'WARNING: ...'`` or ``'ERROR: ...'`` message otherwise.

    Every new trial step is passed through :meth:`determine_step`. That
    method limits how far any 3-component block of the search direction may
    move per iteration.

    ``fc`` and ``gc`` count function and gradient evaluations. They
    accumulate over repeated ``_line_search`` calls; the remaining search
    state is reset at the start of each call.
    """

    def __init__(self, xtol=1e-14):
        # Relative tolerance on the width of the bracketing interval.
        self.xtol = xtol
        self.task = 'START'
        # isave = [bracketed flag, stage]; dsave = saved float state.
        self.isave = np.zeros((2,), np.intc)
        self.dsave = np.zeros((13,), float)
        self.fc = 0  # function evaluations
        self.gc = 0  # gradient evaluations
        self.case = 0  # which update() case (1-4) produced the last step
        self.old_stp = 0  # step at which the objective was last evaluated

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
                     stpmax=50., stpmin=1e-8, args=()):
        """Search along ``pk`` from ``xk`` for a step satisfying strong Wolfe.

        Parameters
        ----------
        func : callable
            ``func(x, *args)`` returning the objective value.
        myfprime : callable or tuple
            Either the gradient ``fprime(x, *args)``, or a tuple
            ``(approx_fprime, epsilon)``. In the tuple case the gradient is
            ``approx_fprime(x, func, epsilon, *args)`` and each gradient
            call is counted as ``len(xk) + 1`` function evaluations.
        xk, pk : ndarray
            Starting point and search direction. ``len(pk)`` must be a
            multiple of 3 (see :meth:`determine_step`).
        gfk : ndarray
            Gradient at ``xk``; ``np.dot(gfk, pk)`` must be negative.
        old_fval : float
            Objective value at ``xk``.
        old_old_fval : float
            Unused; kept for interface compatibility.
        maxstep : float
            Largest displacement allowed for any 3-component block of
            ``stp * pk`` per trial step.
        c1, c2 : float
            Sufficient-decrease and curvature parameters.
        xtrapl, xtrapu : float
            Lower/upper extrapolation factors used while no minimizer is
            bracketed.
        stpmax, stpmin : float
            Bounds on the step length.
        args : tuple
            Extra arguments passed to ``func`` and to the gradient.

        Returns
        -------
        tuple
            ``(stp, fval, old_fval, no_update)``, where:

            * ``stp`` is the final step, or ``None`` if the search ended
              with an ``'ERROR'`` task;
            * ``fval`` is the objective at the last evaluated point
              (``old_fval`` if nothing was evaluated);
            * ``old_fval`` is echoed back unchanged;
            * ``no_update`` is True when the search stopped because the
              step was stuck at ``stpmax``.
        """
        # Reset per-search state. Without this, a second call on the same
        # instance resumed from the finished task and returned stp=1 without
        # evaluating anything.
        self.task = 'START'
        self.old_stp = 0
        self.no_update = False

        self.stpmin = stpmin
        self.pk = pk
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep
        alpha1 = 1.

        if isinstance(myfprime, tuple):
            # (approx_fprime, epsilon): finite-difference gradient. Previously
            # `newargs` was left undefined here, so this branch always raised.
            fprime = myfprime[0]
            eps = myfprime[1]
            newargs = (func, eps) + tuple(args)
            gradient = False
        else:
            fprime = myfprime
            newargs = args
            gradient = True

        fval = old_fval
        gval = gfk
        self.steps = []

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
                            self.xtol,
                            self.isave, self.dsave)
            if self.task[:2] != 'FG':
                break
            # step() asks for the function and derivative at the new stp.
            alpha1 = stp
            fval = func(xk + stp * pk, *args)
            self.fc += 1
            gval = fprime(xk + stp * pk, *newargs)
            if gradient:
                self.gc += 1
            else:
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # NOTE(review): the original test here was
        #     self.task[:5] == 'ERROR' or self.task[1:4] == 'WARN'
        # A 3-character slice can never equal 'WARN', so WARNING terminations
        # have always returned the last step rather than None. Callers that
        # treat None as failure may depend on this, so the effective behaviour
        # is kept. Confirm before making warnings fatal.
        if self.task[:5] == 'ERROR':
            stp = None  # failed
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the search by one evaluation.

        ``f`` and ``g`` are the objective and the directional derivative at
        ``stp``. On the first call (``task == 'START'``) they are the values
        at step 0, and ``stp`` is the initial trial step.

        When ``self.task`` is set to ``'FG'``, the return value is the next
        step to evaluate. Otherwise it is the final step, or the rejected
        input step when ``self.task`` starts with ``'ERROR'``.

        ``isave``/``dsave`` are accepted for interface compatibility only;
        the saved state is read from ``self.isave``/``self.dsave``.
        """
        if self.task[:5] == 'START':
            # Check the input arguments for errors.
            if stp < self.stpmin:
                self.task = 'ERROR: STP .LT. minstep'
            if stp > self.stpmax:
                self.task = 'ERROR: STP .GT. maxstep'
            if g >= 0:
                self.task = 'ERROR: INITIAL G >= 0'
            if c1 < 0:
                self.task = 'ERROR: c1 .LT. 0'
            if c2 < 0:
                self.task = 'ERROR: c2 .LT. 0'
            if xtol < 0:
                self.task = 'ERROR: XTOL .LT. 0'
            if self.stpmin < 0:
                self.task = 'ERROR: minstep .LT. 0'
            if self.stpmax < self.stpmin:
                self.task = 'ERROR: maxstep .LT. minstep'
            if self.task[:5] == 'ERROR':
                return stp

            # Initialize local variables.
            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5  # MINPACK: width1 = width/p5, i.e. 2*width
            # stx, fx, gx: step, function and derivative at the best step.
            # sty, fy, gy: step, function and derivative at sty.
            # stp, f, g:   step, function and derivative at stp.
            stx = 0
            fx = finit
            gx = ginit
            sty = 0
            fy = finit
            gy = ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return self.determine_step(stp)

        # Restore the state saved by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
        # algorithm enters the second stage.
        ftest = finit + stp * gtest
        if stage == 1 and f < ftest and g >= 0.:
            stage = 2

        # Test for warnings.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence.
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # MINPACK uses a modified function during stage 1 when f <= fx and
        # f > ftest. That variant is not used here, so `stage` is only
        # recorded and never changes the step.
        # Update stx, sty and compute the new trial step.
        stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
                                                    fy, gy, stp, f, g,
                                                    stmin, stmax)

        # Decide if a bisection step is needed.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds maxstep and minstep.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        # Stuck at stpmax with the lower bound already beyond it:
        # _line_search stops after evaluating this step.
        if stx == stp and stp == self.stpmax and stmin > self.stpmax:
            self.no_update = True

        # If further progress is not possible, let stp be the best point
        # obtained during the search. As in MINPACK dcsrch, every condition
        # here requires a bracketed minimizer. The original expression lacked
        # parentheses, so `stp >= stmax` was also tested while unbracketed.
        if self.bracket and (stp <= stmin or stp >= stmax
                             or stmax - stmin <= self.xtol * stmax):
            stp = stx

        # Obtain another function and derivative.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Compute a safeguarded trial step and update the interval (dcstep).

        The arguments describe three points:

        * ``stx``, ``fx``, ``gx`` — the best step found so far;
        * ``sty``, ``fy``, ``gy`` — the other end of the interval;
        * ``stp``, ``fp``, ``gp`` — the step just evaluated.

        ``stpmin``/``stpmax`` bound the new step in the cases where no
        minimizer is bracketed.

        Sets ``self.case`` to 1-4 and may set ``self.bracket`` to True.
        Returns ``(stx, sty, stp, gx, fx, gy, fy)``, where ``stp`` is the
        next trial step after :meth:`determine_step`.
        """
        # Sign of gp relative to gx. np.sign avoids a ZeroDivisionError when
        # gx == 0 and gives the same branch decisions as gx/abs(gx).
        sign = gp * np.sign(gx)

        if fp > fx:
            # Case 1: a higher function value. The minimum is bracketed.
            # Take the cubic step if it is closer to stx than the quadratic
            # step; otherwise take the average of the two.
            self.case = 1
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if abs(stpc - stx) < abs(stpq - stx):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.
            self.bracket = True

        elif sign < 0:
            # Case 2: a lower function value and derivatives of opposite
            # sign. The minimum is bracketed. Take the cubic step if it is
            # farther from stp than the secant step; otherwise the secant.
            self.case = 2
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if abs(stpc - stp) > abs(stpq - stp):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        elif abs(gp) < abs(gx):
            # Case 3: a lower function value, derivatives of the same sign,
            # and the magnitude of the derivative decreases. The cubic step
            # is used only if the cubic tends to infinity in the direction
            # of the step or its minimum is beyond stp; otherwise the cubic
            # step falls back to stpmin/stpmax.
            self.case = 3
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            # gamma == 0 only if the cubic does not tend to infinity in the
            # direction of the step.
            gamma = s * np.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s)))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # Bracketed: take the step closer to stp, but stay at most
                # 66% of the way toward sty.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # Not bracketed: take the step farther from stp, clipped to
                # [stpmin, stpmax].
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        else:
            # Case 4: a lower function value, derivatives of the same sign,
            # and the magnitude of the derivative does not decrease. If not
            # bracketed the step is stpmin or stpmax; otherwise take the
            # cubic step between stp and sty.
            self.case = 4
            if self.bracket:
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
                s = max(abs(theta), abs(gy), abs(gp))
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpc = stp + r * (sty - stp)
                stpf = stpc
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new step, limited by maxstep.
        stp = self.determine_step(stpf)
        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Limit the move from ``self.old_stp`` toward ``stp``.

        ``self.pk`` is viewed as rows of 3 components (presumably one row
        per atom; TODO confirm). If moving from ``old_stp`` to ``stp`` would
        displace some row by ``self.maxstep`` or more, the move is scaled
        down so the largest row displacement equals ``self.maxstep``.
        """
        dr = stp - self.old_stp
        x = np.reshape(self.pk, (-1, 3))
        steplengths = ((dr * x) ** 2).sum(1) ** 0.5
        maxsteplength = max(steplengths)
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        return self.old_stp + dr

    def save(self, data):
        """Persist the search state between ``step`` calls.

        ``data`` is ``(stage, ginit, gtest, gx, gy, finit, fx, fy, stx,
        sty, stmin, stmax, width, width1)``. The bracket flag and the stage
        go into ``self.isave``; the remaining 13 values replace
        ``self.dsave`` as a tuple.
        """
        self.isave[0] = 1 if self.bracket else 0
        self.isave[1] = data[0]
        self.dsave = data[1:]