Coverage for /builds/ase/ase/ase/db/sqlite.py: 88.33%

557 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3"""SQLite3 backend. 

4 

5Versions: 

6 

71) Added 3 more columns. 

82) Changed "user" to "username". 

93) Now adding keys to keyword table and added an "information" table containing 

10 a version number. 

114) Got rid of keywords. 

125) Add fmax, smax, mass, volume, charge 

136) Use REAL for magmom and drop possibility for non-collinear spin 

147) Volume can be None 

158) Added name='metadata' row to "information" table 

169) Row data is now stored in binary format. 

17""" 

18 

19import json 

20import numbers 

21import os 

22import sqlite3 

23import sys 

24from contextlib import contextmanager 

25 

26import numpy as np 

27 

28import ase.io.jsonio 

29from ase.calculators.calculator import all_properties 

30from ase.data import atomic_numbers 

31from ase.db.core import ( 

32 Database, 

33 bytes_to_object, 

34 invop, 

35 lock, 

36 now, 

37 object_to_bytes, 

38 ops, 

39 parse_selection, 

40) 

41from ase.db.row import AtomsRow 

42from ase.parallel import parallel_function 

43 

# Schema version written to the "information" table of newly created files.
VERSION = 9

# Statements executed, in order, when a database file has no "systems" table.
# NOTE: SQLite3Database.columnnames takes the first word of every line after
# the first line of init_statements[0], so keep exactly one column per line.
init_statements = [
    """CREATE TABLE systems (
    id INTEGER PRIMARY KEY AUTOINCREMENT, -- ID's, timestamps and user name
    unique_id TEXT UNIQUE,
    ctime REAL,
    mtime REAL,
    username TEXT,
    numbers BLOB, -- stuff that defines an Atoms object
    positions BLOB,
    cell BLOB,
    pbc INTEGER,
    initial_magmoms BLOB,
    initial_charges BLOB,
    masses BLOB,
    tags BLOB,
    momenta BLOB,
    constraints TEXT, -- constraints and calculator
    calculator TEXT,
    calculator_parameters TEXT,
    energy REAL, -- calculated properties
    free_energy REAL,
    forces BLOB,
    stress BLOB,
    dipole BLOB,
    magmoms BLOB,
    magmom REAL,
    charges BLOB,
    key_value_pairs TEXT, -- key-value pairs and data as json
    data BLOB,
    natoms INTEGER, -- stuff for making queries faster
    fmax REAL,
    smax REAL,
    volume REAL,
    mass REAL,
    charge REAL)""",

    # One row per (system, element): Z is the atomic number, n the count.
    """CREATE TABLE species (
    Z INTEGER,
    n INTEGER,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # One row per key present in a system's key-value pairs.
    """CREATE TABLE keys (
    key TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # String-valued key-value pairs.
    """CREATE TABLE text_key_values (
    key TEXT,
    value TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # Numeric key-value pairs (stored as REAL).
    """CREATE TABLE number_key_values (
    key TEXT,
    value REAL,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    # Name/value store: version, metadata and external-table bookkeeping.
    """CREATE TABLE information (
    name TEXT,
    value TEXT)""",

    f"INSERT INTO information VALUES ('version', '{VERSION}')"]

# Optional indices; only created when the database has create_indices set.
index_statements = [
    'CREATE INDEX unique_id_index ON systems(unique_id)',
    'CREATE INDEX ctime_index ON systems(ctime)',
    'CREATE INDEX username_index ON systems(username)',
    'CREATE INDEX calculator_index ON systems(calculator)',
    'CREATE INDEX species_index ON species(Z)',
    'CREATE INDEX key_index ON keys(key)',
    'CREATE INDEX text_index ON text_key_values(key)',
    'CREATE INDEX number_index ON number_key_values(key)']

# Built-in tables holding per-row data; deletions iterate this list reversed.
all_tables = ['systems', 'species', 'keys',
              'text_key_values', 'number_key_values']

123 

124 

def float_if_not_none(x):
    """Return ``float(x)``, or None when *x* is None.

    Old db-interfaces need plain Python floats instead of numpy.float64.
    """
    return None if x is None else float(x)

129 

130 

class SQLite3Database(Database):
    # Backend tag; methods below compare it against 'db' and 'postgresql'.
    type = 'db'
    # Set to True at the end of a successful _initialize().
    initialized = False
    # When False, _initialize() refuses files with schema version < 5.
    _allow_reading_old_format = False
    default = 'NULL'  # used for autoincrement id
    # Shared connection while used as a context manager, otherwise None.
    connection = None
    # Schema version of the opened file; filled in by _initialize().
    version = None
    # Column names of the "systems" table: the first word of every line
    # after the first one of the CREATE TABLE statement.
    columnnames = [line.split()[0].lstrip()
                   for line in init_statements[0].splitlines()[1:]]

140 

def encode(self, obj, binary=False):
    """Serialize *obj*.

    Uses object_to_bytes() when *binary* is true, otherwise the ASE JSON
    encoder.
    """
    serializer = object_to_bytes if binary else ase.io.jsonio.encode
    return serializer(obj)

145 

def decode(self, txt, lazy=False):
    """Deserialize *txt*.

    With *lazy* the input is returned untouched.  Otherwise a str is
    decoded with the ASE JSON decoder and anything else with
    bytes_to_object().
    """
    if not lazy:
        if isinstance(txt, str):
            return ase.io.jsonio.decode(txt)
        return bytes_to_object(txt)
    return txt

152 

def blob(self, array):
    """Convert array to blob/buffer object.

    None is passed through.  Empty input becomes an empty float array,
    int64 data is stored as int32, and the bytes are swapped on
    big-endian machines before a memoryview is returned.
    """
    if array is None:
        return None

    data = np.zeros(0) if len(array) == 0 else array
    if data.dtype == np.int64:
        data = data.astype(np.int32)
    if not np.little_endian:
        data = data.byteswap()
    return memoryview(np.ascontiguousarray(data))

165 

def deblob(self, buf, dtype=float, shape=None):
    """Convert blob/buffer object to ndarray of correct dtype and shape.

    None is passed through and an empty buffer gives an empty array.
    The shape is assigned in place (without creating an extra view).
    """
    if buf is None:
        return None

    if len(buf) > 0:
        array = np.frombuffer(buf, dtype)
        if not np.little_endian:
            array = array.byteswap()
    else:
        array = np.zeros(0, dtype)

    if shape is not None:
        array.shape = shape
    return array

181 

182 def _connect(self): 

183 return sqlite3.connect(self.filename, timeout=20) 

184 

185 def __enter__(self): 

186 assert self.connection is None 

187 self.change_count = 0 

188 self.connection = self._connect() 

189 return self 

190 

191 def __exit__(self, exc_type, exc_value, tb): 

192 if exc_type is None: 

193 self.connection.commit() 

194 else: 

195 self.connection.rollback() 

196 self.connection.close() 

197 self.connection = None 

198 

@contextmanager
def managed_connection(self, commit_frequency=5000):
    """Yield an initialized connection.

    Without a shared connection (see __enter__) a private one is opened,
    committed on success and always closed, whatever exception leaves
    the block.  With a shared connection nothing is closed here; the
    change counter is incremented and a commit is issued every
    *commit_frequency* uses.
    """
    owns_connection = self.connection is None
    con = self._connect() if owns_connection else self.connection
    try:
        self._initialize(con)
        yield con
    except BaseException:
        # Also covers GeneratorExit from abandoned generators: a private
        # connection must not outlive this block.
        if owns_connection:
            con.close()
        raise
    else:
        if owns_connection:
            try:
                con.commit()
            finally:
                con.close()
        else:
            self.change_count += 1
            if self.change_count % commit_frequency == 0:
                con.commit()

217 

def _initialize(self, con):
    """Create the schema if needed and determine the file's version.

    Does nothing when already initialized.  For a file without a
    "systems" table all init statements (and optionally the indices)
    are executed.  For an existing file the version and the stored
    metadata are read.

    Raises OSError when the file is newer than this code, or older than
    version 5 while old formats are not allowed.
    """
    if self.initialized:
        return

    self._metadata = {}

    # String literals use single quotes: double quotes denote identifiers
    # in SQL and only work as strings through a legacy SQLite quirk.
    cur = con.execute(
        "SELECT COUNT(*) FROM sqlite_master WHERE name='systems'")

    if cur.fetchone()[0] == 0:
        for statement in init_statements:
            con.execute(statement)
        if self.create_indices:
            for statement in index_statements:
                con.execute(statement)
        con.commit()
        self.version = VERSION
    else:
        cur = con.execute(
            "SELECT COUNT(*) FROM sqlite_master WHERE name='user_index'")
        if cur.fetchone()[0] == 1:
            # Old version with "user" instead of "username" column
            self.version = 1
        else:
            try:
                cur = con.execute(
                    "SELECT value FROM information WHERE name='version'")
            except sqlite3.OperationalError:
                self.version = 2
            else:
                self.version = int(cur.fetchone()[0])

        try:
            cur = con.execute(
                "SELECT value FROM information WHERE name='metadata'")
        except sqlite3.OperationalError:
            # Very old files have no "information" table; fall through
            # so the version check below reports the real problem.
            results = []
        else:
            results = cur.fetchall()
        if results:
            self._metadata = json.loads(results[0][0])

    if self.version > VERSION:
        raise OSError('Can not read new ase.db format '
                      '(version {}). Please update to latest ASE.'
                      .format(self.version))
    if self.version < 5 and not self._allow_reading_old_format:
        raise OSError('Please convert to new format. ' +
                      'Use: python -m ase.db.convert ' + self.filename)

    self.initialized = True

265 

def _write(self, atoms, key_value_pairs, data, id):
    """Insert a new row, or overwrite row *id*, and return the row id.

    *atoms* may be an AtomsRow or anything AtomsRow() accepts.  An
    "external_tables" entry is removed from *key_value_pairs* and written
    to the external tables.  With ``id=None`` a new row is inserted;
    otherwise the existing row and its key/species entries are replaced.
    """
    ext_tables = key_value_pairs.pop("external_tables", {})
    Database._write(self, atoms, key_value_pairs, data)

    mtime = now()

    encode = self.encode
    blob = self.blob

    if not isinstance(atoms, AtomsRow):
        row = AtomsRow(atoms)
        row.ctime = mtime
        row.user = os.getenv('USER')
    else:
        row = atoms
        # Extract the external tables from AtomsRow
        names = self._get_external_table_names()
        for name in names:
            new_table = row.get(name, {})
            if new_table:
                ext_tables[name] = new_table

    if not id and not key_value_pairs and not ext_tables:
        key_value_pairs = row.key_value_pairs

    for k, v in ext_tables.items():
        dtype = self._guess_type(v)
        self._create_table_if_not_exists(k, dtype)

    constraints = row._constraints
    if constraints:
        if isinstance(constraints, list):
            constraints = encode(constraints)
    else:
        constraints = None

    values = (row.unique_id,
              row.ctime,
              mtime,
              row.user,
              blob(row.numbers),
              blob(row.positions),
              blob(row.cell),
              # pbc flags packed into one integer (bits 1, 2, 4)
              int(np.dot(row.pbc, [1, 2, 4])),
              blob(row.get('initial_magmoms')),
              blob(row.get('initial_charges')),
              blob(row.get('masses')),
              blob(row.get('tags')),
              blob(row.get('momenta')),
              constraints)

    if 'calculator' in row:
        values += (row.calculator, encode(row.calculator_parameters))
    else:
        values += (None, None)

    if not data:
        data = row._data

    with self.managed_connection() as con:
        # self.version is only known once the connection is initialized.
        if not isinstance(data, (str, bytes)):
            data = encode(data, binary=self.version >= 9)

        values += (float_if_not_none(row.get('energy')),
                   float_if_not_none(row.get('free_energy')),
                   blob(row.get('forces')),
                   blob(row.get('stress')),
                   blob(row.get('dipole')),
                   blob(row.get('magmoms')),
                   # converted like the other scalar columns
                   float_if_not_none(row.get('magmom')),
                   blob(row.get('charges')),
                   encode(key_value_pairs),
                   data,
                   len(row.numbers),
                   float_if_not_none(row.get('fmax')),
                   float_if_not_none(row.get('smax')),
                   float_if_not_none(row.get('volume')),
                   float(row.mass),
                   float(row.charge))

        cur = con.cursor()
        if id is None:
            q = self.default + ', ' + ', '.join('?' * len(values))
            cur.execute(f'INSERT INTO systems VALUES ({q})',
                        values)
            id = self.get_last_id(cur)
        else:
            self._delete(cur, [id], ['keys', 'text_key_values',
                                     'number_key_values', 'species'])
            q = ', '.join(name + '=?' for name in self.columnnames[1:])
            cur.execute(f'UPDATE systems SET {q} WHERE id=?',
                        values + (id,))

        count = row.count_atoms()
        if count:
            species = [(atomic_numbers[symbol], n, id)
                       for symbol, n in count.items()]
            cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                            species)

        text_key_values = []
        number_key_values = []
        for key, value in key_value_pairs.items():
            if isinstance(value, (numbers.Real, np.bool_)):
                number_key_values.append([key, float(value), id])
            else:
                assert isinstance(value, str)
                text_key_values.append([key, value, id])

        cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                        text_key_values)
        cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                        number_key_values)
        cur.executemany('INSERT INTO keys VALUES (?, ?)',
                        [(key, id) for key in key_value_pairs])

        # Insert entries in the valid tables
        for tabname, table in ext_tables.items():
            # Hand over a copy: the insert helper consumes its argument
            # and the caller's (or the row's) dict must stay intact.
            entries = dict(table)
            entries['id'] = id
            self._insert_in_external_table(
                cur, name=tabname, entries=entries)

    return id

390 

def _update(self, id, key_value_pairs, data=None):
    """Update key_value_pairs and data for a single row.

    An "external_tables" entry is removed from *key_value_pairs* and
    written to the external tables.  The row's key tables are rebuilt
    from *key_value_pairs*.  Returns *id*.
    """
    encode = self.encode
    ext_tables = key_value_pairs.pop('external_tables', {})

    for k, v in ext_tables.items():
        dtype = self._guess_type(v)
        self._create_table_if_not_exists(k, dtype)

    mtime = now()
    with self.managed_connection() as con:
        cur = con.cursor()
        cur.execute(
            'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
            (mtime, encode(key_value_pairs), id))
        if data:
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)
            cur.execute('UPDATE systems set data=? where id=?', (data, id))

        self._delete(cur, [id], ['keys', 'text_key_values',
                                 'number_key_values'])

        text_key_values = []
        number_key_values = []
        for key, value in key_value_pairs.items():
            if isinstance(value, (numbers.Real, np.bool_)):
                number_key_values.append([key, float(value), id])
            else:
                assert isinstance(value, str)
                text_key_values.append([key, value, id])

        cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                        text_key_values)
        cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                        number_key_values)
        cur.executemany('INSERT INTO keys VALUES (?, ?)',
                        [(key, id) for key in key_value_pairs])

        # Insert entries in the valid tables
        for tabname, table in ext_tables.items():
            # Hand over a copy: the insert helper consumes its argument
            # and the caller's dict must stay intact.
            entries = dict(table)
            entries['id'] = id
            self._insert_in_external_table(
                cur, name=tabname, entries=entries)

    return id

438 

def get_last_id(self, cur):
    """Return the last autoincrement id used for "systems", or 0 if none.

    The value is read from sqlite_sequence through cursor *cur*.
    """
    # Single-quoted literal: "systems" in double quotes is an identifier
    # in SQL and only works through a legacy SQLite quirk.
    cur.execute("SELECT seq FROM sqlite_sequence WHERE name='systems'")
    result = cur.fetchone()
    if result is None:
        return 0
    return result[0]

447 

448 def _get_row(self, id): 

449 with self.managed_connection() as con: 

450 cur = con.cursor() 

451 if id is None: 

452 cur.execute('SELECT COUNT(*) FROM systems') 

453 assert cur.fetchone()[0] == 1 

454 cur.execute('SELECT * FROM systems') 

455 else: 

456 cur.execute('SELECT * FROM systems WHERE id=?', (id,)) 

457 values = cur.fetchone() 

458 

459 return self._convert_tuple_to_row(values) 

460 

def _convert_tuple_to_row(self, values):
    """Build an AtomsRow from one tuple of "systems" column values.

    NULL columns are left out of the row.  Blob columns are converted
    back to arrays, JSON columns are decoded (the data column lazily),
    and the contents of all external tables are merged in.
    """
    deblob = self.deblob
    decode = self.decode

    def unchanged(value):
        return value

    def as_int32(buf):
        return deblob(buf, np.int32)

    def as_nx3(buf):
        return deblob(buf, shape=(-1, 3))

    def as_pbc(bits):
        return (bits & np.array([1, 2, 4])).astype(bool)

    # (column index, row key, converter) for the optional columns.
    optional_columns = (
        (8, 'pbc', as_pbc),
        (9, 'initial_magmoms', deblob),
        (10, 'initial_charges', deblob),
        (11, 'masses', deblob),
        (12, 'tags', as_int32),
        (13, 'momenta', as_nx3),
        (14, 'constraints', unchanged),
        (15, 'calculator', unchanged),
        (16, 'calculator_parameters', decode),
        (17, 'energy', unchanged),
        (18, 'free_energy', unchanged),
        (19, 'forces', as_nx3),
        (20, 'stress', deblob),
        (21, 'dipole', deblob),
        (22, 'magmoms', deblob),
        (23, 'magmom', unchanged),
        (24, 'charges', deblob),
    )

    values = self._old2new(values)
    dct = {'id': values[0],
           'unique_id': values[1],
           'ctime': values[2],
           'mtime': values[3],
           'user': values[4],
           'numbers': as_int32(values[5]),
           'positions': as_nx3(values[6]),
           'cell': deblob(values[7], shape=(3, 3))}

    for index, name, convert in optional_columns:
        raw = values[index]
        if raw is not None:
            dct[name] = convert(raw)

    if values[25] != '{}':
        dct['key_value_pairs'] = decode(values[25])
    if len(values) >= 27 and values[26] != 'null':
        dct['data'] = decode(values[26], lazy=True)

    # Now we need to update with info from the external tables
    row_id = dct["id"]
    dct.update({tab: self._read_external_table(tab, row_id)
                for tab in self._get_external_table_names()})
    return AtomsRow(dct)

523 

524 def _old2new(self, values): 

525 if self.type == 'postgresql': 

526 assert self.version >= 8, 'Your db-version is too old!' 

527 assert self.version >= 4, 'Your db-file is too old!' 

528 if self.version < 5: 

529 pass # should be ok for reading by convert.py script 

530 if self.version < 6: 

531 m = values[23] 

532 if m is not None and not isinstance(m, float): 

533 magmom = float(self.deblob(m, shape=())) 

534 values = values[:23] + (magmom,) + values[24:] 

535 return values 

536 

def create_select_statement(self, keys, cmps,
                            sort=None, order=None, sort_table=None,
                            what='systems.*'):
    """Build the SELECT statement for a parsed selection.

    *keys* are keys that must be present (or absent, when they contain
    "-").  *cmps* are (key, op, value) comparisons; an int key compares
    the number of atoms of that atomic number.  Returns the SQL text
    and the list of "?" arguments.
    """
    not_null_columns = ('energy', 'fmax', 'smax',
                        'constraints', 'calculator')
    systems_columns = ('id', 'energy', 'magmom', 'ctime', 'user',
                       'calculator', 'natoms', 'pbc', 'unique_id',
                       'fmax', 'smax', 'volume', 'mass', 'charge')

    tables = ['systems']
    where = []
    args = []
    for key in keys:
        if key == 'forces':
            where.append('systems.fmax IS NOT NULL')
            continue
        if key == 'strain':
            where.append('systems.smax IS NOT NULL')
            continue
        if key in not_null_columns:
            where.append(f'systems.{key} IS NOT NULL')
            continue
        if '-' in key:
            key = key.replace('-', '')
            q = 'systems.id not in (select id from keys where key=?)'
        else:
            q = 'systems.id in (select id from keys where key=?)'
        where.append(q)
        args.append(key)

    # Special handling of "H=0" and "H<2" type of selections:
    bad = {}
    for key, op, value in cmps:
        if isinstance(key, int):
            bad[key] = bad.get(key, True) and ops[op](0, value)

    for key, op, value in cmps:
        if key in systems_columns:
            if key == 'user':
                key = 'username'
            elif key == 'pbc':
                assert op in ['=', '!=']
                value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
            elif key == 'magmom':
                assert self.version >= 6, 'Update your db-file'
            where.append(f'systems.{key}{op}?')
            args.append(value)
            continue

        if isinstance(key, int):
            if self.type == 'postgresql':
                where.append(
                    'cardinality(array_positions(' +
                    f'numbers::int[], ?)){op}?')
            elif bad[key]:
                where.append(
                    'systems.id not in (select id from species ' +
                    f'where Z=? and n{invop[op]}?)')
            else:
                where.append('systems.id in (select id from species ' +
                             f'where Z=? and n{op}?)')
            args += [key, value]
            continue

        if self.type == 'postgresql':
            jsonop = '->'
            if isinstance(value, str):
                jsonop = '->>'
            elif isinstance(value, bool):
                jsonop = '->>'
                value = str(value).lower()
            where.append(f"systems.key_value_pairs {jsonop} '{key}'{op}?")
            args.append(str(value))
            continue

        if isinstance(value, str):
            where.append('systems.id in (select id from text_key_values ' +
                         f'where key=? and value{op}?)')
            args += [key, value]
        else:
            where.append(
                'systems.id in (select id from number_key_values ' +
                f'where key=? and value{op}?)')
            args += [key, float(value)]

    if sort and sort_table != 'systems':
        # Sort on a key-value pair: join its table and order by value.
        tables.append(f'{sort_table} AS sort_table')
        where.append('systems.id=sort_table.id AND '
                     'sort_table.key=?')
        args.append(sort)
        sort_table = 'sort_table'
        sort = 'value'

    sql = f'SELECT {what} FROM\n ' + ', '.join(tables)
    if where:
        sql += '\n WHERE\n ' + ' AND\n '.join(where)
    if sort:
        # XXX use "?" instead of "{}"
        sql += (f'\nORDER BY {sort_table}.{sort} IS NULL, '
                f'{sort_table}.{sort} {order}')

    return sql, args

635 

def _select(self, keys, cmps, explain=False, verbosity=0,
            limit=None, offset=0, sort=None, include_data=True,
            columns='all'):
    """Generate the rows matching a parsed selection.

    Yields AtomsRow objects, or ``{'explain': row}`` dicts when *explain*
    is set.  *sort* is a column or key name, prefixed with "-" for
    descending order.  *columns* restricts which "systems" columns are
    fetched and *include_data* adds the data column.  With
    ``verbosity == 2`` the SQL and its arguments are printed.
    """

    # Template for a full 27-column tuple; columns that are not fetched
    # keep these placeholders ('{}' and 'null' mean "empty" to
    # _convert_tuple_to_row).
    values = np.array([None for _ in range(27)])
    values[25] = '{}'
    values[26] = 'null'

    if columns == 'all':
        columnindex = list(range(26))
    else:
        columnindex = [c for c in range(26)
                       if self.columnnames[c] in columns]
    if include_data:
        columnindex.append(26)

    if sort:
        if sort[0] == '-':
            order = 'DESC'
            sort = sort[1:]
        else:
            order = 'ASC'
        if sort in ['id', 'energy', 'username', 'calculator',
                    'ctime', 'mtime', 'magmom', 'pbc',
                    'fmax', 'smax', 'volume', 'mass', 'charge', 'natoms']:
            sort_table = 'systems'
        else:
            # Sorting on a key-value pair: look at one row that has the
            # key to find out whether it holds text or a number.
            for dct in self._select(keys + [sort], cmps=[], limit=1,
                                    include_data=False,
                                    columns=['key_value_pairs']):
                if isinstance(dct['key_value_pairs'][sort], str):
                    sort_table = 'text_key_values'
                else:
                    sort_table = 'number_key_values'
                break
            else:
                # No rows. Just pick a table:
                sort_table = 'number_key_values'

    else:
        order = None
        sort_table = None

    what = ', '.join('systems.' + name
                     for name in
                     np.array(self.columnnames)[np.array(columnindex)])

    sql, args = self.create_select_statement(keys, cmps, sort, order,
                                             sort_table, what)

    if explain:
        sql = 'EXPLAIN QUERY PLAN ' + sql

    if limit:
        sql += f'\nLIMIT {limit}'

    if offset:
        sql += self.get_offset_string(offset, limit=limit)

    if verbosity == 2:
        print(sql, args)

    with self.managed_connection() as con:
        cur = con.cursor()
        cur.execute(sql, args)
        if explain:
            for row in cur.fetchall():
                yield {'explain': row}
        else:
            n = 0
            for shortvalues in cur.fetchall():
                # Scatter the fetched columns into the full template.
                values[columnindex] = shortvalues
                yield self._convert_tuple_to_row(tuple(values))
                n += 1

            if sort and sort_table != 'systems':
                # Yield rows without sort key last:
                if limit is not None:
                    if n == limit:
                        return
                    limit -= n
                # NOTE(review): the same offset is passed to this second
                # query as well - confirm that this is intended.
                for row in self._select(keys + ['-' + sort], cmps,
                                        limit=limit, offset=offset,
                                        include_data=include_data,
                                        columns=columns):
                    yield row

722 

def get_offset_string(self, offset, limit=None):
    """Return the SQL fragment that skips the first *offset* rows."""
    parts = []
    if not limit:
        # In sqlite you cannot have offset without limit, so we
        # set it to -1 meaning no limit
        parts.append('\nLIMIT -1')
    parts.append(f'\nOFFSET {offset}')
    return ''.join(parts)

731 

@parallel_function
def count(self, selection=None, **kwargs):
    """Return the number of rows matching the selection."""
    keys, cmps = parse_selection(selection, **kwargs)
    statement, parameters = self.create_select_statement(
        keys, cmps, what='COUNT(*)')

    with self.managed_connection() as connection:
        cursor = connection.cursor()
        cursor.execute(statement, parameters)
        first_row = cursor.fetchone()
    return first_row[0]

741 

def analyse(self):
    """Run the SQL ANALYZE command on the database."""
    with self.managed_connection() as connection:
        connection.execute('ANALYZE')

745 

@parallel_function
@lock
def delete(self, ids):
    """Delete the rows with the given *ids* from all tables, then vacuum.

    Does nothing for an empty *ids*.
    """
    if len(ids) == 0:
        return

    builtin_tables = all_tables[::-1]
    table_names = self._get_external_table_names() + builtin_tables
    with self.managed_connection() as con:
        cursor = con.cursor()
        self._delete(cursor, ids, tables=table_names)
    self.vacuum()

756 

757 def _delete(self, cur, ids, tables=None): 

758 tables = tables or all_tables[::-1] 

759 for table in tables: 

760 cur.execute('DELETE FROM {} WHERE id in ({});'. 

761 format(table, ', '.join([str(id) for id in ids]))) 

762 

def vacuum(self):
    """Commit and run VACUUM; does nothing unless ``self.type`` is 'db'."""
    if self.type == 'db':
        with self.managed_connection() as con:
            con.commit()
            con.cursor().execute("VACUUM")

770 

@property
def metadata(self):
    """Copy of the metadata dictionary stored in the database.

    If the metadata has not been loaded yet, the database is initialized
    through a temporary connection, which is closed again afterwards.
    """
    if self._metadata is None:
        con = self._connect()
        try:
            self._initialize(con)
        finally:
            # The connection is only needed for initialization.
            con.close()
    return self._metadata.copy()


@metadata.setter
def metadata(self, dct):
    """Store *dct* as JSON in the name='metadata' information row."""
    self._metadata = dct
    md = json.dumps(dct)
    with self.managed_connection() as con:
        cur = con.cursor()
        cur.execute(
            "SELECT COUNT(*) FROM information WHERE name='metadata'")

        if cur.fetchone()[0]:
            cur.execute(
                "UPDATE information SET value=? WHERE name='metadata'",
                [md])
        else:
            cur.execute('INSERT INTO information VALUES (?, ?)',
                        ('metadata', md))

793 

794 def _get_external_table_names(self, db_con=None): 

795 """Return a list with the external table names.""" 

796 sql = "SELECT value FROM information WHERE name='external_table_name'" 

797 with self.managed_connection() as con: 

798 cur = con.cursor() 

799 cur.execute(sql) 

800 ext_tab_names = [x[0] for x in cur.fetchall()] 

801 return ext_tab_names 

802 

803 def _external_table_exists(self, name): 

804 """Return True if an external table name exists.""" 

805 return name in self._get_external_table_names() 

806 

807 def _create_table_if_not_exists(self, name, dtype): 

808 """Create a new table if it does not exits. 

809 

810 Arguments 

811 ========== 

812 name: str 

813 Name of the new table 

814 dtype: str 

815 Datatype of the value field (typically REAL, INTEGER, TEXT etc.) 

816 """ 

817 

818 taken_names = set(all_tables + all_properties + self.columnnames) 

819 if name in taken_names: 

820 raise ValueError("External table can not be any of {}" 

821 "".format(taken_names)) 

822 

823 if self._external_table_exists(name): 

824 return 

825 

826 sql = f"CREATE TABLE IF NOT EXISTS {name} " 

827 sql += f"(key TEXT, value {dtype}, id INTEGER, " 

828 sql += "FOREIGN KEY (id) REFERENCES systems(id))" 

829 sql2 = "INSERT INTO information VALUES (?, ?)" 

830 with self.managed_connection() as con: 

831 cur = con.cursor() 

832 cur.execute(sql) 

833 # Insert an entry saying that there is a new external table 

834 # present and an entry with the datatype 

835 cur.execute(sql2, ("external_table_name", name)) 

836 cur.execute(sql2, (name + "_dtype", dtype)) 

837 

def delete_external_table(self, name):
    """Delete an external table.

    Drops the table and removes its two bookkeeping rows from the
    information table.  Does nothing if no such table is registered.
    """
    if not self._external_table_exists(name):
        return

    with self.managed_connection() as con:
        cur = con.cursor()

        sql = f"DROP TABLE {name}"
        cur.execute(sql)

        # Match the registration row only; other information rows may
        # carry the same value.
        sql = ("DELETE FROM information "
               "WHERE name='external_table_name' AND value=?")
        cur.execute(sql, (name,))
        sql = "DELETE FROM information WHERE name=?"
        cur.execute(sql, (name + "_dtype",))

853 

854 def _convert_to_recognized_types(self, value): 

855 """Convert Numpy types to python types.""" 

856 if np.issubdtype(type(value), np.integer): 

857 return int(value) 

858 elif np.issubdtype(type(value), np.floating): 

859 return float(value) 

860 return value 

861 

862 def _insert_in_external_table(self, cursor, name=None, entries=None): 

863 """Insert into external table""" 

864 if name is None or entries is None: 

865 # There is nothing to do 

866 return 

867 

868 id = entries.pop("id") 

869 dtype = self._guess_type(entries) 

870 expected_dtype = self._get_value_type_of_table(cursor, name) 

871 if dtype != expected_dtype: 

872 raise ValueError("The provided data type for table {} " 

873 "is {}, while it is initialized to " 

874 "be of type {}" 

875 "".format(name, dtype, expected_dtype)) 

876 

877 # First we check if entries already exists 

878 cursor.execute(f"SELECT key FROM {name} WHERE id=?", (id,)) 

879 updates = [] 

880 for item in cursor.fetchall(): 

881 value = entries.pop(item[0], None) 

882 if value is not None: 

883 updates.append( 

884 (value, id, self._convert_to_recognized_types(item[0]))) 

885 

886 # Update entry if key and ID already exists 

887 sql = f"UPDATE {name} SET value=? WHERE id=? AND key=?" 

888 cursor.executemany(sql, updates) 

889 

890 # Insert the ones that does not already exist 

891 inserts = [(k, self._convert_to_recognized_types(v), id) 

892 for k, v in entries.items()] 

893 sql = f"INSERT INTO {name} VALUES (?, ?, ?)" 

894 cursor.executemany(sql, inserts) 

895 

896 def _guess_type(self, entries): 

897 """Guess the type based on the first entry.""" 

898 values = [v for _, v in entries.items()] 

899 

900 # Check if all datatypes are the same 

901 all_types = [type(v) for v in values] 

902 if any(t != all_types[0] for t in all_types): 

903 typenames = [t.__name__ for t in all_types] 

904 raise ValueError("Inconsistent datatypes in the table. " 

905 "given types: {}".format(typenames)) 

906 

907 val = values[0] 

908 if isinstance(val, int) or np.issubdtype(type(val), np.integer): 

909 return "INTEGER" 

910 if isinstance(val, float) or np.issubdtype(type(val), np.floating): 

911 return "REAL" 

912 if isinstance(val, str): 

913 return "TEXT" 

914 raise ValueError("Unknown datatype!") 

915 

916 def _get_value_type_of_table(self, cursor, tab_name): 

917 """Return the expected value name.""" 

918 sql = "SELECT value FROM information WHERE name=?" 

919 cursor.execute(sql, (tab_name + "_dtype",)) 

920 return cursor.fetchone()[0] 

921 

922 def _read_external_table(self, name, id): 

923 """Read row from external table.""" 

924 

925 with self.managed_connection() as con: 

926 cur = con.cursor() 

927 cur.execute(f"SELECT * FROM {name} WHERE id=?", (id,)) 

928 items = cur.fetchall() 

929 dictionary = {item[0]: item[1] for item in items} 

930 

931 return dictionary 

932 

def get_all_key_names(self):
    """Return the set of all key names found in the "keys" table."""
    with self.managed_connection() as con:
        cursor = con.cursor()
        cursor.execute('SELECT DISTINCT key FROM keys;')
        records = cursor.fetchall()

    names = set()
    for record in records:
        names.add(record[0])
    return names

940 

941 

if __name__ == '__main__':
    # Print the schema version of the database file given as argument.
    from ase.db import connect
    database = connect(sys.argv[1])
    raw_connection = database._connect()
    database._initialize(raw_connection)
    print('Version:', database.version)