Coverage for ase / optimize / test / test.py: 21.49%

121 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import argparse 

4import traceback 

5from math import pi 

6from time import time 

7 

8import numpy as np 

9 

10import ase.db 

11import ase.optimize 

12from ase import Atoms 

13from ase.calculators.emt import EMT 

14from ase.io import Trajectory 

15 

# Optimizer names run by default: everything exported by ase.optimize,
# plus the preconditioned and SciPy-based optimizers, minus the
# 'QuasiNewton' and 'RestartError' entries of ase.optimize.__all__.
all_optimizers = [*ase.optimize.__all__,
                  'PreconLBFGS', 'PreconFIRE',
                  'SciPyFminCG', 'SciPyFminBFGS']
all_optimizers.remove('QuasiNewton')
all_optimizers.remove('RestartError')

20 

21 

def get_optimizer(name: str):
    """Return the optimizer class called *name*.

    Names starting with ``'Precon'`` are looked up in
    ``ase.optimize.precon``, names starting with ``'SciPy'`` in
    ``ase.optimize.sciopt``, and every other name in ``ase.optimize``.
    An unknown name raises AttributeError (from ``getattr``).
    """
    if name.startswith('Precon'):
        import ase.optimize.precon as module
    elif name.startswith('SciPy'):
        import ase.optimize.sciopt as module
    else:
        module = ase.optimize
    return getattr(module, name)

31 

32 

class Wrapper:
    """Atoms-object wrapper that can count number of moves.

    Energy, force and position calls are forwarded to the wrapped atoms
    object, while the wrapper records:

    * ``nsteps`` -- the number of distinct configurations for which an
      energy or forces were requested (repeated evaluations without a move
      count once);
    * ``texcl`` -- accumulated wall-clock time spent inside the energy and
      force calls.

    If ``eggbox`` is non-zero, an artificial periodic "egg-box" term on a
    grid of spacing ``gridspacing`` is added to the energy and forces.
    ``set_positions`` raises RuntimeError when the atoms are moved again
    after ``maxsteps`` configurations have been evaluated.
    """

    def __init__(
        self,
        atoms: 'Atoms',
        gridspacing: float = 0.2,
        eggbox: float = 0.0,
        maxsteps: int = 200,
    ) -> None:
        self.t0 = time()
        self.texcl = 0.0
        self.nsteps = 0
        self.maxsteps = maxsteps
        self.atoms = atoms
        self.ready = False  # True once the current configuration is counted
        self.pos: np.ndarray | None = None
        self.eggbox = eggbox

        self.x = None
        if eggbox:
            # Find small unit cell for grid-points.  Clamp to at least one
            # subdivision so an axis shorter than gridspacing does not
            # divide by zero and make the grid matrix non-finite.
            h = []
            for axis in atoms.get_cell(complete=True):
                L = np.linalg.norm(axis)
                n = max(1, int(L / gridspacing))
                h.append(axis / n)
            # positions @ x gives the coordinates in units of the grid cell.
            self.x = np.linalg.inv(h)

    def _count_evaluation(self, t1: float) -> None:
        """Add the time elapsed since *t1*; count the configuration once."""
        self.texcl += time() - t1
        if not self.ready:
            self.nsteps += 1
            self.ready = True

    def get_potential_energy(self, force_consistent=False):
        """Return the wrapped atoms' energy, plus the egg-box term if any."""
        t1 = time()
        e = self.atoms.get_potential_energy(force_consistent)

        if self.eggbox:
            # Add egg-box error: each grid coordinate s contributes
            # cos(2*pi*s) * eggbox / 6.
            s = np.dot(self.atoms.positions, self.x)
            e += np.cos(2 * pi * s).sum() * self.eggbox / 6

        self._count_evaluation(t1)
        return e

    def get_forces(self):
        """Return the wrapped atoms' forces, plus the egg-box term if any."""
        t1 = time()
        f = self.atoms.get_forces()

        if self.eggbox:
            # Add egg-box error: minus the gradient of the energy term
            # added in get_potential_energy.
            s = np.dot(self.atoms.positions, self.x)
            f += np.dot(np.sin(2 * pi * s),
                        self.x.T) * (2 * pi * self.eggbox / 6)

        self._count_evaluation(t1)
        return f

    def set_positions(self, pos):
        """Move the atoms to *pos*.

        A move larger than 1e-15 marks the configuration as new, so the
        next energy/force call counts a step.  Raises RuntimeError when
        such a move happens after ``maxsteps`` steps have been counted.
        """
        if self.pos is not None and abs(pos - self.pos).max() > 1e-15:
            self.ready = False
            if self.nsteps >= self.maxsteps:
                raise RuntimeError('Did not converge!')

        # Keep a private copy: storing the caller's array would let an
        # optimizer that updates that array in place compare it against
        # itself, so its moves would never be counted.
        self.pos = np.array(pos)
        self.atoms.set_positions(pos)

    def get_positions(self):
        """Return the wrapped atoms' positions."""
        return self.atoms.get_positions()

    def get_calculator(self):
        """Return the calculator attached to the wrapped atoms."""
        return self.atoms.calc

    def __len__(self):
        return len(self.atoms)

    def __ase_optimizable__(self):
        from ase.optimize.optimize import OptimizableAtoms
        return OptimizableAtoms(self)

114 

115 

def run_test(atoms, optimizer, tag, fmax=0.02, eggbox=0.0):
    """Optimize atoms with optimizer.

    The atoms are wrapped in a ``Wrapper`` that counts steps.  The
    optimizer logs to ``<tag>.log`` and a trajectory is written to
    ``<tag>.traj``.  Any exception raised during the run is caught.  Its
    message and traceback are written to ``<tag>.err``, and the step count
    is set to infinity.

    Returns:
        ``(error, nsteps, texcl, tincl)``: the error string (empty on
        success), the counted number of steps, the time spent inside
        energy/force calls, and the total wall-clock time of the run.
    """
    wrapper = Wrapper(atoms, eggbox=eggbox)
    relax = optimizer(wrapper, logfile=tag + '.log')
    traj = Trajectory(tag + '.traj', 'w', atoms=atoms)
    relax.attach(traj)

    t0 = time()
    error = ''

    try:
        relax.run(fmax=fmax, steps=10000000)
    except Exception as x:
        # Deliberately broad: a failing optimizer is a benchmark result,
        # not a reason to abort the whole test run.
        wrapper.nsteps = float('inf')
        error = f'{x.__class__.__name__}: {x}'
        tb = traceback.format_exc()

        with open(tag + '.err', 'w') as fd:
            fd.write(f'{error}\n{tb}\n')
    finally:
        # Previously the trajectory writer was never closed, leaking the
        # file handle and possibly leaving buffered frames unwritten.
        traj.close()

    tincl = time() - t0

    return error, wrapper.nsteps, wrapper.texcl, tincl

138 

139 

def test_optimizer(systems, optimizer, calculator, prefix='', db=None,
                   eggbox=0.0):
    """Test optimizer on systems.

    For every ``(name, atoms)`` pair in *systems*, a copy of the atoms gets
    a new calculator (``calculator(txt=<tag>.txt)``) and is relaxed with
    ``run_test``.  The file tag is ``<prefix><optimizer name>-<name>``.
    When *db* is given, each optimizer/system pair is first reserved with
    ``db.reserve``; pairs for which that returns None are skipped.  The
    relaxed atoms and the run statistics are written back under the
    reserved id.
    """
    for name, atoms in systems:
        # Needed for the file tag even without a database; previously this
        # was only bound when db was given, so db=None raised NameError.
        optname = optimizer.__name__
        row_id = None
        if db is not None:
            row_id = db.reserve(optimizer=optname, name=name)
            if row_id is None:
                # Presumably this pair is already taken in the database
                # (e.g. by another run) -- skip it.
                continue

        atoms = atoms.copy()
        tag = f'{prefix}{optname}-{name}'
        atoms.calc = calculator(txt=tag + '.txt')
        error, nsteps, texcl, tincl = run_test(atoms, optimizer, tag,
                                               eggbox=eggbox)

        if db is not None:
            db.write(atoms,
                     id=row_id,
                     optimizer=optname,
                     name=name,
                     error=error,
                     n=nsteps,
                     t=texcl,
                     T=tincl,
                     eggbox=eggbox)

166 

167 

def main():
    """Run the optimizer benchmark from the command line.

    Reads the test systems from the ASE database named by the first
    positional argument and runs each requested optimizer on every system
    with the EMT calculator.  All of ``all_optimizers`` are run when none
    are named.  Results are recorded in ``results.db``.
    """
    parser = argparse.ArgumentParser(description='Test ASE optimizers')
    parser.add_argument('systems', help='File containing test systems.')
    available = ', '.join(all_optimizers)
    parser.add_argument(
        'optimizer',
        nargs='*',
        help=f'Optimizer name(s). Choose from: {available}. '
             'Default is all optimizers.',
    )
    parser.add_argument('-e', '--egg-box', type=float, default=0.0,
                        help='Fake egg-box error in eV.')
    args = parser.parse_args()

    source_db = ase.db.connect(args.systems)
    systems = [(row.name, row.toatoms()) for row in source_db.select()]

    results_db = ase.db.connect('results.db')

    # An empty positional list means "run every known optimizer".
    selected = args.optimizer or all_optimizers
    for opt_name in selected:
        print(opt_name)
        test_optimizer(systems, get_optimizer(opt_name), EMT,
                       db=results_db, eggbox=args.egg_box)


if __name__ == '__main__':
    main()