Coverage for ase/db/sqlite.py: 88.71%

549 statements  

coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1"""SQLite3 backend. 

2 

3Versions: 

4 

51) Added 3 more columns. 

62) Changed "user" to "username". 

73) Now adding keys to keyword table and added an "information" table containing 

8 a version number. 

94) Got rid of keywords. 

105) Add fmax, smax, mass, volume, charge 

116) Use REAL for magmom and drop possibility for non-collinear spin 

127) Volume can be None 

138) Added name='metadata' row to "information" table 

149) Row data is now stored in binary format. 

15""" 

16 

17import json 

18import numbers 

19import os 

20import sqlite3 

21import sys 

22from contextlib import contextmanager 

23 

24import numpy as np 

25 

26import ase.io.jsonio 

27from ase.calculators.calculator import all_properties 

28from ase.data import atomic_numbers 

29from ase.db.core import ( 

30 Database, 

31 bytes_to_object, 

32 invop, 

33 lock, 

34 now, 

35 object_to_bytes, 

36 ops, 

37 parse_selection, 

38) 

39from ase.db.row import AtomsRow 

40from ase.parallel import parallel_function 

41 

42VERSION = 9 

43 

44init_statements = [ 

45 """CREATE TABLE systems ( 

46 id INTEGER PRIMARY KEY AUTOINCREMENT, -- ID's, timestamps and user name 

47 unique_id TEXT UNIQUE, 

48 ctime REAL, 

49 mtime REAL, 

50 username TEXT, 

51 numbers BLOB, -- stuff that defines an Atoms object 

52 positions BLOB, 

53 cell BLOB, 

54 pbc INTEGER, 

55 initial_magmoms BLOB, 

56 initial_charges BLOB, 

57 masses BLOB, 

58 tags BLOB, 

59 momenta BLOB, 

60 constraints TEXT, -- constraints and calculator 

61 calculator TEXT, 

62 calculator_parameters TEXT, 

63 energy REAL, -- calculated properties 

64 free_energy REAL, 

65 forces BLOB, 

66 stress BLOB, 

67 dipole BLOB, 

68 magmoms BLOB, 

69 magmom REAL, 

70 charges BLOB, 

71 key_value_pairs TEXT, -- key-value pairs and data as json 

72 data BLOB, 

73 natoms INTEGER, -- stuff for making queries faster 

74 fmax REAL, 

75 smax REAL, 

76 volume REAL, 

77 mass REAL, 

78 charge REAL)""", 

79 """CREATE TABLE species ( 

80 Z INTEGER, 

81 n INTEGER, 

82 id INTEGER, 

83 FOREIGN KEY (id) REFERENCES systems(id))""", 

84 """CREATE TABLE keys ( 

85 key TEXT, 

86 id INTEGER, 

87 FOREIGN KEY (id) REFERENCES systems(id))""", 

88 """CREATE TABLE text_key_values ( 

89 key TEXT, 

90 value TEXT, 

91 id INTEGER, 

92 FOREIGN KEY (id) REFERENCES systems(id))""", 

93 """CREATE TABLE number_key_values ( 

94 key TEXT, 

95 value REAL, 

96 id INTEGER, 

97 FOREIGN KEY (id) REFERENCES systems(id))""", 

98 """CREATE TABLE information ( 

99 name TEXT, 

100 value TEXT)""", 

101 f"INSERT INTO information VALUES ('version', '{VERSION}')", 

102] 

103 

104index_statements = [ 

105 'CREATE INDEX unique_id_index ON systems(unique_id)', 

106 'CREATE INDEX ctime_index ON systems(ctime)', 

107 'CREATE INDEX username_index ON systems(username)', 

108 'CREATE INDEX calculator_index ON systems(calculator)', 

109 'CREATE INDEX species_index ON species(Z)', 

110 'CREATE INDEX key_index ON keys(key)', 

111 'CREATE INDEX text_index ON text_key_values(key)', 

112 'CREATE INDEX number_index ON number_key_values(key)', 

113] 

114 

115all_tables = [ 

116 'systems', 

117 'species', 

118 'keys', 

119 'text_key_values', 

120 'number_key_values', 

121] 
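

# --- Illustrative sketch (not part of the original module) -----------------
# The statements above fully describe the on-disk schema.  The hypothetical
# helper below only shows how they could be applied to a bare sqlite3
# connection, e.g. to inspect the schema by hand; normally SQLite3Database
# does this itself in _initialize().
def _example_create_schema(filename=':memory:'):
    """Create an empty ase.db schema in a standalone SQLite database."""
    con = sqlite3.connect(filename)
    for statement in init_statements + index_statements:
        con.execute(statement)
    con.commit()
    return con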


def float_if_not_none(x):
    """Convert numpy.float64 to float - old db-interfaces need that."""
    if x is not None:
        return float(x)


class SQLite3Database(Database):
    type = 'db'
    initialized = False
    _allow_reading_old_format = False
    default = 'NULL'  # used for autoincrement id
    connection = None
    version = None
    columnnames = [
        line.split()[0].lstrip() for line in init_statements[0].splitlines()[1:]
    ]

    def encode(self, obj, binary=False):
        if binary:
            return object_to_bytes(obj)
        return ase.io.jsonio.encode(obj)

    def decode(self, txt, lazy=False):
        if lazy:
            return txt
        if isinstance(txt, str):
            return ase.io.jsonio.decode(txt)
        return bytes_to_object(txt)
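
    # Illustrative sketch (not part of the original module).  encode()/decode()
    # round-trip Python objects either as a JSON string (text format) or as
    # bytes (binary format used since version 9).  The method below is
    # hypothetical and only documents that symmetry.
    def _example_encode_roundtrip(self, obj=None):
        obj = obj if obj is not None else {'a': 1, 'b': 2.5, 'c': 'text'}
        as_text = self.encode(obj)                 # JSON string
        as_bytes = self.encode(obj, binary=True)   # binary blob
        # decode() dispatches on the input type, so both should give obj back:
        return self.decode(as_text), self.decode(as_bytes)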

    def blob(self, array):
        """Convert array to blob/buffer object."""

        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            array = array.astype(np.int32)
        if not np.little_endian:
            array = array.byteswap()
        return memoryview(np.ascontiguousarray(array))

    def deblob(self, buf, dtype=float, shape=None):
        """Convert blob/buffer object to ndarray of correct dtype and shape.

        (without creating an extra view)."""
        if buf is None:
            return None
        if len(buf) == 0:
            array = np.zeros(0, dtype)
        else:
            array = np.frombuffer(buf, dtype)
            if not np.little_endian:
                array = array.byteswap()
        if shape is not None:
            array.shape = shape
        return array
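
    # Illustrative sketch (not part of the original module): blob() stores
    # arrays little-endian and downcasts int64 to int32, so deblob() must be
    # told the dtype and shape to reverse it.  Hypothetical helper:
    def _example_blob_roundtrip(self):
        forces = np.array([[0.0, 0.1, -0.2], [0.3, 0.0, 0.0]])
        buf = self.blob(forces)                    # memoryview for a BLOB column
        restored = self.deblob(buf, float, shape=(-1, 3))
        assert np.allclose(restored, forces)
        return restored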

    def _connect(self):
        return sqlite3.connect(self.filename, timeout=20)

    def __enter__(self):
        assert self.connection is None
        self.change_count = 0
        self.connection = self._connect()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            self.connection.commit()
        else:
            self.connection.rollback()
        self.connection.close()
        self.connection = None

    @contextmanager
    def managed_connection(self, commit_frequency=5000):
        from contextlib import ExitStack

        with ExitStack() as stack:
            con = self.connection or stack.enter_context(self._connect())
            self._initialize(con)
            yield con

            if self.connection is None:
                con.commit()
            else:
                self.change_count += 1
                if self.change_count % commit_frequency == 0:
                    con.commit()
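
    # Illustrative sketch (not part of the original module): managed_connection()
    # reuses self.connection when the database is used as a context manager and
    # otherwise opens, commits and closes a short-lived connection.  A
    # hypothetical helper using it for a one-off query:
    def _example_count_systems(self):
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute('SELECT COUNT(*) FROM systems')
            return cur.fetchone()[0]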

    def _initialize(self, con):
        if self.initialized:
            return

        self._metadata = {}

        cur = con.execute(
            "SELECT COUNT(*) FROM sqlite_master WHERE name='systems'"
        )

        if cur.fetchone()[0] == 0:
            for statement in init_statements:
                con.execute(statement)
            if self.create_indices:
                for statement in index_statements:
                    con.execute(statement)
            con.commit()
            self.version = VERSION
        else:
            cur = con.execute(
                "SELECT COUNT(*) FROM sqlite_master WHERE name='user_index'"
            )
            if cur.fetchone()[0] == 1:
                # Old version with "user" instead of "username" column
                self.version = 1
            else:
                try:
                    cur = con.execute(
                        "SELECT value FROM information WHERE name='version'"
                    )
                except sqlite3.OperationalError:
                    self.version = 2
                else:
                    self.version = int(cur.fetchone()[0])

                cur = con.execute(
                    "SELECT value FROM information WHERE name='metadata'"
                )
                results = cur.fetchall()
                if results:
                    self._metadata = json.loads(results[0][0])

        if self.version > VERSION:
            raise OSError(
                f'Can not read new ase.db format (version {self.version}). '
                'Please update to latest ASE.'
            )
        if self.version < 5 and not self._allow_reading_old_format:
            raise OSError(
                'Please convert to new format. '
                f'Use: python -m ase.db.convert {self.filename}'
            )

        self.initialized = True

    def _write(self, atoms, key_value_pairs, data, id):
        ext_tables = key_value_pairs.pop('external_tables', {})
        Database._write(self, atoms, key_value_pairs, data)

        mtime = now()

        encode = self.encode
        blob = self.blob

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms
            # Extract the external tables from AtomsRow
            names = self._get_external_table_names()
            for name in names:
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        if not id and not key_value_pairs and not ext_tables:
            key_value_pairs = row.key_value_pairs

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None

        values = (
            row.unique_id,
            row.ctime,
            mtime,
            row.user,
            blob(row.numbers),
            blob(row.positions),
            blob(row.cell),
            int(np.dot(row.pbc, [1, 2, 4])),
            blob(row.get('initial_magmoms')),
            blob(row.get('initial_charges')),
            blob(row.get('masses')),
            blob(row.get('tags')),
            blob(row.get('momenta')),
            constraints,
        )

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data

        with self.managed_connection() as con:
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)

            values += (
                float_if_not_none(row.get('energy')),
                float_if_not_none(row.get('free_energy')),
                blob(row.get('forces')),
                blob(row.get('stress')),
                blob(row.get('dipole')),
                blob(row.get('magmoms')),
                row.get('magmom'),
                blob(row.get('charges')),
                encode(key_value_pairs),
                data,
                len(row.numbers),
                float_if_not_none(row.get('fmax')),
                float_if_not_none(row.get('smax')),
                float_if_not_none(row.get('volume')),
                float(row.mass),
                float(row.charge),
            )

            cur = con.cursor()
            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute(f'INSERT INTO systems VALUES ({q})', values)
                id = self.get_last_id(cur)
            else:
                self._delete(
                    cur,
                    [id],
                    ['keys', 'text_key_values', 'number_key_values', 'species'],
                )
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute(
                    f'UPDATE systems SET {q} WHERE id=?', values + (id,)
                )

            count = row.count_atoms()
            if count:
                species = [
                    (atomic_numbers[symbol], n, id)
                    for symbol, n in count.items()
                ]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)', species)

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany(
                'INSERT INTO text_key_values VALUES (?, ?, ?)', text_key_values
            )
            cur.executemany(
                'INSERT INTO number_key_values VALUES (?, ?, ?)',
                number_key_values,
            )
            cur.executemany(
                'INSERT INTO keys VALUES (?, ?)',
                [(key, id) for key in key_value_pairs],
            )

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname]
                )

        return id
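
    # Illustrative sketch (not part of the original module): _write() is driven
    # by the public Database.write()/get() API.  Keyword arguments become rows
    # in keys plus text_key_values/number_key_values, and `data` is stored as a
    # (binary) blob in the systems table.  Key names below are hypothetical.
    def _example_write_and_read(self, atoms):
        id = self.write(atoms, relaxed=True, project='demo',
                        data={'notes': 'anything JSON/binary encodable'})
        row = self.get(id=id)
        return row.key_value_pairs, row.data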

    def _update(self, id, key_value_pairs, data=None):
        """Update key_value_pairs and data for a single row"""
        encode = self.encode
        ext_tables = key_value_pairs.pop('external_tables', {})

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        mtime = now()
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
                (mtime, encode(key_value_pairs), id),
            )
            if data:
                if not isinstance(data, (str, bytes)):
                    data = encode(data, binary=self.version >= 9)
                cur.execute('UPDATE systems set data=? where id=?', (data, id))

            self._delete(
                cur, [id], ['keys', 'text_key_values', 'number_key_values']
            )

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany(
                'INSERT INTO text_key_values VALUES (?, ?, ?)', text_key_values
            )
            cur.executemany(
                'INSERT INTO number_key_values VALUES (?, ?, ?)',
                number_key_values,
            )
            cur.executemany(
                'INSERT INTO keys VALUES (?, ?)',
                [(key, id) for key in key_value_pairs],
            )

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname]
                )

        return id

    def get_last_id(self, cur):
        cur.execute("SELECT seq FROM sqlite_sequence WHERE name='systems'")
        result = cur.fetchone()
        if result is not None:
            id = result[0]
            return id
        else:
            return 0

    def _get_row(self, id):
        with self.managed_connection() as con:
            cur = con.cursor()
            if id is None:
                cur.execute('SELECT COUNT(*) FROM systems')
                assert cur.fetchone()[0] == 1
                cur.execute('SELECT * FROM systems')
            else:
                cur.execute('SELECT * FROM systems WHERE id=?', (id,))
            values = cur.fetchone()

        return self._convert_tuple_to_row(values)

    def _convert_tuple_to_row(self, values):
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = {
            'id': values[0],
            'unique_id': values[1],
            'ctime': values[2],
            'mtime': values[3],
            'user': values[4],
            'numbers': deblob(values[5], np.int32),
            'positions': deblob(values[6], shape=(-1, 3)),
            'cell': deblob(values[7], shape=(3, 3)),
        }

        if values[8] is not None:
            dct['pbc'] = (values[8] & np.array([1, 2, 4])).astype(bool)
        if values[9] is not None:
            dct['initial_magmoms'] = deblob(values[9])
        if values[10] is not None:
            dct['initial_charges'] = deblob(values[10])
        if values[11] is not None:
            dct['masses'] = deblob(values[11])
        if values[12] is not None:
            dct['tags'] = deblob(values[12], np.int32)
        if values[13] is not None:
            dct['momenta'] = deblob(values[13], shape=(-1, 3))
        if values[14] is not None:
            dct['constraints'] = values[14]
        if values[15] is not None:
            dct['calculator'] = values[15]
        if values[16] is not None:
            dct['calculator_parameters'] = decode(values[16])
        if values[17] is not None:
            dct['energy'] = values[17]
        if values[18] is not None:
            dct['free_energy'] = values[18]
        if values[19] is not None:
            dct['forces'] = deblob(values[19], shape=(-1, 3))
        if values[20] is not None:
            dct['stress'] = deblob(values[20])
        if values[21] is not None:
            dct['dipole'] = deblob(values[21])
        if values[22] is not None:
            dct['magmoms'] = deblob(values[22])
        if values[23] is not None:
            dct['magmom'] = values[23]
        if values[24] is not None:
            dct['charges'] = deblob(values[24])
        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            dct['data'] = decode(values[26], lazy=True)

        # Now we need to update with info from the external tables
        external_tab = self._get_external_table_names()
        tables = {}
        for tab in external_tab:
            row = self._read_external_table(tab, dct['id'])
            tables[tab] = row

        dct.update(tables)
        return AtomsRow(dct)

    def _old2new(self, values):
        if self.type == 'postgresql':
            assert self.version >= 8, 'Your db-version is too old!'
        assert self.version >= 4, 'Your db-file is too old!'
        if self.version < 5:
            pass  # should be ok for reading by convert.py script
        if self.version < 6:
            m = values[23]
            if m is not None and not isinstance(m, float):
                magmom = float(self.deblob(m, shape=()))
                values = values[:23] + (magmom,) + values[24:]
        return values

    def create_select_statement(
        self,
        keys,
        cmps,
        sort=None,
        order=None,
        sort_table=None,
        what='systems.*',
    ):
        tables = ['systems']
        where = []
        args = []
        for key in keys:
            if key == 'forces':
                where.append('systems.fmax IS NOT NULL')
            elif key == 'strain':
                where.append('systems.smax IS NOT NULL')
            elif key in ['energy', 'fmax', 'smax', 'constraints', 'calculator']:
                where.append(f'systems.{key} IS NOT NULL')
            else:
                if '-' not in key:
                    q = 'systems.id in (select id from keys where key=?)'
                else:
                    key = key.replace('-', '')
                    q = 'systems.id not in (select id from keys where key=?)'
                where.append(q)
                args.append(key)

        # Special handling of "H=0" and "H<2" type of selections:
        bad = {}
        for key, op, value in cmps:
            if isinstance(key, int):
                bad[key] = bad.get(key, True) and ops[op](0, value)

        for key, op, value in cmps:
            if key in [
                'id',
                'energy',
                'magmom',
                'ctime',
                'user',
                'calculator',
                'natoms',
                'pbc',
                'unique_id',
                'fmax',
                'smax',
                'volume',
                'mass',
                'charge',
            ]:
                if key == 'user':
                    key = 'username'
                elif key == 'pbc':
                    assert op in ['=', '!=']
                    value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
                elif key == 'magmom':
                    assert self.version >= 6, 'Update your db-file'
                where.append(f'systems.{key}{op}?')
                args.append(value)
            elif isinstance(key, int):
                if self.type == 'postgresql':
                    where.append(
                        'cardinality(array_positions('
                        + f'numbers::int[], ?)){op}?'
                    )
                    args += [key, value]
                else:
                    if bad[key]:
                        where.append(
                            'systems.id not in (select id from species '
                            + f'where Z=? and n{invop[op]}?)'
                        )
                        args += [key, value]
                    else:
                        where.append(
                            'systems.id in (select id from species '
                            + f'where Z=? and n{op}?)'
                        )
                        args += [key, value]

            elif self.type == 'postgresql':
                jsonop = '->'
                if isinstance(value, str):
                    jsonop = '->>'
                elif isinstance(value, bool):
                    jsonop = '->>'
                    value = str(value).lower()
                where.append(
                    "systems.key_value_pairs {} '{}'{}?".format(jsonop, key, op)
                )
                args.append(str(value))

            elif isinstance(value, str):
                where.append(
                    'systems.id in (select id from text_key_values '
                    + f'where key=? and value{op}?)'
                )
                args += [key, value]
            else:
                where.append(
                    'systems.id in (select id from number_key_values '
                    + f'where key=? and value{op}?)'
                )
                args += [key, float(value)]

        if sort:
            if sort_table != 'systems':
                tables.append(f'{sort_table} AS sort_table')
                where.append('systems.id=sort_table.id AND sort_table.key=?')
                args.append(sort)
                sort_table = 'sort_table'
                sort = 'value'

        sql = f'SELECT {what} FROM\n  ' + ', '.join(tables)
        if where:
            sql += '\n  WHERE\n  ' + ' AND\n  '.join(where)
        if sort:
            # XXX use "?" instead of "{}"
            sql += '\nORDER BY {0}.{1} IS NULL, {0}.{1} {2}'.format(
                sort_table, sort, order
            )

        return sql, args
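
    # Illustrative sketch (not part of the original module): the WHERE clause
    # is assembled from parse_selection() output, e.g. 'energy<0,relaxed=1'
    # becomes a column comparison plus a number_key_values subquery.  A
    # hypothetical helper for inspecting the generated SQL and its arguments:
    def _example_show_sql(self, selection='energy<0,relaxed=1'):
        keys, cmps = parse_selection(selection)
        sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)')
        print(sql)
        print(args)
        return sql, args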

    def _select(
        self,
        keys,
        cmps,
        explain=False,
        verbosity=0,
        limit=None,
        offset=0,
        sort=None,
        include_data=True,
        columns='all',
    ):
        values = np.array([None for _ in range(27)])
        values[25] = '{}'
        values[26] = 'null'

        if columns == 'all':
            columnindex = list(range(26))
        else:
            columnindex = [
                c for c in range(26) if self.columnnames[c] in columns
            ]
        if include_data:
            columnindex.append(26)

        if sort:
            if sort[0] == '-':
                order = 'DESC'
                sort = sort[1:]
            else:
                order = 'ASC'
            if sort in [
                'id',
                'energy',
                'username',
                'calculator',
                'ctime',
                'mtime',
                'magmom',
                'pbc',
                'fmax',
                'smax',
                'volume',
                'mass',
                'charge',
                'natoms',
            ]:
                sort_table = 'systems'
            else:
                for dct in self._select(
                    keys + [sort],
                    cmps=[],
                    limit=1,
                    include_data=False,
                    columns=['key_value_pairs'],
                ):
                    if isinstance(dct['key_value_pairs'][sort], str):
                        sort_table = 'text_key_values'
                    else:
                        sort_table = 'number_key_values'
                    break
                else:
                    # No rows. Just pick a table:
                    sort_table = 'number_key_values'

        else:
            order = None
            sort_table = None

        what = ', '.join(
            'systems.' + name
            for name in np.array(self.columnnames)[np.array(columnindex)]
        )

        sql, args = self.create_select_statement(
            keys, cmps, sort, order, sort_table, what
        )

        if explain:
            sql = 'EXPLAIN QUERY PLAN ' + sql

        if limit:
            sql += f'\nLIMIT {limit}'

        if offset:
            sql += self.get_offset_string(offset, limit=limit)

        if verbosity == 2:
            print(sql, args)

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            if explain:
                for row in cur.fetchall():
                    yield {'explain': row}
            else:
                n = 0
                for shortvalues in cur.fetchall():
                    values[columnindex] = shortvalues
                    yield self._convert_tuple_to_row(tuple(values))
                    n += 1

                if sort and sort_table != 'systems':
                    # Yield rows without sort key last:
                    if limit is not None:
                        if n == limit:
                            return
                        limit -= n
                    for row in self._select(
                        keys + ['-' + sort],
                        cmps,
                        limit=limit,
                        offset=offset,
                        include_data=include_data,
                        columns=columns,
                    ):
                        yield row

    def get_offset_string(self, offset, limit=None):
        sql = ''
        if not limit:
            # In sqlite you cannot have offset without limit, so we
            # set it to -1 meaning no limit
            sql += '\nLIMIT -1'
        sql += f'\nOFFSET {offset}'
        return sql

    @parallel_function
    def count(self, selection=None, **kwargs):
        keys, cmps = parse_selection(selection, **kwargs)
        sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)')

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            return cur.fetchone()[0]

    def analyse(self):
        with self.managed_connection() as con:
            con.execute('ANALYZE')

    @parallel_function
    @lock
    def delete(self, ids):
        if len(ids) == 0:
            return
        table_names = self._get_external_table_names() + all_tables[::-1]
        with self.managed_connection() as con:
            self._delete(con.cursor(), ids, tables=table_names)
        self.vacuum()

    def _delete(self, cur, ids, tables=None):
        tables = tables or all_tables[::-1]
        for table in tables:
            cur.execute(
                'DELETE FROM {} WHERE id in ({});'.format(
                    table, ', '.join([str(id) for id in ids])
                )
            )

    def vacuum(self):
        if self.type != 'db':
            return

        with self.managed_connection() as con:
            con.commit()
            con.cursor().execute('VACUUM')

    @property
    def metadata(self):
        if self._metadata is None:
            assert self.connection is not None
            self._initialize(self.connection)
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        self._metadata = dct
        md = json.dumps(dct)
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                "SELECT COUNT(*) FROM information WHERE name='metadata'"
            )

            if cur.fetchone()[0]:
                cur.execute(
                    "UPDATE information SET value=? WHERE name='metadata'", [md]
                )
            else:
                cur.execute(
                    'INSERT INTO information VALUES (?, ?)', ('metadata', md)
                )

    def _get_external_table_names(self, db_con=None):
        """Return a list with the external table names."""
        sql = "SELECT value FROM information WHERE name='external_table_name'"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            ext_tab_names = [x[0] for x in cur.fetchall()]
        return ext_tab_names

    def _external_table_exists(self, name):
        """Return True if an external table name exists."""
        return name in self._get_external_table_names()

    def _create_table_if_not_exists(self, name, dtype):
        """Create a new table if it does not exist.

        Arguments
        =========
        name: str
            Name of the new table
        dtype: str
            Datatype of the value field (typically REAL, INTEGER, TEXT etc.)
        """

        taken_names = set(all_tables + all_properties + self.columnnames)
        if name in taken_names:
            raise ValueError(f'External table can not be any of {taken_names}')

        if self._external_table_exists(name):
            return

        sql = f'CREATE TABLE IF NOT EXISTS {name} '
        sql += f'(key TEXT, value {dtype}, id INTEGER, '
        sql += 'FOREIGN KEY (id) REFERENCES systems(id))'
        sql2 = 'INSERT INTO information VALUES (?, ?)'
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            # Insert an entry saying that there is a new external table
            # present and an entry with the datatype
            cur.execute(sql2, ('external_table_name', name))
            cur.execute(sql2, (name + '_dtype', dtype))

    def delete_external_table(self, name):
        """Delete an external table."""
        if not self._external_table_exists(name):
            return

        with self.managed_connection() as con:
            cur = con.cursor()

            sql = f'DROP TABLE {name}'
            cur.execute(sql)

            sql = 'DELETE FROM information WHERE value=?'
            cur.execute(sql, (name,))
            sql = 'DELETE FROM information WHERE name=?'
            cur.execute(sql, (name + '_dtype',))

    def _convert_to_recognized_types(self, value):
        """Convert Numpy types to python types."""
        if np.issubdtype(type(value), np.integer):
            return int(value)
        elif np.issubdtype(type(value), np.floating):
            return float(value)
        return value

    def _insert_in_external_table(self, cursor, name=None, entries=None):
        """Insert into external table"""
        if name is None or entries is None:
            # There is nothing to do
            return

        id = entries.pop('id')
        dtype = self._guess_type(entries)
        expected_dtype = self._get_value_type_of_table(cursor, name)
        if dtype != expected_dtype:
            raise ValueError(
                f'The provided data type for table {name} is {dtype}, while '
                f'it is initialized to be of type {expected_dtype}'
            )

        # First we check if the entries already exist
        cursor.execute(f'SELECT key FROM {name} WHERE id=?', (id,))
        updates = []
        for item in cursor.fetchall():
            value = entries.pop(item[0], None)
            if value is not None:
                updates.append(
                    (value, id, self._convert_to_recognized_types(item[0]))
                )

        # Update entries whose key and ID already exist
        sql = f'UPDATE {name} SET value=? WHERE id=? AND key=?'
        cursor.executemany(sql, updates)

        # Insert the ones that do not already exist
        inserts = [
            (k, self._convert_to_recognized_types(v), id)
            for k, v in entries.items()
        ]
        sql = f'INSERT INTO {name} VALUES (?, ?, ?)'
        cursor.executemany(sql, inserts)
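
    # Illustrative sketch (not part of the original module): external tables
    # are requested through the `external_tables` keyword of write(); each
    # table is a flat dict with a single value type and is attached to the
    # returned AtomsRow.  Table and key names below are hypothetical.
    def _example_external_table(self, atoms):
        id = self.write(
            atoms, external_tables={'features': {'d1': 1.0, 'd2': 2.0}}
        )
        row = self.get(id=id)
        return row.features  # expected: {'d1': 1.0, 'd2': 2.0}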

    def _guess_type(self, entries):
        """Guess the type based on the first entry."""
        values = [v for _, v in entries.items()]

        # Check if all datatypes are the same
        all_types = [type(v) for v in values]
        if any(t != all_types[0] for t in all_types):
            typenames = [t.__name__ for t in all_types]
            raise ValueError(
                f'Inconsistent datatypes in the table. given types: {typenames}'
            )

        val = values[0]
        if isinstance(val, int) or np.issubdtype(type(val), np.integer):
            return 'INTEGER'
        if isinstance(val, float) or np.issubdtype(type(val), np.floating):
            return 'REAL'
        if isinstance(val, str):
            return 'TEXT'
        raise ValueError('Unknown datatype!')

    def _get_value_type_of_table(self, cursor, tab_name):
        """Return the expected value type of the table."""
        sql = 'SELECT value FROM information WHERE name=?'
        cursor.execute(sql, (tab_name + '_dtype',))
        return cursor.fetchone()[0]

    def _read_external_table(self, name, id):
        """Read row from external table."""

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(f'SELECT * FROM {name} WHERE id=?', (id,))
            items = cur.fetchall()
            dictionary = {item[0]: item[1] for item in items}

        return dictionary

    def get_all_key_names(self):
        """Create set of all key names."""
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute('SELECT DISTINCT key FROM keys;')
            all_keys = {row[0] for row in cur.fetchall()}
        return all_keys


if __name__ == '__main__':
    from ase.db import connect

    con = connect(sys.argv[1])
    con._initialize(con._connect())
    print('Version:', con.version)
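

# --- Illustrative usage sketch (not part of the original module) -----------
# End-to-end flow through the public ase.db API backed by this backend; the
# file name and the 'project' key are hypothetical.
def _example_usage(filename='example.db'):
    from ase.build import bulk
    from ase.db import connect

    db = connect(filename)                  # SQLite3Database for *.db files
    id = db.write(bulk('Cu'), project='demo')   # INSERT into systems/keys/...
    row = db.get(id=id)                     # SELECT * FROM systems WHERE id=?
    atoms = row.toatoms()                   # rebuild the Atoms object
    n = db.count(project='demo')            # COUNT(*) via create_select_statement
    return atoms, n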