Coverage for /builds/ase/ase/ase/calculators/loggingcalc.py: 74.05%
131 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
# fmt: off

"""
Provides the LoggingCalculator class, which wraps a Calculator and records
the number of energy and force calls.
"""

import json
import logging
import time
from typing import Any, Dict

import numpy as np

from ase.calculators.calculator import Calculator, all_properties

# Module logger; LoggingCalculator.calculate reports each energy and force
# call through it at INFO level.
logger = logging.getLogger(__name__)


class LoggingCalculator(Calculator):
    """Calculator wrapper to record and plot the history of energy and force
    evaluations.

    Every calculation is delegated to the wrapped ``calculator``.  For each
    label (see ``set_label``) the wrapper records:

    * ``energy_evals[label]``: number of calls that requested ``'energy'``
      or ``'energies'``;
    * ``fmax[label]``: list with the largest absolute force component (after
      the atoms' constraints have adjusted the forces) at each force call;
    * ``walltime[label]``: ``time.time()`` at each force call;
    * ``energy_count[label]``: value of ``energy_evals[label]`` at each
      force call.
    """
    implemented_properties = all_properties
    default_parameters: Dict[str, Any] = {}
    name = 'LoggingCalculator'

    # Methods called on wrapped objects that are not ASE Calculators.
    property_to_method_name = {
        'energy': 'get_potential_energy',
        'energies': 'get_potential_energies',
        'forces': 'get_forces',
        'stress': 'get_stress',
        'stresses': 'get_stresses'}

    def __init__(self, calculator, jsonfile=None, dumpjson=False):
        """
        Parameters
        ----------
        calculator
            Wrapped calculator: either an ASE ``Calculator`` (queried through
            ``get_property``) or any object providing the methods listed in
            ``property_to_method_name``.
        jsonfile : str, optional
            History file previously written by :meth:`write_json`; it is
            loaded with :meth:`read_json`.
        dumpjson : bool
            If true, write the history to ``'dump.json'`` after every
            calculation.
        """
        Calculator.__init__(self)
        self.calculator = calculator
        self.fmax = {}
        self.walltime = {}
        self.energy_evals = {}
        self.energy_count = {}
        self.set_label('(none)')
        if jsonfile is not None:
            self.read_json(jsonfile)
        self.dumpjson = dumpjson

    def calculate(self, atoms, properties, system_changes):
        """Delegate the calculation to the wrapped calculator and log it.

        Energy requests increment ``energy_evals`` for the current label.
        Force requests append the constrained max absolute force component,
        the current time and the current energy count to the history.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        if isinstance(self.calculator, Calculator):
            results = [self.calculator.get_property(prop, atoms)
                       for prop in properties]
        else:
            # Duck-typed wrapped object: call the matching get_* method.
            results = []
            for prop in properties:
                method_name = self.property_to_method_name[prop]
                method = getattr(self.calculator, method_name)
                results.append(method(atoms))
        self.results = dict(zip(properties, results))

        if 'energy' in self.results or 'energies' in self.results:
            self.energy_evals.setdefault(self.label, 0)
            self.energy_evals[self.label] += 1
            if 'energy' in self.results:
                energy = self.results['energy']
            else:
                # Only per-atom energies were requested: report their sum.
                energy = sum(self.results['energies'])
            logger.info('energy call count=%d energy=%.3f',
                        self.energy_evals[self.label], energy)

        if 'forces' in self.results:
            fmax = self.fmax.setdefault(self.label, [])
            walltime = self.walltime.setdefault(self.label, [])
            energy_count = self.energy_count.setdefault(self.label, [])
            energy_count.append(self.energy_evals.setdefault(self.label, 0))
            # Work on a float copy so the constraints do not modify the
            # wrapped calculator's result in place.
            forces = np.array(self.results['forces'], dtype=float)
            for constraint in atoms.constraints:
                constraint.adjust_forces(atoms, forces)
            # Store a plain float so the history is always JSON-serializable.
            fmax.append(float(abs(forces).max()))
            walltime.append(time.time())
            logger.info('force call fmax=%.3f', fmax[-1])

        if self.dumpjson:
            self.write_json('dump.json')

    def write_json(self, filename):
        """Write the recorded history (all labels) to ``filename`` as JSON."""
        with open(filename, 'w') as fd:
            json.dump({'fmax': self.fmax,
                       'walltime': self.walltime,
                       'energy_evals': self.energy_evals,
                       'energy_count': self.energy_count}, fd)

    def read_json(self, filename, append=False, label=None):
        """Load a history previously written by :meth:`write_json`.

        Parameters
        ----------
        filename : str
            JSON file to read.
        append : bool
            If true, merge into the current history (entries with the same
            label are overwritten); otherwise replace the current history.
        label : str, optional
            If given and the file holds exactly one label, store that data
            under ``label`` instead.
        """
        with open(filename) as fd:
            dct = json.load(fd)

        # dict views are not indexable in Python 3: materialize the keys.
        labels = list(dct['fmax'])
        if label is not None and len(labels) == 1:
            old_label = labels[0]
            for key in ('fmax', 'walltime', 'energy_evals', 'energy_count'):
                # pop before assigning so renaming a label to itself keeps
                # its data instead of deleting it.
                dct[key][label] = dct[key].pop(old_label)
        if not append:
            self.fmax = {}
            self.walltime = {}
            self.energy_evals = {}
            self.energy_count = {}
        self.fmax.update(dct['fmax'])
        self.walltime.update(dct['walltime'])
        self.energy_evals.update(dct['energy_evals'])
        self.energy_count.update(dct['energy_count'])

    def tabulate(self):
        """Print, per label, the number of force and energy evaluations and
        the walltime between the first and the last force call."""
        fmt1 = '%-10s %10s %10s %8s'
        title = fmt1 % ('Label', '# Force', '# Energy', 'Walltime/s')
        print(title)
        print('-' * len(title))
        fmt2 = '%-10s %10d %10d %8.2f'
        for label in sorted(self.fmax):
            times = self.walltime[label]
            # energy_count has one entry per *force* call, so the number of
            # energy evaluations comes from energy_evals.
            print(fmt2 % (label, len(self.fmax[label]),
                          self.energy_evals.get(label, 0),
                          times[-1] - times[0]))

    def plot(self, fmaxlim=(1e-2, 1e2), forces=True, energy=True,
             walltime=True,
             markers=None, labels=None, **kwargs):
        """Plot the recorded fmax history with matplotlib.

        One subplot, stacked vertically, is drawn for each enabled x-axis:
        number of force evaluations (``forces``), number of energy
        evaluations (``energy``) and elapsed walltime (``walltime``).

        Parameters
        ----------
        fmaxlim : tuple
            y-axis limits passed to ``plt.ylim``.
        forces, energy, walltime : bool
            Which subplots to draw.
        markers : list of str, optional
            matplotlib format strings; they are cycled over the labels.
        labels : list of str, optional
            Legend names given to the stored labels in sorted order.  The
            recorded history itself is left unchanged.
        **kwargs
            Passed on to ``plt.semilogy``.
        """
        import itertools

        import matplotlib.pyplot as plt

        if markers is None:
            markers = [c + s for c in ['r', 'g', 'b', 'c', 'm', 'y', 'k']
                       for s in ['.-', '.--']]
        nsub = sum([forces, energy, walltime])
        nplot = 0

        fmax_hist = self.fmax
        energy_count_hist = self.energy_count
        walltime_hist = self.walltime
        if labels is not None:
            def relabel(history):
                # Map the stored labels, in sorted order, onto ``labels``.
                return dict(zip(labels,
                                (history[key] for key in sorted(history))))

            fmax_hist = relabel(self.fmax)
            energy_count_hist = relabel(self.energy_count)
            walltime_hist = relabel(self.walltime)

        if forces:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, marker in zip(sorted(fmax_hist),
                                     itertools.cycle(markers)):
                fmax_values = np.array(fmax_hist[label])
                idx = np.arange(len(fmax_values))
                plt.semilogy(idx, fmax_values, marker, label=label, **kwargs)

            plt.xlabel('Number of force evaluations')
            plt.ylabel('Maximum force / eV/A')
            plt.ylim(*fmaxlim)
            plt.legend()

        if energy:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, marker in zip(sorted(energy_count_hist),
                                     itertools.cycle(markers)):
                counts = np.array(energy_count_hist[label])
                fmax_values = np.array(fmax_hist[label])
                plt.semilogy(counts, fmax_values, marker, label=label,
                             **kwargs)

            plt.xlabel('Number of energy evaluations')
            plt.ylabel('Maximum force / eV/A')
            plt.ylim(*fmaxlim)
            plt.legend()

        if walltime:
            nplot += 1
            plt.subplot(nsub, 1, nplot)
            for label, marker in zip(sorted(walltime_hist),
                                     itertools.cycle(markers)):
                times = np.asarray(walltime_hist[label], dtype=float)
                elapsed = times - times[0]
                fmax_values = np.array(fmax_hist[label])
                plt.semilogy(elapsed, fmax_values, marker, label=label,
                             **kwargs)

            plt.xlabel('Walltime / s')
            plt.ylabel('Maximum force / eV/A')
            plt.ylim(*fmaxlim)
            plt.legend()

        plt.subplots_adjust(hspace=0.33)