Coverage for ase/db/sqlite.py: 88.71% (549 statements)

"""SQLite3 backend.

Versions:

1) Added 3 more columns.
2) Changed "user" to "username".
3) Now adding keys to keyword table and added an "information" table containing
   a version number.
4) Got rid of keywords.
5) Add fmax, smax, mass, volume, charge
6) Use REAL for magmom and drop possibility for non-collinear spin
7) Volume can be None
8) Added name='metadata' row to "information" table
9) Row data is now stored in binary format.
"""
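
# Hedged usage sketch (not part of the original module): this backend is
# normally reached through ase.db.connect() rather than instantiated directly.
# Assuming write access to a local file 'example.db' (hypothetical name) and a
# simple bulk-Cu structure, something like the following would exercise the
# SQLite3Database class defined below:
#
#     from ase.build import bulk
#     from ase.db import connect
#
#     db = connect('example.db')           # returns a SQLite3Database
#     id = db.write(bulk('Cu'), tag='demo')
#     row = db.get(id=id)
#     print(row.formula, row.tag)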

import json
import numbers
import os
import sqlite3
import sys
from contextlib import contextmanager

import numpy as np

import ase.io.jsonio
from ase.calculators.calculator import all_properties
from ase.data import atomic_numbers
from ase.db.core import (
    Database,
    bytes_to_object,
    invop,
    lock,
    now,
    object_to_bytes,
    ops,
    parse_selection,
)
from ase.db.row import AtomsRow
from ase.parallel import parallel_function

VERSION = 9

init_statements = [
    """CREATE TABLE systems (
    id INTEGER PRIMARY KEY AUTOINCREMENT,  -- ID's, timestamps and user name
    unique_id TEXT UNIQUE,
    ctime REAL,
    mtime REAL,
    username TEXT,
    numbers BLOB,  -- stuff that defines an Atoms object
    positions BLOB,
    cell BLOB,
    pbc INTEGER,
    initial_magmoms BLOB,
    initial_charges BLOB,
    masses BLOB,
    tags BLOB,
    momenta BLOB,
    constraints TEXT,  -- constraints and calculator
    calculator TEXT,
    calculator_parameters TEXT,
    energy REAL,  -- calculated properties
    free_energy REAL,
    forces BLOB,
    stress BLOB,
    dipole BLOB,
    magmoms BLOB,
    magmom REAL,
    charges BLOB,
    key_value_pairs TEXT,  -- key-value pairs and data as json
    data BLOB,
    natoms INTEGER,  -- stuff for making queries faster
    fmax REAL,
    smax REAL,
    volume REAL,
    mass REAL,
    charge REAL)""",
    """CREATE TABLE species (
    Z INTEGER,
    n INTEGER,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",
    """CREATE TABLE keys (
    key TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",
    """CREATE TABLE text_key_values (
    key TEXT,
    value TEXT,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",
    """CREATE TABLE number_key_values (
    key TEXT,
    value REAL,
    id INTEGER,
    FOREIGN KEY (id) REFERENCES systems(id))""",
    """CREATE TABLE information (
    name TEXT,
    value TEXT)""",
    f"INSERT INTO information VALUES ('version', '{VERSION}')",
]

index_statements = [
    'CREATE INDEX unique_id_index ON systems(unique_id)',
    'CREATE INDEX ctime_index ON systems(ctime)',
    'CREATE INDEX username_index ON systems(username)',
    'CREATE INDEX calculator_index ON systems(calculator)',
    'CREATE INDEX species_index ON species(Z)',
    'CREATE INDEX key_index ON keys(key)',
    'CREATE INDEX text_index ON text_key_values(key)',
    'CREATE INDEX number_index ON number_key_values(key)',
]

all_tables = [
    'systems',
    'species',
    'keys',
    'text_key_values',
    'number_key_values',
]


def float_if_not_none(x):
    """Convert numpy.float64 to float - old db-interfaces need that."""
    if x is not None:
        return float(x)


class SQLite3Database(Database):
    type = 'db'
    initialized = False
    _allow_reading_old_format = False
    default = 'NULL'  # used for autoincrement id
    connection = None
    version = None
    columnnames = [
        line.split()[0].lstrip() for line in init_statements[0].splitlines()[1:]
    ]

    def encode(self, obj, binary=False):
        if binary:
            return object_to_bytes(obj)
        return ase.io.jsonio.encode(obj)

    def decode(self, txt, lazy=False):
        if lazy:
            return txt
        if isinstance(txt, str):
            return ase.io.jsonio.decode(txt)
        return bytes_to_object(txt)

    def blob(self, array):
        """Convert array to blob/buffer object."""

        if array is None:
            return None
        if len(array) == 0:
            array = np.zeros(0)
        if array.dtype == np.int64:
            array = array.astype(np.int32)
        if not np.little_endian:
            array = array.byteswap()
        return memoryview(np.ascontiguousarray(array))

    def deblob(self, buf, dtype=float, shape=None):
        """Convert blob/buffer object to ndarray of correct dtype and shape.

        (without creating an extra view)."""
        if buf is None:
            return None
        if len(buf) == 0:
            array = np.zeros(0, dtype)
        else:
            array = np.frombuffer(buf, dtype)
            if not np.little_endian:
                array = array.byteswap()
        if shape is not None:
            array.shape = shape
        return array
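
    # Illustrative sketch (not in the original source): blob() and deblob()
    # are meant to round-trip numpy arrays through SQLite BLOB columns.
    # Assuming ``db`` is a SQLite3Database instance:
    #
    #     import numpy as np
    #     positions = np.arange(6, dtype=float).reshape(3, 2)
    #     buf = db.blob(positions)          # memoryview accepted by sqlite3
    #     restored = db.deblob(buf, dtype=float, shape=(3, 2))
    #     assert np.array_equal(positions, restored)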

    def _connect(self):
        return sqlite3.connect(self.filename, timeout=20)

    def __enter__(self):
        assert self.connection is None
        self.change_count = 0
        self.connection = self._connect()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is None:
            self.connection.commit()
        else:
            self.connection.rollback()
        self.connection.close()
        self.connection = None

    @contextmanager
    def managed_connection(self, commit_frequency=5000):
        from contextlib import ExitStack

        with ExitStack() as stack:
            con = self.connection or stack.enter_context(self._connect())
            self._initialize(con)
            yield con

            if self.connection is None:
                con.commit()
            else:
                self.change_count += 1
                if self.change_count % commit_frequency == 0:
                    con.commit()
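
    # Sketch of how managed_connection() is used internally (illustrative, not
    # original code): it hands out the long-lived connection when the database
    # is used as a context manager, and otherwise opens, commits and closes a
    # temporary one.  A hypothetical caller inside this class would look like:
    #
    #     with self.managed_connection() as con:
    #         cur = con.cursor()
    #         cur.execute('SELECT COUNT(*) FROM systems')
    #         nrows = cur.fetchone()[0]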

    def _initialize(self, con):
        if self.initialized:
            return

        self._metadata = {}

        cur = con.execute(
            "SELECT COUNT(*) FROM sqlite_master WHERE name='systems'"
        )

        if cur.fetchone()[0] == 0:
            for statement in init_statements:
                con.execute(statement)
            if self.create_indices:
                for statement in index_statements:
                    con.execute(statement)
            con.commit()
            self.version = VERSION
        else:
            cur = con.execute(
                "SELECT COUNT(*) FROM sqlite_master WHERE name='user_index'"
            )
            if cur.fetchone()[0] == 1:
                # Old version with "user" instead of "username" column
                self.version = 1
            else:
                try:
                    cur = con.execute(
                        "SELECT value FROM information WHERE name='version'"
                    )
                except sqlite3.OperationalError:
                    self.version = 2
                else:
                    self.version = int(cur.fetchone()[0])

                cur = con.execute(
                    "SELECT value FROM information WHERE name='metadata'"
                )
                results = cur.fetchall()
                if results:
                    self._metadata = json.loads(results[0][0])

        if self.version > VERSION:
            raise OSError(
                f'Can not read new ase.db format (version {self.version}). '
                'Please update to latest ASE.'
            )
        if self.version < 5 and not self._allow_reading_old_format:
            raise OSError(
                'Please convert to new format. '
                f'Use: python -m ase.db.convert {self.filename}'
            )

        self.initialized = True

    def _write(self, atoms, key_value_pairs, data, id):
        ext_tables = key_value_pairs.pop('external_tables', {})
        Database._write(self, atoms, key_value_pairs, data)

        mtime = now()

        encode = self.encode
        blob = self.blob

        if not isinstance(atoms, AtomsRow):
            row = AtomsRow(atoms)
            row.ctime = mtime
            row.user = os.getenv('USER')
        else:
            row = atoms
            # Extract the external tables from AtomsRow
            names = self._get_external_table_names()
            for name in names:
                new_table = row.get(name, {})
                if new_table:
                    ext_tables[name] = new_table

        if not id and not key_value_pairs and not ext_tables:
            key_value_pairs = row.key_value_pairs

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        constraints = row._constraints
        if constraints:
            if isinstance(constraints, list):
                constraints = encode(constraints)
        else:
            constraints = None

        values = (
            row.unique_id,
            row.ctime,
            mtime,
            row.user,
            blob(row.numbers),
            blob(row.positions),
            blob(row.cell),
            int(np.dot(row.pbc, [1, 2, 4])),
            blob(row.get('initial_magmoms')),
            blob(row.get('initial_charges')),
            blob(row.get('masses')),
            blob(row.get('tags')),
            blob(row.get('momenta')),
            constraints,
        )

        if 'calculator' in row:
            values += (row.calculator, encode(row.calculator_parameters))
        else:
            values += (None, None)

        if not data:
            data = row._data

        with self.managed_connection() as con:
            if not isinstance(data, (str, bytes)):
                data = encode(data, binary=self.version >= 9)

            values += (
                float_if_not_none(row.get('energy')),
                float_if_not_none(row.get('free_energy')),
                blob(row.get('forces')),
                blob(row.get('stress')),
                blob(row.get('dipole')),
                blob(row.get('magmoms')),
                row.get('magmom'),
                blob(row.get('charges')),
                encode(key_value_pairs),
                data,
                len(row.numbers),
                float_if_not_none(row.get('fmax')),
                float_if_not_none(row.get('smax')),
                float_if_not_none(row.get('volume')),
                float(row.mass),
                float(row.charge),
            )

            cur = con.cursor()
            if id is None:
                q = self.default + ', ' + ', '.join('?' * len(values))
                cur.execute(f'INSERT INTO systems VALUES ({q})', values)
                id = self.get_last_id(cur)
            else:
                self._delete(
                    cur,
                    [id],
                    ['keys', 'text_key_values', 'number_key_values', 'species'],
                )
                q = ', '.join(name + '=?' for name in self.columnnames[1:])
                cur.execute(
                    f'UPDATE systems SET {q} WHERE id=?', values + (id,)
                )

            count = row.count_atoms()
            if count:
                species = [
                    (atomic_numbers[symbol], n, id)
                    for symbol, n in count.items()
                ]
                cur.executemany('INSERT INTO species VALUES (?, ?, ?)', species)

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany(
                'INSERT INTO text_key_values VALUES (?, ?, ?)', text_key_values
            )
            cur.executemany(
                'INSERT INTO number_key_values VALUES (?, ?, ?)',
                number_key_values,
            )
            cur.executemany(
                'INSERT INTO keys VALUES (?, ?)',
                [(key, id) for key in key_value_pairs],
            )

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname]
                )

        return id

    def _update(self, id, key_value_pairs, data=None):
        """Update key_value_pairs and data for a single row"""
        encode = self.encode
        ext_tables = key_value_pairs.pop('external_tables', {})

        for k, v in ext_tables.items():
            dtype = self._guess_type(v)
            self._create_table_if_not_exists(k, dtype)

        mtime = now()
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                'UPDATE systems SET mtime=?, key_value_pairs=? WHERE id=?',
                (mtime, encode(key_value_pairs), id),
            )
            if data:
                if not isinstance(data, (str, bytes)):
                    data = encode(data, binary=self.version >= 9)
                cur.execute('UPDATE systems set data=? where id=?', (data, id))

            self._delete(
                cur, [id], ['keys', 'text_key_values', 'number_key_values']
            )

            text_key_values = []
            number_key_values = []
            for key, value in key_value_pairs.items():
                if isinstance(value, (numbers.Real, np.bool_)):
                    number_key_values.append([key, float(value), id])
                else:
                    assert isinstance(value, str)
                    text_key_values.append([key, value, id])

            cur.executemany(
                'INSERT INTO text_key_values VALUES (?, ?, ?)', text_key_values
            )
            cur.executemany(
                'INSERT INTO number_key_values VALUES (?, ?, ?)',
                number_key_values,
            )
            cur.executemany(
                'INSERT INTO keys VALUES (?, ?)',
                [(key, id) for key in key_value_pairs],
            )

            # Insert entries in the valid tables
            for tabname in ext_tables.keys():
                entries = ext_tables[tabname]
                entries['id'] = id
                self._insert_in_external_table(
                    cur, name=tabname, entries=ext_tables[tabname]
                )

        return id

    def get_last_id(self, cur):
        cur.execute("SELECT seq FROM sqlite_sequence WHERE name='systems'")
        result = cur.fetchone()
        if result is not None:
            id = result[0]
            return id
        else:
            return 0

    def _get_row(self, id):
        with self.managed_connection() as con:
            cur = con.cursor()
            if id is None:
                cur.execute('SELECT COUNT(*) FROM systems')
                assert cur.fetchone()[0] == 1
                cur.execute('SELECT * FROM systems')
            else:
                cur.execute('SELECT * FROM systems WHERE id=?', (id,))
            values = cur.fetchone()

        return self._convert_tuple_to_row(values)

    def _convert_tuple_to_row(self, values):
        deblob = self.deblob
        decode = self.decode

        values = self._old2new(values)
        dct = {
            'id': values[0],
            'unique_id': values[1],
            'ctime': values[2],
            'mtime': values[3],
            'user': values[4],
            'numbers': deblob(values[5], np.int32),
            'positions': deblob(values[6], shape=(-1, 3)),
            'cell': deblob(values[7], shape=(3, 3)),
        }

        if values[8] is not None:
            dct['pbc'] = (values[8] & np.array([1, 2, 4])).astype(bool)
        if values[9] is not None:
            dct['initial_magmoms'] = deblob(values[9])
        if values[10] is not None:
            dct['initial_charges'] = deblob(values[10])
        if values[11] is not None:
            dct['masses'] = deblob(values[11])
        if values[12] is not None:
            dct['tags'] = deblob(values[12], np.int32)
        if values[13] is not None:
            dct['momenta'] = deblob(values[13], shape=(-1, 3))
        if values[14] is not None:
            dct['constraints'] = values[14]
        if values[15] is not None:
            dct['calculator'] = values[15]
        if values[16] is not None:
            dct['calculator_parameters'] = decode(values[16])
        if values[17] is not None:
            dct['energy'] = values[17]
        if values[18] is not None:
            dct['free_energy'] = values[18]
        if values[19] is not None:
            dct['forces'] = deblob(values[19], shape=(-1, 3))
        if values[20] is not None:
            dct['stress'] = deblob(values[20])
        if values[21] is not None:
            dct['dipole'] = deblob(values[21])
        if values[22] is not None:
            dct['magmoms'] = deblob(values[22])
        if values[23] is not None:
            dct['magmom'] = values[23]
        if values[24] is not None:
            dct['charges'] = deblob(values[24])
        if values[25] != '{}':
            dct['key_value_pairs'] = decode(values[25])
        if len(values) >= 27 and values[26] != 'null':
            dct['data'] = decode(values[26], lazy=True)

        # Now we need to update with info from the external tables
        external_tab = self._get_external_table_names()
        tables = {}
        for tab in external_tab:
            row = self._read_external_table(tab, dct['id'])
            tables[tab] = row

        dct.update(tables)
        return AtomsRow(dct)

    def _old2new(self, values):
        if self.type == 'postgresql':
            assert self.version >= 8, 'Your db-version is too old!'
        assert self.version >= 4, 'Your db-file is too old!'
        if self.version < 5:
            pass  # should be ok for reading by convert.py script
        if self.version < 6:
            m = values[23]
            if m is not None and not isinstance(m, float):
                magmom = float(self.deblob(m, shape=()))
                values = values[:23] + (magmom,) + values[24:]
        return values

    def create_select_statement(
        self,
        keys,
        cmps,
        sort=None,
        order=None,
        sort_table=None,
        what='systems.*',
    ):
        tables = ['systems']
        where = []
        args = []
        for key in keys:
            if key == 'forces':
                where.append('systems.fmax IS NOT NULL')
            elif key == 'strain':
                where.append('systems.smax IS NOT NULL')
            elif key in ['energy', 'fmax', 'smax', 'constraints', 'calculator']:
                where.append(f'systems.{key} IS NOT NULL')
            else:
                if '-' not in key:
                    q = 'systems.id in (select id from keys where key=?)'
                else:
                    key = key.replace('-', '')
                    q = 'systems.id not in (select id from keys where key=?)'
                where.append(q)
                args.append(key)

        # Special handling of "H=0" and "H<2" type of selections:
        bad = {}
        for key, op, value in cmps:
            if isinstance(key, int):
                bad[key] = bad.get(key, True) and ops[op](0, value)

        for key, op, value in cmps:
            if key in [
                'id',
                'energy',
                'magmom',
                'ctime',
                'user',
                'calculator',
                'natoms',
                'pbc',
                'unique_id',
                'fmax',
                'smax',
                'volume',
                'mass',
                'charge',
            ]:
                if key == 'user':
                    key = 'username'
                elif key == 'pbc':
                    assert op in ['=', '!=']
                    value = int(np.dot([x == 'T' for x in value], [1, 2, 4]))
                elif key == 'magmom':
                    assert self.version >= 6, 'Update your db-file'
                where.append(f'systems.{key}{op}?')
                args.append(value)
            elif isinstance(key, int):
                if self.type == 'postgresql':
                    where.append(
                        'cardinality(array_positions('
                        + f'numbers::int[], ?)){op}?'
                    )
                    args += [key, value]
                else:
                    if bad[key]:
                        where.append(
                            'systems.id not in (select id from species '
                            + f'where Z=? and n{invop[op]}?)'
                        )
                        args += [key, value]
                    else:
                        where.append(
                            'systems.id in (select id from species '
                            + f'where Z=? and n{op}?)'
                        )
                        args += [key, value]

            elif self.type == 'postgresql':
                jsonop = '->'
                if isinstance(value, str):
                    jsonop = '->>'
                elif isinstance(value, bool):
                    jsonop = '->>'
                    value = str(value).lower()
                where.append(
                    "systems.key_value_pairs {} '{}'{}?".format(jsonop, key, op)
                )
                args.append(str(value))

            elif isinstance(value, str):
                where.append(
                    'systems.id in (select id from text_key_values '
                    + f'where key=? and value{op}?)'
                )
                args += [key, value]
            else:
                where.append(
                    'systems.id in (select id from number_key_values '
                    + f'where key=? and value{op}?)'
                )
                args += [key, float(value)]

        if sort:
            if sort_table != 'systems':
                tables.append(f'{sort_table} AS sort_table')
                where.append('systems.id=sort_table.id AND sort_table.key=?')
                args.append(sort)
                sort_table = 'sort_table'
                sort = 'value'

        sql = f'SELECT {what} FROM\n  ' + ', '.join(tables)
        if where:
            sql += '\n  WHERE\n  ' + ' AND\n  '.join(where)
        if sort:
            # XXX use "?" instead of "{}"
            sql += '\nORDER BY {0}.{1} IS NULL, {0}.{1} {2}'.format(
                sort_table, sort, order
            )

        return sql, args
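
    # Illustrative example (not original code) of the SQL this method builds.
    # With keys=[] and cmps=[('energy', '<', 0.0), (1, '>', 0)] (i.e. hydrogen
    # count greater than zero), the generated statement would look roughly
    # like:
    #
    #     SELECT systems.* FROM
    #       systems
    #       WHERE
    #       systems.energy<? AND
    #       systems.id in (select id from species where Z=? and n>?)
    #
    # with args [0.0, 1, 0].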

    def _select(
        self,
        keys,
        cmps,
        explain=False,
        verbosity=0,
        limit=None,
        offset=0,
        sort=None,
        include_data=True,
        columns='all',
    ):
        values = np.array([None for _ in range(27)])
        values[25] = '{}'
        values[26] = 'null'

        if columns == 'all':
            columnindex = list(range(26))
        else:
            columnindex = [
                c for c in range(26) if self.columnnames[c] in columns
            ]
        if include_data:
            columnindex.append(26)

        if sort:
            if sort[0] == '-':
                order = 'DESC'
                sort = sort[1:]
            else:
                order = 'ASC'
            if sort in [
                'id',
                'energy',
                'username',
                'calculator',
                'ctime',
                'mtime',
                'magmom',
                'pbc',
                'fmax',
                'smax',
                'volume',
                'mass',
                'charge',
                'natoms',
            ]:
                sort_table = 'systems'
            else:
                for dct in self._select(
                    keys + [sort],
                    cmps=[],
                    limit=1,
                    include_data=False,
                    columns=['key_value_pairs'],
                ):
                    if isinstance(dct['key_value_pairs'][sort], str):
                        sort_table = 'text_key_values'
                    else:
                        sort_table = 'number_key_values'
                    break
                else:
                    # No rows. Just pick a table:
                    sort_table = 'number_key_values'

        else:
            order = None
            sort_table = None

        what = ', '.join(
            'systems.' + name
            for name in np.array(self.columnnames)[np.array(columnindex)]
        )

        sql, args = self.create_select_statement(
            keys, cmps, sort, order, sort_table, what
        )

        if explain:
            sql = 'EXPLAIN QUERY PLAN ' + sql

        if limit:
            sql += f'\nLIMIT {limit}'

        if offset:
            sql += self.get_offset_string(offset, limit=limit)

        if verbosity == 2:
            print(sql, args)

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            if explain:
                for row in cur.fetchall():
                    yield {'explain': row}
            else:
                n = 0
                for shortvalues in cur.fetchall():
                    values[columnindex] = shortvalues
                    yield self._convert_tuple_to_row(tuple(values))
                    n += 1

                if sort and sort_table != 'systems':
                    # Yield rows without sort key last:
                    if limit is not None:
                        if n == limit:
                            return
                        limit -= n
                    for row in self._select(
                        keys + ['-' + sort],
                        cmps,
                        limit=limit,
                        offset=offset,
                        include_data=include_data,
                        columns=columns,
                    ):
                        yield row

    def get_offset_string(self, offset, limit=None):
        sql = ''
        if not limit:
            # In sqlite you cannot have offset without limit, so we
            # set it to -1 meaning no limit
            sql += '\nLIMIT -1'
        sql += f'\nOFFSET {offset}'
        return sql
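
    # Illustrative note (not original code): SQLite requires a LIMIT clause
    # whenever OFFSET is used, hence the "LIMIT -1" (no limit) fallback, e.g.
    #
    #     db.get_offset_string(10)            # '\nLIMIT -1\nOFFSET 10'
    #     db.get_offset_string(10, limit=5)   # '\nOFFSET 10'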

    @parallel_function
    def count(self, selection=None, **kwargs):
        keys, cmps = parse_selection(selection, **kwargs)
        sql, args = self.create_select_statement(keys, cmps, what='COUNT(*)')

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql, args)
            return cur.fetchone()[0]
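
    # Usage sketch (illustrative, not original code): count() reuses the
    # select machinery with what='COUNT(*)', so it never materializes rows.
    # The key name 'tag' below is hypothetical:
    #
    #     n_negative = db.count('energy<0')
    #     n_tagged = db.count(tag='demo')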

    def analyse(self):
        with self.managed_connection() as con:
            con.execute('ANALYZE')

    @parallel_function
    @lock
    def delete(self, ids):
        if len(ids) == 0:
            return
        table_names = self._get_external_table_names() + all_tables[::-1]
        with self.managed_connection() as con:
            self._delete(con.cursor(), ids, tables=table_names)
        self.vacuum()

    def _delete(self, cur, ids, tables=None):
        tables = tables or all_tables[::-1]
        for table in tables:
            cur.execute(
                'DELETE FROM {} WHERE id in ({});'.format(
                    table, ', '.join([str(id) for id in ids])
                )
            )

    def vacuum(self):
        if self.type != 'db':
            return

        with self.managed_connection() as con:
            con.commit()
            con.cursor().execute('VACUUM')

    @property
    def metadata(self):
        if self._metadata is None:
            assert self.connection is not None
            self._initialize(self.connection)
        return self._metadata.copy()

    @metadata.setter
    def metadata(self, dct):
        self._metadata = dct
        md = json.dumps(dct)
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(
                "SELECT COUNT(*) FROM information WHERE name='metadata'"
            )

            if cur.fetchone()[0]:
                cur.execute(
                    "UPDATE information SET value=? WHERE name='metadata'", [md]
                )
            else:
                cur.execute(
                    'INSERT INTO information VALUES (?, ?)', ('metadata', md)
                )
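
    # Usage sketch (illustrative, not original code): the metadata property
    # round-trips a JSON-encoded dict through the "information" table.
    # Assuming ``db`` is a SQLite3Database and 'title' is an arbitrary key:
    #
    #     db.metadata = {'title': 'My collection'}
    #     print(db.metadata['title'])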

    def _get_external_table_names(self, db_con=None):
        """Return a list with the external table names."""
        sql = "SELECT value FROM information WHERE name='external_table_name'"
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            ext_tab_names = [x[0] for x in cur.fetchall()]
        return ext_tab_names

    def _external_table_exists(self, name):
        """Return True if an external table name exists."""
        return name in self._get_external_table_names()

    def _create_table_if_not_exists(self, name, dtype):
        """Create a new table if it does not exist.

        Arguments
        =========
        name: str
            Name of the new table
        dtype: str
            Datatype of the value field (typically REAL, INTEGER, TEXT etc.)
        """

        taken_names = set(all_tables + all_properties + self.columnnames)
        if name in taken_names:
            raise ValueError(f'External table can not be any of {taken_names}')

        if self._external_table_exists(name):
            return

        sql = f'CREATE TABLE IF NOT EXISTS {name} '
        sql += f'(key TEXT, value {dtype}, id INTEGER, '
        sql += 'FOREIGN KEY (id) REFERENCES systems(id))'
        sql2 = 'INSERT INTO information VALUES (?, ?)'
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(sql)
            # Insert an entry saying that there is a new external table
            # present and an entry with the datatype
            cur.execute(sql2, ('external_table_name', name))
            cur.execute(sql2, (name + '_dtype', dtype))
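
    # Illustrative sketch (not original code) of the external-table machinery:
    # passing a dict of uniformly typed values under ``external_tables`` when
    # writing creates the table if needed and fills it with (key, value, id)
    # rows.  The table name 'bond_lengths' and its keys are hypothetical:
    #
    #     id = db.write(atoms, external_tables={'bond_lengths': {'d01': 1.1,
    #                                                            'd02': 1.5}})
    #     row = db.get(id=id)
    #     print(row.bond_lengths)    # {'d01': 1.1, 'd02': 1.5}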

    def delete_external_table(self, name):
        """Delete an external table."""
        if not self._external_table_exists(name):
            return

        with self.managed_connection() as con:
            cur = con.cursor()

            sql = f'DROP TABLE {name}'
            cur.execute(sql)

            sql = 'DELETE FROM information WHERE value=?'
            cur.execute(sql, (name,))
            sql = 'DELETE FROM information WHERE name=?'
            cur.execute(sql, (name + '_dtype',))

    def _convert_to_recognized_types(self, value):
        """Convert NumPy types to Python types."""
        if np.issubdtype(type(value), np.integer):
            return int(value)
        elif np.issubdtype(type(value), np.floating):
            return float(value)
        return value

    def _insert_in_external_table(self, cursor, name=None, entries=None):
        """Insert into external table"""
        if name is None or entries is None:
            # There is nothing to do
            return

        id = entries.pop('id')
        dtype = self._guess_type(entries)
        expected_dtype = self._get_value_type_of_table(cursor, name)
        if dtype != expected_dtype:
            raise ValueError(
                f'The provided data type for table {name} is {dtype}, while '
                f'it is initialized to be of type {expected_dtype}'
            )

        # First we check if the entries already exist
        cursor.execute(f'SELECT key FROM {name} WHERE id=?', (id,))
        updates = []
        for item in cursor.fetchall():
            value = entries.pop(item[0], None)
            if value is not None:
                updates.append(
                    (value, id, self._convert_to_recognized_types(item[0]))
                )

        # Update entries where the key and ID already exist
        sql = f'UPDATE {name} SET value=? WHERE id=? AND key=?'
        cursor.executemany(sql, updates)

        # Insert the ones that do not already exist
        inserts = [
            (k, self._convert_to_recognized_types(v), id)
            for k, v in entries.items()
        ]
        sql = f'INSERT INTO {name} VALUES (?, ?, ?)'
        cursor.executemany(sql, inserts)

    def _guess_type(self, entries):
        """Guess the type based on the first entry."""
        values = [v for _, v in entries.items()]

        # Check that all datatypes are the same
        all_types = [type(v) for v in values]
        if any(t != all_types[0] for t in all_types):
            typenames = [t.__name__ for t in all_types]
            raise ValueError(
                f'Inconsistent datatypes in the table. given types: {typenames}'
            )

        val = values[0]
        if isinstance(val, int) or np.issubdtype(type(val), np.integer):
            return 'INTEGER'
        if isinstance(val, float) or np.issubdtype(type(val), np.floating):
            return 'REAL'
        if isinstance(val, str):
            return 'TEXT'
        raise ValueError('Unknown datatype!')

    def _get_value_type_of_table(self, cursor, tab_name):
        """Return the expected value type of a table."""
        sql = 'SELECT value FROM information WHERE name=?'
        cursor.execute(sql, (tab_name + '_dtype',))
        return cursor.fetchone()[0]

    def _read_external_table(self, name, id):
        """Read row from external table."""

        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute(f'SELECT * FROM {name} WHERE id=?', (id,))
            items = cur.fetchall()
            dictionary = {item[0]: item[1] for item in items}

        return dictionary

    def get_all_key_names(self):
        """Create set of all key names."""
        with self.managed_connection() as con:
            cur = con.cursor()
            cur.execute('SELECT DISTINCT key FROM keys;')
            all_keys = {row[0] for row in cur.fetchall()}
        return all_keys
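
    # Usage sketch (illustrative, not original code): collects the distinct
    # user keys stored in the "keys" table across all rows, e.g.
    #
    #     sorted(db.get_all_key_names())    # e.g. ['relaxed', 'tag']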


if __name__ == '__main__':
    from ase.db import connect

    con = connect(sys.argv[1])
    con._initialize(con._connect())
    print('Version:', con.version)
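
# The __main__ block above is a small maintenance helper.  Run directly on an
# existing database file (hypothetical name below), it connects, initializes
# the backend and reports the on-disk schema version, e.g.:
#
#     $ python -m ase.db.sqlite mydata.db
#     Version: 9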