Coverage for ase / db / table.py: 90.07%
141 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
4import numpy as np
6from ase.db.core import float_to_time_string, now
# Column names recognised by this module.  Several of them are not stored
# under the same name in SQL; get_sql_columns() performs that translation.
all_columns = (
    'id',
    'age',
    'user',
    'formula',
    'calculator',
    'energy',
    'natoms',
    'fmax',
    'pbc',
    'volume',
    'charge',
    'mass',
    'smax',
    'magmom',
)
def get_sql_columns(columns):
    """Translate table column names into the SQL columns to request.

    The result starts as a copy of *columns* and is then adjusted:

    * ``'age'`` is dropped and ``'mtime'`` and ``'ctime'`` are appended;
    * a fixed set of display names is replaced in place by the SQL name
      (for example ``'user'`` -> ``'username'``);
    * ``'key_value_pairs'`` and ``'constraints'`` are always appended;
    * ``'id'`` is appended if it is not already present.

    Returns a new list; *columns* itself is not modified.
    """
    # (display name, SQL name) pairs, applied in this order.
    renamed = (
        ('user', 'username'),
        ('formula', 'numbers'),
        ('fmax', 'forces'),
        ('smax', 'stress'),
        ('volume', 'cell'),
        ('mass', 'masses'),
        ('charge', 'charges'),
    )

    result = list(columns)

    if 'age' in columns:
        result.remove('age')
        result.extend(('mtime', 'ctime'))

    for shown, stored in renamed:
        if shown in columns:
            # Replace in place so the position in the list is preserved.
            result[result.index(shown)] = stored

    result.extend(('key_value_pairs', 'constraints'))
    if 'id' not in result:
        result.append('id')

    return result
def plural(n, word):
    """Return *n* followed by *word*, adding a trailing 's' unless n == 1."""
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)
def cut(txt, length):
    """Shorten *txt* to at most *length* characters, ending in '...'.

    Strings that already fit are returned unchanged.  A non-positive
    *length* means "no limit".

    When *length* is smaller than the ellipsis itself (1 or 2), the
    ellipsis is shortened to fit.  Slicing with ``length - 3`` would be a
    negative index there and could return a string longer than *txt*.
    """
    if length <= 0 or len(txt) <= length:
        return txt

    ellipsis = '...'
    if length < len(ellipsis):
        # No room for any text: return as many dots as fit.
        return ellipsis[:length]
    return txt[:length - len(ellipsis)] + ellipsis
def cutlist(lst, length, keep=9):
    """Abbreviate *lst* when it has more than *length* items.

    Lists that fit are returned unchanged (the same object), and a
    non-positive *length* means "no limit".  Otherwise a new list is
    returned holding the leading items followed by one summary entry
    ``'... (N more)'``.

    *keep* is the number of leading items to retain (default 9).  It is
    capped at *length* so the head never exceeds the limit and the
    reported remainder can never be negative.
    """
    if length <= 0 or len(lst) <= length:
        return lst

    head = max(0, min(keep, length))
    return lst[:head] + [f'... ({len(lst) - head} more)']
class Table:
    """Rows selected from a database connection, formatted as a text table.

    Typical use is ``select()`` followed by ``write()`` or ``write_csv()``.

    Parameters
    ----------
    connection:
        Object providing ``select(query, ...)`` and ``count(query)``.
    unique_key:
        Name of the attribute used as each row's ``uid``.
    verbosity:
        Passed on to ``connection.select``; also controls whether headers
        and the summary lines are printed.
    cut:
        Threshold handed to ``cutlist`` when ``write()`` prints the keys.
    """

    def __init__(
        self,
        connection,
        unique_key: str = 'id',
        verbosity: int = 1,
        cut: int = 35,
    ):
        self.connection = connection
        self.verbosity = verbosity
        self.cut = cut
        self.rows: list[Row] = []
        self.columns = None
        self.id = None
        self.right = None
        self.keys = None
        self.unique_key = unique_key
        self.addcolumns: list[str] | None = None
        # Overwritten by select().  Initialised here so that write() does
        # not raise AttributeError on a table filled in some other way.
        self.limit = None
        self.offset = None

    def select(self, query, columns, sort, limit, offset,
               show_empty_columns=False):
        """Query database and create rows.

        Stores *limit* and *offset*, builds one ``Row`` per result of
        ``connection.select`` and sets ``self.columns``.  Unless
        *show_empty_columns* is true, every column whose value is ``None``
        in all rows is removed from both the rows and ``self.columns``.
        With no rows at all, that removes every column.
        """
        sql_columns = get_sql_columns(columns)
        self.limit = limit
        self.offset = offset
        self.rows = [Row(row, columns, self.unique_key)
                     for row in self.connection.select(
                         query, verbosity=self.verbosity,
                         limit=limit, offset=offset, sort=sort,
                         include_data=False, columns=sql_columns)]

        self.columns = list(columns)

        if not show_empty_columns:
            # Start with every column index as a candidate and discard an
            # index as soon as any row has a value for it.
            empty = set(range(len(columns)))
            for row in self.rows:
                for n in empty.copy():
                    if row.values[n] is not None:
                        empty.remove(n)

            # Delete from the highest index down so that the remaining
            # indices stay valid.
            doomed = sorted(empty, reverse=True)
            for row in self.rows:
                for n in doomed:
                    del row.values[n]

            for n in doomed:
                del self.columns[n]

    def format(self, subscript=None):
        """Format all rows and work out alignment and the set of keys.

        Sets ``self.right`` (one bool per column, true for columns to be
        right-adjusted) and ``self.keys`` (sorted names found in the rows'
        ``key_value_pairs``).  *subscript* is passed to ``Row.format``.
        """
        right = set()  # right-adjust numbers
        allkeys = set()
        for row in self.rows:
            numbers = row.format(self.columns, subscript)
            right.update(numbers)
            allkeys.update(row.dct.get('key_value_pairs', {}))

        right.add('age')
        self.right = [column in right for column in self.columns]

        self.keys = sorted(allkeys)

    def write(self, query=None):
        """Print the table with '|' separated, width-aligned columns.

        The header is printed only when ``verbosity > 0``.  With
        ``verbosity == 0`` nothing but the rows is printed.  Otherwise a
        ``Rows:`` line follows, and a ``Keys:`` line if any keys exist.
        *query* is only used for ``connection.count(query)`` when the
        number of rows equals the limit given to ``select()``.
        """
        self.format()
        L = [[len(s) for s in row.strings]
             for row in self.rows]
        L.append([len(c) for c in self.columns])
        # Column width = widest cell, header included.
        N = np.max(L, axis=0)

        fmt = '{:{align}{width}}'
        if self.verbosity > 0:
            # '<>'[a] picks '<' for a false flag and '>' for a true one.
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in zip(self.columns, self.right, N)))
        for row in self.rows:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(row.strings, self.right, N)))

        if self.verbosity == 0:
            return

        nrows = len(self.rows)

        if self.limit and nrows == self.limit:
            # The result may have been truncated: report the full count.
            n = self.connection.count(query)
            print('Rows:', n, f'(showing first {self.limit})')
        else:
            print('Rows:', nrows)

        if self.keys:
            print('Keys:', ', '.join(cutlist(self.keys, self.cut)))

    def write_csv(self):
        """Print the rows as ', ' separated values.

        The header line is printed only when ``verbosity > 0``.  Values are
        converted with ``str`` and are not quoted or escaped.
        """
        if self.verbosity > 0:
            print(', '.join(self.columns))
        for row in self.rows:
            print(', '.join(str(val) for val in row.values))
class Row:
    """One table row: values taken from a database row, plus their text.

    Attributes
    ----------
    dct:
        The underlying row object; columns are read from it as attributes.
    values:
        One raw value per requested column (``None`` where missing).
    strings:
        Text form of ``values``; ``None`` until ``format()`` is called.
    uid:
        Value of the *unique_key* attribute of ``dct``.
    """

    def __init__(self, dct, columns, unique_key='id'):
        self.dct = dct
        self.values = None
        self.strings = None
        self.set_columns(columns)
        self.uid = getattr(dct, unique_key)

    def set_columns(self, columns):
        """Fill ``self.values`` with one entry per name in *columns*.

        ``'age'`` is computed from ``dct.ctime`` and ``'pbc'`` becomes a
        string of ``'T'``/``'F'`` characters.  Any other name is read as an
        attribute of ``dct``, with ``None`` if it is absent.
        """
        self.values = []
        for c in columns:
            if c == 'age':
                value = float_to_time_string(now() - self.dct.ctime)
            elif c == 'pbc':
                value = ''.join('FT'[int(p)] for p in self.dct.pbc)
            else:
                value = getattr(self.dct, c, None)
            self.values.append(value)

    def format(self, columns, subscript=None):
        """Convert ``self.values`` to strings, stored in ``self.strings``.

        *columns* names the columns in the same order as ``self.values``.
        *subscript*, if given, is a compiled regular expression that is
        applied to the ``'formula'`` column to wrap matches in
        ``<sub>...</sub>``.

        Returns the set of column names holding a number in this row.
        Integers are formatted with ``str``, floats with three decimals,
        ``None`` as the empty string and NumPy arrays via ``tolist()``.
        Every entry of ``self.strings`` is guaranteed to be a ``str``, so
        callers can safely take ``len()`` of each one.
        """
        self.strings = []
        numbers = set()
        for value, column in zip(self.values, columns):
            if (column == 'formula' and subscript
                    and isinstance(value, str)):
                # Only substitute on real strings: a missing formula
                # (None) falls through to the branches below.
                value = subscript.sub(r'<sub>\1</sub>', value)
            elif value is None:
                value = ''
            elif isinstance(value, np.ndarray):
                value = str(value.tolist())
            elif isinstance(value, (int, np.integer)):
                # NumPy integer scalars are not instances of int.
                value = str(value)
                numbers.add(column)
            elif isinstance(value, (float, np.floating)):
                # np.float32 and friends are not instances of float.
                numbers.add(column)
                value = f'{value:.3f}'
            elif not isinstance(value, str):
                # dict, list, tuple and anything else not handled above.
                value = str(value)
            self.strings.append(value)

        return numbers