Coverage for ase / utils / linesearch.py: 83.61%
238 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
1# fmt: off
3# flake8: noqa
4import numpy as np
def standard_gradient_norm(array):
    """Return the largest Euclidean norm among the 3-vectors in ``array``.

    ``array`` is reshaped into rows of three components (so its size must
    be a multiple of 3). The norm of every row is computed and the maximum
    is returned. ``LineSearch`` uses this by default to measure how far a
    trial step moves the "worst" row (presumably one atom's displacement).
    """
    return np.linalg.norm(array.reshape(-1, 3), axis=1).max()


class LineSearch:
    """Reverse-communication line search for the strong Wolfe conditions.

    ``step`` is called repeatedly with the function value and directional
    derivative at the current trial step. ``self.task`` then tells the
    caller what to do next:

    - ``'FG'``: evaluate again at the returned step;
    - ``'CONVERGENCE'``: the sufficient-decrease and curvature tests hold;
    - ``'WARNING: ...'``: no further progress is possible;
    - ``'ERROR: ...'``: invalid input.

    NOTE(review): the Fortran-style messages and variable names suggest this
    is a port of MINPACK-2 ``dcsrch``/``dcstep``; confirm against the
    reference when changing the algorithm.

    On top of the reference scheme, every new trial step is passed through
    ``determine_step``. That method limits how far one iteration may move
    along the search direction (``maxstep``, measured with
    ``get_gradient_norm``).

    ``self.task`` is only set to ``'START'`` in ``__init__``, so a fresh
    instance is needed for every line search.
    """

    def __init__(self, xtol=1e-14, get_gradient_norm=standard_gradient_norm):
        """Create a line search.

        Parameters
        ----------
        xtol : float
            Relative tolerance on the width of the bracketing interval.
            Once the minimum is bracketed and
            ``stmax - stmin <= xtol * stmax``, the search stops with a
            warning.
        get_gradient_norm : callable
            Maps a displacement ``pk * dr`` to a scalar length. The result
            is compared with ``maxstep`` in ``determine_step``.
        """
        self.xtol = xtol
        # Reverse-communication state; see the class docstring.
        self.task = 'START'
        # isave = [bracketed flag (0/1), stage].
        self.isave = np.zeros((2,), np.intc)
        # dsave holds the 13 floats stored by save().
        self.dsave = np.zeros((13,), float)
        self.fc = 0  # function-evaluation counter
        self.gc = 0  # gradient-evaluation counter
        self.case = 0  # which of the four update() branches ran last
        # Last evaluated step; the reference point for determine_step.
        self.old_stp = 0
        self.get_gradient_norm = get_gradient_norm

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
                     stpmax=50., stpmin=1e-8, args=()):
        """Search along ``pk`` from ``xk`` for an acceptable step length.

        Parameters
        ----------
        func : callable
            ``func(x, *args)`` returns the objective value at ``x``.
        myfprime : callable or tuple
            Either ``fprime(x, *args)``, returning the gradient, or a tuple
            ``(fprime, eps)``. In the tuple form the gradient is obtained as
            ``fprime(x, func, eps, *args)`` (finite-difference style), and
            each such call is counted as ``len(xk) + 1`` extra function
            evaluations in ``self.fc``.
        xk, pk : array
            Starting point and search direction. ``np.dot(gfk, pk)`` must be
            negative, otherwise ``step`` reports
            ``'ERROR: INITIAL G >= 0'``.
        gfk : array
            Gradient at ``xk``.
        old_fval : float
            Objective value at ``xk``.
        old_old_fval
            Unused; kept for call compatibility.
        maxstep : float
            Largest length (as measured by ``get_gradient_norm``) that one
            trial step may move.
        c1, c2 : float
            Sufficient-decrease and curvature constants.
        xtrapl, xtrapu : float
            Lower and upper extrapolation factors used while the minimum is
            not yet bracketed.
        stpmax, stpmin : float
            Bounds on the step length.
        args : tuple
            Extra arguments forwarded to ``func`` and to the gradient.

        Returns
        -------
        stp : float or None
            The last evaluated step, or ``None`` if ``step`` reported an
            error.
        fval : float
            ``func`` at ``xk + stp * pk``, or ``old_fval`` if nothing was
            evaluated.
        old_fval : float
            The input ``old_fval``, unchanged.
        no_update : bool
            True if the search stopped because both the best step and the
            new trial step sat at ``stpmax`` (see ``step``).
        """
        self.stpmin = stpmin
        self.pk = pk
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)
        # dim and gms are not read anywhere in this class; they are kept as
        # attributes in case external code inspects them.
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep
        alpha1 = 1.
        self.no_update = False

        if isinstance(myfprime, tuple):
            # (fprime, eps): fprime approximates the gradient from func.
            # Previously newargs was never assigned on this path, so the
            # first gradient call raised UnboundLocalError.
            fprime, eps = myfprime[0], myfprime[1]
            newargs = (func, eps) + tuple(args)
            gradient = False
        else:
            fprime = myfprime
            newargs = args
            gradient = True

        fval = old_fval
        gval = gfk
        self.steps = []  # not read anywhere in this class

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
                            self.xtol,
                            self.isave, self.dsave)
            if self.task[:2] != 'FG':
                break  # converged, warned or errored

            # Evaluate function and gradient at the requested trial step.
            alpha1 = stp
            fval = func(xk + stp * pk, *args)
            self.fc += 1
            gval = fprime(xk + stp * pk, *newargs)
            if gradient:
                self.gc += 1
            else:
                # Finite-difference gradient: count its function calls.
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # NOTE(review): only errors count as failure. The former extra test
        # ``self.task[1:4] == 'WARN'`` compared a 3-character slice with a
        # 4-character string and never matched, so on a warning the last
        # evaluated step is returned. That behaviour is kept deliberately,
        # because callers treat ``None`` as a hard failure.
        if self.task[:5] == 'ERROR':
            stp = None  # failed
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the reverse-communication search by one evaluation.

        On the first call (``self.task == 'START'``) the inputs are
        validated and the search state is initialised. On later calls,
        ``f`` and ``g`` are the function value and the directional
        derivative at the trial step ``stp`` previously returned.

        Parameters
        ----------
        stp : float
            Current trial step (the initial guess on the first call).
        f, g : float
            Function value and directional derivative at ``stp``.
        c1, c2 : float
            Sufficient-decrease and curvature constants.
        xtol : float
            Only validated here (must be >= 0); the interval-width tests
            use ``self.xtol``.
        isave, dsave
            Unused; the state lives in ``self.isave`` and ``self.dsave``.

        Returns
        -------
        float
            The next trial step when ``self.task`` becomes ``'FG'``;
            otherwise ``stp`` unchanged (error, warning or convergence).
        """
        if self.task[:5] == 'START':
            # Check the input arguments for errors; the last failing check
            # determines the message.
            if stp < self.stpmin:
                self.task = 'ERROR: STP .LT. minstep'
            if stp > self.stpmax:
                self.task = 'ERROR: STP .GT. maxstep'
            if g >= 0:
                self.task = 'ERROR: INITIAL G >= 0'
            if c1 < 0:
                self.task = 'ERROR: c1 .LT. 0'
            if c2 < 0:
                self.task = 'ERROR: c2 .LT. 0'
            if xtol < 0:
                self.task = 'ERROR: XTOL .LT. 0'
            if self.stpmin < 0:
                self.task = 'ERROR: minstep .LT. 0'
            if self.stpmax < self.stpmin:
                self.task = 'ERROR: maxstep .LT. minstep'
            if self.task[:5] == 'ERROR':
                return stp

            # Initialize local variables.
            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5
            # stx, fx, gx: step, function value and derivative at the best
            # step so far. sty, fy, gy: the same at the other end of the
            # interval of uncertainty. stp, f, g: the same at the current
            # trial step.
            stx = 0
            fx = finit
            gx = ginit
            sty = 0
            fy = finit
            gy = ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return self.determine_step(stp)

        # Restore the state saved by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
        # algorithm enters the second stage.
        # NOTE(review): the reference dcsrch appears to use ``f <= ftest``
        # here; confirm whether the strict ``<`` is intentional. ``stage``
        # is only stored here, because the stage-1 "modified function"
        # variant of the update is not implemented.
        ftest = finit + stp * gtest
        if stage == 1 and f < ftest and g >= 0.:
            stage = 2

        # Test for warnings.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence (sufficient decrease + strong curvature).
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # Update stx, sty and compute the new trial step.
        stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
                                                    fy, gy, stp, f, g,
                                                    stmin, stmax)

        # Decide if a bisection step is needed: the bracket did not shrink
        # enough relative to the width two iterations ago.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds maxstep and minstep.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        # Stuck at the upper bound: ask the caller to stop after this
        # evaluation.
        if (stx == stp and stp == self.stpmax and stmin > self.stpmax):
            self.no_update = True

        # If further progress is not possible, let stp be the best point
        # obtained during the search. All three conditions apply only once
        # the minimum is bracketed. The previous unparenthesised
        # ``bracket and a or b`` also tested ``stp >= stmax`` when
        # unbracketed, unlike the warning test above.
        if self.bracket and (stp < stmin or stp >= stmax
                             or stmax - stmin < self.xtol * stmax):
            stp = stx

        # Obtain another function and derivative.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Update the interval of uncertainty and compute a new trial step.

        ``(stx, fx, gx)`` is the best step so far and ``(sty, fy, gy)`` is
        the other endpoint of the interval. ``(stp, fp, gp)`` is the trial
        step just evaluated (step, function value, directional derivative).
        ``stpmin`` and ``stpmax`` clamp the new step in the unbracketed
        branches of cases 3 and 4.

        Sets ``self.case`` to the branch taken (1-4). Sets ``self.bracket``
        to True in cases 1 and 2.

        Returns
        -------
        tuple
            ``(stx, sty, stp, gx, fx, gy, fy)``. Note that this order
            differs from the argument order. The new ``stp`` has already
            been passed through ``determine_step``.
        """
        # Negative when gp and gx have opposite signs (divides by abs(gx),
        # so gx must be non-zero).
        sign = gp * (gx / abs(gx))

        if fp > fx:
            # Case 1: a higher function value, so the minimum is bracketed.
            # If the cubic step is closer to stx than the quadratic step,
            # the cubic step is taken; otherwise the average of the two.
            self.case = 1
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if (abs(stpc - stx) < abs(stpq - stx)):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.
            self.bracket = True

        elif sign < 0:
            # Case 2: a lower function value and derivatives of opposite
            # sign, so the minimum is bracketed. If the cubic step is
            # farther from stp than the secant step, the cubic step is
            # taken; otherwise the secant step.
            self.case = 2
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if (abs(stpc - stp) > abs(stpq - stp)):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        elif abs(gp) < abs(gx):
            # Case 3: a lower function value, derivatives of the same sign,
            # and the magnitude of the derivative decreases. The cubic step
            # is computed only if the cubic tends to infinity in the
            # direction of the step or if its minimum is beyond stp;
            # otherwise the cubic step is the corresponding bound.
            self.case = 3
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            # gamma == 0 only arises if the cubic does not tend to infinity
            # in the direction of the step.
            gamma = s * np.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s)))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # Bracketed: take the cubic step if it is closer to stp
                # than the secant step, otherwise the secant step; then
                # stay within 66% of the way towards sty.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # Not bracketed: take the cubic step if it is farther from
                # stp than the secant step, otherwise the secant step;
                # then clamp to [stpmin, stpmax].
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        else:
            # Case 4: a lower function value, derivatives of the same sign,
            # and the magnitude of the derivative does not decrease. If the
            # minimum is not bracketed, the step is stpmin or stpmax;
            # otherwise the cubic step through (sty, stp) is taken.
            self.case = 4
            if self.bracket:
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
                s = max(abs(theta), abs(gy), abs(gp))
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpc = stp + r * (sty - stp)
                stpf = stpc
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new step (limited by determine_step).
        stp = self.determine_step(stpf)

        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Limit the move from ``self.old_stp`` towards ``stp``.

        The displacement ``self.pk * (stp - self.old_stp)`` is measured with
        ``self.get_gradient_norm``. If that length is ``>= self.maxstep``,
        the step increment is scaled by ``self.maxstep / length``. For a
        norm that scales linearly, such as ``standard_gradient_norm``, this
        makes the measured move exactly ``self.maxstep``.
        """
        dr = stp - self.old_stp
        maxsteplength = self.get_gradient_norm(self.pk * dr)
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        return self.old_stp + dr

    def save(self, data):
        """Store the search state between calls to ``step``.

        ``data`` is ``(stage, ginit, gtest, gx, gy, finit, fx, fy, stx,
        sty, stmin, stmax, width, width1)``. The bracket flag and ``stage``
        go into ``self.isave``; the remaining 13 values replace
        ``self.dsave``.
        """
        self.isave[0] = 1 if self.bracket else 0
        self.isave[1] = data[0]
        self.dsave = data[1:]