Coverage for /builds/ase/ase/ase/calculators/loggingcalc.py: 74.05%

131 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3""" 

4Provides LoggingCalculator class to wrap a Calculator and record 

5number of energy and force calls 

6""" 

7 

8import json 

9import logging 

10import time 

11from typing import Any, Dict 

12 

13import numpy as np 

14 

15from ase.calculators.calculator import Calculator, all_properties 

16 

# Module-level logger; LoggingCalculator.calculate reports every energy and
# force call through it with logger.info.
17logger = logging.getLogger(__name__)

18 

19 

class LoggingCalculator(Calculator):
    """Calculator wrapper to record and plot the history of energy and
    force evaluations.

    Every property request is forwarded to the wrapped calculator.  For each
    label (``self.label``, initialised through ``set_label('(none)')``) the
    wrapper keeps:

    * ``fmax[label]`` -- list with the largest absolute force component
      (after the atoms' constraints have adjusted the forces) at each force
      call;
    * ``walltime[label]`` -- list of ``time.time()`` stamps, one per force
      call;
    * ``energy_evals[label]`` -- running number of energy evaluations;
    * ``energy_count[label]`` -- value of ``energy_evals[label]`` recorded at
      each force call.
    """
    implemented_properties = all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'LoggingCalculator'

    # Accessor names used when the wrapped object is not an ASE Calculator
    # (and therefore cannot be queried with ``get_property``).
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, jsonfile=None, dumpjson=False):
        """Wrap *calculator*.

        Parameters
        ----------
        calculator:
            An ASE ``Calculator`` (queried via ``get_property``) or any
            object offering the methods named in ``property_to_method_name``.
        jsonfile:
            Optional path of a JSON file written by ``write_json``; its
            history is loaded immediately.
        dumpjson:
            If true, the full history is written to ``'dump.json'`` after
            every calculation.
        """
        Calculator.__init__(self)
        self.calculator = calculator
        self.fmax = {}
        self.walltime = {}
        self.energy_evals = {}
        self.energy_count = {}
        self.set_label('(none)')
        if jsonfile is not None:
            self.read_json(jsonfile)
        self.dumpjson = dumpjson

    def calculate(self, atoms, properties, system_changes):
        """Forward *properties* to the wrapped calculator and record the call.

        Energy requests ('energy' or 'energies') increment the energy
        counter of the current label; force requests append the constraint-
        adjusted maximum force component, a wall-clock stamp and the current
        energy counter to the label's history.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        if isinstance(self.calculator, Calculator):
            results = [self.calculator.get_property(prop, atoms)
                       for prop in properties]
        else:
            results = []
            for prop in properties:
                method = getattr(self.calculator,
                                 self.property_to_method_name[prop])
                results.append(method(atoms))

        if 'energy' in properties or 'energies' in properties:
            self.energy_evals.setdefault(self.label, 0)
            self.energy_evals[self.label] += 1
            # list.index raises ValueError (not IndexError) for a missing
            # item, so test membership explicitly to reach the fallback.
            if 'energy' in properties:
                energy = results[properties.index('energy')]
            else:
                energy = sum(results[properties.index('energies')])
            logger.info('energy call count=%d energy=%.3f',
                        self.energy_evals[self.label], energy)
        self.results = dict(zip(properties, results))

        if 'forces' in self.results:
            fmax = self.fmax.setdefault(self.label, [])
            walltime = self.walltime.setdefault(self.label, [])
            # Copy so that constraint adjustment does not alter the results
            # handed back to the caller.
            forces = self.results['forces'].copy()
            energy_count = self.energy_count.setdefault(self.label, [])
            energy_evals = self.energy_evals.setdefault(self.label, 0)
            energy_count.append(energy_evals)
            for constraint in atoms.constraints:
                constraint.adjust_forces(atoms, forces)
            # Store a plain float so the history is always JSON-serialisable
            # (e.g. np.float32 is not accepted by json.dump).
            fmax.append(float(abs(forces).max()))
            walltime.append(time.time())
            logger.info('force call fmax=%.3f', fmax[-1])

        if self.dumpjson:
            self.write_json('dump.json')

    def write_json(self, filename):
        """Write the fmax, walltime, energy_evals and energy_count histories
        to *filename* as a single JSON object."""
        with open(filename, 'w') as fd:
            json.dump({'fmax': self.fmax,
                       'walltime': self.walltime,
                       'energy_evals': self.energy_evals,
                       'energy_count': self.energy_count}, fd)

    def read_json(self, filename, append=False, label=None):
        """Load histories previously written by ``write_json``.

        Parameters
        ----------
        filename:
            JSON file to read.
        append:
            If false (default) the current histories are discarded first;
            otherwise the file's labels are merged into them (entries with
            the same label are replaced).
        label:
            If given and the file holds exactly one label, that history is
            stored under *label* instead of its original name.
        """
        with open(filename) as fd:
            dct = json.load(fd)

        # dict views are not indexable in Python 3; materialise the keys.
        labels = list(dct['fmax'])
        if label is not None and len(labels) == 1:
            old_label = labels[0]
            for key in ('fmax', 'walltime', 'energy_evals', 'energy_count'):
                # pop before assigning so renaming to the same label keeps
                # the data instead of deleting it.
                dct[key][label] = dct[key].pop(old_label)
        if not append:
            self.fmax = {}
            self.walltime = {}
            self.energy_evals = {}
            self.energy_count = {}
        self.fmax.update(dct['fmax'])
        self.walltime.update(dct['walltime'])
        self.energy_evals.update(dct['energy_evals'])
        self.energy_count.update(dct['energy_count'])

    def tabulate(self):
        """Print, per label, the number of force calls, the number of energy
        evaluations at the last force call, and the elapsed wall time between
        the first and last force call."""
        fmt1 = '%-10s %10s %10s %8s'
        title = fmt1 % ('Label', '# Force', '# Energy', 'Walltime/s')
        print(title)
        print('-' * len(title))
        fmt2 = '%-10s %10d %10d %8.2f'
        for label in sorted(self.fmax):
            # energy_count[label] has one entry per force call, so its
            # length would just repeat the force count; its last entry is
            # the cumulative number of energy evaluations.
            print(fmt2 % (label, len(self.fmax[label]),
                          self.energy_count[label][-1],
                          self.walltime[label][-1] - self.walltime[label][0]))

    @staticmethod
    def _relabel(history, labels):
        """Return *history* re-keyed so that the i-th entry of *labels*
        takes the value stored under the i-th old label in sorted order."""
        return dict(zip(labels, [history[key] for key in sorted(history)]))

    @staticmethod
    def _decorate_fmax_axes(plt, xlabel, fmaxlim):
        """Label the current axes as maximum force versus *xlabel*, apply
        the y-limits *fmaxlim* and draw the legend."""
        plt.xlabel(xlabel)
        plt.ylabel('Maximum force / eV/A')
        plt.ylim(*fmaxlim)
        plt.legend()

    def plot(self, fmaxlim=(1e-2, 1e2), forces=True, energy=True,
             walltime=True,
             markers=None, labels=None, **kwargs):
        """Plot the maximum force of each label on a log scale.

        One subplot is drawn for each enabled panel: against the number of
        force evaluations (*forces*), the number of energy evaluations
        (*energy*) and the elapsed wall time (*walltime*).

        Parameters
        ----------
        fmaxlim:
            (lower, upper) y-limits of every panel.
        markers:
            Matplotlib format strings, cycled over the sorted labels.
        labels:
            New names for the existing labels, in sorted order of the old
            labels.  NOTE: this renames the stored histories in place.
        **kwargs:
            Passed to ``plt.semilogy``.
        """
        from itertools import cycle

        import matplotlib.pyplot as plt

        if markers is None:
            markers = [c + s for c in ['r', 'g', 'b', 'c', 'm', 'y', 'k']
                       for s in ['.-', '.--']]
        nsub = sum([forces, energy, walltime])
        nplot = 0

        if labels is not None:
            self.fmax = self._relabel(self.fmax, labels)
            self.energy_count = self._relabel(self.energy_count, labels)
            self.walltime = self._relabel(self.walltime, labels)

        # Markers are cycled so labels beyond len(markers) are still drawn.
        if forces:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, marker in zip(sorted(self.fmax), cycle(markers)):
                fmax_values = np.array(self.fmax[label])
                plt.semilogy(np.arange(len(fmax_values)), fmax_values,
                             marker, label=label, **kwargs)
            self._decorate_fmax_axes(plt, 'Number of force evaluations',
                                     fmaxlim)

        if energy:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, marker in zip(sorted(self.energy_count),
                                     cycle(markers)):
                counts = np.array(self.energy_count[label])
                fmax_values = np.array(self.fmax[label])
                plt.semilogy(counts, fmax_values, marker, label=label,
                             **kwargs)
            self._decorate_fmax_axes(plt, 'Number of energy evaluations',
                                     fmaxlim)

        if walltime:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, marker in zip(sorted(self.walltime), cycle(markers)):
                times = np.asarray(self.walltime[label], dtype=float)
                fmax_values = np.array(self.fmax[label])
                plt.semilogy(times - times[0], fmax_values, marker,
                             label=label, **kwargs)
            self._decorate_fmax_axes(plt, 'Walltime / s', fmaxlim)

        plt.subplots_adjust(hspace=0.33)