Coverage for /builds/ase/ase/ase/utils/linesearch.py: 83.68%
239 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3# flake8: noqa
4import numpy as np
# Module-level aliases for the built-in min() and max().
pymin, pymax = min, max
class LineSearch:
    """Stateful line search along a fixed search direction.

    The search is driven by a ``task`` string ('START', 'FG',
    'CONVERGENCE', 'WARNING: ...', 'ERROR: ...').  ``step`` proposes a
    trial step, ``_line_search`` evaluates the function and gradient
    there, and the loop repeats until ``step`` stops asking for another
    evaluation.  A step is accepted when it satisfies both

    * sufficient decrease: ``f <= f0 + stp * c1 * g0``
    * curvature:           ``abs(g) <= c2 * (-g0)``

    where ``g`` is the directional derivative along the search direction.

    NOTE(review): the structure and the task strings look like a port of
    the MINPACK-2 ``dcsrch``/``dcstep`` routines -- confirm provenance.

    An instance carries state between calls (``task``, ``isave``,
    ``dsave``, ``old_stp``); ``task`` is never reset to 'START', so a
    fresh instance is needed for each new line search.

    Attributes set in ``__init__``:
        xtol:    relative tolerance on the width of the bracketing interval.
        task:    current state of the search (see above).
        isave:   integer state: ``[bracket flag, stage]``.
        dsave:   13 floating-point state values, see ``save``.
        fc, gc:  counters of function and gradient evaluations.
        case:    number (1-4) of the branch last taken in ``update``.
        old_stp: last step at which the function was evaluated.
    """

    def __init__(self, xtol=1e-14):
        self.xtol = xtol
        self.task = 'START'
        self.isave = np.zeros((2,), np.intc)
        self.dsave = np.zeros((13,), float)
        self.fc = 0
        self.gc = 0
        self.case = 0
        self.old_stp = 0

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
                     stpmax=50., stpmin=1e-8, args=()):
        """Run the line search from ``xk`` along ``pk``.

        Parameters:
            func: objective, called as ``func(x, *args)``.
            myfprime: either a gradient callable, called as
                ``myfprime(x, *args)``, or a tuple ``(approx, eps)`` where
                ``approx`` is called as ``approx(x, func, eps, *args)``
                (the calling convention of a finite-difference helper
                such as ``scipy.optimize.approx_fprime``).
            xk: starting point.
            pk: search direction; ``determine_step`` reshapes it to
                ``(-1, 3)``, so its size must be a multiple of 3.
            gfk: gradient at ``xk``.
            old_fval: function value at ``xk``.
            old_old_fval: unused; kept for interface compatibility.
            maxstep: cap on the longest per-row displacement of a trial
                step (see ``determine_step``).
            c1, c2: sufficient-decrease and curvature parameters.
            xtrapl, xtrapu: lower/upper extrapolation factors used to set
                the allowed step interval while no minimum is bracketed.
            stpmax, stpmin: hard bounds on the step.
            args: extra positional arguments for ``func`` / gradient.

        Returns:
            ``(stp, fval, old_fval, no_update)``; ``stp`` is ``None`` when
            the search ended with an 'ERROR' task, ``fval`` is the last
            function value computed (``old_fval`` if none was), and
            ``no_update`` is the flag set by ``step``.
        """
        self.stpmin = stpmin
        self.pk = pk
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep
        alpha1 = 1.
        self.no_update = False

        if isinstance(myfprime, tuple):
            # Finite-difference mode.  The helper needs the objective and
            # the step size in front of the user arguments; previously
            # ``newargs`` was left undefined on this path (NameError).
            fprime, eps = myfprime[0], myfprime[1]
            newargs = (func, eps) + tuple(args)
            gradient = False
        else:
            fprime = myfprime
            newargs = args
            gradient = True

        fval = old_fval
        gval = gfk
        self.steps = []

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
                            self.xtol, self.isave, self.dsave)

            if self.task[:2] != 'FG':
                # Converged, warned or errored: no further evaluation.
                break

            alpha1 = stp
            fval = func(xk + stp * pk, *args)
            self.fc += 1
            gval = fprime(xk + stp * pk, *newargs)
            if gradient:
                self.gc += 1
            else:
                # Count the gradient as len(xk) + 1 function evaluations.
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # NOTE(review): the original also tested ``self.task[1:4] == 'WARN'``
        # here.  A three-character slice can never equal 'WARN', so that
        # test was dead and warnings have always returned the step.  The
        # behaviour is preserved; decide separately whether warnings
        # should count as failures, since callers treat ``None`` as one.
        if self.task[:5] == 'ERROR':
            stp = None  # failed
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the search by one stage and return the next trial step.

        ``stp`` is the current step and ``f``, ``g`` are the function value
        and directional derivative there.  On return ``self.task`` says
        what to do next: 'FG' asks for an evaluation at the returned
        step; 'CONVERGENCE', 'WARNING: ...' or 'ERROR: ...' end the
        search.  ``isave`` and ``dsave`` are accepted for interface
        compatibility; state is read from and written to ``self.isave``
        and ``self.dsave``.
        """
        if self.task[:5] == 'START':
            # Check the input arguments for errors.  Later checks
            # overwrite earlier messages, as in the original.
            if stp < self.stpmin:
                self.task = 'ERROR: STP .LT. minstep'
            if stp > self.stpmax:
                self.task = 'ERROR: STP .GT. maxstep'
            if g >= 0:
                self.task = 'ERROR: INITIAL G >= 0'
            if c1 < 0:
                self.task = 'ERROR: c1 .LT. 0'
            if c2 < 0:
                self.task = 'ERROR: c2 .LT. 0'
            if xtol < 0:
                self.task = 'ERROR: XTOL .LT. 0'
            if self.stpmin < 0:
                self.task = 'ERROR: minstep .LT. 0'
            if self.stpmax < self.stpmin:
                self.task = 'ERROR: maxstep .LT. minstep'
            if self.task[:5] == 'ERROR':
                return stp

            # Initialize local variables.
            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5

            # The variables stx, fx, gx contain the values of the step,
            # function, and derivative at the best step.
            # The variables sty, fy, gy contain the values of the step,
            # function, and derivative at sty.
            # The variables stp, f, g contain the values of the step,
            # function, and derivative at stp.
            stx = 0
            fx = finit
            gx = ginit
            sty = 0
            fy = finit
            gy = ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return self.determine_step(stp)

        # Restore the state stored by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
        # algorithm enters the second stage.
        ftest = finit + stp * gtest
        if stage == 1 and f < ftest and g >= 0.:
            stage = 2

        # Test for warnings.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence.
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # Update stx, sty and compute the new step.  The variant that
        # works on a modified function during the first stage is
        # disabled: ``update`` always receives the plain f and g values.
        stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
                                                    fy, gy, stp, f, g,
                                                    stmin, stmax)

        # Decide if a bisection step is needed.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds maxstep and minstep.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        if stx == stp and stp == self.stpmax and stmin > self.stpmax:
            # The best point already sits on the upper bound; tell
            # _line_search to stop after the next evaluation.
            self.no_update = True

        # If further progress is not possible, let stp be the best
        # point obtained during the search.
        # NOTE(review): the original mixed ``and``/``or`` without
        # parentheses; ``and`` binds tighter, so the ``stp >= stmax`` test
        # applies even when no minimum is bracketed.  Kept unchanged.
        if ((self.bracket and stp < stmin)
                or stp >= stmax
                or (self.bracket and stmax - stmin < self.xtol * stmax)):
            stp = stx

        # Obtain another function and derivative.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Compute a safeguarded trial step and update the interval.

        ``(stx, fx, gx)`` is the best step so far, ``(sty, fy, gy)`` the
        other end point and ``(stp, fp, gp)`` the current trial.
        ``stpmin``/``stpmax`` bound the new step.  Sets ``self.case`` to
        the branch taken and may set ``self.bracket``.

        Returns ``(stx, sty, stp, gx, fx, gy, fy)`` where ``stp`` has been
        passed through ``determine_step``.
        """
        # Positive when gp and gx have the same sign.
        sign = gp * (gx / abs(gx))

        # First case: A higher function value. The minimum is bracketed.
        # If the cubic step is closer to stx than the quadratic step, the
        # cubic step is taken, otherwise the average of the cubic and
        # quadratic steps is taken.
        if fp > fx:  # case1
            self.case = 1
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if abs(stpc - stx) < abs(stpq - stx):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.

            self.bracket = True

        # Second case: A lower function value and derivatives of opposite
        # sign. The minimum is bracketed. If the cubic step is farther from
        # stp than the secant step, the cubic step is taken, otherwise the
        # secant step is taken.
        elif sign < 0:  # case2
            self.case = 2
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if abs(stpc - stp) > abs(stpq - stp):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        # Third case: A lower function value, derivatives of the same sign,
        # and the magnitude of the derivative decreases.
        elif abs(gp) < abs(gx):  # case3
            self.case = 3
            # The cubic step is computed only if the cubic tends to
            # infinity in the direction of the step or if the minimum of
            # the cubic is beyond stp. Otherwise the cubic step is defined
            # to be the secant step.
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))

            # The case gamma = 0 only arises if the cubic does not tend
            # to infinity in the direction of the step.
            gamma = s * np.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s)))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # A minimizer has been bracketed. If the cubic step is
                # closer to stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # A minimizer has not been bracketed. If the cubic step is
                # farther from stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        # Fourth case: A lower function value, derivatives of the same
        # sign, and the magnitude of the derivative does not decrease. If
        # the minimum is not bracketed, the step is either minstep or
        # maxstep, otherwise the cubic step is taken.
        else:  # case4
            self.case = 4
            if self.bracket:
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
                s = max(abs(theta), abs(gy), abs(gp))
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpc = stp + r * (sty - stp)
                stpf = stpc
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new step.
        stp = self.determine_step(stpf)

        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Limit the move from ``self.old_stp`` to ``stp`` by ``maxstep``.

        ``self.pk`` is viewed as rows of three components (presumably
        per-atom Cartesian vectors -- confirm against callers).  If the
        longest row of the displacement ``(stp - old_stp) * pk`` reaches
        ``self.maxstep``, the increment is scaled down so that the longest
        row has length ``maxstep``.  Returns the possibly shortened step.
        """
        dr = stp - self.old_stp
        x = np.reshape(self.pk, (-1, 3))
        steplengths = ((dr * x)**2).sum(1)**0.5
        maxsteplength = max(steplengths)
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        stp = self.old_stp + dr
        return stp

    def save(self, data):
        """Store the search state for the next ``step`` call.

        ``data`` is ``(stage, ginit, gtest, gx, gy, finit, fx, fy, stx,
        sty, stmin, stmax, width, width1)``.  The bracket flag and the
        stage go to ``self.isave``; the remaining 13 values replace
        ``self.dsave`` (as a slice of ``data``, not as an array).
        """
        self.isave[0] = 1 if self.bracket else 0
        self.isave[1] = data[0]
        self.dsave = data[1:]