Coverage for /builds/ase/ase/ase/db/jsondb.py: 91.43%
175 statements
coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
# fmt: off

import os
import sys
from contextlib import ExitStack

import numpy as np

from ase.db.core import Database, lock, now, ops
from ase.db.row import AtomsRow
from ase.io.jsonio import decode, encode
from ase.parallel import parallel_function, world
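

# JSON-file backend for ase.db: the whole database lives in one JSON
# document mapping integer ids to row dictionaries, plus "ids", "nextid"
# and an optional "metadata" entry.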
class JSONDatabase(Database):
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        pass
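
    # Write one row: load the whole file into memory, update the dict and
    # rewrite the file (the format does not support appending in place).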
    def _write(self, atoms, key_value_pairs, data, id):
        Database._write(self, atoms, key_value_pairs, data)

        bigdct = {}
        ids = []
        nextid = 1

        if (isinstance(self.filename, str) and
                os.path.isfile(self.filename)):
            try:
                bigdct, ids, nextid = self._read_json()
            except (SyntaxError, ValueError):
                pass

        mtime = now()

        if isinstance(atoms, AtomsRow):
            row = atoms
        else:
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')

        dct = {}
        for key in row.__dict__:
            if key[0] == '_' or key in row._keys or key == 'id':
                continue
            dct[key] = row[key]

        dct['mtime'] = mtime

        if key_value_pairs:
            dct['key_value_pairs'] = key_value_pairs

        if data:
            dct['data'] = data

        constraints = row.get('constraints')
        if constraints:
            dct['constraints'] = constraints

        if id is None:
            id = nextid
            ids.append(id)
            nextid += 1
        else:
            assert id in bigdct

        bigdct[id] = dct
        self._write_json(bigdct, ids, nextid)
        return id
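
    # Read and decode the whole JSON file; returns (bigdct, ids, nextid).
    # Accepts either a filename or an open file object, which is rewound
    # after reading unless it is sys.stdin.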
    def _read_json(self):
        if isinstance(self.filename, str):
            with open(self.filename) as fd:
                bigdct = decode(fd.read())
        else:
            bigdct = decode(self.filename.read())
            if self.filename is not sys.stdin:
                self.filename.seek(0)

        if not isinstance(bigdct, dict) or ('ids' not in bigdct and 1 not in
                                            bigdct):
            from ase.io.formats import UnknownFileTypeError
            raise UnknownFileTypeError('Does not resemble ASE JSON database')

        ids = bigdct.get('ids')
        if ids is None:
            # Allow for missing "ids" and "nextid":
            assert 1 in bigdct
            return bigdct, [1], 2
        if not isinstance(ids, list):
            ids = ids.tolist()
        return bigdct, ids, bigdct['nextid']
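
    # Serialize and rewrite the entire database.  Only the master MPI rank
    # writes; all other ranks return immediately.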
    def _write_json(self, bigdct, ids, nextid):
        if world.rank > 0:
            return

        with ExitStack() as stack:
            if isinstance(self.filename, str):
                fd = stack.enter_context(open(self.filename, 'w'))
            else:
                fd = self.filename
            print('{', end='', file=fd)
            for id in ids:
                dct = bigdct[id]
                txt = ',\n '.join(f'"{key}": {encode(dct[key])}'
                                  for key in sorted(dct.keys()))
                print(f'"{id}": {{\n {txt}}},', file=fd)
            if self._metadata is not None:
                print(f'"metadata": {encode(self.metadata)},', file=fd)
            print(f'"ids": {ids},', file=fd)
            print(f'"nextid": {nextid}}}', file=fd)
    @parallel_function
    @lock
    def delete(self, ids):
        bigdct, myids, nextid = self._read_json()
        for id in ids:
            del bigdct[id]
            myids.remove(id)
        self._write_json(bigdct, myids, nextid)
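
    # Fetch one row; id=None is only allowed when the database contains
    # exactly one row.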
    def _get_row(self, id):
        bigdct, ids, _nextid = self._read_json()
        if id is None:
            assert len(ids) == 1
            id = ids[0]
        dct = bigdct[id]
        dct['id'] = id
        return AtomsRow(dct)
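
    # Linear scan over all rows.  When sorting is requested, recurse
    # without sort/limit, collect every match, sort in memory and append
    # rows lacking the sort key at the end.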
    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
        if explain:
            yield {'explain': (0, 0, 0, 'scan table')}
            return

        if sort:
            if sort[0] == '-':
                reverse = True
                sort = sort[1:]
            else:
                reverse = False

            def f(row):
                return row.get(sort, missing)

            rows = []
            missing = []
            for row in self._select(keys, cmps):
                key = row.get(sort)
                if key is None:
                    missing.append((0, row))
                else:
                    rows.append((key, row))

            rows.sort(reverse=reverse, key=lambda x: x[0])
            rows += missing

            if limit:
                rows = rows[offset:offset + limit]
            for key, row in rows:
                yield row
            return

        try:
            bigdct, ids, _nextid = self._read_json()
        except OSError:
            return

        if not limit:
            limit = -offset - 1

        cmps = [(key, ops[op], val) for key, op, val in cmps]
        n = 0
        for id in ids:
            if n - offset == limit:
                return
            dct = bigdct[id]
            if not include_data:
                dct.pop('data', None)
            row = AtomsRow(dct)
            row.id = id
            for key in keys:
                if key not in row:
                    break
            else:
                for key, op, val in cmps:
                    if isinstance(key, int):
                        value = np.equal(row.numbers, key).sum()
                    else:
                        value = row.get(key)
                        if key == 'pbc':
                            assert op in [ops['='], ops['!=']]
                            value = ''.join('FT'[x] for x in value)
                    if value is None or not op(value, val):
                        break
                else:
                    if n >= offset:
                        yield row
                    n += 1
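
    # Metadata is cached after the first read; the setter rewrites the
    # whole file with the new metadata included.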
    @property
    def metadata(self):
        if self._metadata is None:
            bigdct, _myids, _nextid = self._read_json()
            self._metadata = bigdct.get('metadata', {})
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        bigdct, ids, nextid = self._read_json()
        self._metadata = dct
        self._write_json(bigdct, ids, nextid)
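
    # Return the union of all user-defined key names across rows.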
    def get_all_key_names(self):
        keys = set()
        bigdct, ids, _nextid = self._read_json()
        for id in ids:
            dct = bigdct[id]
            kvp = dct.get('key_value_pairs')
            if kvp:
                keys.update(kvp)
        return keys
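

# Usage sketch (illustrative, not part of this module): a JSONDatabase is
# normally obtained through ase.db.connect, which selects this backend for
# ".json" filenames.  The filename 'molecules.json' below is hypothetical.
#
#     from ase.build import molecule
#     from ase.db import connect
#
#     db = connect('molecules.json')          # returns a JSONDatabase
#     rowid = db.write(molecule('H2O'), relaxed=False)
#     row = db.get(id=rowid)                  # AtomsRow; row.id == rowid
#     for row in db.select(relaxed=False):    # linear scan via _select()
#         print(row.id, row.formula)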