Coverage for /builds/ase/ase/ase/cli/template.py: 94.00%

200 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import string 

4 

5import numpy as np 

6 

7from ase.data import chemical_symbols 

8from ase.io import string2index 

9from ase.io.formats import parse_filename 

10 

# default fields


def field_specs_on_conditions(calculator_outputs, rank_order):
    """Return the default column specs for comparing two Atoms objects.

    Specs use the ``field[:hier[:scent]]`` syntax read by
    ``parse_field_specs``.

    Parameters
    ----------
    calculator_outputs : bool
        If truthy, the columns are index, element, displacement norm and
        its rank, force-difference norm and its rank (``'df'``, ``'rdf'``).
        Otherwise the per-axis displacements ``'dx'``, ``'dy'``, ``'dz'``
        are shown instead of the force columns.
    rank_order : str or None
        Field to sort by.  If None, the index column becomes ``'i:0:1'``.
        Otherwise the index column becomes ``'i:1'`` and *rank_order* gets
        ``':0:1'`` appended: in place if it is already one of the default
        columns, or as an extra trailing column if it is not.

    Returns
    -------
    list of str
    """
    if calculator_outputs:
        specs = ['i:0', 'el', 'd', 'rd', 'df', 'rdf']
    else:
        specs = ['i:0', 'el', 'dx', 'dy', 'dz', 'd', 'rd']

    if rank_order is None:
        specs[0] += ':1'
        return specs

    specs[0] = 'i:1'
    if rank_order not in specs:
        return specs + [rank_order + ':0:1']
    return [spec + ':0:1' if spec == rank_order else spec for spec in specs]

30 

31 

def summary_functions_on_conditions(has_calc):
    """Return the summary callables printed below the table.

    ``rmsd`` is always included; ``energy_delta`` (which calls
    ``get_potential_energy`` on both Atoms objects) is added only when
    *has_calc* is truthy.
    """
    funcs = [rmsd]
    if has_calc:
        funcs.append(energy_delta)
    return funcs

36 

37 

def header_alias(h):
    """Return a human-readable column header for the field name *h*.

    Fixed abbreviations are spelled out: 'i' -> 'index', 'an' -> 'atomic #',
    't' -> 'tag', 'el' -> 'element'.  Otherwise, by first character:

    * 'd' (difference): every 'd' is replaced by 'Δ', e.g. 'dfx' -> 'Δfx';
    * 'r' (rank): 'rank ' followed by the alias of the rest, e.g.
      'rd' -> 'rank Δ';
    * 'a' (average): every 'a' is replaced by '<' and '>' is appended,
      e.g. 'afx' -> '<fx>'.

    Any other name (including '') is returned unchanged.
    """
    if h == 'i':
        return 'index'
    if h == 'an':
        return 'atomic #'
    if h == 't':
        return 'tag'
    if h == 'el':
        return 'element'
    # startswith() rather than h[0]: an empty name (e.g. the remainder of a
    # bare 'r') must not raise IndexError.
    if h.startswith('d'):
        return h.replace('d', 'Δ')
    if h.startswith('r'):
        return 'rank ' + header_alias(h[1:])
    if h.startswith('a'):
        return h.replace('a', '<') + '>'
    return h

57 

58 

def prec_round(a, prec=2):
    """Round *a* to *prec* decimals of its base-10 mantissa.

    *a* is written as ``s * 10**c * 10**m`` with ``s`` its sign, ``m`` an
    integer and ``0 <= c < 1``.  The mantissa ``10**c`` (in [1, 10)) is
    rounded to *prec* decimal places, so *a* keeps ``prec + 1``
    significant digits.  Zero is returned unchanged.

    Applied to the sort keys so that values agreeing to this precision tie,
    letting the next key in the sort hierarchy decide their order.
    """
    if a == 0:
        return a
    s = 1 if a > 0 else -1
    m = np.log10(s * a) // 1  # decimal exponent (floor of log10|a|)
    c = np.log10(s * a) % 1   # fractional part of log10|a|
    return s * np.round(10**c, prec) * 10**m


# otypes=[float]: without it np.vectorize infers the output dtype from the
# first element.  For an integer array starting with 0 (returned unchanged,
# still an int) every later rounded value would be truncated to int, and
# size-0 inputs would raise ValueError.
prec_round = np.vectorize(prec_round, otypes=[float])

74 

# end most settings

# Element symbols are stored as integer codes so that the 'el' column can be
# sorted together with the numeric fields.  A symbol's code is its position
# in sorted(chemical_symbols), so sorting by code sorts alphabetically.
num2sym = dict(enumerate(sorted(chemical_symbols)))
# to sort by atomic number instead, use:
# num2sym = dict(enumerate(chemical_symbols))
sym2num = {v: k for k, v in num2sym.items()}

# Fields that get_field_data computes from positions, tags, atomic numbers,
# symbols or atom indices; any other field name is derived from forces.
atoms_props = [
    'dx',
    'dy',
    'dz',
    'd',
    't',
    'an',
    'i',
    'el',
    'p1',
    'p2',
    'p1x',
    'p1y',
    'p1z',
    'p2x',
    'p2y',
    'p2z']

100 

101 

def _xyz_or_norm(vectors, field):
    """Select a Cartesian component of *vectors* by *field*'s last letter.

    Returns column 0, 1 or 2 when *field* ends in 'x', 'y' or 'z' (checked
    in that order), otherwise the Euclidean norm of every row.
    """
    for axis, suffix in enumerate('xyz'):
        if field.endswith(suffix):
            return vectors[:, axis]
    return np.linalg.norm(vectors, axis=1)


def get_field_data(atoms1, atoms2, field):
    """Return the per-atom values of one table column.

    Parameters
    ----------
    atoms1, atoms2 : Atoms
        The two structures being compared (presumably with matching atom
        count and order -- TODO confirm at the call site).
    field : str
        Field name, optionally prefixed with 'r' to request ranks.

        Names in ``atoms_props`` come from the structures directly:
        't' tags, 'an' atomic numbers, 'el' element codes (``sym2num``),
        'i' atom indices, 'd*' position differences (atoms2 - atoms1) and
        'p1*' / 'p2*' positions of atoms1 / atoms2.

        Any other name is computed from ``get_forces()``: a leading 'd'
        means forces2 - forces1, a leading 'a' their mean, and otherwise
        the second character selects atoms1 ('1') or atoms2.

        For vector quantities a trailing 'x', 'y' or 'z' selects that
        component; otherwise the norm is used.

    Returns
    -------
    numpy.ndarray
        One value per atom or, with the 'r' prefix, the 0-based rank of each
        value in descending order.
    """
    rank_order = field[0] == 'r'
    if rank_order:
        field = field[1:]

    if field not in atoms_props:
        # Every name outside atoms_props is derived from get_forces().
        kind = field[0]
        if kind == 'd':
            vectors = atoms2.get_forces() - atoms1.get_forces()
        elif kind == 'a':
            vectors = (atoms2.get_forces() + atoms1.get_forces()) / 2
        else:
            source = atoms1 if field[1] == '1' else atoms2
            vectors = source.get_forces()
        data = _xyz_or_norm(vectors, field)
    elif field == 't':
        data = atoms1.get_tags()
    elif field == 'an':
        data = atoms1.numbers
    elif field == 'el':
        data = np.array([sym2num[sym] for sym in atoms1.symbols])
    elif field == 'i':
        data = np.arange(len(atoms1))
    else:
        if field.startswith('d'):
            vectors = atoms2.positions - atoms1.positions
        elif field.startswith('p'):
            source = atoms1 if field[1] == '1' else atoms2
            vectors = source.positions
        data = _xyz_or_norm(vectors, field)

    if rank_order:
        # Double argsort turns values into ranks; negation makes the
        # largest value rank 0.
        return np.argsort(np.argsort(-data))
    return data

159 

160 

# Summary Functions

def rmsd(atoms1, atoms2):
    """Return ``'RMSD=<value>'`` for the root-mean-square displacement.

    The value is ``sqrt(mean_i |r2_i - r1_i|**2)`` over all atoms (assumes
    both structures have the same number of atoms), formatted with an
    explicit sign and one decimal in E notation.
    """
    displacement_norms = np.linalg.norm(
        atoms2.positions - atoms1.positions, axis=1)
    mean_square = (displacement_norms**2).mean()
    return f'RMSD={np.sqrt(mean_square):+.1E}'

167 

168 

def energy_delta(atoms1, atoms2):
    """Return both potential energies and their difference as one line.

    The difference is E2 - E1.  All three values are shown with an explicit
    sign and one decimal in E notation.
    """
    e1, e2 = [atoms.get_potential_energy() for atoms in (atoms1, atoms2)]
    return f'E1 = {e1:+.1E}, E2 = {e2:+.1E}, dE = {e2 - e1:+1.1E}'

173 

174 

def parse_field_specs(field_specs):
    """Parse ``field[:hier[:scent]]`` column specs.

    Parameters
    ----------
    field_specs : iterable of str
        Each spec is a field name, optionally followed by an integer sort
        hierarchy and an integer sort multiplier ("scent"), separated by
        ':'.  An omitted hierarchy or scent defaults to -1.

    Returns
    -------
    fields : list of str
        The field names in input order.
    hier : numpy.ndarray
        Field indices ordered for ``np.lexsort``, which treats the last key
        as primary.  Lower hierarchy numbers are therefore higher-priority
        sort keys.  Fields without an explicit hierarchy come after all
        explicit ones, in input order.
    scent : numpy.ndarray
        The multiplier of each field (-1 when omitted).

    Raises
    ------
    ValueError
        If a spec has more than three ':'-separated parts, or if a
        hierarchy or scent is not an integer.
    """
    fields = []
    hier = []
    scent = []
    for fs in field_specs:
        fhs = fs.split(':')
        if len(fhs) > 3:
            # Previously such specs were silently dropped from the table.
            raise ValueError(
                f'Invalid field spec {fs!r}: expected field[:hier[:scent]]')
        fields.append(fhs[0])
        hier.append(int(fhs[1]) if len(fhs) > 1 else -1)
        scent.append(int(fhs[2]) if len(fhs) > 2 else -1)

    # Give fields without an explicit hierarchy successive values after
    # the largest explicit one, so they rank last in input order.
    mxm = max(hier, default=-1)
    for c, h in enumerate(hier):
        if h < 0:
            mxm += 1
            hier[c] = mxm
    # reversed by convention of numpy lexsort (last key is primary)
    hier = np.argsort(hier)[::-1]
    return fields, hier, np.array(scent)

201 

# Class definitions


class MapFormatter(string.Formatter):
    """Formatter with an extra 'h' presentation type for element columns.

    Element symbols are held in the numeric table as integer codes so they
    can be sorted.  A format spec ending in 'h' converts such a code back to
    its symbol with ``num2sym[int(value)]`` and formats it as a string, with
    the rest of the spec (e.g. width and alignment) unchanged.  All other
    specs are handled by :class:`string.Formatter` as usual.
    """

    def format_field(self, value, spec):
        if not spec.endswith('h'):
            return super().format_field(value, spec)
        symbol = num2sym[int(value)]
        return super().format_field(symbol, spec[:-1] + 's')

215 

216 

class TableFormat:
    """Column format strings and rule characters for a table.

    Parameters
    ----------
    columnwidth : int
        Width of every column in characters.
    precision : int
        Significant digits for float columns.  The table format uses
        ``precision - 1`` digits after the decimal point.
    representation : str
        Float presentation type used in the format strings (default 'E').
    toprule, midrule, bottomrule : str
        Strings repeated to draw the horizontal rules.

    Attributes
    ----------
    formatter : callable
        ``format`` method of a ``MapFormatter``, which understands the extra
        'h' spec used for element columns.
    fmt_class : dict
        Format string for each kind of column.
    fmt : dict
        Format string for each field name.
    """

    def __init__(self,
                 columnwidth=9,
                 precision=2,
                 representation='E',
                 toprule='=',
                 midrule='-',
                 bottomrule='='):

        self.precision = precision
        self.representation = representation
        self.columnwidth = columnwidth
        self.formatter = MapFormatter().format
        self.toprule = toprule
        self.midrule = midrule
        self.bottomrule = bottomrule

        self.fmt_class = {
            # NOTE(review): in ' ^' the space is parsed as the fill
            # character, not as the ' ' sign option, so this currently
            # renders exactly like 'unsigned float'.  '{:^ 9.1E}' would
            # reserve a sign position; left unchanged to keep output stable.
            'signed float': "{{: ^{}.{}{}}}".format(
                self.columnwidth,
                self.precision - 1,
                self.representation),
            'unsigned float': "{{:^{}.{}{}}}".format(
                self.columnwidth,
                self.precision - 1,
                self.representation),
            'int': "{{:^{}n}}".format(
                self.columnwidth),
            'str': "{{:^{}s}}".format(
                self.columnwidth),
            # 'h' is resolved by MapFormatter (element code -> symbol).
            'conv': "{{:^{}h}}".format(
                self.columnwidth)}
        fmt = {}
        signed_floats = [
            'dx',
            'dy',
            'dz',
            'dfx',
            'dfy',
            'dfz',
            'afx',
            'afy',
            'afz',
            'p1x',
            'p2x',
            'p1y',
            'p2y',
            'p1z',
            'p2z',
            'f1x',
            'f2x',
            'f1y',
            'f2y',
            'f1z',
            'f2z']
        for sf in signed_floats:
            fmt[sf] = self.fmt_class['signed float']
        unsigned_floats = ['d', 'df', 'af', 'p1', 'p2', 'f1', 'f2']
        for usf in unsigned_floats:
            fmt[usf] = self.fmt_class['unsigned float']
        # Rank columns ('r' + field) hold integer ranks for every field,
        # including the integer-valued fields and 'el'.  Without entries for
        # 'ri', 'ran', 'rt' and 'rel', those columns would raise KeyError
        # when formatted.
        integer_fields = ['i', 'an', 't']
        ranked_fields = signed_floats + unsigned_floats + integer_fields + ['el']
        for i in integer_fields + ['r' + f for f in ranked_fields]:
            fmt[i] = self.fmt_class['int']
        fmt['el'] = self.fmt_class['conv']

        self.fmt = fmt

284 

285 

class Table:
    """Sorted, formatted per-atom comparison of two structures.

    Parameters
    ----------
    field_specs : list of str
        Column specs of the form ``field[:hier[:scent]]``, parsed by
        ``parse_field_specs``.
    summary_functions : list of callable, optional
        Each is called as ``f(atoms1, atoms2)`` and must return a string.
        The results are printed below the table, one per line.
        Defaults to no summary lines.
    tableformat : TableFormat, optional
        Formatting options; a default ``TableFormat()`` is used if None.
    max_lines : int, optional
        Maximum number of body rows to print; all rows if None.
    title : str
        First line of the output.
    tablewidth : int, optional
        Length of the horizontal rules.  Defaults to
        ``columnwidth * number of fields``.
    """

    def __init__(self,
                 field_specs,
                 summary_functions=None,
                 tableformat=None,
                 max_lines=None,
                 title='',
                 tablewidth=None):

        self.max_lines = max_lines
        # None -> a fresh list, so instances never share a mutable default.
        if summary_functions is None:
            summary_functions = []
        self.summary_functions = summary_functions
        self.field_specs = field_specs

        self.fields, self.hier, self.scent = parse_field_specs(self.field_specs)
        self.nfields = len(self.fields)

        # formatting
        if tableformat is None:
            self.tableformat = TableFormat()
        else:
            self.tableformat = tableformat

        if tablewidth is None:
            self.tablewidth = self.tableformat.columnwidth * self.nfields
        else:
            self.tablewidth = tablewidth

        self.title = title

    @staticmethod
    def _limit_lines(text, max_lines):
        """Return the first *max_lines* lines of *text* (all if None)."""
        if max_lines is None:
            return text
        return '\n'.join(text.split('\n')[:max_lines])

    def make(self, atoms1, atoms2, csv=False):
        """Return the full output.

        The output is the title, top rule, header, mid rule, body, bottom
        rule and summary, joined by newlines.
        """
        header = self.make_header(csv=csv)
        # make_body returns a single string: limit it by rows, not by
        # characters as a plain string slice would.
        body = self._limit_lines(self.make_body(atoms1, atoms2, csv=csv),
                                 self.max_lines)
        summary = self.make_summary(atoms1, atoms2)

        return '\n'.join([self.title,
                          self.tableformat.toprule * self.tablewidth,
                          header,
                          self.tableformat.midrule * self.tablewidth,
                          body,
                          self.tableformat.bottomrule * self.tablewidth,
                          summary])

    def make_header(self, csv=False):
        """Return the header row of ``header_alias`` names.

        The names are comma-separated for CSV, otherwise centered in
        fixed-width columns.
        """
        if csv:
            return ','.join([header_alias(field) for field in self.fields])

        fields = self.tableformat.fmt_class['str'] * self.nfields
        headers = [header_alias(field) for field in self.fields]

        return self.tableformat.formatter(fields, *headers)

    def make_summary(self, atoms1, atoms2):
        """Return the summary functions' strings, one per line."""
        return '\n'.join([summary_function(atoms1, atoms2)
                          for summary_function in self.summary_functions])

    def make_body(self, atoms1, atoms2, csv=False):
        """Return all table rows, sorted by the field hierarchy."""
        # One row per field, one column per atom.
        field_data = np.array([get_field_data(atoms1, atoms2, field)
                               for field in self.fields])

        # Apply each field's scent multiplier, order the keys by hierarchy
        # (np.lexsort uses the last key as primary) and round them so values
        # equal to the table precision tie and fall through to the next key.
        sorting_array = field_data * self.scent[:, np.newaxis]
        sorting_array = sorting_array[self.hier]
        sorting_array = prec_round(sorting_array, self.tableformat.precision)

        field_data = field_data[:, np.lexsort(sorting_array)].transpose()

        if csv:
            # Elements via the 'h' spec, all other fields in E notation.
            rowformat = ','.join(
                ['{:h}' if field == 'el' else '{{:.{}E}}'.format(
                    self.tableformat.precision) for field in self.fields])
        else:
            rowformat = ''.join([self.tableformat.fmt[field]
                                 for field in self.fields])
        body = [
            self.tableformat.formatter(
                rowformat,
                *row) for row in field_data]
        return '\n'.join(body)

365 

366 

# Index passed to parse_filename when the filename has no '@' part;
# built from the ':' index string (presumably selecting all images --
# TODO confirm against string2index).
default_index = string2index(':')


def slice_split(filename):
    """Split *filename* into ``(filename, index)`` using parse_filename.

    When *filename* contains '@', parse_filename gets None as the default
    index, so only the index written in the name applies.  Otherwise
    ``default_index`` is passed.
    """
    fallback = None if '@' in filename else default_index
    name, index = parse_filename(filename, fallback)
    return name, index