Coverage for ase / db / cli.py: 62.90%
248 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-21 15:52 +0000
1# fmt: off
3import json
4import sys
5from collections import defaultdict
6from collections.abc import Iterable, Iterator
7from contextlib import contextmanager
8from pathlib import Path
10import ase.io
11from ase.db import connect
12from ase.db.core import convert_str_to_int_float_bool_or_str
13from ase.db.row import row2dct
14from ase.db.table import Table, all_columns
15from ase.utils import plural
def count_keys(db, query):
    """Print how many selected rows carry each key-value-pair key.

    Parameters
    ----------
    db:
        Database whose ``select(query)`` yields rows exposing ``_keys``.
    query:
        Selection passed unchanged to ``db.select``.

    Prints one ``key: count`` line per key with the counts aligned in a
    column.  Prints nothing when no selected row has any key (previously
    this crashed with ``max()`` of an empty sequence).
    """
    keys = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            keys[key] += 1

    if not keys:
        return

    # Width of the longest "key:" label so that the counts line up.
    n = max(len(key) for key in keys) + 1
    for key, number in keys.items():
        print('{:{}} {}'.format(key + ':', n, number))
def show_values(args, db, query):
    """Summarise the values taken by some key-value-pair keys.

    ``args.show_values`` is a comma-separated list of keys.  A key whose
    values are all non-strings is printed as a ``[min..max]`` range.
    Otherwise every distinct value is listed with its number of
    occurrences, e.g. ``x(2), y(1)``.  Missing (``None``) values are
    ignored.
    """
    keys = args.show_values.split(',')
    counts = {key: defaultdict(int) for key in keys}
    # Keys with at least one string value cannot be shown as a range:
    # min()/max() over a mix of str and numbers raises TypeError.
    has_str = set()
    for row in db.select(query):
        kvp = row.key_value_pairs
        for key in keys:
            value = kvp.get(key)
            if value is None:
                continue
            counts[key][value] += 1
            if isinstance(value, str):
                has_str.add(key)

    width = max(len(key) for key in keys) + 1
    for key in keys:
        vals = counts[key]
        label = key + ':'
        if vals and key not in has_str:
            print('{:{}} [{}..{}]'
                  .format(label, width, min(vals), max(vals)))
        else:
            listing = ', '.join(f'{v}({count})'
                                for v, count in vals.items())
            print('{:{}} {}'.format(label, width, listing))
def insert_into(args, db, query, out, add_key_value_pairs):
    """Copy the rows selected by *query* from *db* into another database.

    Parameters
    ----------
    args:
        Parsed CLI namespace.  Reads ``insert_into`` (target database),
        ``limit``, ``offset``, ``sort``, ``progress_bar``, ``no_lock_file``
        and ``strip_data``.  ``args.limit`` is rewritten from -1 to 0.
    db:
        Source database.
    query:
        Selection passed to ``db.select``/``db.count``.
    out:
        Callable used for the summary messages.
    add_key_value_pairs:
        Extra key-value pairs written onto every copied row.
    """
    if args.limit == -1:
        # -1 appears to be the CLI's "not given" value; presumably 0 means
        # "no limit" to db.select -- confirm against Database.select.
        args.limit = 0

    progressbar = no_progressbar
    length = None

    if args.progress_bar:
        # Try to import the one from click.
        # People using ase.db will most likely have flask installed
        # and therefore also click.
        try:
            from click import progressbar
        except ImportError:
            pass
        else:
            length = db.count(query)
            # Only the offset/limit window of the selection is iterated
            # below, so size the bar to that window, not to every match.
            if args.offset:
                length = max(length - args.offset, 0)
            if args.limit:
                length = min(length, args.limit)

    nkvp = 0  # key-value pairs that did not exist on the row before
    nrows = 0
    with connect(args.insert_into,
                 use_lock_file=not args.no_lock_file) as db2:
        with progressbar(db.select(query,
                                   sort=args.sort,
                                   limit=args.limit,
                                   offset=args.offset),
                         length=length) as rows:
            for row in rows:
                kvp = row.get('key_value_pairs', {})
                nkvp -= len(kvp)
                kvp.update(add_key_value_pairs)
                nkvp += len(kvp)
                if args.strip_data:
                    db2.write(row.toatoms(), **kvp)
                else:
                    db2.write(row, data=row.get('data'), **kvp)
                nrows += 1

    out('Added %s (%s updated)' %
        (plural(nkvp, 'key-value pair'),
         plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
    out(f'Inserted {plural(nrows, "row")}')
def main(args):
    """Entry point of the ``ase db`` command-line tool.

    ``args`` is the parsed command-line namespace.  Exactly one action is
    performed, chosen by the first matching option in this order: analyse,
    show keys, show values, add from file, count, insert into another
    database, explain, show metadata, set metadata, add/delete key-value
    pairs, delete rows, plot, JSON dump, long listing, web browser.  If
    none of these apply, the selected rows are written as a table (or CSV).
    """
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A bare integer query is passed on as an int (presumably a row id).
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only, so values may themselves
            # contain '='.  A pair without '=' still raises ValueError.
            key, value = pair.split('=', 1)
            add_key_value_pairs[key] = \
                convert_str_to_int_float_bool_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*args):
        # Informational output; silenced when verbosity drops to 0.
        if verbosity > 0:
            print(*args)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        show_values(args, db, query)
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        # ase.io.read may return a single Atoms object or a list of them.
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print(f'{plural(n, "row")}')
        return

    if args.insert_into:
        insert_into(args, db, query, out, add_key_value_pairs)
        return

    if args.limit == -1:
        # -1 means the limit was not given: show at most 20 rows.
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        # Only the ids are needed, so skip loading the data blobs
        # (same as the delete branch below).
        ids = [row['id'] for row in db.select(query, include_data=False)]
        nadded = 0
        nremoved = 0
        with db:
            for row_id in ids:
                m, n = db.update(row_id, delete_keys=delete_keys,
                                 **add_key_value_pairs)
                nadded += m
                nremoved += n
        out('Added %s (%s updated)' %
            (plural(nadded, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - nadded, 'pair')))
        out('Removed', plural(nremoved, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = f'Delete {plural(len(ids), "row")}? (yes/No): '
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out(f'Deleted {plural(len(ids), "row")}')
        return

    if args.plot:
        # Syntax: "tag1,tag2:x,y1,y2" or just "x,y1,y2".
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # string x-value -> integer position on the x axis
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.DBApp().run_db(db)
        return

    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        # "++" adds every key found in the selected rows; sort them so the
        # column order is deterministic.
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(sorted(keys))
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            columns = []
        for col in c.split(','):
            if not col:
                # Tolerate empty entries, e.g. a trailing comma or bare "+".
                continue
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
def row2str(row) -> str:
    """Format a database row as a multi-line, human-readable report.

    The report contains the formula, a unit-cell table (periodicity, cell
    vector components, lengths and angles), optional stress / dipole /
    constraints / data sections, and a ``Key | Description | Value`` table
    built from ``row2dct(row)['table']``.
    """
    t = row2dct(row, key_descriptions={})

    # Header and rows share the same column widths so they always line up:
    # axis=4, periodic=8, x/y/z=11 each, length/angle=10 each.
    header = ('{:>4}|{:>8}|{:>11}|{:>11}|{:>11}|{:>10}|{:>10}'
              .format('axis', 'periodic', 'x', 'y', 'z', 'length', 'angle'))
    fmt = '{0:>4}|{1:>8}|{2[0]:>11}|{2[1]:>11}|{2[2]:>11}|{3:>10}|{4:>10}'

    S = [t['formula'] + ':',
         'Unit cell in Ang:',
         header]
    cell_rows = zip(row.pbc, t['cell'], t['lengths'], t['angles'])
    for c, (p, axis, L, A) in enumerate(cell_rows, start=1):
        S.append(fmt.format(c, [' no', 'yes'][int(p)], axis, L, A))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              ' {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # Column widths, never narrower than the "Key"/"Description" headings;
    # default=0 keeps an empty table from crashing max().
    width0 = max(max((len(key) for key, _, _ in table), default=0), 3)
    width1 = max(max((len(desc) for _, desc, _ in table), default=0), 11)
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)
331@contextmanager
332def no_progressbar(iterable: Iterable,
333 length: int | None = None) -> Iterator[Iterable]:
334 """A do-nothing implementation."""
335 yield iterable
def check_jsmol():
    """Warn on stderr when the JSmol viewer files are not installed.

    Looks for ``static/jsmol/JSmol.min.js`` in the directory of this module
    and, if that file is missing, prints download and soft-link
    instructions to ``sys.stderr``.  Returns None.
    """
    # Directory expected to hold (a soft-link to) the jsmol distribution.
    static = Path(__file__).parent / 'static'
    if not (static / 'jsmol/JSmol.min.js').is_file():
        # The instructions interpolate the actual static path so the
        # suggested "ln -s" command can be copied verbatim.
        print(f"""
    WARNING:
        You don't have jsmol on your system.

        Download Jmol-*-binary.tar.gz from
        https://sourceforge.net/projects/jmol/files/Jmol/,
        extract jsmol.zip, unzip it and create a soft-link:

        $ tar -xf Jmol-*-binary.tar.gz
        $ unzip jmol-*/jsmol.zip
        $ ln -s $PWD/jsmol {static}/jsmol
    """,
              file=sys.stderr)