Coverage for /builds/ase/ase/ase/db/table.py: 90.14%
142 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3from typing import List, Optional
5import numpy as np
7from ase.db.core import float_to_time_string, now
# Names of the table columns known to this module; presumably the full set
# a user may request (get_sql_columns maps several of them to SQL names).
all_columns = (
    'id', 'age', 'user', 'formula', 'calculator',
    'energy', 'natoms', 'fmax', 'pbc', 'volume',
    'charge', 'mass', 'smax', 'magmom',
)
def get_sql_columns(columns):
    """Translate table column names into the SQL column names to fetch.

    'age' is replaced by the 'mtime' and 'ctime' columns (appended at the
    end).  Several other names are renamed in place to the SQL column they
    are derived from.  'key_value_pairs' and 'constraints' are always
    requested, and 'id' is appended if it is not already present.

    Returns a new list; *columns* itself is not modified.
    """
    # (table name, SQL name) pairs; renamed at the first occurrence only.
    renames = (
        ('user', 'username'),
        ('formula', 'numbers'),
        ('fmax', 'forces'),
        ('smax', 'stress'),
        ('volume', 'cell'),
        ('mass', 'masses'),
        ('charge', 'charges'),
    )

    sql_columns = list(columns)

    if 'age' in columns:
        sql_columns.remove('age')
        sql_columns.extend(['mtime', 'ctime'])

    for table_name, sql_name in renames:
        if table_name in columns:
            position = sql_columns.index(table_name)
            sql_columns[position] = sql_name

    sql_columns.extend(['key_value_pairs', 'constraints'])
    if 'id' not in sql_columns:
        sql_columns.append('id')

    return sql_columns
def plural(n, word):
    """Return '<n> <word>' with an 's' appended to *word* unless n == 1.

    *n* is rendered with '%d', so non-integer numbers are truncated.
    """
    if n != 1:
        return '%d %ss' % (n, word)
    return '1 ' + word
def cut(txt, length):
    """Shorten *txt* to at most *length* characters, ending in '...'.

    A *length* of 0 (or less) disables truncation.  Text that already fits
    is returned unchanged.  When *length* is too small to hold the
    three-character ellipsis (1 or 2), the text is simply clipped so the
    result still never exceeds *length*.
    """
    if length <= 0 or len(txt) <= length:
        return txt
    if length < 3:
        # txt[:length - 3] would slice from the end here and return a
        # string longer than the limit.
        return txt[:length]
    return txt[:length - 3] + '...'
def cutlist(lst, length, head=9):
    """Shorten the list *lst* for display when it has more than *length* items.

    Lists with at most *length* items are returned unchanged, as is any list
    when *length* is 0 (or less).  Otherwise the first ``min(head, length)``
    items are kept and a final ``'... (N more)'`` entry reports how many
    items were dropped.

    Parameters:

    lst: list
        Items to show.
    length: int
        Maximum number of items before truncation kicks in; 0 = no limit.
    head: int
        Preferred number of leading items to keep once truncating.
    """
    if length <= 0 or len(lst) <= length:
        return lst
    # Never keep more than *length* items, so the "more" count is always
    # positive and the result respects the requested limit.
    keep = min(head, length)
    return lst[:keep] + [f'... ({len(lst) - keep} more)']
class Table:
    """Tabular view of database rows for terminal output.

    Call ``select()`` to query the connection and build the rows, then
    ``write()`` for aligned '|'-separated output or ``write_csv()``.
    """

    def __init__(
        self,
        connection,
        unique_key: str = 'id',
        verbosity: int = 1,
        cut: int = 35,
    ):
        """
        Parameters:

        connection:
            Database object; ``select()`` calls its ``select`` method and
            ``write()`` may call its ``count`` method.
        unique_key: str
            Row attribute stored as ``Row.uid``.
        verbosity: int
            Passed to ``connection.select``.  Values > 0 print header lines;
            0 also suppresses the 'Rows:'/'Keys:' summary of ``write()``.
        cut: int
            Limit passed to ``cutlist()`` for the printed list of keys.
        """
        self.connection = connection
        self.verbosity = verbosity
        self.cut = cut
        self.rows: List['Row'] = []
        self.columns: Optional[List[str]] = None
        # Not used inside this class; presumably read/set by callers.
        self.id = None
        # One bool per column: True means right-align (set by format()).
        self.right: Optional[List[bool]] = None
        # Sorted union of key_value_pairs keys (set by format()).
        self.keys: Optional[List[str]] = None
        self.unique_key = unique_key
        # Not used inside this class; presumably read/set by callers.
        self.addcolumns: Optional[List[str]] = None
        # Initialized here as well as in select() so that write() works
        # (and reports 'Rows: 0') on a table that was never selected.
        self.limit = None
        self.offset = None

    def select(self, query, columns, sort, limit, offset,
               show_empty_columns=False):
        """Query the database and create one Row per result.

        The table column names are translated with ``get_sql_columns``
        before querying; row data (``include_data=False``) is not fetched.
        Unless *show_empty_columns* is true, columns that are None in every
        row are dropped.
        """
        sql_columns = get_sql_columns(columns)
        self.limit = limit
        self.offset = offset
        self.rows = [Row(row, columns, self.unique_key)
                     for row in self.connection.select(
                         query, verbosity=self.verbosity,
                         limit=limit, offset=offset, sort=sort,
                         include_data=False, columns=sql_columns)]

        self.columns = list(columns)

        if not show_empty_columns:
            self._drop_empty_columns()

    def _drop_empty_columns(self):
        """Remove each column whose value is None in every row."""
        empty = [n for n in range(len(self.columns))
                 if all(row.values[n] is None for row in self.rows)]
        # Delete from the right so the remaining indices stay valid.
        for n in reversed(empty):
            for row in self.rows:
                del row.values[n]
            del self.columns[n]

    def format(self, subscript=None):
        """Format every row and compute column alignment and the key set.

        Columns reported as numeric by ``Row.format`` and the 'age' column
        are right-aligned.
        """
        right = set()  # right-adjust numbers
        allkeys = set()
        for row in self.rows:
            numbers = row.format(self.columns, subscript)
            right.update(numbers)
            allkeys.update(row.dct.get('key_value_pairs', {}))

        right.add('age')
        self.right = [column in right for column in self.columns]

        self.keys = sorted(allkeys)

    def write(self, query=None):
        """Print the rows as padded, '|'-separated columns plus a summary.

        *query* is only used to count all matches when the number of rows
        shown equals the limit used in ``select()``.
        """
        self.format()
        lengths = [[len(s) for s in row.strings]
                   for row in self.rows]
        lengths.append([len(c) for c in self.columns])
        # Column width = longest cell or header in that column.
        widths = np.max(lengths, axis=0)

        fmt = '{:{align}{width}}'
        if self.verbosity > 0:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in zip(self.columns, self.right,
                                              widths)))
        for row in self.rows:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(row.strings, self.right, widths)))

        if self.verbosity == 0:
            return

        nrows = len(self.rows)

        if self.limit and nrows == self.limit:
            n = self.connection.count(query)
            print('Rows:', n, f'(showing first {self.limit})')
        else:
            print('Rows:', nrows)

        if self.keys:
            print('Keys:', ', '.join(cutlist(self.keys, self.cut)))

    def write_csv(self):
        """Print the raw values as ', '-separated lines (no quoting)."""
        if self.verbosity > 0:
            print(', '.join(self.columns))
        for row in self.rows:
            print(', '.join(str(val) for val in row.values))
class Row:
    """One table row: raw values per column plus their display strings."""

    def __init__(self, dct, columns, unique_key='id'):
        """
        Parameters:

        dct:
            Row object; column values are read from it as attributes.
        columns: list of str
            Column names whose values are extracted.
        unique_key: str
            Attribute of *dct* stored as ``self.uid`` (must exist).
        """
        self.dct = dct
        self.values = None   # raw values, filled by set_columns()
        self.strings = None  # display strings, filled by format()
        self.set_columns(columns)
        self.uid = getattr(dct, unique_key)

    def set_columns(self, columns):
        """Extract one raw value per column name from ``self.dct``.

        'age' is the time since ``dct.ctime`` rendered by
        ``float_to_time_string``; 'pbc' becomes a string of 'T'/'F' flags;
        any other column is read with ``getattr`` and is None when missing.
        """
        self.values = []
        for c in columns:
            if c == 'age':
                value = float_to_time_string(now() - self.dct.ctime)
            elif c == 'pbc':
                value = ''.join('FT'[int(p)] for p in self.dct.pbc)
            else:
                value = getattr(self.dct, c, None)
            self.values.append(value)

    def format(self, columns, subscript=None):
        """Convert ``self.values`` into strings stored in ``self.strings``.

        Parameters:

        columns: list of str
            Column names, matched position by position with ``self.values``.
        subscript:
            Optional compiled regex; when given, the 'formula' value has
            every match replaced by ``<sub>`` + group 1 + ``</sub>``.

        Returns the set of column names holding numeric values (Python or
        NumPy integers/floats, bools included) so callers can right-align
        them.  Every entry of ``self.strings`` is a ``str``.
        """
        self.strings = []
        numeric_columns = set()
        for value, column in zip(self.values, columns):
            if value is None:
                # Checked first so a missing formula never reaches re.sub().
                text = ''
            elif column == 'formula' and subscript:
                text = subscript.sub(r'<sub>\1</sub>', value)
            elif isinstance(value, np.ndarray):
                text = str(value.tolist())
            elif isinstance(value, (int, np.integer)):
                # np.integer is not a subclass of int, so list it explicitly.
                text = str(value)
                numeric_columns.add(column)
            elif isinstance(value, (float, np.floating)):
                # np.floating also covers np.float32, which is not a float.
                text = f'{value:.3f}'
                numeric_columns.add(column)
            else:
                # str, dict, list, tuple, np.bool_, ...: Table.write needs
                # real strings to compute column widths.
                text = str(value)
            self.strings.append(text)

        return numeric_columns