Coverage for ase / optimize / test / test.py: 21.49%
121 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3import argparse
4import traceback
5from math import pi
6from time import time
8import numpy as np
10import ase.db
11import ase.optimize
12from ase import Atoms
13from ase.calculators.emt import EMT
14from ase.io import Trajectory
# Every optimizer name the command line accepts: the public names of
# ase.optimize plus optimizers that live in submodules (get_optimizer knows
# where to find them), minus 'QuasiNewton' and 'RestartError', which are
# left out of the benchmark.
all_optimizers = list(ase.optimize.__all__)
all_optimizers += ['PreconLBFGS', 'PreconFIRE',
                   'SciPyFminCG', 'SciPyFminBFGS']
for _excluded in ('QuasiNewton', 'RestartError'):
    all_optimizers.remove(_excluded)
del _excluded
def get_optimizer(name):
    """Return the optimizer class called *name*.

    Names starting with ``Precon`` are looked up in ``ase.optimize.precon``
    and names starting with ``SciPy`` in ``ase.optimize.sciopt``; those
    submodules are imported only when such a name is requested.  Any other
    name is looked up directly on ``ase.optimize``.

    Raises AttributeError if the selected module has no attribute *name*.
    """
    if name.startswith('Precon'):
        from ase.optimize import precon as source
    elif name.startswith('SciPy'):
        from ase.optimize import sciopt as source
    else:
        source = ase.optimize
    return getattr(source, name)
class Wrapper:
    """Atoms-object wrapper that can count number of moves.

    The optimizer is given this object instead of the Atoms (see
    ``__ase_optimizable__``).  The wrapper:

    * counts how many distinct positions were evaluated (``nsteps``),
    * accumulates the wall time spent in the energy and force calls
      (``texcl``),
    * can add an artificial "egg-box" term to energy and forces, and
    * aborts with ``RuntimeError`` once ``maxsteps`` evaluations are used.
    """

    def __init__(
        self,
        atoms: Atoms,
        gridspacing: float = 0.2,
        eggbox: float = 0.0,
        maxsteps: int = 200,
    ) -> None:
        """Wrap *atoms*.

        atoms: the system being optimized; energies, forces and positions
            are read from and written to it.
        gridspacing: target spacing of the egg-box grid along each cell
            vector (each vector is split into an integer number of segments
            no shorter than this, except when the vector itself is shorter).
        eggbox: amplitude of the egg-box term; 0 disables it.
        maxsteps: number of evaluations after which the next
            ``set_positions`` call raises ``RuntimeError`` (default 200).
        """
        self.t0 = time()  # construction time (not used within this class)
        self.texcl = 0.0  # wall time spent in get_potential_energy/get_forces
        self.nsteps = 0  # number of distinct positions evaluated so far
        self.maxsteps = maxsteps
        self.atoms = atoms
        # True once the current positions have been counted in nsteps.
        self.ready = False
        self.pos: np.ndarray | None = None  # copy of the last positions set
        self.eggbox = eggbox

        # Matrix mapping Cartesian positions to egg-box grid coordinates.
        self.x = None
        if eggbox:
            # Find small unit cell for grid-points: split each cell vector
            # into an integer number of segments of roughly gridspacing.
            h = []
            for axis in atoms.get_cell(complete=True):
                length = np.linalg.norm(axis)
                # At least one segment: a cell vector shorter than
                # gridspacing would otherwise give n == 0 and divide by 0.
                n = max(1, int(length / gridspacing))
                h.append(axis / n)
            self.x = np.linalg.inv(h)

    def _eggbox_phase(self):
        """Return 2*pi times the grid coordinates of the current positions."""
        return 2 * pi * np.dot(self.atoms.positions, self.x)

    def _record_evaluation(self, t_start):
        """Charge the time since *t_start* to texcl; count new positions once."""
        self.texcl += time() - t_start
        if not self.ready:
            self.nsteps += 1
            self.ready = True

    def get_potential_energy(self, force_consistent=False):
        """Return the energy of the atoms, plus the egg-box term if enabled.

        *force_consistent* is passed positionally to the wrapped atoms.
        """
        t_start = time()
        e = self.atoms.get_potential_energy(force_consistent)

        if self.eggbox:
            # Add egg-box error: eggbox/6 * sum(cos(2*pi*s)).
            e += np.cos(self._eggbox_phase()).sum() * self.eggbox / 6

        self._record_evaluation(t_start)
        return e

    def get_forces(self):
        """Return the forces on the atoms, plus the egg-box term if enabled.

        The egg-box contribution is the negative gradient of the energy term
        added in ``get_potential_energy``.
        """
        t_start = time()
        f = self.atoms.get_forces()

        if self.eggbox:
            # Add egg-box error (chain rule through s = positions @ x).
            f += np.dot(np.sin(self._eggbox_phase()),
                        self.x.T) * (2 * pi * self.eggbox / 6)

        self._record_evaluation(t_start)
        return f

    def set_positions(self, pos):
        """Move the atoms to *pos*; a real move starts a new evaluation.

        Raises RuntimeError once ``maxsteps`` evaluations have been used.
        """
        if self.pos is not None and np.abs(pos - self.pos).max() > 1e-15:
            self.ready = False
        if self.nsteps >= self.maxsteps:
            raise RuntimeError('Did not converge!')

        # Keep a private copy.  If we stored the caller's array and the
        # caller later modified it in place and passed it again, the
        # comparison above would compare the array with itself and miss
        # the move.
        self.pos = np.array(pos)
        self.atoms.set_positions(pos)

    def get_positions(self):
        """Return the current positions of the wrapped atoms."""
        return self.atoms.get_positions()

    def get_calculator(self):
        """Return the calculator attached to the wrapped atoms."""
        return self.atoms.calc

    def __len__(self):
        """Return the number of atoms."""
        return len(self.atoms)

    def __ase_optimizable__(self):
        """Expose this wrapper to ASE optimizers as an OptimizableAtoms."""
        from ase.optimize.optimize import OptimizableAtoms
        return OptimizableAtoms(self)
def run_test(atoms, optimizer, tag, fmax=0.02, eggbox=0.0):
    """Optimize atoms with optimizer.

    Writes ``<tag>.log`` (optimizer log) and ``<tag>.traj`` (trajectory).
    If the run raises, it also writes ``<tag>.err`` with the error and
    traceback.

    Returns a tuple ``(error, nsteps, texcl, tincl)``:
        error: ``''`` on success, else ``'<ExceptionClass>: <message>'``;
        nsteps: evaluations counted by the Wrapper (``inf`` on failure);
        texcl: time spent in the energy/force calls;
        tincl: total wall time of the run.
    """
    wrapper = Wrapper(atoms, eggbox=eggbox)
    relax = optimizer(wrapper, logfile=tag + '.log')
    traj = Trajectory(tag + '.traj', 'w', atoms=atoms)
    relax.attach(traj)

    tincl = -time()
    error = ''

    try:
        relax.run(fmax=fmax, steps=10000000)
    except Exception as x:
        # Any exception counts as a failed run.  This includes the
        # RuntimeError the Wrapper raises when its step budget is used up.
        wrapper.nsteps = float('inf')
        error = f'{x.__class__.__name__}: {x}'
        tb = traceback.format_exc()

        with open(tag + '.err', 'w') as fd:
            fd.write(f'{error}\n{tb}\n')
    finally:
        # Release the file handles opened for this run.  Many runs are done
        # in one process, so leaking them accumulates.
        traj.close()
        # NOTE(review): close() is only present on optimizers that manage
        # their own log file; guarded so other optimizers still work.
        close_optimizer = getattr(relax, 'close', None)
        if close_optimizer is not None:
            close_optimizer()

    tincl += time()

    return error, wrapper.nsteps, wrapper.texcl, tincl
def test_optimizer(systems, optimizer, calculator, prefix='', db=None,
                   eggbox=0.0):
    """Test optimizer on systems.

    systems: iterable of ``(name, atoms)`` pairs; each atoms is copied
        before use, so the originals are not modified.
    optimizer: optimizer class; its ``__name__`` names the output files and
        the database rows.
    calculator: factory called as ``calculator(txt=...)`` to create a fresh
        calculator for each system.
    prefix: prepended to the per-run file tag ``<optimizer>-<name>``.
    db: optional database.  When given, a row is reserved for each
        (optimizer, name) pair before running and the results are written
        to it.  Pairs whose reservation returns None are skipped.
    eggbox: egg-box amplitude passed on to ``run_test``.
    """
    # Computed up front: the file tag needs it even when db is None.
    optname = optimizer.__name__

    for name, atoms in systems:
        row_id = None
        if db is not None:
            row_id = db.reserve(optimizer=optname, name=name)
            if row_id is None:
                # Presumably already reserved by another run -- skip it.
                continue

        atoms = atoms.copy()
        tag = f'{prefix}{optname}-{name}'
        atoms.calc = calculator(txt=tag + '.txt')
        error, nsteps, texcl, tincl = run_test(atoms, optimizer, tag,
                                               eggbox=eggbox)

        if db is not None:
            db.write(atoms,
                     id=row_id,
                     optimizer=optname,
                     name=name,
                     error=error,
                     n=nsteps,
                     t=texcl,
                     T=tincl,
                     eggbox=eggbox)
def main():
    """Command-line entry point: benchmark ASE optimizers.

    Reads the test systems from the given ASE database.  Each requested
    optimizer (default: every name in ``all_optimizers``) is run on every
    system with the EMT calculator, and the results are recorded in
    ``results.db``.
    """
    parser = argparse.ArgumentParser(
        description='Test ASE optimizers')

    parser.add_argument('systems', help='File containing test systems.')
    parser.add_argument('optimizer', nargs='*',
                        help='Optimizer name(s). Choose from: {}. '
                        .format(', '.join(all_optimizers)) +
                        'Default is all optimizers.')
    parser.add_argument('-e', '--egg-box', type=float, default=0.0,
                        help='Fake egg-box error in eV.')

    args = parser.parse_args()

    names = args.optimizer or all_optimizers

    # Resolve every name before starting any (potentially long) run.  A typo
    # is then reported at once as a usage error, instead of as a traceback
    # after the earlier optimizers have already been benchmarked.
    optimizers = []
    for opt in names:
        try:
            optimizers.append((opt, get_optimizer(opt)))
        except AttributeError:
            parser.error(f'unknown optimizer: {opt!r}')

    systems = [(row.name, row.toatoms())
               for row in ase.db.connect(args.systems).select()]

    db = ase.db.connect('results.db')

    for opt, optimizer in optimizers:
        print(opt)
        test_optimizer(systems, optimizer, EMT, db=db, eggbox=args.egg_box)


if __name__ == '__main__':
    main()