Coverage for ase / optimize / test / test.py: 22.13%

122 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1# fmt: off 

2 

3import argparse 

4import traceback 

5from math import pi 

6from time import time 

7from typing import Union 

8 

9import numpy as np 

10 

11import ase.db 

12import ase.optimize 

13from ase import Atoms 

14from ase.calculators.emt import EMT 

15from ase.io import Trajectory 

16 

# Optimizer names selectable from the command line: everything exported by
# ase.optimize plus the preconditioned and SciPy-based optimizers, which live
# in submodules and are resolved lazily by get_optimizer().
all_optimizers = [*ase.optimize.__all__,
                  'PreconLBFGS', 'PreconFIRE',
                  'SciPyFminCG', 'SciPyFminBFGS']
# Exported names that are not benchmarked.  QuasiNewton is presumably an
# alias of another listed optimizer and RestartError presumably an exception
# type rather than an optimizer -- NOTE(review): confirm against ase.optimize.
all_optimizers.remove('QuasiNewton')
all_optimizers.remove('RestartError')

21 

22 

def get_optimizer(name):
    """Look up an optimizer class by its name.

    Names beginning with ``'Precon'`` are resolved in
    ``ase.optimize.precon``, names beginning with ``'SciPy'`` in
    ``ase.optimize.sciopt``, and every other name in ``ase.optimize``.
    The two submodules are only imported when one of their names is
    requested.

    Parameters
    ----------
    name : str
        Class name of the optimizer, e.g. ``'BFGS'``.

    Returns
    -------
    type
        The optimizer class itself (callers instantiate it).

    Raises
    ------
    AttributeError
        If the selected module has no attribute called *name*.
    """
    if name.startswith('Precon'):
        import ase.optimize.precon as namespace
    elif name.startswith('SciPy'):
        import ase.optimize.sciopt as namespace
    else:
        namespace = ase.optimize
    return getattr(namespace, name)

32 

33 

class Wrapper:
    """Atoms-object wrapper that can count number of moves.

    Energy and force requests are forwarded to the wrapped atoms.
    ``nsteps`` counts evaluations at *distinct* positions: repeated
    requests at the same geometry are counted once.  Optionally a fake,
    periodic "egg-box" error is added to the energy and (consistently) to
    the forces.  Once ``maxsteps`` evaluations have been counted, the next
    move raises ``RuntimeError('Did not converge!')``.

    Parameters
    ----------
    atoms:
        The atoms to optimize; a calculator must already be attached.
    gridspacing:
        Approximate spacing of the egg-box grid along each cell vector
        (same length unit as the cell).
    eggbox:
        Strength of the egg-box error; 0 disables it.  Each atom adds
        ``eggbox / 6 * sum(cos(2 * pi * s))`` over its three fractional
        grid coordinates ``s``.
    maxsteps:
        Number of counted evaluations after which a further move aborts
        the optimization.
    """

    def __init__(
        self,
        atoms: Atoms,
        gridspacing: float = 0.2,
        eggbox: float = 0.0,
        maxsteps: int = 200,
    ) -> None:
        self.t0 = time()
        # Accumulated wall time spent inside energy/force calls.
        self.texcl = 0.0
        # Number of evaluations at distinct positions.
        self.nsteps = 0
        self.atoms = atoms
        # True once an evaluation at the current positions was counted.
        self.ready = False
        # Private copy of the positions most recently passed to
        # set_positions (None until the first call).
        self.pos: Union[np.ndarray, None] = None
        self.eggbox = eggbox
        self.maxsteps = maxsteps

        self.x = None
        if eggbox:
            # Find small unit cell for grid-points: split each cell vector
            # into an integer number of segments of roughly `gridspacing`.
            h = []
            for axis in atoms.get_cell(complete=True):
                length = np.linalg.norm(axis)
                # At least one segment, so a very short axis cannot lead
                # to a division by zero.
                n = max(1, int(length / gridspacing))
                h.append(axis / n)
            # Maps Cartesian positions to fractional grid coordinates.
            self.x = np.linalg.inv(h)

    def _eggbox_phase(self):
        """Return ``2 * pi`` times the fractional grid coordinates."""
        return 2 * pi * np.dot(self.atoms.positions, self.x)

    def _count_evaluation(self, t1):
        """Add the time elapsed since *t1*; count the geometry once."""
        self.texcl += time() - t1
        if not self.ready:
            self.nsteps += 1
            self.ready = True

    def get_potential_energy(self, force_consistent=False):
        """Return the energy, including the egg-box term if enabled."""
        t1 = time()
        e = self.atoms.get_potential_energy(force_consistent)

        if self.eggbox:
            # Add egg-box error:
            e += np.cos(self._eggbox_phase()).sum() * self.eggbox / 6

        self._count_evaluation(t1)
        return e

    def get_forces(self):
        """Return the forces, including the egg-box force if enabled.

        The egg-box force is the negative gradient of the egg-box energy
        term added in get_potential_energy.
        """
        t1 = time()
        f = self.atoms.get_forces()

        if self.eggbox:
            # Add egg-box error (out of place, so the array obtained from
            # the atoms object is never modified).
            f = f + np.dot(np.sin(self._eggbox_phase()),
                           self.x.T) * (2 * pi * self.eggbox / 6)

        self._count_evaluation(t1)
        return f

    def set_positions(self, pos):
        """Move the atoms to *pos*.

        A genuine move (change larger than 1e-15) marks the geometry as
        not yet counted.

        Raises
        ------
        RuntimeError
            If the atoms are moved after ``maxsteps`` counted evaluations.
        """
        # Before the first call there is no stored request; compare against
        # the current positions so that the very first move is counted too.
        previous = self.pos if self.pos is not None else self.atoms.positions
        if abs(pos - previous).max() > 1e-15:
            self.ready = False
            if self.nsteps >= self.maxsteps:
                raise RuntimeError('Did not converge!')

        # Store a copy: the caller may update `pos` in place later, which
        # would otherwise make the next comparison see no change at all.
        self.pos = np.array(pos)
        self.atoms.set_positions(pos)

    def get_positions(self):
        """Return the current positions of the wrapped atoms."""
        return self.atoms.get_positions()

    def get_calculator(self):
        """Return the calculator attached to the wrapped atoms."""
        return self.atoms.calc

    def __len__(self):
        """Return the number of atoms."""
        return len(self.atoms)

    def __ase_optimizable__(self):
        """Expose this wrapper to ASE optimizers as an optimizable."""
        from ase.optimize.optimize import OptimizableAtoms
        return OptimizableAtoms(self)

115 

116 

def run_test(atoms, optimizer, tag, fmax=0.02, eggbox=0.0):
    """Optimize atoms with optimizer.

    Writes ``<tag>.log`` (optimizer log) and ``<tag>.traj`` (trajectory).
    If the optimization raises, ``<tag>.err`` receives the error message
    and traceback.

    Parameters
    ----------
    atoms:
        Atoms with a calculator attached; they are moved in place.
    optimizer:
        Optimizer class, instantiated as ``optimizer(wrapper, logfile=...)``.
    tag:
        Prefix for all output file names.
    fmax:
        Force convergence criterion passed to ``run``.
    eggbox:
        Strength of the fake egg-box error (see ``Wrapper``).

    Returns
    -------
    tuple
        ``(error, nsteps, texcl, tincl)``: the error message (``''`` on
        success), the number of counted evaluations (``inf`` on failure),
        the time spent inside energy/force calls, and the total wall time
        of the run.
    """
    wrapper = Wrapper(atoms, eggbox=eggbox)
    relax = optimizer(wrapper, logfile=tag + '.log')
    traj = Trajectory(tag + '.traj', 'w', atoms=atoms)
    relax.attach(traj)

    tincl = -time()
    error = ''

    try:
        relax.run(fmax=fmax, steps=10000000)
    except Exception as x:
        # Any failure (including Wrapper's step limit) is recorded as a
        # benchmark result instead of aborting the whole benchmark.
        wrapper.nsteps = float('inf')
        error = f'{x.__class__.__name__}: {x}'
        tb = traceback.format_exc()

        with open(tag + '.err', 'w') as fd:
            fd.write(f'{error}\n{tb}\n')
    finally:
        # Release the trajectory file even when the optimization failed.
        traj.close()

    tincl += time()

    return error, wrapper.nsteps, wrapper.texcl, tincl

139 

140 

def test_optimizer(systems, optimizer, calculator, prefix='', db=None,
                   eggbox=0.0):
    """Test optimizer on systems.

    Parameters
    ----------
    systems:
        Iterable of ``(name, atoms)`` pairs.  Each atoms object is copied
        before use, so the originals are left untouched.
    optimizer:
        Optimizer class to benchmark.
    calculator:
        Calculator factory, called as ``calculator(txt=...)`` per system.
    prefix:
        Prefix for all output file names.
    db:
        Optional results database.  When given, each (optimizer, system)
        pair is reserved first and skipped if ``db.reserve`` returns
        ``None`` (presumably already done or in progress elsewhere).  The
        result is then written to the reserved row.
    eggbox:
        Strength of the fake egg-box error.
    """
    # Needed for the output file names whether or not a database is used.
    optname = optimizer.__name__

    for name, atoms in systems:
        row_id = None
        if db is not None:
            row_id = db.reserve(optimizer=optname, name=name)
            if row_id is None:
                continue

        atoms = atoms.copy()
        tag = f'{prefix}{optname}-{name}'
        atoms.calc = calculator(txt=tag + '.txt')
        error, nsteps, texcl, tincl = run_test(atoms, optimizer, tag,
                                               eggbox=eggbox)

        if db is not None:
            db.write(atoms,
                     id=row_id,
                     optimizer=optname,
                     name=name,
                     error=error,
                     n=nsteps,
                     t=texcl,
                     T=tincl,
                     eggbox=eggbox)

167 

168 

def main():
    """Command-line entry point: benchmark ASE optimizers.

    Reads the test systems from the database file named on the command
    line and runs every requested optimizer (default: all names in
    ``all_optimizers``) on each system with the EMT calculator.  Results
    are stored in ``results.db`` in the current directory.
    """
    parser = argparse.ArgumentParser(
        description='Test ASE optimizers')

    parser.add_argument('systems', help='File containing test systems.')
    parser.add_argument('optimizer', nargs='*',
                        help='Optimizer name(s). Choose from: {}. '
                        .format(', '.join(all_optimizers)) +
                        'Default is all optimizers.')
    parser.add_argument('-e', '--egg-box', type=float, default=0.0,
                        help='Fake egg-box error in eV.')

    args = parser.parse_args()

    names = args.optimizer or all_optimizers

    # Resolve every name before doing any work, so a typo is reported as a
    # usage error instead of an AttributeError raised after the earlier
    # optimizers have already been run.
    optimizers = []
    for opt in names:
        try:
            optimizers.append((opt, get_optimizer(opt)))
        except AttributeError:
            parser.error(f'unknown optimizer: {opt!r}')

    systems = [(row.name, row.toatoms())
               for row in ase.db.connect(args.systems).select()]

    db = ase.db.connect('results.db')

    for opt, optimizer in optimizers:
        print(opt)
        test_optimizer(systems, optimizer, EMT, db=db, eggbox=args.egg_box)


if __name__ == '__main__':
    main()