Coverage for /builds/ase/ase/ase/cli/template.py: 94.00%
200 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import string
5import numpy as np
7from ase.data import chemical_symbols
8from ase.io import string2index
9from ase.io.formats import parse_filename
11# default fields
def field_specs_on_conditions(calculator_outputs, rank_order):
    """Return the default list of column specs for the comparison table.

    Parameters
    ----------
    calculator_outputs:
        If truthy, force columns (``df``, ``rdf``) are included; otherwise
        the displacement components ``dx``, ``dy``, ``dz`` are shown.
    rank_order:
        ``None`` to sort primarily by atom index (``'i:0:1'``).  Otherwise a
        field name that becomes the primary, ascending sort key
        (``'<field>:0:1'``), with the index column demoted to hierarchy 1.
        The field is appended if it is not already among the defaults.
    """
    specs = (['i:0', 'el', 'd', 'rd', 'df', 'rdf'] if calculator_outputs
             else ['i:0', 'el', 'dx', 'dy', 'dz', 'd', 'rd'])

    if rank_order is None:
        specs[0] += ':1'
        return specs

    specs[0] = 'i:1'
    if rank_order in specs:
        specs = [spec + ':0:1' if spec == rank_order else spec
                 for spec in specs]
    else:
        specs.append(rank_order + ':0:1')
    return specs
def summary_functions_on_conditions(has_calc):
    """Return the summary functions to print under the table.

    ``rmsd`` is always included; ``energy_delta`` is added after it when
    *has_calc* is truthy.
    """
    functions = [rmsd]
    if has_calc:
        functions.append(energy_delta)
    return functions
def header_alias(h):
    """Return a human-readable column header for the field name *h*.

    The names ``i``, ``an``, ``t`` and ``el`` are spelled out.  Otherwise:

    * a name starting with ``d`` has every ``d`` replaced by ``Δ``;
    * a name starting with ``r`` becomes ``'rank '`` plus the alias of the
      rest of the name;
    * a name starting with ``a`` (average) has every ``a`` replaced by
      ``<`` and gets a closing ``>``.

    Any other name is returned unchanged.
    """
    spelled_out = {'i': 'index', 'an': 'atomic #', 't': 'tag',
                   'el': 'element'}
    if h in spelled_out:
        return spelled_out[h]

    lead = h[0]
    if lead == 'd':
        return h.replace('d', 'Δ')
    if lead == 'r':
        return 'rank ' + header_alias(h[1:])
    if lead == 'a':
        return h.replace('a', '<') + '>'
    return h
def prec_round(a, prec=2):
    """Round *a* to *prec* decimal places of its base-10 mantissa.

    Nonzero *a* is decomposed as ``sign * 10**fraction * 10**exponent`` with
    ``0 <= fraction < 1``.  ``10**fraction`` is rounded to *prec* decimals
    and the value is rebuilt; zero is returned unchanged.

    ``Table.make_body`` rounds its sort keys this way so that floats which
    agree after rounding compare equal, letting lower-priority columns decide
    their order.  This is what makes hierarchical sorting differ from a plain
    sort on the raw floats.
    """
    if a == 0:
        return a
    sign = 1 if a > 0 else -1
    log_magnitude = np.log10(sign * a)
    exponent = log_magnitude // 1
    fraction = log_magnitude % 1
    return sign * np.round(10**fraction, prec) * 10**exponent


# Apply element-wise, so whole arrays of sort keys can be rounded at once.
prec_round = np.vectorize(prec_round)
# end most settings

# Element symbols are stored as integer keys so the 'el' column can be sorted
# numerically and printed back as symbols (see MapFormatter).  Each key is the
# symbol's alphabetical rank, so sorting on the keys orders elements
# alphabetically.
num2sym = dict(enumerate(sorted(chemical_symbols)))
# to sort by atomic number instead, use:
# num2sym = dict(zip(range(len(chemical_symbols)), chemical_symbols))
sym2num = {v: k for k, v in num2sym.items()}
# Field names that get_field_data computes from positions, indices, tags and
# elements; any other field name is interpreted there as a force field.
atoms_props = (['dx', 'dy', 'dz', 'd', 't', 'an', 'i', 'el', 'p1', 'p2']
               + [f'p{n}{axis}' for n in '12' for axis in 'xyz'])
def _select_component(vectors, field):
    """Reduce per-atom vectors to the quantity named by *field*'s last char.

    A trailing ``x``, ``y`` or ``z`` selects column 0, 1 or 2 of the 2-D
    array *vectors*; any other ending yields the per-row Euclidean norm.
    """
    axis = 'xyz'.find(field[-1])
    if axis >= 0:
        return vectors[:, axis]
    return np.linalg.norm(vectors, axis=1)


def get_field_data(atoms1, atoms2, field):
    """Compute one per-atom column of the comparison table.

    Parameters
    ----------
    atoms1, atoms2:
        The two structures being compared.
    field: str
        Field name, optionally prefixed by ``r`` to request ranks instead of
        values.  Recognized names:

        * ``i`` index, ``an`` atomic number, ``t`` tag, ``el`` element sort
          key (via ``sym2num``) -- all taken from *atoms1*;
        * ``d`` displacement ``atoms2 - atoms1``; ``p1``/``p2`` positions of
          *atoms1*/*atoms2*;
        * ``df`` force difference ``atoms2 - atoms1``, ``af`` average force,
          ``f1``/``f2`` forces of *atoms1*/*atoms2*.

        Vector quantities take an optional ``x``/``y``/``z`` suffix for a
        single component; without it the per-atom norm is returned.

    Returns
    -------
    numpy.ndarray
        One value per atom.  With the ``r`` prefix: the 0-based rank of each
        atom when the values are ordered from largest to smallest.

    Raises
    ------
    ValueError
        If *field* is not a recognized field name.
    """
    requested = field
    rank_order = field.startswith('r')
    if rank_order:
        field = field[1:]

    if field in atoms_props:
        if field == 't':
            data = atoms1.get_tags()
        elif field == 'an':
            data = atoms1.numbers
        elif field == 'el':
            data = np.array([sym2num[sym] for sym in atoms1.symbols])
        elif field == 'i':
            data = np.arange(len(atoms1))
        else:
            # Geometric vectors: displacements (d*) or positions (p1*, p2*).
            if field.startswith('d'):
                vectors = atoms2.positions - atoms1.positions
            elif field.startswith('p1'):
                vectors = atoms1.positions
            else:
                vectors = atoms2.positions
            data = _select_component(vectors, field)
    else:
        # Force fields; strip an optional component suffix to find the kind.
        base = field[:-1] if field.endswith(('x', 'y', 'z')) else field
        if base == 'df':
            vectors = atoms2.get_forces() - atoms1.get_forces()
        elif base == 'af':
            vectors = (atoms2.get_forces() + atoms1.get_forces()) / 2
        elif base == 'f1':
            vectors = atoms1.get_forces()
        elif base == 'f2':
            vectors = atoms2.get_forces()
        else:
            raise ValueError(f'Unknown field: {requested!r}')
        data = _select_component(vectors, field)

    if rank_order:
        # Double argsort turns values into ranks; negate so rank 0 = largest.
        return np.argsort(np.argsort(-data))

    return data
161# Summary Functions
def rmsd(atoms1, atoms2):
    """Return the root-mean-square displacement between the two structures.

    The result is a summary line of the form ``'RMSD=+1.2E-03'``.
    """
    displacement = atoms2.positions - atoms1.positions
    squared_lengths = np.linalg.norm(displacement, axis=1)**2
    value = np.sqrt(np.mean(squared_lengths))
    return 'RMSD={:+.1E}'.format(value)
def energy_delta(atoms1, atoms2):
    """Return a summary line with both potential energies and their change.

    ``dE`` is ``E2 - E1``.  The energy of *atoms1* is requested before that
    of *atoms2*.
    """
    e1, e2 = (atoms.get_potential_energy() for atoms in (atoms1, atoms2))
    change = e2 - e1
    return f'E1 = {e1:+.1E}, E2 = {e2:+.1E}, dE = {change:+1.1E}'
def parse_field_specs(field_specs):
    """Split column specs into field names, sort-key order and directions.

    Each spec has the form ``'field[:hierarchy[:direction]]'``:

    * ``hierarchy`` -- sort priority; lower values are more significant.
      Specs without one (or with a negative one) are ranked after the
      largest explicit value, in order of appearance.
    * ``direction`` -- multiplier applied to the column before sorting
      (``1`` ascending, ``-1`` descending); defaults to ``-1``.

    Returns
    -------
    fields: list of str
        The field names, in the given order.
    hier: numpy.ndarray
        Row order for ``numpy.lexsort``, which uses its *last* key as the
        primary one.
    scent: numpy.ndarray
        The direction multiplier of each field.

    Raises
    ------
    ValueError
        If no specs are given, a spec has more than three ``:``-separated
        parts, or a hierarchy/direction is not an integer.
    """
    fields = []
    hier = []
    scent = []
    for spec in field_specs:
        name, *options = spec.split(':')
        if len(options) > 2:
            # Previously such specs were silently dropped from the table.
            raise ValueError(
                f'invalid field spec {spec!r}: expected '
                "'field[:hierarchy[:direction]]'")
        fields.append(name)
        hier.append(int(options[0]) if options else -1)
        scent.append(int(options[1]) if len(options) == 2 else -1)

    if not hier:
        raise ValueError('at least one field spec is required')

    # Unranked fields follow all explicitly ranked ones, in order.
    next_rank = max(hier)
    for c, h in enumerate(hier):
        if h < 0:
            next_rank += 1
            hier[c] = next_rank

    # reversed by convention of numpy lexsort (last key is primary); a stable
    # sort keeps equally ranked fields in the order they were given.
    hier = np.argsort(hier, kind='stable')[::-1]
    return fields, hier, np.array(scent)
202# Class definitions
class MapFormatter(string.Formatter):
    """``string.Formatter`` with an extra ``h`` presentation type.

    A value formatted with a spec ending in ``h`` is converted to ``int``,
    looked up in ``num2sym`` and printed as that chemical symbol, with the
    rest of the spec applied as for ``s``.  This lets the element column be
    stored as numbers (for sorting) and still be printed as symbols.  All
    other specs are handled exactly as by ``string.Formatter``.
    """

    def format_field(self, value, spec):
        if not spec.endswith('h'):
            return super().format_field(value, spec)
        symbol = num2sym[int(value)]
        return super().format_field(symbol, spec[:-1] + 's')
class TableFormat:
    """Formatting options for a ``Table``.

    Parameters
    ----------
    columnwidth: int
        Width, in characters, of every column.
    precision: int
        Floats are printed with ``precision - 1`` digits after the decimal
        point in the chosen *representation* (for the default ``'E'`` that is
        *precision* significant digits).
    representation: str
        Presentation type used for floats, e.g. ``'E'``.
    toprule, midrule, bottomrule: str
        Strings repeated to draw the rule above the header, below the header
        and below the body.

    The derived attributes are ``fmt_class`` (a format string per value
    kind) and ``fmt`` (a format string per field name).
    """

    def __init__(self,
                 columnwidth=9,
                 precision=2,
                 representation='E',
                 toprule='=',
                 midrule='-',
                 bottomrule='='):

        self.precision = precision
        self.representation = representation
        self.columnwidth = columnwidth
        self.formatter = MapFormatter().format
        self.toprule = toprule
        self.midrule = midrule
        self.bottomrule = bottomrule

        width = self.columnwidth
        float_spec = f'{width}.{self.precision - 1}{self.representation}'
        # NOTE(review): in '{: ^...}' the space is parsed as the *fill*
        # character (fill precedes align), not as the ' ' sign option, so
        # 'signed float' renders exactly like 'unsigned float'.  Perhaps
        # '{:^ ...}' was intended -- confirm before changing output.
        self.fmt_class = {
            'signed float': f'{{: ^{float_spec}}}',
            'unsigned float': f'{{:^{float_spec}}}',
            'int': f'{{:^{width}n}}',
            'str': f'{{:^{width}s}}',
            'conv': f'{{:^{width}h}}'}

        signed_floats = (
            [base + axis for base in ('d', 'df', 'af') for axis in 'xyz']
            + [pos + axis for axis in 'xyz' for pos in ('p1', 'p2')]
            + [force + axis for axis in 'xyz' for force in ('f1', 'f2')])
        unsigned_floats = ['d', 'df', 'af', 'p1', 'p2', 'f1', 'f2']
        ranked = ['r' + name for name in signed_floats + unsigned_floats]

        fmt = dict.fromkeys(signed_floats, self.fmt_class['signed float'])
        fmt.update(dict.fromkeys(unsigned_floats,
                                 self.fmt_class['unsigned float']))
        # Indices, atomic numbers, tags and every rank column are integers.
        fmt.update(dict.fromkeys(['i', 'an', 't'] + ranked,
                                 self.fmt_class['int']))
        # Element keys are mapped back to symbols by MapFormatter ('h').
        fmt['el'] = self.fmt_class['conv']

        self.fmt = fmt
class Table:
    """Text table comparing two structures atom by atom.

    Parameters
    ----------
    field_specs: list of str
        Column specs ``'field[:hierarchy[:direction]]'`` as understood by
        ``parse_field_specs``.
    summary_functions: list of callables, optional
        Each is called as ``f(atoms1, atoms2)`` and must return a string; the
        strings are printed below the table, one per line.
    tableformat: TableFormat, optional
        Formatting options; ``TableFormat()`` is used when omitted.
    max_lines: int, optional
        If given, only this many lines of the table body are kept.
    title: str
        Text placed on the first output line.
    tablewidth: int, optional
        Length of the horizontal rules; defaults to
        ``columnwidth * number of fields``.
    """

    def __init__(self,
                 field_specs,
                 summary_functions=None,
                 tableformat=None,
                 max_lines=None,
                 title='',
                 tablewidth=None):

        self.max_lines = max_lines
        # A fresh list per table: a shared mutable default would leak
        # modifications from one table to every other default-built table.
        if summary_functions is None:
            summary_functions = []
        self.summary_functions = summary_functions
        self.field_specs = field_specs

        self.fields, self.hier, self.scent = parse_field_specs(
            self.field_specs)
        self.nfields = len(self.fields)

        # formatting
        if tableformat is None:
            self.tableformat = TableFormat()
        else:
            self.tableformat = tableformat

        if tablewidth is None:
            self.tablewidth = self.tableformat.columnwidth * self.nfields
        else:
            self.tablewidth = tablewidth

        self.title = title

    def make(self, atoms1, atoms2, csv=False):
        """Return the full table as a string.

        The parts are the title, top rule, header, mid rule, body, bottom
        rule and summary lines, joined by newlines.
        """
        header = self.make_header(csv=csv)
        body = self.make_body(atoms1, atoms2, csv=csv)
        if self.max_lines is not None:
            # body is one newline-joined string: limit lines, not characters.
            body = self._truncate_lines(body, self.max_lines)
        summary = self.make_summary(atoms1, atoms2)

        return '\n'.join([self.title,
                          self.tableformat.toprule * self.tablewidth,
                          header,
                          self.tableformat.midrule * self.tablewidth,
                          body,
                          self.tableformat.bottomrule * self.tablewidth,
                          summary])

    @staticmethod
    def _truncate_lines(text, max_lines):
        """Keep only the first *max_lines* lines of newline-joined *text*."""
        return '\n'.join(text.split('\n')[:max_lines])

    def make_header(self, csv=False):
        """Return the header row: comma-separated aliases when *csv*,
        otherwise each alias centered in its own column."""
        if csv:
            return ','.join([header_alias(field) for field in self.fields])

        fields = self.tableformat.fmt_class['str'] * self.nfields
        headers = [header_alias(field) for field in self.fields]

        return self.tableformat.formatter(fields, *headers)

    def make_summary(self, atoms1, atoms2):
        """Return the outputs of all summary functions, one per line."""
        return '\n'.join([summary_function(atoms1, atoms2)
                          for summary_function in self.summary_functions])

    def make_body(self, atoms1, atoms2, csv=False):
        """Return the sorted table rows as one newline-joined string."""
        # One row per field, one column per atom.
        field_data = np.array([get_field_data(atoms1, atoms2, field)
                               for field in self.fields])

        # Sort keys: apply each field's direction, reorder the keys the way
        # lexsort expects (primary key last) and round so near-equal floats
        # tie and lower-priority keys decide.
        sorting_array = field_data * self.scent[:, np.newaxis]
        sorting_array = sorting_array[self.hier]
        sorting_array = prec_round(sorting_array, self.tableformat.precision)

        field_data = field_data[:, np.lexsort(sorting_array)].transpose()

        if csv:
            rowformat = ','.join(
                ['{:h}' if field == 'el' else '{{:.{}E}}'.format(
                    self.tableformat.precision) for field in self.fields])
        else:
            rowformat = ''.join([self.tableformat.fmt[field]
                                 for field in self.fields])
        body = [
            self.tableformat.formatter(
                rowformat,
                *row) for row in field_data]
        return '\n'.join(body)
# Index used when a filename carries no '@' selector: parsed from ':' (the
# full slice).
default_index = string2index(':')


def slice_split(filename):
    """Split *filename* into the file name and the index to read.

    If *filename* contains ``'@'``, ``parse_filename`` is called with no
    default index.  Otherwise ``default_index`` is passed as the default.
    Returns a ``(filename, index)`` tuple.
    """
    fallback = None if '@' in filename else default_index
    name, index = parse_filename(filename, fallback)
    return name, index