Coverage for ase / optimize / test / test.py: 22.13%
122 statements
« prev ^ index » next — coverage.py v7.13.3, created at 2026-02-04 10:20 +0000
1# fmt: off
3import argparse
4import traceback
5from math import pi
6from time import time
7from typing import Union
9import numpy as np
11import ase.db
12import ase.optimize
13from ase import Atoms
14from ase.calculators.emt import EMT
15from ase.io import Trajectory
# Names of every optimizer the benchmark can run: everything exported by
# ase.optimize, plus the preconditioned and SciPy-based optimizers that live
# in submodules (get_optimizer resolves those by their name prefix).
all_optimizers = [*ase.optimize.__all__,
                  'PreconLBFGS', 'PreconFIRE',
                  'SciPyFminCG', 'SciPyFminBFGS']
# Exported names that are not benchmarked by default (presumably an alias
# and an exception class rather than distinct optimizers -- confirm).
all_optimizers.remove('QuasiNewton')
all_optimizers.remove('RestartError')
def get_optimizer(name):
    """Return the optimizer class called *name*.

    Names starting with ``'Precon'`` are looked up in
    ``ase.optimize.precon`` and names starting with ``'SciPy'`` in
    ``ase.optimize.sciopt``; those submodules are imported only when such a
    name is requested.  Any other name is looked up in ``ase.optimize``.
    An unknown name raises ``AttributeError`` from ``getattr``.
    """
    if name.startswith('Precon'):
        from ase.optimize import precon as module
    elif name.startswith('SciPy'):
        from ase.optimize import sciopt as module
    else:
        module = ase.optimize
    return getattr(module, name)
class Wrapper:
    """Atoms-object wrapper that can count number of moves.

    The wrapper sits between an optimizer and a real ``Atoms`` object.  It
    counts one step for every distinct set of positions at which the energy
    or forces are requested, accumulates the wall-clock time spent inside
    those evaluations (``texcl``), and can optionally add an artificial
    periodic "egg-box" term to the energy and forces.

    Parameters
    ----------
    atoms:
        The atoms to optimize; a calculator must be attached.
    gridspacing:
        Approximate spacing of the fake egg-box grid along each cell axis
        (same length unit as the cell).
    eggbox:
        Amplitude scale of the egg-box term; 0 disables it.
    maxsteps:
        Once this many geometries have been evaluated, the next move raises
        ``RuntimeError('Did not converge!')``.
    """

    def __init__(
        self,
        atoms: Atoms,
        gridspacing: float = 0.2,
        eggbox: float = 0.0,
        maxsteps: int = 200,
    ) -> None:
        self.t0 = time()
        self.texcl = 0.0  # time spent in energy/force evaluations
        self.nsteps = 0  # number of distinct geometries evaluated
        self.atoms = atoms
        # True once energy/forces were evaluated at the current positions.
        self.ready = False
        # Last positions seen.  Starting from the initial geometry (instead
        # of None) makes sure the very first move is detected and counted.
        self.pos: Union[np.ndarray, None] = atoms.get_positions()
        self.eggbox = eggbox
        self.maxsteps = maxsteps

        self.x = None
        if eggbox:
            # Find small unit cell for grid-points
            h = []
            for axis in atoms.get_cell(complete=True):
                L = np.linalg.norm(axis)
                # At least one grid cell per axis; otherwise an axis shorter
                # than the grid spacing would cause a division by zero.
                n = max(1, int(L / gridspacing))
                h.append(axis / n)
            # Maps Cartesian positions to coordinates in units of the grid
            # vectors (positions = s @ h  =>  s = positions @ inv(h)).
            self.x = np.linalg.inv(h)

    def _count_evaluation(self):
        """Count a new step if this geometry has not been evaluated yet."""
        if not self.ready:
            self.nsteps += 1
            self.ready = True

    def get_potential_energy(self, force_consistent=False):
        """Return the energy of the wrapped atoms plus the egg-box term."""
        t1 = time()
        e = self.atoms.get_potential_energy(force_consistent)

        if self.eggbox:
            # Add egg-box error:
            s = np.dot(self.atoms.positions, self.x)
            e += np.cos(2 * pi * s).sum() * self.eggbox / 6

        t2 = time()
        self.texcl += t2 - t1
        self._count_evaluation()
        return e

    def get_forces(self):
        """Return the forces of the wrapped atoms plus the egg-box term.

        The egg-box contribution is minus the gradient of the energy term
        added in ``get_potential_energy``.
        """
        t1 = time()
        f = self.atoms.get_forces()

        if self.eggbox:
            # Add egg-box error.  Build a new array instead of adding in
            # place, so an array owned by the atoms/calculator is never
            # modified.
            s = np.dot(self.atoms.positions, self.x)
            f = f + np.dot(np.sin(2 * pi * s),
                           self.x.T) * (2 * pi * self.eggbox / 6)

        t2 = time()
        self.texcl += t2 - t1
        self._count_evaluation()
        return f

    def set_positions(self, pos):
        """Move the atoms; a real move makes the next evaluation a new step.

        Raises
        ------
        RuntimeError
            If the atoms are moved after ``maxsteps`` geometries have
            already been evaluated.
        """
        if self.pos is not None and abs(pos - self.pos).max() > 1e-15:
            self.ready = False
            if self.nsteps >= self.maxsteps:
                raise RuntimeError('Did not converge!')

        # Keep a private copy: the optimizer may later modify its own array
        # in place, which would otherwise hide the next move from the check.
        self.pos = np.array(pos)
        self.atoms.set_positions(pos)

    def get_positions(self):
        """Return the positions of the wrapped atoms."""
        return self.atoms.get_positions()

    def get_calculator(self):
        """Return the calculator attached to the wrapped atoms."""
        return self.atoms.calc

    def __len__(self):
        return len(self.atoms)

    def __ase_optimizable__(self):
        # Return an OptimizableAtoms adapter around this wrapper; presumably
        # ASE optimizers call this hook when given a non-Atoms object.
        from ase.optimize.optimize import OptimizableAtoms
        return OptimizableAtoms(self)
def run_test(atoms, optimizer, tag, fmax=0.02, eggbox=0.0):
    """Optimize atoms with optimizer.

    Writes ``{tag}.log`` (optimizer log) and ``{tag}.traj`` (trajectory).
    On failure it also writes ``{tag}.err`` (error plus traceback).

    Returns
    -------
    tuple
        ``(error, nsteps, texcl, tincl)``.  ``error`` is ``''`` on success,
        otherwise ``'ExceptionName: message'``.  ``nsteps`` is the number of
        evaluated geometries (``inf`` on failure).  ``texcl`` is the time
        spent in energy/force evaluations.  ``tincl`` is the total wall-clock
        time of the run.
    """
    wrapper = Wrapper(atoms, eggbox=eggbox)
    error = ''

    # Context managers close the optimizer's log file and the trajectory
    # file even when the relaxation fails; previously both were left open.
    with optimizer(wrapper, logfile=tag + '.log') as relax:
        with Trajectory(tag + '.traj', 'w', atoms=atoms) as traj:
            relax.attach(traj)

            tincl = -time()
            try:
                relax.run(fmax=fmax, steps=10000000)
            except Exception as x:
                # Record the failure and return it instead of propagating.
                wrapper.nsteps = float('inf')
                error = f'{x.__class__.__name__}: {x}'
                tb = traceback.format_exc()

                with open(tag + '.err', 'w') as fd:
                    fd.write(f'{error}\n{tb}\n')

            tincl += time()

    return error, wrapper.nsteps, wrapper.texcl, tincl
def test_optimizer(systems, optimizer, calculator, prefix='', db=None,
                   eggbox=0.0):
    """Test optimizer on systems.

    Parameters
    ----------
    systems:
        Iterable of ``(name, atoms)`` pairs; each atoms object is copied
        before it is relaxed.
    optimizer:
        Optimizer class.  Its ``__name__`` is used in output file names and
        in the database rows.
    calculator:
        Calculator factory, called as ``calculator(txt=...)`` once per
        system.
    prefix:
        String prepended to every output file name.
    db:
        Optional ASE database.  When given, a row is reserved for each
        (optimizer, system) pair before running, and the result is written
        into that row.  Pairs for which ``db.reserve`` returns ``None`` are
        skipped.
    eggbox:
        Egg-box amplitude forwarded to ``run_test``.
    """
    # Needed for the output file names even when no database is used, so
    # it must not be assigned only inside the ``db is not None`` branch.
    optname = optimizer.__name__

    for name, atoms in systems:
        row_id = None
        if db is not None:
            row_id = db.reserve(optimizer=optname, name=name)
            if row_id is None:
                # Nothing reserved (presumably a row for this pair already
                # exists), so skip the pair.
                continue

        atoms = atoms.copy()
        tag = f'{prefix}{optname}-{name}'
        atoms.calc = calculator(txt=tag + '.txt')
        error, nsteps, texcl, tincl = run_test(atoms, optimizer, tag,
                                               eggbox=eggbox)

        if db is not None:
            db.write(atoms,
                     id=row_id,
                     optimizer=optname,
                     name=name,
                     error=error,
                     n=nsteps,
                     t=texcl,
                     T=tincl,
                     eggbox=eggbox)
def main():
    """Command-line entry point for the optimizer benchmark.

    Reads the test systems from the database file named on the command line.
    Runs each requested optimizer (all of ``all_optimizers`` when none are
    given) on every system with the EMT calculator.  Results are stored in
    ``results.db`` in the current directory.
    """
    parser = argparse.ArgumentParser(description='Test ASE optimizers')
    known = ', '.join(all_optimizers)

    parser.add_argument('systems', help='File containing test systems.')
    parser.add_argument(
        'optimizer',
        nargs='*',
        help=f'Optimizer name(s). Choose from: {known}. '
             'Default is all optimizers.',
    )
    parser.add_argument('-e', '--egg-box', type=float, default=0.0,
                        help='Fake egg-box error in eV.')

    args = parser.parse_args()

    source = ase.db.connect(args.systems)
    systems = [(row.name, row.toatoms()) for row in source.select()]

    results = ase.db.connect('results.db')

    names = args.optimizer if args.optimizer else all_optimizers
    for opt in names:
        print(opt)
        test_optimizer(systems, get_optimizer(opt), EMT,
                       db=results, eggbox=args.egg_box)


if __name__ == "__main__":
    main()