Coverage for /builds/ase/ase/ase/db/jsondb.py: 91.43%

175 statements  

coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

# fmt: off

import os
import sys
from contextlib import ExitStack

import numpy as np

from ase.db.core import Database, lock, now, ops
from ase.db.row import AtomsRow
from ase.io.jsonio import decode, encode
from ase.parallel import parallel_function, world


class JSONDatabase(Database):
    """Database backend that stores all rows in a single JSON file."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        pass
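
    # Usage sketch (illustrative): this backend is normally reached through
    # ase.db.connect, which selects JSONDatabase for '.json' file names, and
    # the context-manager methods above make "with" blocks work.  The file
    # name 'example.json' and the key 'tag' are hypothetical:
    #
    #     from ase import Atoms
    #     from ase.db import connect
    #
    #     with connect('example.json') as db:
    #         id = db.write(Atoms('H2O'), tag='demo')
    #         row = db.get(id)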

    def _write(self, atoms, key_value_pairs, data, id):
        Database._write(self, atoms, key_value_pairs, data)

        # Read-modify-write: load the whole file (if it exists), update one
        # record, then rewrite everything.
        bigdct = {}
        ids = []
        nextid = 1

        if (isinstance(self.filename, str) and
                os.path.isfile(self.filename)):
            try:
                bigdct, ids, nextid = self._read_json()
            except (SyntaxError, ValueError):
                pass

        mtime = now()

        if isinstance(atoms, AtomsRow):
            row = atoms
        else:
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')

        # Collect the row's public attributes; private ones, key-value
        # pairs and the id are handled separately.
        dct = {}
        for key in row.__dict__:
            if key[0] == '_' or key in row._keys or key == 'id':
                continue
            dct[key] = row[key]

        dct['mtime'] = mtime

        if key_value_pairs:
            dct['key_value_pairs'] = key_value_pairs

        if data:
            dct['data'] = data

        constraints = row.get('constraints')
        if constraints:
            dct['constraints'] = constraints

        if id is None:
            # New row: take the next free id.
            id = nextid
            ids.append(id)
            nextid += 1
        else:
            # Explicit id: must replace an existing record.
            assert id in bigdct

        bigdct[id] = dct
        self._write_json(bigdct, ids, nextid)
        return id
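
    # Illustrative consequence of the id bookkeeping above (hypothetical
    # session): the first write gets id 1, later writes take nextid, and an
    # update with an explicit id replaces that record in place:
    #
    #     db.write(Atoms('H'))         # -> id 1
    #     db.write(Atoms('He'))        # -> id 2
    #     db.update(1, tag='updated')  # rewrites record 1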

    def _read_json(self):
        if isinstance(self.filename, str):
            with open(self.filename) as fd:
                bigdct = decode(fd.read())
        else:
            bigdct = decode(self.filename.read())
            if self.filename is not sys.stdin:
                # Rewind so the file-like object can be read again later.
                self.filename.seek(0)

        if not isinstance(bigdct, dict) or ('ids' not in bigdct and
                                            1 not in bigdct):
            from ase.io.formats import UnknownFileTypeError
            raise UnknownFileTypeError('Does not resemble ASE JSON database')

        ids = bigdct.get('ids')
        if ids is None:
            # Allow for missing "ids" and "nextid" (legacy single-row files):
            assert 1 in bigdct
            return bigdct, [1], 2
        if not isinstance(ids, list):
            ids = ids.tolist()
        return bigdct, ids, bigdct['nextid']
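
    # A minimal legacy file accepted by the fallback above is a single
    # record keyed by 1, with no "ids" or "nextid" entries (schematic):
    #
    #     {"1": {"mtime": ..., "numbers": ..., ...}}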

    def _write_json(self, bigdct, ids, nextid):
        if world.rank > 0:
            # Only the master MPI rank writes; other ranks do nothing.
            return

        with ExitStack() as stack:
            if isinstance(self.filename, str):
                fd = stack.enter_context(open(self.filename, 'w'))
            else:
                fd = self.filename
            print('{', end='', file=fd)
            for id in ids:
                dct = bigdct[id]
                txt = ',\n '.join(f'"{key}": {encode(dct[key])}'
                                  for key in sorted(dct.keys()))
                print(f'"{id}": {{\n {txt}}},', file=fd)
            if self._metadata is not None:
                print(f'"metadata": {encode(self.metadata)},', file=fd)
            print(f'"ids": {ids},', file=fd)
            print(f'"nextid": {nextid}}}', file=fd)
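
    # Schematically, the writer above produces output like the following
    # (which row keys appear depends on the row itself):
    #
    #     {"1": {
    #      "cell": ..., "mtime": ..., "numbers": ...},
    #     "ids": [1],
    #     "nextid": 2}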

    @parallel_function
    @lock
    def delete(self, ids):
        bigdct, myids, nextid = self._read_json()
        for id in ids:
            del bigdct[id]
            myids.remove(id)
        self._write_json(bigdct, myids, nextid)
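
    # Illustrative call: db.delete([1, 2]) removes records 1 and 2 and
    # rewrites the whole file.  Nonexistent ids raise KeyError (from the
    # del) or ValueError (from list.remove).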

    def _get_row(self, id):
        bigdct, ids, _nextid = self._read_json()
        if id is None:
            assert len(ids) == 1
            id = ids[0]
        dct = bigdct[id]
        dct['id'] = id
        return AtomsRow(dct)

    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
        if explain:
            yield {'explain': (0, 0, 0, 'scan table')}
            return

        if sort:
            if sort[0] == '-':
                reverse = True
                sort = sort[1:]
            else:
                reverse = False

            def f(row):
                return row.get(sort, missing)

            # Sorting: fetch all matching rows first (recursing without
            # sort), then sort in memory; rows missing the sort key go last.
            rows = []
            missing = []
            for row in self._select(keys, cmps):
                key = row.get(sort)
                if key is None:
                    missing.append((0, row))
                else:
                    rows.append((key, row))

            rows.sort(reverse=reverse, key=lambda x: x[0])
            rows += missing

            if limit:
                rows = rows[offset:offset + limit]
            for key, row in rows:
                yield row
            return

        try:
            bigdct, ids, _nextid = self._read_json()
        except OSError:
            return

        if not limit:
            # No limit given: pick a value "n - offset" can never reach.
            limit = -offset - 1

        cmps = [(key, ops[op], val) for key, op, val in cmps]
        n = 0
        for id in ids:
            if n - offset == limit:
                return
            dct = bigdct[id]
            if not include_data:
                dct.pop('data', None)
            row = AtomsRow(dct)
            row.id = id
            for key in keys:
                if key not in row:
                    break
            else:
                # All required keys are present; now check the comparisons.
                for key, op, val in cmps:
                    if isinstance(key, int):
                        # An integer key counts atoms of that atomic number.
                        value = np.equal(row.numbers, key).sum()
                    else:
                        value = row.get(key)
                        if key == 'pbc':
                            assert op in [ops['='], ops['!=']]
                            value = ''.join('FT'[x] for x in value)
                    if value is None or not op(value, val):
                        break
                else:
                    if n >= offset:
                        yield row
                    n += 1
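
    # Usage sketch (illustrative): _select backs Database.select, so a query
    # such as
    #
    #     for row in db.select('H>0', sort='-mtime', limit=5):
    #         print(row.id)
    #
    # arrives here with parsed keys/cmps; the chemical symbol in 'H>0' is
    # translated to an integer atomic number, which the np.equal branch
    # above counts against row.numbers.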

    @property
    def metadata(self):
        if self._metadata is None:
            bigdct, _myids, _nextid = self._read_json()
            self._metadata = bigdct.get('metadata', {})
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        bigdct, ids, nextid = self._read_json()
        self._metadata = dct
        self._write_json(bigdct, ids, nextid)
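
    # Illustrative: assigning db.metadata = {'title': 'my data'} runs the
    # setter above, which rewrites the file with the new "metadata" entry.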

    def get_all_key_names(self):
        keys = set()
        bigdct, ids, _nextid = self._read_json()
        for id in ids:
            dct = bigdct[id]
            kvp = dct.get('key_value_pairs')
            if kvp:
                keys.update(kvp)
        return keys
228 return keys