Coverage for /builds/ase/ase/ase/db/core.py: 82.91%
392 statements
coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
import functools
import json
import numbers
import operator
import os
import re
import warnings
from time import time
from typing import Any, Dict, List

import numpy as np

from ase.atoms import Atoms
from ase.calculators.calculator import all_changes, all_properties
from ase.data import atomic_numbers
from ase.db.row import AtomsRow
from ase.formula import Formula
from ase.io.jsonio import create_ase_object
from ase.parallel import DummyMPI, parallel_function, parallel_generator, world
from ase.utils import Lock, PurePath

T2000 = 946681200.0  # January 1, 2000
YEAR = 31557600.0  # 365.25 days


@functools.total_ordering
class KeyDescription:
    _subscript = re.compile(r'`(.)_(.)`')
    _superscript = re.compile(r'`(.*)\^\{?(.*?)\}?`')

    def __init__(self, key, shortdesc=None, longdesc=None, unit=''):
        self.key = key

        if shortdesc is None:
            shortdesc = key

        if longdesc is None:
            longdesc = shortdesc

        self.shortdesc = shortdesc
        self.longdesc = longdesc

        # Somewhat arbitrary that we do this conversion. Can we avoid that?
        # Previously done in create_key_descriptions().
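        # For example (illustrative inputs, not taken from the table below):
        # a unit written as '`Å^3`' becomes 'Å<sup>3</sup>' and '`m_e`'
        # becomes 'm<sub>e</sub>'.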
        unit = self._subscript.sub(r'\1<sub>\2</sub>', unit)
        unit = self._superscript.sub(r'\1<sup>\2</sup>', unit)
        unit = unit.replace(r'\text{', '').replace('}', '')

        self.unit = unit

    def __repr__(self):
        cls = type(self).__name__
        return (
            f'{cls}({self.key!r}, {self.shortdesc!r}, {self.longdesc!r}, '
            f'unit={self.unit!r})'
        )

    # The templates like to sort key descriptions by shortdesc.
    def __eq__(self, other):
        return self.shortdesc == getattr(other, 'shortdesc', None)

    def __lt__(self, other):
        return self.shortdesc < getattr(other, 'shortdesc', self.shortdesc)


def get_key_descriptions():
    KD = KeyDescription
    return {
        keydesc.key: keydesc
        for keydesc in [
            KD('id', 'ID', 'Unique row ID'),
            KD('age', 'Age', 'Time since creation'),
            KD('formula', 'Formula', 'Chemical formula'),
            KD('pbc', 'PBC', 'Periodic boundary conditions'),
            KD('user', 'Username'),
            KD('calculator', 'Calculator', 'ASE-calculator name'),
            KD('energy', 'Energy', 'Total energy', unit='eV'),
            KD('natoms', 'Number of atoms'),
            KD('fmax', 'Maximum force', unit='eV/Å'),
            KD(
                'smax',
                'Maximum stress',
                'Maximum stress on unit cell',
                unit='eV/ų',
            ),
            KD('charge', 'Charge', 'Net charge in unit cell', unit='|e|'),
            KD('mass', 'Mass', 'Sum of atomic masses in unit cell', unit='au'),
            KD('magmom', 'Magnetic moment', unit='μ_B'),
            KD('unique_id', 'Unique ID', 'Random (unique) ID'),
            KD('volume', 'Volume', 'Volume of unit cell', unit='ų'),
        ]
    }


def now():
    """Return time since January 1, 2000 in years."""
    return (time() - T2000) / YEAR


seconds = {
    's': 1,
    'm': 60,
    'h': 3600,
    'd': 86400,
    'w': 604800,
    'M': 2629800,
    'y': YEAR,
}
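
# Note: 'M' is an average month, YEAR / 12 = 2629800 seconds, and 'y' equals
# the YEAR constant defined above.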

longwords = {
    's': 'second',
    'm': 'minute',
    'h': 'hour',
    'd': 'day',
    'w': 'week',
    'M': 'month',
    'y': 'year',
}

ops = {
    '<': operator.lt,
    '<=': operator.le,
    '=': operator.eq,
    '>=': operator.ge,
    '>': operator.gt,
    '!=': operator.ne,
}

invop = {'<': '>=', '<=': '>', '>=': '<', '>': '<=', '=': '!=', '!=': '='}

word = re.compile('[_a-zA-Z][_0-9a-zA-Z]*$')

reserved_keys = set(
    all_properties
    + all_changes
    + list(atomic_numbers)
    + [
        'id',
        'unique_id',
        'ctime',
        'mtime',
        'user',
        'fmax',
        'smax',
        'momenta',
        'constraints',
        'natoms',
        'formula',
        'age',
        'calculator',
        'calculator_parameters',
        'key_value_pairs',
        'data',
    ]
)

numeric_keys = {'id', 'energy', 'magmom', 'charge', 'natoms'}


def check(key_value_pairs):
    for key, value in key_value_pairs.items():
        if key == 'external_tables':
            # Checks for external_tables are not performed
            continue

        if not word.match(key) or key in reserved_keys:
            raise ValueError(f'Bad key: {key}')
        try:
            Formula(key, strict=True)
        except ValueError:
            pass
        else:
            warnings.warn(
                'It is best not to use keys ({0}) that are also a '
                'chemical formula. If you do a "db.select({0!r})", '
                'you will not find rows with your key. Instead, you will get '
                'rows containing the atoms in the formula!'.format(key)
            )
        if not isinstance(value, (numbers.Real, str, np.bool_)):
            raise ValueError(f'Bad value for {key!r}: {value}')
        if isinstance(value, str):
            for t in [bool, int, float]:
                if str_represents(value, t):
                    raise ValueError(
                        'Value '
                        + value
                        + ' is given as a string '
                        + 'but can be interpreted as '
                        + f'{t.__name__}! Please convert '
                        + f'to {t.__name__} before '
                        + 'writing to the database OR change '
                        + 'to a different string.'
                    )
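
# Illustrative sketch of the rules above (comments only, not executed):
#   check({'my_tag': 1.0})    # OK: plain key, numeric value
#   check({'energy': 1.0})    # ValueError: 'energy' is a reserved key
#   check({'H2O': 1.0})       # warns: the key is also a chemical formula
#   check({'x': '1.5'})       # ValueError: string that represents a float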


def str_represents(value, t=int):
    new_value = convert_str_to_int_float_bool_or_str(value)
    return isinstance(new_value, t)


def connect(
    name,
    type='extract_from_name',
    create_indices=True,
    use_lock_file=True,
    append=True,
    serial=False,
    **db_kwargs,
):
    """Create connection to database.

    name: str
        Filename or address of database.
    type: str
        One of 'json', 'db', 'postgresql', 'mysql', 'aselmdb'
        (JSON, SQLite, PostgreSQL, MySQL, ASELMDB).
        Default is 'extract_from_name', which will guess the type
        from the name.
    use_lock_file: bool
        You can turn this off if you know what you are doing ...
    append: bool
        Use append=False to start a new database.
    db_kwargs: dict
        Optional extra kwargs to pass on to the underlying db.
    """

    if isinstance(name, PurePath):
        name = str(name)

    if type == 'extract_from_name':
        if name is None:
            type = None
        elif not isinstance(name, str):
            type = 'json'
        elif name.startswith('postgresql://') or name.startswith('postgres://'):
            type = 'postgresql'
        elif name.startswith('mysql://') or name.startswith('mariadb://'):
            type = 'mysql'
        else:
            type = os.path.splitext(name)[1][1:]
            if type == '':
                raise ValueError('No file extension or database type given')

    if type is None:
        return Database(**db_kwargs)

    if not append and world.rank == 0:
        if isinstance(name, str) and os.path.isfile(name):
            os.remove(name)

    if type not in ['postgresql', 'mysql'] and isinstance(name, str):
        name = os.path.abspath(name)

    if type == 'json':
        from ase.db.jsondb import JSONDatabase

        return JSONDatabase(
            name, use_lock_file=use_lock_file, serial=serial, **db_kwargs
        )
    if type == 'db':
        from ase.db.sqlite import SQLite3Database

        return SQLite3Database(
            name, create_indices, use_lock_file, serial=serial, **db_kwargs
        )
    if type == 'postgresql':
        from ase_db_backends.postgresql import PostgreSQLDatabase

        return PostgreSQLDatabase(name, **db_kwargs)

    if type == 'mysql':
        from ase_db_backends.mysql import MySQLDatabase

        return MySQLDatabase(name, **db_kwargs)

    if type == 'aselmdb':
        from ase_db_backends.aselmdb import LMDBDatabase

        return LMDBDatabase(name, **db_kwargs)

    raise ValueError('Unknown database type: ' + type)
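
# A minimal usage sketch (the file names here are just examples):
#
#   from ase.db import connect
#   db = connect('results.db')                 # SQLite, inferred from '.db'
#   db = connect('results.json')               # JSON backend
#   db = connect('results.db', append=False)   # start over with an empty file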


def lock(method):
    """Decorator for using a lock-file."""

    @functools.wraps(method)
    def new_method(self, *args, **kwargs):
        if self.lock is None:
            return method(self, *args, **kwargs)
        else:
            with self.lock:
                return method(self, *args, **kwargs)

    return new_method


def convert_str_to_int_float_bool_or_str(value):
    """Safe eval()"""
    try:
        return int(value)
    except ValueError:
        try:
            value = float(value)
        except ValueError:
            value = {'True': True, 'False': False}.get(value, value)
        return value
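
# Examples of the conversion: '42' -> 42, '1.5' -> 1.5, 'True' -> True, and
# anything else ('abc') is returned unchanged as a string.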


def parse_selection(selection, **kwargs):
    if selection is None or selection == '':
        expressions = []
    elif isinstance(selection, int):
        expressions = [('id', '=', selection)]
    elif isinstance(selection, list):
        expressions = selection
    else:
        expressions = [w.strip() for w in selection.split(',')]
    keys = []
    comparisons = []
    for expression in expressions:
        if isinstance(expression, (list, tuple)):
            comparisons.append(expression)
            continue
        if expression.count('<') == 2:
            value, expression = expression.split('<', 1)
            if expression[0] == '=':
                op = '>='
                expression = expression[1:]
            else:
                op = '>'
            key = expression.split('<', 1)[0]
            comparisons.append((key, op, value))
        for op in ['!=', '<=', '>=', '<', '>', '=']:
            if op in expression:
                break
        else:  # no break
            if expression in atomic_numbers:
                comparisons.append((expression, '>', 0))
            else:
                try:
                    count = Formula(expression).count()
                except ValueError:
                    keys.append(expression)
                else:
                    comparisons.extend(
                        (symbol, '>', n - 1) for symbol, n in count.items()
                    )
            continue
        key, value = expression.split(op)
        comparisons.append((key, op, value))

    cmps = []
    for key, value in kwargs.items():
        comparisons.append((key, '=', value))

    for key, op, value in comparisons:
        if key == 'age':
            key = 'ctime'
            op = invop[op]
            value = now() - time_string_to_float(value)
        elif key == 'formula':
            if op != '=':
                raise ValueError('Use formula=...')
            f = Formula(value)
            count = f.count()
            cmps.extend(
                (atomic_numbers[symbol], '=', n) for symbol, n in count.items()
            )
            key = 'natoms'
            value = len(f)
        elif key in atomic_numbers:
            key = atomic_numbers[key]
            value = int(value)
        elif isinstance(value, str):
            if key != 'unique_id':
                value = convert_str_to_int_float_bool_or_str(value)
        if key in numeric_keys and not isinstance(value, (int, float)):
            msg = 'Wrong type for "{}{}{}" - must be a number'
            raise ValueError(msg.format(key, op, value))
        cmps.append((key, op, value))

    return keys, cmps
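
# For example (illustrative, not executed here):
#
#   parse_selection('natoms>2,energy<0,Cu')
#
# returns keys == [] and cmps == [('natoms', '>', 2), ('energy', '<', 0),
# (29, '>', 0)]; the bare symbol 'Cu' is translated to its atomic number.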


class Database:
    """Base class for all databases."""

    def __init__(
        self,
        filename: str = None,
        create_indices: bool = True,
        use_lock_file: bool = False,
        serial: bool = False,
    ):
        """Database object.

        serial: bool
            Let someone else handle parallelization. Default behavior is
            to interact with the database on the master only and then
            distribute results to all slaves.
        """
        if isinstance(filename, str):
            filename = os.path.expanduser(filename)
        self.filename = filename
        self.create_indices = create_indices
        self.lock = None
        if use_lock_file and isinstance(filename, str):
            self.lock = Lock(filename + '.lock', world=DummyMPI())
        self.serial = serial

        # Description of columns and other stuff:
        self._metadata = None

    @property
    def metadata(self) -> Dict[str, Any]:
        raise NotImplementedError

    @parallel_function
    @lock
    def write(self, atoms, key_value_pairs={}, data={}, id=None, **kwargs):
        """Write atoms to database with key-value pairs.

        atoms: Atoms object
            Write atomic numbers, positions, unit cell and boundary
            conditions. If a calculator is attached, write also already
            calculated properties such as the energy and forces.
        key_value_pairs: dict
            Dictionary of key-value pairs. Values must be strings or numbers.
        data: dict
            Extra stuff (not for searching).
        id: int
            Overwrite existing row.

        Key-value pairs can also be set using keyword arguments::

            connection.write(atoms, name='ABC', frequency=42.0)

        Returns integer id of the new row.
        """

        if atoms is None:
            atoms = Atoms()

        kvp = dict(key_value_pairs)  # modify a copy
        kvp.update(kwargs)

        id = self._write(atoms, kvp, data, id)
        return id

    def _write(self, atoms, key_value_pairs, data, id=None):
        check(key_value_pairs)
        return 1

    @parallel_function
    @lock
    def reserve(self, **key_value_pairs):
        """Write empty row if not already present.

        Usage::

            id = conn.reserve(key1=value1, key2=value2, ...)

        Write an empty row with the given key-value pairs and
        return the integer id. If such a row already exists, don't write
        anything and return None.
        """
        for _ in self._select(
            [], [(key, '=', value) for key, value in key_value_pairs.items()]
        ):
            return None

        atoms = Atoms()

        calc_name = key_value_pairs.pop('calculator', None)

        if calc_name:
            # Allow use of calculator key
            assert calc_name.lower() == calc_name

            # Fake calculator class:
            class Fake:
                name = calc_name

                def todict(self):
                    return {}

                def check_state(self, atoms):
                    return ['positions']

            atoms.calc = Fake()

        id = self._write(atoms, key_value_pairs, {}, None)

        return id

    def __delitem__(self, id):
        self.delete([id])

    def get_atoms(
        self, selection=None, add_additional_information=False, **kwargs
    ):
        """Get Atoms object.

        selection: int, str or list
            See the select() method.
        add_additional_information: bool
            Put key-value pairs and data into Atoms.info dictionary.

        In addition, one can use keyword arguments to select specific
        key-value pairs.
        """

        row = self.get(selection, **kwargs)
        return row.toatoms(add_additional_information)

    def __getitem__(self, selection):
        return self.get(selection)

    def get(self, selection=None, **kwargs):
        """Select a single row and return it as a dictionary.

        selection: int, str or list
            See the select() method.
        """
        rows = list(self.select(selection, limit=2, **kwargs))
        if not rows:
            raise KeyError('no match')
        assert len(rows) == 1, 'more than one row matched'
        return rows[0]

    @parallel_generator
    def select(
        self,
        selection=None,
        filter=None,
        explain=False,
        verbosity=1,
        limit=None,
        offset=0,
        sort=None,
        include_data=True,
        columns='all',
        **kwargs,
    ):
        """Select rows.

        Return AtomsRow iterator with results. Selection is done
        using key-value pairs and the special keys:

            formula, age, user, calculator, natoms, energy, magmom
            and/or charge.

        selection: int, str or list
            Can be:

            * an integer id
            * a string like 'key=value', where '=' can also be one of
              '<=', '<', '>', '>=' or '!='.
            * a string like 'key'
            * comma-separated strings like 'key1<value1,key2=value2,key'
            * list of strings or tuples: [('charge', '=', 1)].
        filter: function
            A function that takes as input a row and returns True or False.
        explain: bool
            Explain query plan.
        verbosity: int
            Possible values: 0, 1 or 2.
        limit: int or None
            Limit selection.
        offset: int
            Offset into selected rows.
        sort: str
            Sort rows by key. Prepend with a minus sign for a descending sort.
        include_data: bool
            Use include_data=False to skip reading data from rows.
        columns: 'all' or list of str
            Specify which columns from the SQL table to include.
            For example, if only the row id and the energy are needed,
            queries can be sped up by setting columns=['id', 'energy'].
        """

        if sort:
            if sort == 'age':
                sort = '-ctime'
            elif sort == '-age':
                sort = 'ctime'
            elif sort.lstrip('-') == 'user':
                sort += 'name'

        keys, cmps = parse_selection(selection, **kwargs)
        for row in self._select(
            keys,
            cmps,
            explain=explain,
            verbosity=verbosity,
            limit=limit,
            offset=offset,
            sort=sort,
            include_data=include_data,
            columns=columns,
        ):
            if filter is None or filter(row):
                yield row

    def count(self, selection=None, **kwargs):
        """Count rows.

        See the select() method for the selection syntax. Use db.count() or
        len(db) to count all rows.
        """
        n = 0
        for _ in self.select(selection, **kwargs):
            n += 1
        return n

    def __len__(self):
        return self.count()

    @parallel_function
    @lock
    def update(
        self, id, atoms=None, delete_keys=[], data=None, **add_key_value_pairs
    ):
        """Update and/or delete key-value pairs of row(s).

        id: int
            ID of row to update.
        atoms: Atoms object
            Optionally update the Atoms data (positions, cell, ...).
        data: dict
            Data dict to be added to the existing data.
        delete_keys: list of str
            Keys to remove.

        Use keyword arguments to add new key-value pairs.

        Returns number of key-value pairs added and removed.
        """

        if not isinstance(id, numbers.Integral):
            if isinstance(id, list):
                err = (
                    'First argument must be an int and not a list.\n'
                    'Do something like this instead:\n\n'
                    'with db:\n'
                    '    for id in ids:\n'
                    '        db.update(id, ...)'
                )
                raise ValueError(err)
            raise TypeError('id must be an int')

        check(add_key_value_pairs)

        row = self._get_row(id)
        kvp = row.key_value_pairs

        n = len(kvp)
        for key in delete_keys:
            kvp.pop(key, None)
        n -= len(kvp)
        m = -len(kvp)
        kvp.update(add_key_value_pairs)
        m += len(kvp)

        moredata = data
        data = row.get('data', {})
        if moredata:
            data.update(moredata)
        if not data:
            data = None

        if atoms:
            oldrow = row
            row = AtomsRow(atoms)
            # Copy over data, kvp, ctime, user and id
            row._data = oldrow._data
            row.__dict__.update(kvp)
            row._keys = list(kvp)
            row.ctime = oldrow.ctime
            row.user = oldrow.user
            row.id = id

        if atoms or os.path.splitext(self.filename)[1] == '.json':
            self._write(row, kvp, data, row.id)
        else:
            self._update(row.id, kvp, data)
        return m, n

    def delete(self, ids):
        """Delete rows."""
        raise NotImplementedError
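
# Typical usage through a concrete backend (illustrative sketch; the file
# name and key names are made up):
#
#   db = connect('results.db')
#   row_id = db.write(atoms, relaxed=True)
#   for row in db.select('relaxed=True,energy<0'):
#       print(row.id, row.energy)
#   db.update(row_id, relaxed=False)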


def time_string_to_float(s):
    if isinstance(s, (float, int)):
        return s
    s = s.replace(' ', '')
    if '+' in s:
        return sum(time_string_to_float(x) for x in s.split('+'))
    if s[-2].isalpha() and s[-1] == 's':
        s = s[:-1]
    i = 1
    while s[i].isdigit():
        i += 1
    return seconds[s[i:]] * int(s[:i]) / YEAR
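
# Examples: '1y' -> 1.0, '3d' -> 3 * 86400 / YEAR, and '2h+30m' ->
# (2 * 3600 + 30 * 60) / YEAR (all return values are in years).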


def float_to_time_string(t, long=False):
    t *= YEAR
    for s in 'yMwdhms':
        x = t / seconds[s]
        if x > 5:
            break
    if long:
        return f'{x:.3f} {longwords[s]}s'
    else:
        return f'{round(x):.0f}{s}'
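
# Example: float_to_time_string(0.5) gives '6M', and with long=True it gives
# '6.000 months'.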


def object_to_bytes(obj: Any) -> bytes:
    """Serialize Python object to bytes."""
    parts = [b'12345678']
    obj = o2b(obj, parts)
    offset = sum(len(part) for part in parts)
    x = np.array(offset, np.int64)
    if not np.little_endian:
        x.byteswap(True)
    parts[0] = x.tobytes()
    parts.append(json.dumps(obj, separators=(',', ':')).encode())
    return b''.join(parts)
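
# Byte layout produced above: 8 bytes holding the int64 offset of the JSON
# part, then the raw ndarray buffers, then a JSON document that refers to the
# arrays by their offsets.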


def bytes_to_object(b: bytes) -> Any:
    """Deserialize bytes to Python object."""
    x = np.frombuffer(b[:8], np.int64)
    if not np.little_endian:
        x = x.byteswap()
    offset = x.item()
    obj = json.loads(b[offset:].decode())
    return b2o(obj, b)
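
# Round-trip sketch (illustrative): bytes_to_object(object_to_bytes(obj))
# reconstructs obj, with each ndarray recreated from its slice of the buffer.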


def o2b(obj: Any, parts: List[bytes]):
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: o2b(value, parts) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [o2b(value, parts) for value in obj]
    if isinstance(obj, np.ndarray):
        assert obj.dtype != object, (
            'Cannot convert ndarray of type "object" to bytes.'
        )
        offset = sum(len(part) for part in parts)
        if not np.little_endian:
            obj = obj.byteswap()
        parts.append(obj.tobytes())
        return {'__ndarray__': [obj.shape, obj.dtype.name, offset]}
    if isinstance(obj, complex):
        return {'__complex__': [obj.real, obj.imag]}
    objtype = obj.ase_objtype
    if objtype:
        dct = o2b(obj.todict(), parts)
        dct['__ase_objtype__'] = objtype
        return dct
    raise ValueError(
        'Objects of type {type} not allowed'.format(type=type(obj))
    )


def b2o(obj: Any, b: bytes) -> Any:
    if isinstance(obj, (int, float, bool, str, type(None))):
        return obj

    if isinstance(obj, list):
        return [b2o(value, b) for value in obj]

    assert isinstance(obj, dict)

    x = obj.get('__complex__')
    if x is not None:
        return complex(*x)

    x = obj.get('__ndarray__')
    if x is not None:
        shape, name, offset = x
        dtype = np.dtype(name)
        size = dtype.itemsize * np.prod(shape).astype(int)
        a = np.frombuffer(b[offset : offset + size], dtype)
        a.shape = shape
        if not np.little_endian:
            a = a.byteswap()
        return a

    dct = {key: b2o(value, b) for key, value in obj.items()}
    objtype = dct.pop('__ase_objtype__', None)
    if objtype is None:
        return dct
    return create_ase_object(objtype, dct)