Coverage for /builds/ase/ase/ase/db/cli.py: 62.30%
244 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
import json
import sys
from collections import defaultdict
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable, Iterator, Optional

import ase.io
from ase.db import connect
from ase.db.core import convert_str_to_int_float_bool_or_str
from ase.db.row import row2dct
from ase.db.table import Table, all_columns
from ase.utils import plural
def count_keys(db, query):
    """Print, for each key-value-pair key, how many selected rows have it.

    Rows are taken from ``db.select(query)``; one line is printed per key
    (in order of first appearance) with the key names padded to a common
    width.  Nothing is printed when no row contributes any key.
    """
    keys = defaultdict(int)
    for row in db.select(query):
        for key in row._keys:
            keys[key] += 1

    # default=0 keeps max() from raising ValueError when no keys were found.
    n = max((len(key) for key in keys), default=0) + 1
    for key, number in keys.items():
        print('{:{}} {}'.format(key + ':', n, number))
def main(args):
    """Run the ``ase db`` command-line tool.

    ``args`` is the parsed argument namespace.  The options are checked in
    a fixed order; the first action option that is set is performed and
    the function returns.  When no action option is set, the rows matching
    the query are printed as a table (or as CSV when ``args.csv`` is set).
    """
    # --quiet lowers and --verbose raises the verbosity (baseline 1).
    verbosity = 1 - args.quiet + args.verbose
    query = ','.join(args.query)

    if args.sort.endswith('-'):
        # Allow using "key-" instead of "-key" for reverse sorting
        args.sort = '-' + args.sort[:-1]

    if query.isdigit():
        # A query that is just a number is passed on as an int.
        query = int(query)

    add_key_value_pairs = {}
    if args.add_key_value_pairs:
        for pair in args.add_key_value_pairs.split(','):
            # Split on the first '=' only, so values may contain '='.
            key, value = pair.split('=', 1)
            add_key_value_pairs[key] = \
                convert_str_to_int_float_bool_or_str(value)

    if args.delete_keys:
        delete_keys = args.delete_keys.split(',')
    else:
        delete_keys = []

    db = connect(args.database, use_lock_file=not args.no_lock_file)

    def out(*items):
        """Print *items* unless verbosity has been reduced to 0 or below."""
        if verbosity > 0:
            print(*items)

    if args.analyse:
        db.analyse()
        return

    if args.show_keys:
        count_keys(db, query)
        return

    if args.show_values:
        keys = args.show_values.split(',')
        values = {key: defaultdict(int) for key in keys}
        numbers = set()  # keys seen with at least one non-string value
        for row in db.select(query):
            kvp = row.key_value_pairs
            for key in keys:
                value = kvp.get(key)
                if value is not None:
                    values[key][value] += 1
                    if not isinstance(value, str):
                        numbers.add(key)

        n = max(len(key) for key in keys) + 1
        for key in keys:
            vals = values[key]
            if key in numbers:
                # Non-string values: only the range is shown.
                print('{:{}} [{}..{}]'
                      .format(key + ':', n, min(vals), max(vals)))
            else:
                # String values: each distinct value with its count.
                print('{:{}} {}'
                      .format(key + ':', n,
                              ', '.join(f'{v}({count})'
                                        for v, count in vals.items())))
        return

    if args.add_from_file:
        filename = args.add_from_file
        configs = ase.io.read(filename)
        if not isinstance(configs, list):
            configs = [configs]
        for atoms in configs:
            db.write(atoms, key_value_pairs=add_key_value_pairs)
        out('Added ' + plural(len(configs), 'row'))
        return

    if args.count:
        n = db.count(query)
        print(f'{plural(n, "row")}')
        return

    if args.insert_into:
        if args.limit == -1:
            # -1 is the "not given" value; 0 is used instead here
            # (presumably meaning "no limit" for db.select -- verify).
            args.limit = 0

        progressbar = no_progressbar
        length = None

        if args.progress_bar:
            # Try to import the one from click.
            # People using ase.db will most likely have flask installed
            # and therefore also click.
            try:
                from click import progressbar
            except ImportError:
                pass
            else:
                length = db.count(query)

        nkvp = 0
        nrows = 0
        with connect(args.insert_into,
                     use_lock_file=not args.no_lock_file) as db2:
            with progressbar(db.select(query,
                                       sort=args.sort,
                                       limit=args.limit,
                                       offset=args.offset),
                             length=length) as rows:
                for row in rows:
                    kvp = row.get('key_value_pairs', {})
                    # nkvp accumulates how many keys the update added
                    # that the row did not already have.
                    nkvp -= len(kvp)
                    kvp.update(add_key_value_pairs)
                    nkvp += len(kvp)
                    if args.strip_data:
                        db2.write(row.toatoms(), **kvp)
                    else:
                        db2.write(row, data=row.get('data'), **kvp)
                    nrows += 1

        out('Added %s (%s updated)' %
            (plural(nkvp, 'key-value pair'),
             plural(len(add_key_value_pairs) * nrows - nkvp, 'pair')))
        out(f'Inserted {plural(nrows, "row")}')
        return

    if args.limit == -1:
        args.limit = 20

    if args.explain:
        for row in db.select(query, explain=True,
                             verbosity=verbosity,
                             limit=args.limit, offset=args.offset):
            print(row['explain'])
        return

    if args.show_metadata:
        print(json.dumps(db.metadata, sort_keys=True, indent=4))
        return

    if args.set_metadata:
        with open(args.set_metadata) as fd:
            db.metadata = json.load(fd)
        return

    if add_key_value_pairs or delete_keys:
        ids = [row['id'] for row in db.select(query)]
        M = 0  # key-value pairs added
        N = 0  # key-value pairs removed
        with db:
            for row_id in ids:
                m, n = db.update(row_id, delete_keys=delete_keys,
                                 **add_key_value_pairs)
                M += m
                N += n
        out('Added %s (%s updated)' %
            (plural(M, 'key-value pair'),
             plural(len(add_key_value_pairs) * len(ids) - M, 'pair')))
        out('Removed', plural(N, 'key-value pair'))

        return

    if args.delete:
        ids = [row['id'] for row in db.select(query, include_data=False)]
        if ids and not args.yes:
            msg = f'Delete {plural(len(ids), "row")}? (yes/No): '
            if input(msg).lower() != 'yes':
                return
        db.delete(ids)
        out(f'Deleted {plural(len(ids), "row")}')
        return

    if args.plot:
        # Syntax: "[tag1,tag2,...:]xkey,ykey1,ykey2,..."
        if ':' in args.plot:
            tags, keys = args.plot.split(':')
            tags = tags.split(',')
        else:
            tags = []
            keys = args.plot
        keys = keys.split(',')
        plots = defaultdict(list)
        X = {}  # string x-value -> integer position on the x-axis
        labels = []
        for row in db.select(query, sort=args.sort, include_data=False):
            name = ','.join(str(row[tag]) for tag in tags)
            x = row.get(keys[0])
            if x is not None:
                if isinstance(x, str):
                    if x not in X:
                        X[x] = len(X)
                        labels.append(x)
                    x = X[x]
                plots[name].append([x] + [row.get(key) for key in keys[1:]])
        import matplotlib.pyplot as plt
        for name, plot in plots.items():
            xyy = list(zip(*plot))
            x = xyy[0]
            for y, key in zip(xyy[1:], keys[1:]):
                plt.plot(x, y, label=name + ':' + key)
        if X:
            plt.xticks(range(len(labels)), labels, rotation=90)
        plt.legend()
        plt.show()
        return

    if args.json:
        row = db.get(query)
        db2 = connect(sys.stdout, 'json', use_lock_file=False)
        kvp = row.get('key_value_pairs', {})
        db2.write(row, data=row.get('data'), **kvp)
        return

    if args.long:
        row = db.get(query)
        print(row2str(row))
        return

    if args.open_web_browser:
        try:
            import flask  # noqa
        except ImportError:
            print('Please install Flask: python3 -m pip install flask')
            return
        check_jsmol()
        import ase.db.app as app
        app.DBApp().run_db(db)
        return

    columns = list(all_columns)
    c = args.columns
    if c and c.startswith('++'):
        # "++" adds every key found in the selected rows.  Sorting makes the
        # column order independent of set iteration (hash-seed) order.
        keys = set()
        for row in db.select(query,
                             limit=args.limit, offset=args.offset,
                             include_data=False):
            keys.update(row._keys)
        columns.extend(sorted(keys))
        if c[2:3] == ',':
            c = c[3:]
        else:
            c = ''
    if c:
        if c[0] == '+':
            c = c[1:]
        elif c[0] != '-':
            # A plain column list replaces the default columns.
            columns = []
        for col in c.split(','):
            if not col:
                # Tolerate empty entries, e.g. from a trailing comma.
                continue
            if col[0] == '-':
                columns.remove(col[1:])
            else:
                columns.append(col.lstrip('+'))

    table = Table(db, verbosity=verbosity, cut=args.cut)
    table.select(query, columns, args.sort, args.limit, args.offset)
    if args.csv:
        table.write_csv()
    else:
        table.write(query)
def row2str(row) -> str:
    """Return a human-readable, multi-line description of a database row.

    The text starts with the formula and a unit-cell table (one line per
    axis: periodicity, cell-vector components, length and angle), followed
    by optional stress, dipole, constraints and data sections taken from
    ``row2dct``, and ends with a ``Key | Description | Value`` table.
    """
    t = row2dct(row, key_descriptions={})

    # Column widths of the unit-cell table.  Header and rows are both built
    # from these values so the two always line up.
    axis_w = len('axis')
    pbc_w = len('periodic')
    vec_w = 11  # each x / y / z component
    num_w = 10  # length and angle
    header = ('axis|periodic|'
              + ''.join(f'{h:>{vec_w}}|' for h in 'xyz')
              + f'{"length":>{num_w}}|{"angle":>{num_w}}')
    fmt = ('{0:>{a}}|{1:>{p}}|{2[0]:>{v}}|{2[1]:>{v}}|{2[2]:>{v}}|'
           '{3:>{n}}|{4:>{n}}')

    S = [t['formula'] + ':',
         'Unit cell in Ang:',
         header]
    for c, (p, axis, L, A) in enumerate(
            zip(row.pbc, t['cell'], t['lengths'], t['angles']), start=1):
        S.append(fmt.format(c, ['no', 'yes'][int(p)], axis, L, A,
                            a=axis_w, p=pbc_w, v=vec_w, n=num_w))
    S.append('')

    if 'stress' in t:
        S += ['Stress tensor (xx, yy, zz, zy, zx, yx) in eV/Ang^3:',
              ' {}\n'.format(t['stress'])]

    if 'dipole' in t:
        S.append('Dipole moment in e*Ang: ({})\n'.format(t['dipole']))

    if 'constraints' in t:
        S.append('Constraints: {}\n'.format(t['constraints']))

    if 'data' in t:
        S.append('Data: {}\n'.format(t['data']))

    table = t['table']
    # Minimum widths fit the header words; default= guards an empty table.
    width0 = max(3, max((len(key) for key, _, _ in table), default=0))
    width1 = max(11, max((len(desc) for _, desc, _ in table), default=0))
    S.append('{:{}} | {:{}} | Value'
             .format('Key', width0, 'Description', width1))
    for key, desc, value in table:
        S.append('{:{}} | {:{}} | {}'
                 .format(key, width0, desc, width1, value))
    return '\n'.join(S)
@contextmanager
def no_progressbar(iterable: Iterable,
                   length: Optional[int] = None) -> Iterator[Iterable]:
    """A do-nothing progress bar.

    Yields *iterable* unchanged.  *length* is accepted only so the call
    signature matches the progress bar it stands in for; it is ignored.
    """
    yield iterable
def check_jsmol():
    """Print JSmol installation instructions to stderr if JSmol is missing.

    The check looks for ``static/jsmol/JSmol.min.js`` in the directory of
    this module.  If the file exists, nothing happens.  Otherwise a warning
    is printed that includes the path where the soft-link should go.
    """
    static_dir = Path(__file__).parent / 'static'
    jsmol_script = static_dir / 'jsmol' / 'JSmol.min.js'
    if jsmol_script.is_file():
        return
    print(f"""
    WARNING:
        You don't have jsmol on your system.

        Download Jmol-*-binary.tar.gz from
        https://sourceforge.net/projects/jmol/files/Jmol/,
        extract jsmol.zip, unzip it and create a soft-link:

            $ tar -xf Jmol-*-binary.tar.gz
            $ unzip jmol-*/jsmol.zip
            $ ln -s $PWD/jsmol {static_dir}/jsmol
    """,
          file=sys.stderr)