Coverage for /builds/ase/ase/ase/db/sqlite.py: 88.33% (557 statements)

# fmt: off

"""SQLite3 backend.

Versions:

1) Added 3 more columns.
2) Changed "user" to "username".
3) Now adding keys to keyword table and added an "information" table containing
   a version number.
4) Got rid of keywords.
5) Add fmax, smax, mass, volume, charge
6) Use REAL for magmom and drop possibility for non-collinear spin
7) Volume can be None
8) Added name='metadata' row to "information" table
9) Row data is now stored in binary format.
"""
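
# Illustrative usage note (not part of the original module): this backend is
# what ase.db.connect() returns for '.db' filenames, e.g.
#
#     from ase.db import connect
#     db = connect('abc.db')          # -> SQLite3Database instance
#     db.write(atoms, relaxed=True)   # store an Atoms object + key-value pair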

import json
import numbers
import os
import sqlite3
import sys
from contextlib import contextmanager

import numpy as np

import ase.io.jsonio
from ase.calculators.calculator import all_properties
from ase.data import atomic_numbers
from ase.db.core import (
    Database,
    bytes_to_object,
    invop,
    lock,
    now,
    object_to_bytes,
    ops,
    parse_selection,
)
from ase.db.row import AtomsRow
from ase.parallel import parallel_function

VERSION = 9

init_statements = [
    """CREATE TABLE systems (
    id INTEGER PRIMARY KEY AUTOINCREMENT,  -- ID's, timestamps and user name
    unique_id TEXT UNIQUE,
    ctime REAL,
    mtime REAL,
    username TEXT,
    numbers BLOB,  -- stuff that defines an Atoms object
    positions BLOB,
    cell BLOB,
    pbc INTEGER,
    initial_magmoms BLOB,
    initial_charges BLOB,
    masses BLOB,
    tags BLOB,
    momenta BLOB,
    constraints TEXT,  -- constraints and calculator
    calculator TEXT,
    calculator_parameters TEXT,
    energy REAL,  -- calculated properties
    free_energy REAL,
    forces BLOB,
    stress BLOB,
    dipole BLOB,
    magmoms BLOB,
    magmom REAL,
    charges BLOB,
    key_value_pairs TEXT,  -- key-value pairs and data as json
    data BLOB,
    natoms INTEGER,  -- stuff for making queries faster
    fmax REAL,
    smax REAL,
    volume REAL,
    mass REAL,
    charge REAL)""",

    """CREATE TABLE species (
    Z INTEGER,
    n INTEGER,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE keys (
    key TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE text_key_values (
    key TEXT,
    value TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE number_key_values (
    key TEXT,
    value REAL,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",

    """CREATE TABLE information (
    name TEXT,
    value TEXT)""",

    f"INSERT INTO information VALUES ('version', '{VERSION}')"]

index_statements = [
    'CREATE INDEX unique_id_index ON systems(unique_id)',
    'CREATE INDEX ctime_index ON systems(ctime)',
    'CREATE INDEX username_index ON systems(username)',
    'CREATE INDEX calculator_index ON systems(calculator)',
    'CREATE INDEX species_index ON species(Z)',
    'CREATE INDEX key_index ON keys(key)',
    'CREATE INDEX text_index ON text_key_values(key)',
    'CREATE INDEX number_index ON number_key_values(key)']

all_tables = ['systems', 'species', 'keys',
              'text_key_values', 'number_key_values']


def float_if_not_none(x):
    """Convert numpy.float64 to float - old db-interfaces need that."""
    if x is not None:
        return float(x)


class SQLite3Database(Database):
    type = 'db'
    initialized = False
    _allow_reading_old_format = False
    default = 'NULL'  # used for autoincrement id
    connection = None
    version = None
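    # The column names of the "systems" table are recovered directly from the
    # CREATE TABLE statement above, so the two can never get out of sync.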
    columnnames = [line.split()[0].lstrip()
                   for line in init_statements[0].splitlines()[1:]]

    def encode(self, obj, binary=False):
        if binary:
            return object_to_bytes(obj)
        return ase.io.jsonio.encode(obj)

    def decode(self, txt, lazy=False):
        if lazy:
            return txt
        if isinstance(txt, str):
            return ase.io.jsonio.decode(txt)
        return bytes_to_object(txt)

    def blob(self, array):
        """Convert array to blob/buffer object."""

        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            array = array.astype(np.int32)
        if not np.little_endian:
            array = array.byteswap()
        return memoryview(np.ascontiguousarray(array))

    def deblob(self, buf, dtype=float, shape=None):
        """Convert blob/buffer object to ndarray of correct dtype and shape.

        (without creating an extra view)."""
        if buf is None:
            return None
        if len(buf) == 0:
            array = np.zeros(0, dtype)
        else:
            array = np.frombuffer(buf, dtype)
            if not np.little_endian:
                array = array.byteswap()
        if shape is not None:
            array.shape = shape
        return array
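
    # Illustrative round trip (assumed example, not from the original source):
    # blobs do not record dtype or shape, so the caller must pass them back to
    # deblob(), and int64 input comes back as int32:
    #
    #     buf = db.blob(np.arange(6).reshape(2, 3))
    #     arr = db.deblob(buf, dtype=np.int32, shape=(2, 3))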

    def _connect(self):
        return sqlite3.connect(self.filename, timeout=20)

    def __enter__(self):
        assert self.connection is None
        self.change_count = 0
        self.connection = self._connect()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            self.connection.commit()
        else:
            self.connection.rollback()
        self.connection.close()
        self.connection = None

    @contextmanager
    def managed_connection(self, commit_frequency=5000):
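        # When the database is used as a context manager (``with db:``) a
        # long-lived connection already exists and commits are batched: one
        # commit per ``commit_frequency`` changes.  Otherwise a fresh
        # connection is opened, committed and closed per call.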
        try:
            con = self.connection or self._connect()
            self._initialize(con)
            yield con
        except ValueError as exc:
            if self.connection is None:
                con.close()
            raise exc
        else:
            if self.connection is None:
                con.commit()
                con.close()
            else:
                self.change_count += 1
                if self.change_count % commit_frequency == 0:
                    con.commit()

    def _initialize(self, con):
        if self.initialized:
            return

        self._metadata = {}

        cur = con.execute(
            'SELECT COUNT(*) FROM sqlite_master WHERE name="systems"')

        if cur.fetchone()[0] == 0:
            for statement in init_statements:
                con.execute(statement)
            if self.create_indices:
                for statement in index_statements:
                    con.execute(statement)
            con.commit()
            self.version = VERSION
        else:
            cur = con.execute(
                'SELECT COUNT(*) FROM sqlite_master WHERE name="user_index"')
            if cur.fetchone()[0] == 1:
                # Old version with "user" instead of "username" column
                self.version = 1
            else:
                try:
                    cur = con.execute(
                        'SELECT value FROM information WHERE name="version"')
                except sqlite3.OperationalError:
                    self.version = 2
                else:
                    self.version = int(cur.fetchone()[0])

                cur = con.execute(
                    'SELECT value FROM information WHERE name="metadata"')
                results = cur.fetchall()
                if results:
                    self._metadata = json.loads(results[0][0])

        if self.version > VERSION:
            raise OSError('Can not read new ase.db format '
                          '(version {}). Please update to latest ASE.'
                          .format(self.version))
        if self.version < 5 and not self._allow_reading_old_format:
            raise OSError('Please convert to new format. ' +
                          'Use: python -m ase.db.convert ' + self.filename)

        self.initialized = True

    def _write(self, atoms, key_value_pairs, data, id):
        ext_tables = key_value_pairs.pop("external_tables", {})
        Database._write(self, atoms, key_value_pairs, data)

        mtime = now()

        encode = self.encode
        blob = self.blob

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms
            # Extract the external tables from AtomsRow
            names = self._get_external_table_names()
            for name in names:
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        if not id and not key_value_pairs and not ext_tables:
            key_value_pairs = row.key_value_pairs

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None
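
        # The tuple below must follow the column order of the "systems"
        # table (see init_statements[0]); pbc is packed into one integer
        # with bits 1, 2 and 4 for the x, y and z directions.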
        values = (row.unique_id,
                  row.ctime,
                  mtime,
                  row.user,
                  blob(row.numbers),
                  blob(row.positions),
                  blob(row.cell),
                  int(np.dot(row.pbc, [1, 2, 4])),
                  blob(row.get('initial_magmoms')),
                  blob(row.get('initial_charges')),
                  blob(row.get('masses')),
                  blob(row.get('tags')),
                  blob(row.get('momenta')),
                  constraints)

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data

        with self.managed_connection() as con:
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)

            values += (float_if_not_none(row.get('energy')),
                       float_if_not_none(row.get('free_energy')),
                       blob(row.get('forces')),
                       blob(row.get('stress')),
                       blob(row.get('dipole')),
                       blob(row.get('magmoms')),
                       row.get('magmom'),
                       blob(row.get('charges')),
                       encode(key_value_pairs),
                       data,
                       len(row.numbers),
                       float_if_not_none(row.get('fmax')),
                       float_if_not_none(row.get('smax')),
                       float_if_not_none(row.get('volume')),
                       float(row.mass),
                       float(row.charge))

            cur = con.cursor()
            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute(f'INSERT INTO systems VALUES ({q})',
                            values)
                id = self.get_last_id(cur)
            else:
                self._delete(cur, [id], ['keys', 'text_key_values',
                                         'number_key_values', 'species'])
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute(f'UPDATE systems SET {q} WHERE id=?',
                            values + (id,))

            count = row.count_atoms()
            if count:
                species = [(atomic_numbers[symbol], n, id)
                           for symbol, n in count.items()]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)',
                                species)

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname])

        return id

    def _update(self, id, key_value_pairs, data=None):
        """Update key_value_pairs and data for a single row."""
        encode = self.encode
        ext_tables = key_value_pairs.pop('external_tables', {})

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        mtime = now()
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
                (mtime, encode(key_value_pairs), id))
            if data:
                if not isinstance(data, (str, bytes)):
                    data = encode(data, binary=self.version >= 9)
                cur.execute('UPDATE systems set data=? where id=?', (data, id))

            self._delete(cur, [id], ['keys', 'text_key_values',
                                     'number_key_values'])

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany('INSERT INTO text_key_values VALUES (?, ?, ?)',
                            text_key_values)
            cur.executemany('INSERT INTO number_key_values VALUES (?, ?, ?)',
                            number_key_values)
            cur.executemany('INSERT INTO keys VALUES (?, ?)',
                            [(key, id) for key in key_value_pairs])

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname])

        return id

    def get_last_id(self, cur):
        cur.execute('SELECT seq FROM sqlite_sequence WHERE name="systems"')
        result = cur.fetchone()
        if result is not None:
            id = result[0]
            return id
        else:
            return 0

    def _get_row(self, id):
        with self.managed_connection() as con:
            cur = con.cursor()
            if id is None:
                cur.execute('SELECT COUNT(*) FROM systems')
                assert cur.fetchone()[0] == 1
                cur.execute('SELECT * FROM systems')
            else:
                cur.execute('SELECT * FROM systems WHERE id=?', (id,))
            values = cur.fetchone()

        return self._convert_tuple_to_row(values)

    def _convert_tuple_to_row(self, values):
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = {'id': values[0],
               'unique_id': values[1],
               'ctime': values[2],
               'mtime': values[3],
               'user': values[4],
               'numbers': deblob(values[5], np.int32),
               'positions': deblob(values[6], shape=(-1, 3)),
               'cell': deblob(values[7], shape=(3, 3))}
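
        # values[8] holds the three pbc flags packed as bits 1, 2 and 4 in
        # _write(); unpack them back into a boolean triple.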
        if values[8] is not None:
            dct['pbc'] = (values[8] & np.array([1, 2, 4])).astype(bool)
        if values[9] is not None:
            dct['initial_magmoms'] = deblob(values[9])
        if values[10] is not None:
            dct['initial_charges'] = deblob(values[10])
        if values[11] is not None:
            dct['masses'] = deblob(values[11])
        if values[12] is not None:
            dct['tags'] = deblob(values[12], np.int32)
        if values[13] is not None:
            dct['momenta'] = deblob(values[13], shape=(-1, 3))
        if values[14] is not None:
            dct['constraints'] = values[14]
        if values[15] is not None:
            dct['calculator'] = values[15]
        if values[16] is not None:
            dct['calculator_parameters'] = decode(values[16])
        if values[17] is not None:
            dct['energy'] = values[17]
        if values[18] is not None:
            dct['free_energy'] = values[18]
        if values[19] is not None:
            dct['forces'] = deblob(values[19], shape=(-1, 3))
        if values[20] is not None:
            dct['stress'] = deblob(values[20])
        if values[21] is not None:
            dct['dipole'] = deblob(values[21])
        if values[22] is not None:
            dct['magmoms'] = deblob(values[22])
        if values[23] is not None:
            dct['magmom'] = values[23]
        if values[24] is not None:
            dct['charges'] = deblob(values[24])
        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            dct['data'] = decode(values[26], lazy=True)

        # Now we need to update with info from the external tables
        external_tab = self._get_external_table_names()
        tables = {}
        for tab in external_tab:
            row = self._read_external_table(tab, dct["id"])
            tables[tab] = row

        dct.update(tables)
        return AtomsRow(dct)

    def _old2new(self, values):
        if self.type == 'postgresql':
            assert self.version >= 8, 'Your db-version is too old!'
        assert self.version >= 4, 'Your db-file is too old!'
        if self.version < 5:
            pass  # should be ok for reading by convert.py script
        if self.version < 6:
            m = values[23]
            if m is not None and not isinstance(m, float):
                magmom = float(self.deblob(m, shape=()))
                values = values[:23] + (magmom,) + values[24:]
        return values
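
    # Illustrative example of what the next method produces (assumed, not
    # part of the original source): for keys=['relaxed'] and
    # cmps=[('energy', '<', 0.0)] the generated statement is roughly
    #
    #     SELECT systems.* FROM
    #      systems
    #      WHERE
    #      systems.id in (select id from keys where key=?) AND
    #      systems.energy<?
    #
    # with args == ['relaxed', 0.0].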

    def create_select_statement(self, keys, cmps,
                                sort=None, order=None, sort_table=None,
                                what='systems.*'):
        tables = ['systems']
        where = []
        args = []
        for key in keys:
            if key == 'forces':
                where.append('systems.fmax IS NOT NULL')
            elif key == 'strain':
                where.append('systems.smax IS NOT NULL')
            elif key in ['energy', 'fmax', 'smax',
                         'constraints', 'calculator']:
                where.append(f'systems.{key} IS NOT NULL')
            else:
                if '-' not in key:
                    q = 'systems.id in (select id from keys where key=?)'
                else:
                    key = key.replace('-', '')
                    q = 'systems.id not in (select id from keys where key=?)'
                where.append(q)
                args.append(key)

        # Special handling of "H=0" and "H<2" type of selections:
        bad = {}
        for key, op, value in cmps:
            if isinstance(key, int):
                bad[key] = bad.get(key, True) and ops[op](0, value)

        for key, op, value in cmps:
            if key in ['id', 'energy', 'magmom', 'ctime', 'user',
                       'calculator', 'natoms', 'pbc', 'unique_id',
                       'fmax', 'smax', 'volume', 'mass', 'charge']:
                if key == 'user':
                    key = 'username'
                elif key == 'pbc':
                    assert op in ['=', '!=']
                    value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
                elif key == 'magmom':
                    assert self.version >= 6, 'Update your db-file'
                where.append(f'systems.{key}{op}?')
                args.append(value)
            elif isinstance(key, int):
                if self.type == 'postgresql':
                    where.append(
                        'cardinality(array_positions(' +
                        f'numbers::int[], ?)){op}?')
                    args += [key, value]
                else:
                    if bad[key]:
                        where.append(
                            'systems.id not in (select id from species ' +
                            f'where Z=? and n{invop[op]}?)')
                        args += [key, value]
                    else:
                        where.append('systems.id in (select id from species ' +
                                     f'where Z=? and n{op}?)')
                        args += [key, value]

            elif self.type == 'postgresql':
                jsonop = '->'
                if isinstance(value, str):
                    jsonop = '->>'
                elif isinstance(value, bool):
                    jsonop = '->>'
                    value = str(value).lower()
                where.append("systems.key_value_pairs {} '{}'{}?"
                             .format(jsonop, key, op))
                args.append(str(value))

            elif isinstance(value, str):
                where.append('systems.id in (select id from text_key_values ' +
                             f'where key=? and value{op}?)')
                args += [key, value]
            else:
                where.append(
                    'systems.id in (select id from number_key_values ' +
                    f'where key=? and value{op}?)')
                args += [key, float(value)]

        if sort:
            if sort_table != 'systems':
                tables.append(f'{sort_table} AS sort_table')
                where.append('systems.id=sort_table.id AND '
                             'sort_table.key=?')
                args.append(sort)
                sort_table = 'sort_table'
                sort = 'value'

        sql = f'SELECT {what} FROM\n ' + ', '.join(tables)
        if where:
            sql += '\n WHERE\n ' + ' AND\n '.join(where)
        if sort:
            # XXX use "?" instead of "{}"
            sql += '\nORDER BY {0}.{1} IS NULL, {0}.{1} {2}'.format(
                sort_table, sort, order)

        return sql, args

    def _select(self, keys, cmps, explain=False, verbosity=0,
                limit=None, offset=0, sort=None, include_data=True,
                columns='all'):
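        # Only the first 27 columns of the "systems" table (up to and
        # including "data") are needed to rebuild an AtomsRow; the remaining
        # columns exist only to speed up queries.  Columns that are not
        # fetched keep the defaults below.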
        values = np.array([None for _ in range(27)])
        values[25] = '{}'
        values[26] = 'null'

        if columns == 'all':
            columnindex = list(range(26))
        else:
            columnindex = [c for c in range(26)
                           if self.columnnames[c] in columns]
        if include_data:
            columnindex.append(26)

        if sort:
            if sort[0] == '-':
                order = 'DESC'
                sort = sort[1:]
            else:
                order = 'ASC'
            if sort in ['id', 'energy', 'username', 'calculator',
                        'ctime', 'mtime', 'magmom', 'pbc',
                        'fmax', 'smax', 'volume', 'mass', 'charge', 'natoms']:
                sort_table = 'systems'
            else:
                for dct in self._select(keys + [sort], cmps=[], limit=1,
                                        include_data=False,
                                        columns=['key_value_pairs']):
                    if isinstance(dct['key_value_pairs'][sort], str):
                        sort_table = 'text_key_values'
                    else:
                        sort_table = 'number_key_values'
                    break
                else:
                    # No rows. Just pick a table:
                    sort_table = 'number_key_values'

        else:
            order = None
            sort_table = None

        what = ', '.join('systems.' + name
                         for name in
                         np.array(self.columnnames)[np.array(columnindex)])

        sql, args = self.create_select_statement(keys, cmps, sort, order,
                                                 sort_table, what)

        if explain:
            sql = 'EXPLAIN QUERY PLAN ' + sql

        if limit:
            sql += f'\nLIMIT {limit}'

        if offset:
            sql += self.get_offset_string(offset, limit=limit)

        if verbosity == 2:
            print(sql, args)

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            if explain:
                for row in cur.fetchall():
                    yield {'explain': row}
            else:
                n = 0
                for shortvalues in cur.fetchall():
                    values[columnindex] = shortvalues
                    yield self._convert_tuple_to_row(tuple(values))
                    n += 1

                if sort and sort_table != 'systems':
                    # Yield rows without sort key last:
                    if limit is not None:
                        if n == limit:
                            return
                        limit -= n
                    for row in self._select(keys + ['-' + sort], cmps,
                                            limit=limit, offset=offset,
                                            include_data=include_data,
                                            columns=columns):
                        yield row

    def get_offset_string(self, offset, limit=None):
        sql = ''
        if not limit:
            # In sqlite you cannot have offset without limit, so we
            # set it to -1 meaning no limit
            sql += '\nLIMIT -1'
        sql += f'\nOFFSET {offset}'
        return sql

    @parallel_function
    def count(self, selection=None, **kwargs):
        keys, cmps = parse_selection(selection, **kwargs)
        sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)')

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            return cur.fetchone()[0]

    def analyse(self):
        with self.managed_connection() as con:
            con.execute('ANALYZE')

    @parallel_function
    @lock
    def delete(self, ids):
        if len(ids) == 0:
            return
        table_names = self._get_external_table_names() + all_tables[::-1]
        with self.managed_connection() as con:
            self._delete(con.cursor(), ids,
                         tables=table_names)
        self.vacuum()

    def _delete(self, cur, ids, tables=None):
        tables = tables or all_tables[::-1]
        for table in tables:
            cur.execute('DELETE FROM {} WHERE id in ({});'.
                        format(table, ', '.join([str(id) for id in ids])))

    def vacuum(self):
        if self.type != 'db':
            return

        with self.managed_connection() as con:
            con.commit()
            con.cursor().execute("VACUUM")

    @property
    def metadata(self):
        if self._metadata is None:
            self._initialize(self._connect())
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        self._metadata = dct
        md = json.dumps(dct)
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                "SELECT COUNT(*) FROM information WHERE name='metadata'")

            if cur.fetchone()[0]:
                cur.execute(
                    "UPDATE information SET value=? WHERE name='metadata'",
                    [md])
            else:
                cur.execute('INSERT INTO information VALUES (?, ?)',
                            ('metadata', md))

    def _get_external_table_names(self, db_con=None):
        """Return a list with the external table names."""
        sql = "SELECT value FROM information WHERE name='external_table_name'"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            ext_tab_names = [x[0] for x in cur.fetchall()]
        return ext_tab_names

    def _external_table_exists(self, name):
        """Return True if an external table name exists."""
        return name in self._get_external_table_names()

    def _create_table_if_not_exists(self, name, dtype):
        """Create a new table if it does not exist.

        Arguments
        ==========
        name: str
            Name of the new table
        dtype: str
            Datatype of the value field (typically REAL, INTEGER, TEXT etc.)
        """

        taken_names = set(all_tables + all_properties + self.columnnames)
        if name in taken_names:
            raise ValueError("External table can not be any of {}"
                             "".format(taken_names))

        if self._external_table_exists(name):
            return

        sql = f"CREATE TABLE IF NOT EXISTS {name} "
        sql += f"(key TEXT, value {dtype}, id INTEGER, "
        sql += "FOREIGN KEY (id) REFERENCES systems(id))"
        sql2 = "INSERT INTO information VALUES (?, ?)"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            # Insert an entry saying that there is a new external table
            # present and an entry with the datatype
            cur.execute(sql2, ("external_table_name", name))
            cur.execute(sql2, (name + "_dtype", dtype))

    def delete_external_table(self, name):
        """Delete an external table."""
        if not self._external_table_exists(name):
            return

        with self.managed_connection() as con:
            cur = con.cursor()

            sql = f"DROP TABLE {name}"
            cur.execute(sql)

            sql = "DELETE FROM information WHERE value=?"
            cur.execute(sql, (name,))
            sql = "DELETE FROM information WHERE name=?"
            cur.execute(sql, (name + "_dtype",))

    def _convert_to_recognized_types(self, value):
        """Convert Numpy types to python types."""
        if np.issubdtype(type(value), np.integer):
            return int(value)
        elif np.issubdtype(type(value), np.floating):
            return float(value)
        return value

    def _insert_in_external_table(self, cursor, name=None, entries=None):
        """Insert into external table"""
        if name is None or entries is None:
            # There is nothing to do
            return

        id = entries.pop("id")
        dtype = self._guess_type(entries)
        expected_dtype = self._get_value_type_of_table(cursor, name)
        if dtype != expected_dtype:
            raise ValueError("The provided data type for table {} "
                             "is {}, while it is initialized to "
                             "be of type {}"
                             "".format(name, dtype, expected_dtype))

        # First we check if the entries already exist
        cursor.execute(f"SELECT key FROM {name} WHERE id=?", (id,))
        updates = []
        for item in cursor.fetchall():
            value = entries.pop(item[0], None)
            if value is not None:
                updates.append(
                    (value, id, self._convert_to_recognized_types(item[0])))

        # Update the entry if the key and ID already exist
        sql = f"UPDATE {name} SET value=? WHERE id=? AND key=?"
        cursor.executemany(sql, updates)

        # Insert the ones that do not already exist
        inserts = [(k, self._convert_to_recognized_types(v), id)
                   for k, v in entries.items()]
        sql = f"INSERT INTO {name} VALUES (?, ?, ?)"
        cursor.executemany(sql, inserts)

    def _guess_type(self, entries):
        """Guess the type based on the first entry."""
        values = [v for _, v in entries.items()]

        # Check if all datatypes are the same
        all_types = [type(v) for v in values]
        if any(t != all_types[0] for t in all_types):
            typenames = [t.__name__ for t in all_types]
            raise ValueError("Inconsistent datatypes in the table. "
                             "given types: {}".format(typenames))

        val = values[0]
        if isinstance(val, int) or np.issubdtype(type(val), np.integer):
            return "INTEGER"
        if isinstance(val, float) or np.issubdtype(type(val), np.floating):
            return "REAL"
        if isinstance(val, str):
            return "TEXT"
        raise ValueError("Unknown datatype!")

    def _get_value_type_of_table(self, cursor, tab_name):
        """Return the expected value type of a table."""
        sql = "SELECT value FROM information WHERE name=?"
        cursor.execute(sql, (tab_name + "_dtype",))
        return cursor.fetchone()[0]

    def _read_external_table(self, name, id):
        """Read row from external table."""

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(f"SELECT * FROM {name} WHERE id=?", (id,))
            items = cur.fetchall()
            dictionary = {item[0]: item[1] for item in items}

        return dictionary

    def get_all_key_names(self):
        """Create set of all key names."""
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute('SELECT DISTINCT key FROM keys;')
            all_keys = {row[0] for row in cur.fetchall()}
        return all_keys


if __name__ == '__main__':
    from ase.db import connect
    con = connect(sys.argv[1])
    con._initialize(con._connect())
    print('Version:', con.version)