Coverage for /builds/ase/ase/ase/db/cli.py: 62.30%

244 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3import json 

4import sys 

5from collections import defaultdict 

6from contextlib import contextmanager 

7from pathlib import Path 

8from typing import Iterable, Iterator 

9 

10import ase.io 

11from ase.db import connect 

12from ase.db.core import convert_str_to_int_float_bool_or_str 

13from ase.db.row import row2dct 

14from ase.db.table import Table, all_columns 

15from ase.utils import plural 

16 

17 

18def count_keys(db, query): 

19 keys = defaultdict(int) 

20 for row in db.select(query): 

21 for key in row._keys: 

22 keys[key] += 1 

23 

24 n = max(len(key) for key in keys) + 1 

25 for key, number in keys.items(): 

26 print('{:{}} {}'.format(key + ':', n, number)) 

27 return 

28 

29 

30def main(args): 

31 verbosity = 1 - args.quiet + args.verbose 

32 query = ','.join(args.query) 

33 

34 if args.sort.endswith('-'): 

35 # Allow using "key-" instead of "-key" for reverse sorting 

36 args.sort = '-' + args.sort[:-1] 

37 

38 if query.isdigit(): 

39 query = int(query) 

40 

41 add_key_value_pairs = {} 

42 if args.add_key_value_pairs: 

43 for pair in args.add_key_value_pairs.split(','): 

44 key, value = pair.split('=') 

45 add_key_value_pairs[key] = \ 

46 convert_str_to_int_float_bool_or_str(value) 

47 

48 if args.delete_keys: 

49 delete_keys = args.delete_keys.split(',') 

50 else: 

51 delete_keys = [] 

52 

53 db = connect(args.database, use_lock_file=not args.no_lock_file) 

54 

55 def out(*args): 

56 if verbosity > 0: 

57 print(*args) 

58 

59 if args.analyse: 

60 db.analyse() 

61 return 

62 

63 if args.show_keys: 

64 count_keys(db, query) 

65 return 

66 

67 if args.show_values: 

68 keys = args.show_values.split(',') 

69 values = {key: defaultdict(int) for key in keys} 

70 numbers = set() 

71 for row in db.select(query): 

72 kvp = row.key_value_pairs 

73 for key in keys: 

74 value = kvp.get(key) 

75 if value is not None: 

76 values[key][value] += 1 

77 if not isinstance(value, str): 

78 numbers.add(key) 

79 

80 n = max(len(key) for key in keys) + 1 

81 for key in keys: 

82 vals = values[key] 

83 if key in numbers: 

84 print('{:{}} [{}..{}]' 

85 .format(key + ':', n, min(vals), max(vals))) 

86 else: 

87 print('{:{}} {}' 

88 .format(key + ':', n, 

89 ', '.join(f'{v}({n})' 

90 for v, n in vals.items()))) 

91 return 

92 

93 if args.add_from_file: 

94 filename = args.add_from_file 

95 configs = ase.io.read(filename) 

96 if not isinstance(configs, list): 

97 configs = [configs] 

98 for atoms in configs: 

99 db.write(atoms, key_value_pairs=add_key_value_pairs) 

100 out('Added ' + plural(len(configs), 'row')) 

101 return 

102 

103 if args.count: 

104 n = db.count(query) 

105 print(f'{plural(n, "row")}') 

106 return 

107 

108 if args.insert_into: 

109 if args.limit == -1: 

110 args.limit = 0 

111 

112 progressbar = no_progressbar 

113 length = None 

114 

115 if args.progress_bar: 

116 # Try to import the one from click. 

117 # People using ase.db will most likely have flask installed 

118 # and therfore also click. 

119 try: 

120 from click import progressbar 

121 except ImportError: 

122 pass 

123 else: 

124 length = db.count(query) 

125 

126 nkvp = 0 

127 nrows = 0 

128 with connect(args.insert_into, 

129 use_lock_file=not args.no_lock_file) as db2: 

130 with progressbar(db.select(query, 

131 sort=args.sort, 

132 limit=args.limit, 

133 offset=args.offset), 

134 length=length) as rows: 

135 for row in rows: 

136 kvp = row.get('key_value_pairs', {}) 

137 nkvp -= len(kvp) 

138 kvp.update(add_key_value_pairs) 

139 nkvp += len(kvp) 

140 if args.strip_data: 

141 db2.write(row.toatoms(), **kvp) 

142 else: 

143 db2.write(row, data=row.get('data'), **kvp) 

144 nrows += 1 

145 

146 out('Added %s (%s updated)' % 

147 (plural(nkvp, 'key-value pair'), 

148 plural(len(add_key_value_pairs) * nrows - nkvp, 'pair'))) 

149 out(f'Inserted {plural(nrows, "row")}') 

150 return 

151 

152 if args.limit == -1: 

153 args.limit = 20 

154 

155 if args.explain: 

156 for row in db.select(query, explain=True, 

157 verbosity=verbosity, 

158 limit=args.limit, offset=args.offset): 

159 print(row['explain']) 

160 return 

161 

162 if args.show_metadata: 

163 print(json.dumps(db.metadata, sort_keys=True, indent=4)) 

164 return 

165 

166 if args.set_metadata: 

167 with open(args.set_metadata) as fd: 

168 db.metadata = json.load(fd) 

169 return 

170 

171 if add_key_value_pairs or delete_keys: 

172 ids = [row['id'] for row in db.select(query)] 

173 M = 0 

174 N = 0 

175 with db: 

176 for id in ids: 

177 m, n = db.update(id, delete_keys=delete_keys, 

178 **add_key_value_pairs) 

179 M += m 

180 N += n 

181 out('Added %s (%s updated)' % 

182 (plural(M, 'key-value pair'), 

183 plural(len(add_key_value_pairs) * len(ids) - M, 'pair'))) 

184 out('Removed', plural(N, 'key-value pair')) 

185 

186 return 

187 

188 if args.delete: 

189 ids = [row['id'] for row in db.select(query, include_data=False)] 

190 if ids and not args.yes: 

191 msg = f'Delete {plural(len(ids), "row")}? (yes/No): ' 

192 if input(msg).lower() != 'yes': 

193 return 

194 db.delete(ids) 

195 out(f'Deleted {plural(len(ids), "row")}') 

196 return 

197 

198 if args.plot: 

199 if ':' in args.plot: 

200 tags, keys = args.plot.split(':') 

201 tags = tags.split(',') 

202 else: 

203 tags = [] 

204 keys = args.plot 

205 keys = keys.split(',') 

206 plots = defaultdict(list) 

207 X = {} 

208 labels = [] 

209 for row in db.select(query, sort=args.sort, include_data=False): 

210 name = ','.join(str(row[tag]) for tag in tags) 

211 x = row.get(keys[0]) 

212 if x is not None: 

213 if isinstance(x, str): 

214 if x not in X: 

215 X[x] = len(X) 

216 labels.append(x) 

217 x = X[x] 

218 plots[name].append([x] + [row.get(key) for key in keys[1:]]) 

219 import matplotlib.pyplot as plt 

220 for name, plot in plots.items(): 

221 xyy = list(zip(*plot)) 

222 x = xyy[0] 

223 for y, key in zip(xyy[1:], keys[1:]): 

224 plt.plot(x, y, label=name + ':' + key) 

225 if X: 

226 plt.xticks(range(len(labels)), labels, rotation=90) 

227 plt.legend() 

228 plt.show() 

229 return 

230 

231 if args.json: 

232 row = db.get(query) 

233 db2 = connect(sys.stdout, 'json', use_lock_file=False) 

234 kvp = row.get('key_value_pairs', {}) 

235 db2.write(row, data=row.get('data'), **kvp) 

236 return 

237 

238 if args.long: 

239 row = db.get(query) 

240 print(row2str(row)) 

241 return 

242 

243 if args.open_web_browser: 

244 try: 

245 import flask # noqa 

246 except ImportError: 

247 print('Please install Flask: python3 -m pip install flask') 

248 return 

249 check_jsmol() 

250 import ase.db.app as app 

251 app.DBApp().run_db(db) 

252 return 

253 

254 columns = list(all_columns) 

255 c = args.columns 

256 if c and c.startswith('++'): 

257 keys = set() 

258 for row in db.select(query, 

259 limit=args.limit, offset=args.offset, 

260 include_data=False): 

261 keys.update(row._keys) 

262 columns.extend(keys) 

263 if c[2:3] == ',': 

264 c = c[3:] 

265 else: 

266 c = '' 

267 if c: 

268 if c[0] == '+': 

269 c = c[1:] 

270 elif c[0] != '-': 

271 columns = [] 

272 for col in c.split(','): 

273 if col[0] == '-': 

274 columns.remove(col[1:]) 

275 else: 

276 columns.append(col.lstrip('+')) 

277 

278 table = Table(db, verbosity=verbosity, cut=args.cut) 

279 table.select(query, columns, args.sort, args.limit, args.offset) 

280 if args.csv: 

281 table.write_csv() 

282 else: 

283 table.write(query) 

284 

285 

286def row2str(row) -> str: 

287 t = row2dct(row, key_descriptions={}) 

288 S = [t['formula'] + ':', 

289 'Unit cell in Ang:', 

290 'axis|periodic| x| y| z|' + 

291 ' length| angle'] 

292 c = 1 

293 fmt = (' {0}| {1}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|' + 

294 '{3:>10}|{4:>10}') 

295 for p, axis, L, A in zip(row.pbc, t['cell'], t['lengths'], t['angles']): 

296 S.append(fmt.format(c, [' no', 'yes'][int(p)], axis, L, A)) 

297 c += 1 

298 S.append('') 

299 

300 if 'stress' in t: 

301 S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:', 

302 ' {}\n'.format(t['stress'])] 

303 

304 if 'dipole' in t: 

305 S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole'])) 

306 

307 if 'constraints' in t: 

308 S.append('Constraints: {}\n'.format(t['constraints'])) 

309 

310 if 'data' in t: 

311 S.append('Data: {}\n'.format(t['data'])) 

312 

313 width0 = max(max(len(row[0]) for row in t['table']), 3) 

314 width1 = max(max(len(row[1]) for row in t['table']), 11) 

315 S.append('{:{}} | {:{}} | Value' 

316 .format('Key', width0, 'Description', width1)) 

317 for key, desc, value in t['table']: 

318 S.append('{:{}} | {:{}} | {}' 

319 .format(key, width0, desc, width1, value)) 

320 return '\n'.join(S) 

321 

322 

323@contextmanager 

324def no_progressbar(iterable: Iterable, 

325 length: int = None) -> Iterator[Iterable]: 

326 """A do-nothing implementation.""" 

327 yield iterable 

328 

329 

330def check_jsmol(): 

331 static = Path(__file__).parent / 'static' 

332 if not (static / 'jsmol/JSmol.min.js').is_file(): 

333 print(f""" 

334 WARNING: 

335 You don't have jsmol on your system. 

336 

337 Download Jmol-*-binary.tar.gz from 

338 https://sourceforge.net/projects/jmol/files/Jmol/, 

339 extract jsmol.zip, unzip it and create a soft-link: 

340 

341 $ tar -xf Jmol-*-binary.tar.gz 

342 $ unzip jmol-*/jsmol.zip 

343 $ ln -s $PWD/jsmol {static}/jsmol 

344 """, 

345 file=sys.stderr)