Coverage for ase / db / cli.py: 62.90%

248 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-21 15:52 +0000

1# fmt: off 

2 

3import json 

4import sys 

5from collections import defaultdict 

6from collections.abc import Iterable, Iterator 

7from contextlib import contextmanager 

8from pathlib import Path 

9 

10import ase.io 

11from ase.db import connect 

12from ase.db.core import convert_str_to_int_float_bool_or_str 

13from ase.db.row import row2dct 

14from ase.db.table import Table, all_columns 

15from ase.utils import plural 

16 

17 

def count_keys(db, query):
    """Print how many rows matching *query* carry each key.

    One line per key is printed, of the form ``key: count``, with the counts
    aligned in a single column. Keys appear in the order they were first
    encountered while iterating the rows. Nothing is printed when no row
    (or no key) matches.

    Parameters
    ----------
    db:
        Database object; only ``db.select(query)`` is used, and every row it
        yields must expose its key names through ``row._keys``.
    query:
        Selection passed unchanged to ``db.select``.
    """
    counts = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            counts[key] += 1

    if not counts:
        # max() over an empty sequence raises ValueError; an empty result
        # simply means there is nothing to report.
        return

    width = max(len(key) for key in counts) + 1
    for key, number in counts.items():
        print('{:{}} {}'.format(key + ':', width, number))

28 

29 

def show_values(args, db, query):
    """Print a one-line summary of the values of selected keys.

    ``args.show_values`` is a comma-separated list of key names. Every row
    yielded by ``db.select(query)`` is scanned, reading ``row.key_value_pairs``.
    For each requested key one line is printed:

    * if any non-string value was seen, the range as ``key: [min..max]``;
    * otherwise, or when the values cannot be ordered (e.g. a mix of strings
      and numbers), each distinct value with its count, as
      ``key: value(count), ...``.

    A key that never occurs prints with an empty value list.
    """
    keys = args.show_values.split(',')
    counts = {key: defaultdict(int) for key in keys}
    numeric_keys = set()
    for row in db.select(query):
        kvp = row.key_value_pairs
        for key in keys:
            value = kvp.get(key)
            if value is None:
                continue
            counts[key][value] += 1
            if not isinstance(value, str):
                numeric_keys.add(key)

    width = max(len(key) for key in keys) + 1
    for key in keys:
        vals = counts[key]
        summary = None
        if key in numeric_keys:
            try:
                summary = '[{}..{}]'.format(min(vals), max(vals))
            except TypeError:
                # Strings and numbers cannot be compared with each other;
                # fall back to listing the individual values below.
                summary = None
        if summary is None:
            summary = ', '.join(f'{value}({count})'
                                for value, count in vals.items())
        print('{:{}} {}'.format(key + ':', width, summary))

54 

55 

def insert_into(args, db, query, out, add_key_value_pairs):
    """Copy the rows of *db* matching *query* into ``args.insert_into``.

    *add_key_value_pairs* is merged into every copied row's key-value pairs.
    With ``args.strip_data`` only the Atoms object and the key-value pairs
    are written; otherwise the row's ``data`` is copied as well. A limit of
    -1 is translated to 0 before selecting. Two summary lines are reported
    through the *out* callable.
    """
    if args.limit == -1:
        args.limit = 0

    make_bar = no_progressbar
    total = None
    if args.progress_bar:
        # Prefer click's progress bar when it can be imported (click is
        # pulled in by flask, which ase.db users will usually have).
        try:
            from click import progressbar as make_bar
        except ImportError:
            pass
        else:
            total = db.count(query)

    added = 0
    copied = 0
    with connect(args.insert_into,
                 use_lock_file=not args.no_lock_file) as target:
        selection = db.select(query, sort=args.sort,
                              limit=args.limit, offset=args.offset)
        with make_bar(selection, length=total) as rows:
            for row in rows:
                pairs = row.get('key_value_pairs', {})
                before = len(pairs)
                pairs.update(add_key_value_pairs)
                # Only keys that were new to this row count as "added".
                added += len(pairs) - before
                if not args.strip_data:
                    target.write(row, data=row.get('data'), **pairs)
                else:
                    target.write(row.toatoms(), **pairs)
                copied += 1

    updated = len(add_key_value_pairs) * copied - added
    out('Added %s (%s updated)' %
        (plural(added, 'key-value pair'), plural(updated, 'pair')))
    out(f'Inserted {plural(copied, "row")}')

98 

99 

def main(args):
    """Run the ``ase db`` command described by the parsed *args*.

    Opens ``args.database`` and performs a single action, chosen by the first
    matching option in this order: analyse, show keys, show values, add from
    file, count, insert into another database, explain, show metadata, set
    metadata, add/delete key-value pairs, delete rows, plot, JSON dump, long
    listing, web browser. When none of these is requested, the selected rows
    are printed as a table (or as CSV when ``args.csv`` is set).

    The positional query words are joined with commas. A purely numeric
    query is converted to ``int`` before being passed to the database.

    Raises
    ------
    ValueError
        If an entry of ``args.add_key_value_pairs`` has no ``=``.
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only, so values may themselves
            # contain '='.
            key, sep, value = pair.partition('=')
            if not sep:
                raise ValueError(
                    f'Expected key=value, got {pair!r} in '
                    f'{args.add_key_value_pairs!r}')
            add_key_value_pairs[key] = \
                convert_str_to_int_float_bool_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*messages):
        # Informational output, silenced when verbosity drops to 0 or below.
        if verbosity > 0:
            print(*messages)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        show_values(args, db, query)
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print(plural(n, 'row'))
        return

    if args.insert_into:
        insert_into(args, db, query, out, add_key_value_pairs)
        return

    # -1 means "not given on the command line"; listings default to 20 rows.
    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        # JSON text is UTF-8 by specification; don't depend on the locale.
        with open(args.set_metadata, encoding='utf-8') as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        # Only the ids are needed here, so skip loading the data column.
        ids = [row['id'] for row in db.select(query, include_data=False)]
        M = 0
        N = 0
        with db:
            for row_id in ids:
                m, n = db.update(row_id, delete_keys=delete_keys,
                                 **add_key_value_pairs)
                M += m
                N += n
        out('Added %s (%s updated)' %
            (plural(M, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - M, 'pair')))
        out('Removed', plural(N, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = f'Delete {plural(len(ids), "row")}? (yes/No): '
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out(f'Deleted {plural(len(ids), "row")}')
        return

    if args.plot:
        # Syntax: "[tag1,tag2:]xkey,ykey1,ykey2,..."
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    # String x-values are mapped to consecutive integer
                    # positions and later used as tick labels.
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.DBApp().run_db(db)
        return

    # Column spec: "++" adds every key found in the selected rows; a leading
    # "+" extends the default columns; "-name" removes one; anything else
    # replaces the default column list.
    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(keys)
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if not col:
                # Tolerate empty entries (trailing comma, bare "+").
                continue
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)

292 

293 

def row2str(row) -> str:
    """Format a database row as a plain-text report (used for ``args.long``).

    The report contains the formula, a unit-cell table (one line per axis,
    with periodicity, the three cell-vector components, length and angle),
    optional stress / dipole / constraints / data sections, and finally a
    ``Key | Description | Value`` table. All values come from
    ``row2dct(row, key_descriptions={})``.
    """
    t = row2dct(row, key_descriptions={})

    # The header and the per-axis lines share the same field widths so the
    # columns always line up.
    header = ('{:>4}|{:>8}|{:>11}|{:>11}|{:>11}|{:>10}|{:>10}'
              .format('axis', 'periodic', 'x', 'y', 'z', 'length', 'angle'))
    axis_fmt = ('{0:>4}|{1:>8}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|'
                '{3:>10}|{4:>10}')
    S = [t['formula'] + ':', 'Unit cell in Ang:', header]
    cell_info = zip(row.pbc, t['cell'], t['lengths'], t['angles'])
    for number, (periodic, axis, length, angle) in enumerate(cell_info,
                                                             start=1):
        S.append(axis_fmt.format(number, [' no', 'yes'][int(periodic)],
                                 axis, length, angle))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              ' {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # Minimum widths fit the column titles; this also keeps an empty table
    # from crashing max().
    width0 = max([3] + [len(key) for key, _, _ in table])
    width1 = max([11] + [len(desc) for _, desc, _ in table])
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)

329 

330 

331@contextmanager 

332def no_progressbar(iterable: Iterable, 

333 length: int | None = None) -> Iterator[Iterable]: 

334 """A do-nothing implementation.""" 

335 yield iterable 

336 

337 

def check_jsmol():
    """Print installation instructions to stderr if JSmol is missing.

    The file checked is ``static/jsmol/JSmol.min.js`` next to this module.
    If that file exists, nothing happens.
    """
    static_dir = Path(__file__).parent / 'static'
    if (static_dir / 'jsmol/JSmol.min.js').is_file():
        return
    print(f"""
    WARNING:
    You don't have jsmol on your system.

    Download Jmol-*-binary.tar.gz from
    https://sourceforge.net/projects/jmol/files/Jmol/,
    extract jsmol.zip, unzip it and create a soft-link:

        $ tar -xf Jmol-*-binary.tar.gz
        $ unzip jmol-*/jsmol.zip
        $ ln -s $PWD/jsmol {static_dir}/jsmol
    """,
          file=sys.stderr)