Coverage for /builds/ase/ase/ase/utils/linesearcharmijo.py: 62.50%
152 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1import logging
2import math
4import numpy as np
5import scipy
6import scipy.linalg
8from ase.utils import longsum
10logger = logging.getLogger(__name__)
class LinearPath:
    """Straight-line search path: step length t maps to the update t * g."""

    def __init__(self, dirn):
        """Set up the path along a fixed direction.

        Args:
            dirn : search direction g; every step is a multiple of it
        """
        self.dirn = dirn

    def step(self, alpha):
        """Return the update for step length `alpha`, i.e. alpha * dirn."""
        direction = self.dirn
        return alpha * direction
def nullspace(A, myeps=1e-10):
    """Return a basis of the numerical null-space of a small matrix.

    The RumPath class needs the ability to compute the null-space of
    a small matrix. This is provided here.

    The approach follows
    http://stackoverflow.com/questions/5889142/python-numpy-scipy-finding-the-null-space-of-a-matrix
    but uses the NumPy functions directly: the `scipy.compress` and
    `scipy.transpose` aliases used by the original snippet are not available
    in current SciPy releases.

    Args:
        A : 2-D array of shape (m, n)
        myeps : absolute threshold; singular values <= myeps count as zero

    Returns:
        Array of shape (n, k). Its columns are the right-singular vectors of
        `A` whose singular value is <= myeps (plus those without a singular
        value when m < n); k may be 0.
    """
    _, s, vh = scipy.linalg.svd(A)
    # For a wide matrix (m < n) only m singular values are returned; the
    # remaining n - m rows of vh always belong to the null-space.
    padding = max(0, np.shape(A)[1] - np.shape(s)[0])
    null_mask = np.concatenate(
        ((s <= myeps), np.ones((padding,), dtype=bool)), axis=0
    )
    null_space = np.compress(null_mask, vh, axis=0)
    return np.transpose(null_space)
class RumPath:
    """Describes a curved search path, taking into account information
    about (near-) rigid unit motions (RUMs).

    One can tag sub-molecules of the system, which are collections of
    particles that form a (near-)rigid unit. Let x1, ... xn be the positions
    of one such molecule, then we construct a path of the form
    xi(t) = xi(0) + (exp(K t) - I) yi + t wi + t c
    where yi = xi - <x>, c = <g> is a rigid translation, K is anti-symmetric
    so that exp(tK) yi denotes a rotation about the centre of mass, and wi
    is the remaining stretch of the molecule.

    The following variables are stored:
     * rotation_factors : array of acceleration factors
     * rigid_units : array of molecule indices
     * stretch : w (3 x nAt array; still contains the translation c)
     * K : list of K matrices, one per rigid unit
     * y : list of y-vectors (3 x n_unit arrays), one per rigid unit
    """

    def __init__(self, x_start, dirn, rigid_units, rotation_factors):
        """Initialise a `RumPath`

        Args:
          x_start : vector containing the positions in d x nAt shape, i.e.
                    the flat vector is reshaped to (3, nAt) in row-major
                    order
          dirn : search direction, same shape as x_start vector
          rigid_units : array of arrays of molecule indices
          rotation_factors : factor by which the rotation of each molecule
                             is accelerated; array of scalars, same length as
                             rigid_units
        """

        # keep some stuff stored
        self.rotation_factors = rotation_factors
        self.rigid_units = rigid_units
        # create storage for more stuff
        self.K = []
        self.y = []
        # We need to reshape x_start and dirn since we want to apply
        # rotations to individual position vectors!
        # we will eventually store the stretch in w, X is just a reference
        # to x_start with different shape.
        # Integer division: reshape() rejects a float dimension.
        n_atoms = len(dirn) // 3
        w = dirn.copy().reshape([3, n_atoms])
        X = x_start.reshape([3, n_atoms])

        for idx in rigid_units:  # idx is a list of indices for one molecule
            # get the positions of this molecule and subtract their mean.
            # The columns are atoms, so the mean position is the average
            # over axis 1 (one 3-vector per molecule).
            x = X[:, idx]
            y = x - x.mean(axis=1, keepdims=True)  # PBC?
            # same for forces >>> translation component
            g = w[:, idx]
            f = g - g.mean(axis=1, keepdims=True)
            # compute the system to solve for K (see accompanying note!)
            # A = \sum_j Yj' Yj
            # b = \sum_j Yj' fj
            # where Yj is built such that Yj k == K yj.
            A = np.zeros((3, 3))
            b = np.zeros(3)
            for yj, fj in zip(y.T, f.T):
                Yj = np.array(
                    [
                        [yj[1], 0.0, -yj[2]],
                        [-yj[0], yj[2], 0.0],
                        [0.0, -yj[1], yj[0]],
                    ]
                )
                A += np.dot(Yj.T, Yj)
                b += np.dot(Yj.T, fj)
            # If the directions y[:,j] span all of R^3 (canonically this is
            # true when there are at least three atoms in the molecule) A is
            # regular; if not, then A is singular so we cannot solve A k = b.
            # In this case we solve Ak = b in the space orthogonal to the
            # null-space of A.
            # TODO:
            #    this can get unstable if A is "near-singular"! We may
            #    need to revisit this idea at some point to get something
            #    more robust
            N = nullspace(A)
            null_projector = np.dot(N, N.T)
            b -= np.dot(null_projector, b)
            A += null_projector
            # `assume_a='pos'` replaces the `sym_pos=True` keyword, which
            # current SciPy releases no longer accept.
            k = scipy.linalg.solve(A, b, assume_a='pos')
            K = np.array(
                [[0.0, k[0], -k[2]], [-k[0], 0.0, k[1]], [k[2], -k[1], 0.0]]
            )
            # now remove the rotational component from the search direction
            # ( we actually keep the translational component as part of w,
            #   but this could be changed as well! )
            w[:, idx] -= np.dot(K, y)
            # store K and y
            self.K.append(K)
            self.y.append(y)

        # store the stretch (no need to copy here, since w is already a copy)
        self.stretch = w

    def step(self, alpha):
        """perform a step in the line-search, given a step-length alpha

        Args:
          alpha : step-length

        Returns:
          s : update for positions, flattened to the same layout as the
              `dirn` vector passed to the constructor
        """
        # translation and stretch
        s = alpha * self.stretch
        # loop through rigid_units
        for idx, K, y, rf in zip(
            self.rigid_units, self.K, self.y, self.rotation_factors
        ):
            # with matrix exponentials:
            #      s[:, idx] += expm(K * alpha * rf) * y - y
            # third-order taylor approximation:
            #      I + t K + 1/2 t^2 K^2 + 1/6 t^3 K^3 - I
            #                            = t K (I + 1/2 t K (I + 1/3 t K))
            aK = alpha * rf * K
            s[:, idx] += np.dot(
                aK, y + 0.5 * np.dot(aK, y + 1 / 3.0 * np.dot(aK, y))
            )

        return s.ravel()
class LineSearchArmijo:
    """Backtracking line search that enforces only the Armijo
    (sufficient decrease) condition."""

    def __init__(self, func, c1=0.1, tol=1e-14):
        """Initialise the linesearch with set parameters and functions.

        Args:
            func: the function we are trying to minimise (energy), which should
                take an array of positions for its argument
            c1: parameter for the sufficient decrease condition in (0.0 0.5);
                values outside that interval are replaced by 0.1
            tol: tolerance for evaluating equality

        """

        self.tol = tol
        self.func = func

        if not (0 < c1 < 0.5):
            logger.error(
                'c1 outside of allowed interval (0, 0.5). Replacing with '
                'default value.'
            )
            print(
                'Warning: C1 outside of allowed interval. Replacing with '
                'default value.'
            )
            c1 = 0.1

        self.c1 = c1

    # CO : added rigid_units and rotation_factors

    def run(
        self,
        x_start,
        dirn,
        a_max=None,
        a_min=None,
        a1=None,
        func_start=None,
        func_old=None,
        func_prime_start=None,
        rigid_units=None,
        rotation_factors=None,
        maxstep=None,
    ):
        """Perform a backtracking / quadratic-interpolation linesearch
            to find an appropriate step length with Armijo condition.
        NOTE THIS LINESEARCH DOES NOT IMPOSE WOLFE CONDITIONS!

        The idea is to do backtracking via quadratic interpolation, stabilised
        by putting a lower bound on the decrease at each linesearch step.
        To ensure BFGS-behaviour, whenever "reasonable" we take 1.0 as the
        starting step.

        Since Armijo does not guarantee convergence of BFGS, the outer
        BFGS algorithm must restart when the current search direction
        ceases to be a descent direction.

        Args:
            x_start: vector containing the position to begin the linesearch
                from (ie the current location of the optimisation)
            dirn: vector pointing in the direction to search in (pk in [NW]).
                Note that this does not have to be a unit vector, but the
                function will return a value scaled with respect to dirn.
            a_max: an upper bound on the maximum step length allowed.
                Default is 2.0.
            a_min: a lower bound on the minimum step length allowed.
                Default is 1e-10.
                A RuntimeError is raised if this bound is violated
                during the line search.
            a1: the initial guess for an acceptable step length. If no value is
                given, this will be set automatically, using quadratic
                interpolation using func_old, or "rounded" to 1.0 if the
                initial guess lies near 1.0. (specifically for LBFGS)
            func_start: the value of func at the start of the linesearch, ie
                phi(0). Passing this information avoids potentially expensive
                re-calculations
            func_prime_start: the value of func_prime at the start of the
                linesearch (this will be dotted with dirn to find phi_prime(0))
            func_old: the value of func_start at the previous step taken in
                the optimisation (this will be used to calculate the initial
                guess for the step length if it is not provided)
            rigid_units, rotation_factors : see documentation of RumPath, if
                it is unclear what these parameters are, then leave them at
                None
            maxstep: maximum allowed displacement in Angstrom. Default is 0.2.

        Returns:
            A tuple: (step, func_val, no_update)

            step: the final chosen step length, representing the number of
                multiples of the direction vector to move
            func_val: the value of func after taking this step, ie phi(step)
            no_update: true if the first trial step was accepted, i.e. the
                linesearch did not have to update alpha

        Raises:
            ValueError for problems with arguments
            RuntimeError for problems encountered during iteration
        """

        a1 = self.handle_args(
            x_start,
            dirn,
            a_max,
            a_min,
            a1,
            func_start,
            func_old,
            func_prime_start,
            maxstep,
        )

        # DEBUG
        logger.debug('a1(auto) = %e', a1)

        # NOTE(review): this snapping to 1.0 happens after handle_args()
        # has applied the maxstep rescaling, so a rescaled guess in
        # [0.5, 1.5] is reset to 1.0 -- confirm that this is intended.
        if abs(a1 - 1.0) <= 0.5:
            a1 = 1.0

        logger.debug('-----------NEW LINESEARCH STARTED---------')

        num_iter = 0

        # create a search-path
        if rigid_units is None:
            # standard linear search-path
            logger.debug('-----using LinearPath-----')
            path = LinearPath(dirn)
        else:
            logger.debug('-----using RumPath------')
            # if rigid_units != None, but rotation_factors == None, then
            # raise an error.
            if rotation_factors is None:
                raise RuntimeError(
                    'RumPath cannot be created since rotation_factors == None'
                )
            path = RumPath(x_start, dirn, rigid_units, rotation_factors)

        while True:
            logger.debug('-----------NEW ITERATION OF LINESEARCH----------')
            logger.debug('Number of linesearch iterations: %d', num_iter)
            logger.debug('a1 = %e', a1)

            # CO replaced: func_a1 = self.func(x_start + a1 * self.dirn)
            phi_a1 = self.func(x_start + path.step(a1))
            # compute sufficient decrease (Armijo) condition
            suff_dec = (
                phi_a1 <= self.func_start + self.c1 * a1 * self.phi_prime_start
            )

            logger.info('a1 = %.3f, suff_dec = %r', a1, suff_dec)
            if a1 < self.a_min:
                raise RuntimeError('a1 < a_min, giving up')
            if self.phi_prime_start > 0.0:
                raise RuntimeError('self.phi_prime_start > 0.0')

            # check sufficient decrease (Armijo condition)
            if suff_dec:
                logger.debug(
                    'Linesearch returned a = %e, phi_a = %e', a1, phi_a1
                )
                logger.debug('-----------LINESEARCH COMPLETE-----------')
                return a1, phi_a1, num_iter == 0

            # we don't have sufficient decrease, so we need to compute a
            # new trial step-length
            at = -(
                (self.phi_prime_start * a1)
                / (2 * ((phi_a1 - self.func_start) / a1 - self.phi_prime_start))
            )
            logger.debug('quadratic_min: initial at = %e', at)

            # because a1 does not satisfy Armijo it follows that at must
            # lie between 0 and a1. In fact, more strongly,
            #     at \leq (2 (1-c1))^{-1} a1, which is a back-tracking condition
            # therefore, we should now only check that at has not become
            # too small, in which case it is likely that nonlinearity has
            # played a big role here, so we take an ultra-conservative
            # backtracking step.
            # A non-finite `at` (e.g. func returned NaN) is treated the same
            # way; otherwise a1 would become NaN, never fall below a_min and
            # the loop would never terminate.
            backtrack = a1 / 10.0
            if math.isfinite(at) and at >= backtrack:
                a1 = at
            else:
                logger.debug(
                    'at (%e) < a1/10: revert to backtracking a1/10', at
                )
                a1 = backtrack

            # count the rejected trial so that the `no_update` return value
            # and the iteration log reflect what actually happened
            num_iter += 1

        # (end of while(True) line-search loop)

    # (end of run())

    def handle_args(
        self,
        x_start,
        dirn,
        a_max,
        a_min,
        a1,
        func_start,
        func_old,
        func_prime_start,
        maxstep,
    ):
        """Verify passed parameters and set appropriate attributes accordingly.

        A suitable value for the initial step-length guess will be either
        verified or calculated, stored in the attribute self.a_start, and
        returned.

        Attributes set: a_max, a_min, x_start, dirn, func_old, func_start,
        func_prime_start, phi_prime_start and a_start.

        Args:
            The args should be identical to those of self.run().

        Returns:
            The suitable initial step-length guess a_start

        Raises:
            ValueError for problems with arguments

        """

        if a_max is None:
            a_max = 2.0

        if a_max < self.tol:
            logger.warning(
                'a_max too small relative to tol. Reverting to '
                'default value a_max = 2.0 (twice the <ideal> step).'
            )
            a_max = 2.0  # THIS ASSUMES NEWTON/BFGS TYPE BEHAVIOUR!

        if a_min is None:
            a_min = 1e-10

        # store the resolved bounds (not the raw, possibly None, arguments)
        self.a_max = a_max
        self.a_min = a_min
        self.x_start = x_start
        self.dirn = dirn
        self.func_old = func_old
        self.func_start = func_start
        self.func_prime_start = func_prime_start

        if func_start is None:
            logger.debug('Setting func_start')
            self.func_start = self.func(x_start)

        if func_prime_start is None:
            raise ValueError('func_prime_start must be provided.')

        self.phi_prime_start = longsum(self.func_prime_start * self.dirn)
        if self.phi_prime_start >= 0:
            logger.error(
                'Passed direction which is not downhill. Aborting...: %e',
                self.phi_prime_start,
            )
            raise ValueError('Direction is not downhill.')
        elif math.isinf(self.phi_prime_start):
            logger.error(
                'Passed func_prime_start and dirn which are too big. '
                'Aborting...'
            )
            raise ValueError('func_prime_start and dirn are too big.')
        elif math.isnan(self.phi_prime_start):
            # NaN compares False with everything, so it slips through the
            # two checks above and would poison every later comparison.
            logger.error('Passed func_prime_start and dirn give a NaN slope.')
            raise ValueError('func_prime_start and dirn give a NaN slope.')

        if a1 is None and func_old is not None:
            # Interpolating a quadratic to func and func_old - see NW
            # equation 3.60
            a1 = 2 * (self.func_start - self.func_old) / self.phi_prime_start
            logger.debug('Interpolated quadratic, obtained a1 = %e', a1)
        if a1 is None or math.isnan(a1):
            # no usable guess (none given/derivable, or NaN from NaN energies)
            logger.debug(
                'No valid a1 available. Reverting to default value a1 = 1.0'
            )
            a1 = 1.0
        elif a1 > a_max:
            logger.debug(
                'a1 greater than a_max. Reverting to default value a1 = 1.0'
            )
            a1 = 1.0
        if a1 < self.tol:
            logger.debug(
                'a1 is None or a1 < self.tol. Reverting to default value '
                'a1 = 1.0'
            )
            a1 = 1.0
        if a1 < self.a_min:
            logger.debug(
                'a1 is None or a1 < a_min. Reverting to default value a1 = 1.0'
            )
            a1 = 1.0

        if maxstep is None:
            maxstep = 0.2
        logger.debug('maxstep = %e', maxstep)

        # Per-atom displacement lengths for the step a1 * dirn; shrink a1 so
        # that the largest one does not exceed maxstep.
        r = np.reshape(dirn, (-1, 3))
        steplengths = ((a1 * r) ** 2).sum(1) ** 0.5
        maxsteplength = np.max(steplengths)
        if maxsteplength >= maxstep:
            a1 *= maxstep / maxsteplength
            logger.debug('Rescaled a1 to fulfill maxstep criterion')

        self.a_start = a1

        logger.debug(
            'phi_start = %e, phi_prime_start = %e',
            self.func_start,
            self.phi_prime_start,
        )
        logger.debug(
            'func_start = %s, self.func_old = %s',
            self.func_start,
            self.func_old,
        )
        logger.debug('a1 = %e, a_max = %e, a_min = %e', a1, a_max, self.a_min)

        return a1