Coverage for /builds/ase/ase/ase/db/table.py: 90.14%

142 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3from typing import List, Optional 

4 

5import numpy as np 

6 

7from ase.db.core import float_to_time_string, now 

8 

# Names of the standard table columns.  Several of them are not stored
# under the same name in SQL; see get_sql_columns() for the mapping.
all_columns = (
    'id', 'age', 'user', 'formula', 'calculator',
    'energy', 'natoms', 'fmax', 'pbc', 'volume',
    'charge', 'mass', 'smax', 'magmom',
)

12 

13 

def get_sql_columns(columns):
    """Translate table column names into the SQL columns they require.

    'age' is dropped and 'mtime' and 'ctime' are appended in its place;
    several columns are renamed to the SQL column they are derived from;
    finally 'key_value_pairs', 'constraints' and, if not already present,
    'id' are appended.
    """
    # Table column name -> SQL column name, applied in this order.
    renames = {
        'user': 'username',
        'formula': 'numbers',
        'fmax': 'forces',
        'smax': 'stress',
        'volume': 'cell',
        'mass': 'masses',
        'charge': 'charges',
    }

    result = list(columns)
    if 'age' in columns:
        result.remove('age')
        result.extend(('mtime', 'ctime'))

    for table_name, sql_name in renames.items():
        if table_name in columns:
            # list.index() locates the first occurrence only.
            result[result.index(table_name)] = sql_name

    result.extend(('key_value_pairs', 'constraints'))
    if 'id' not in result:
        result.append('id')

    return result

42 

43 

def plural(n, word):
    """Return ``'<n> <word>'``, appending an 's' to *word* unless n == 1."""
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)

48 

49 

def cut(txt, length):
    """Shorten *txt* to at most *length* characters.

    Text that is cut is terminated with '...'.  A *length* of 0 disables
    cutting.  When *length* is too small to hold the ellipsis (1 or 2),
    the text is simply truncated.
    """
    if len(txt) <= length or length == 0:
        return txt
    if length < 3:
        # txt[:length - 3] would be a negative slice here and the result
        # would be longer than the input instead of shorter.
        return txt[:length]
    return txt[:length - 3] + '...'

54 

55 

def cutlist(lst, length):
    """Shorten *lst* when it has more than *length* entries.

    A cut list keeps its first ``min(9, length - 1)`` items followed by a
    single ``'... (N more)'`` marker, so the result never has more than
    *length* entries and N is always positive.  A *length* of 0 disables
    cutting.
    """
    if len(lst) <= length or length == 0:
        return lst
    # Previously a fixed 9 items were kept, which for length < 9 produced
    # lists longer than *length* and a negative "more" count.
    keep = max(min(9, length - 1), 0)
    return lst[:keep] + [f'... ({len(lst) - keep} more)']

60 

61 

class Table:
    """Text/CSV view of rows selected from a database connection.

    Parameters
    ----------
    connection:
        Database connection; this class calls its ``select()`` and
        ``count()`` methods.
    unique_key:
        Attribute read from each selected row to set ``Row.uid``.
    verbosity:
        With 0, write() prints only the data lines (no header, no
        summary) and write_csv() omits the header line.
    cut:
        Limit passed to cutlist() for the ``Keys:`` summary line
        (0 disables cutting).
    """

    def __init__(
        self,
        connection,
        unique_key: str = 'id',
        verbosity: int = 1,
        cut: int = 35,
    ):
        self.connection = connection
        self.verbosity = verbosity
        self.cut = cut
        self.rows: list[Row] = []
        self.columns = None
        self.id = None
        self.right = None
        self.keys = None
        self.unique_key = unique_key
        self.addcolumns: Optional[List[str]] = None
        # Normally set by select(); initialised here so write() does not
        # raise AttributeError when rows/columns were filled in directly.
        self.limit = None
        self.offset = None

    def select(self, query, columns, sort, limit, offset,
               show_empty_columns=False):
        """Query database and create rows.

        Columns whose value is None in every selected row are removed
        from ``self.columns`` and from each row's values, unless
        *show_empty_columns* is true.
        """
        sql_columns = get_sql_columns(columns)
        self.limit = limit
        self.offset = offset
        self.rows = [Row(row, columns, self.unique_key)
                     for row in self.connection.select(
                         query, verbosity=self.verbosity,
                         limit=limit, offset=offset, sort=sort,
                         include_data=False, columns=sql_columns)]

        self.columns = list(columns)

        if not show_empty_columns:
            # Assume every column is empty, then discard the indices for
            # which some row has a value.
            delete = set(range(len(columns)))
            for row in self.rows:
                for n in delete.copy():
                    if row.values[n] is not None:
                        delete.remove(n)
                if not delete:
                    break  # every column has a value somewhere
            # Delete from the highest index down so lower indices stay valid.
            delete = sorted(delete, reverse=True)
            for row in self.rows:
                for n in delete:
                    del row.values[n]

            for n in delete:
                del self.columns[n]

    def format(self, subscript=None):
        """Format every row and collect alignment and key information.

        Sets ``self.right`` (one bool per column: right-align it?) and
        ``self.keys`` (sorted union of the rows' key_value_pairs keys).
        *subscript* is forwarded to ``Row.format``.
        """
        right = set()  # right-adjust numbers
        allkeys = set()
        for row in self.rows:
            numbers = row.format(self.columns, subscript)
            right.update(numbers)
            allkeys.update(row.dct.get('key_value_pairs', {}))

        right.add('age')
        self.right = [column in right for column in self.columns]

        self.keys = sorted(allkeys)

    def write(self, query=None):
        """Print the table as '|'-separated, width-padded text.

        *query* is only used to count all matching rows when the number
        of selected rows equals the limit.
        """
        self.format()
        # Column widths: widest cell or header in each column.
        L = [[len(s) for s in row.strings]
             for row in self.rows]
        L.append([len(c) for c in self.columns])
        N = np.max(L, axis=0)

        fmt = '{:{align}{width}}'
        if self.verbosity > 0:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in zip(self.columns, self.right, N)))
        for row in self.rows:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(row.strings, self.right, N)))

        if self.verbosity == 0:
            return

        nrows = len(self.rows)

        if self.limit and nrows == self.limit:
            # The result may have been truncated; report the real total.
            n = self.connection.count(query)
            print('Rows:', n, f'(showing first {self.limit})')
        else:
            print('Rows:', nrows)

        if self.keys:
            print('Keys:', ', '.join(cutlist(self.keys, self.cut)))

    def write_csv(self):
        """Print the raw row values as comma-separated lines.

        The header line with the column names is printed only when
        verbosity > 0.
        """
        if self.verbosity > 0:
            print(', '.join(self.columns))
        for row in self.rows:
            print(', '.join(str(val) for val in row.values))

157 

158 

class Row:
    """One database row prepared for display by Table.

    ``values`` holds the raw value of each requested column, and
    ``strings`` holds their text form once format() has been called.
    """

    def __init__(self, dct, columns, unique_key='id'):
        self.dct = dct
        self.values = None
        self.strings = None
        self.set_columns(columns)
        self.uid = getattr(dct, unique_key)

    def set_columns(self, columns):
        """Extract one value per column from the underlying row.

        'age' is the time since ``ctime`` rendered by
        float_to_time_string(); 'pbc' becomes a string of 'T'/'F' flags;
        any other column is read as an attribute (None when missing).
        """
        self.values = []
        for c in columns:
            if c == 'age':
                value = float_to_time_string(now() - self.dct.ctime)
            elif c == 'pbc':
                value = ''.join('FT'[int(p)] for p in self.dct.pbc)
            else:
                value = getattr(self.dct, c, None)
            self.values.append(value)

    def format(self, columns, subscript=None):
        """Convert the values to strings and report numeric columns.

        Parameters
        ----------
        columns:
            Column names, parallel to ``self.values``.
        subscript:
            Optional compiled regex; in the 'formula' column each match
            is replaced by its first group wrapped in ``<sub>...</sub>``.

        Returns
        -------
        set of str
            Names of the columns whose value in this row is an integer
            or a float (Table right-aligns those).
        """
        self.strings = []
        numbers = set()
        for value, column in zip(self.values, columns):
            if column == 'formula' and subscript:
                value = subscript.sub(r'<sub>\1</sub>', value)
            elif isinstance(value, dict):
                value = str(value)
            elif isinstance(value, list):
                value = str(value)
            elif isinstance(value, np.ndarray):
                value = str(value.tolist())
            elif isinstance(value, (int, np.integer)):
                # NumPy integer scalars are not int subclasses.
                value = str(value)
                numbers.add(column)
            elif isinstance(value, (float, np.floating)):
                numbers.add(column)
                value = f'{value:.3f}'
            elif value is None:
                value = ''
            elif not isinstance(value, str):
                # Any other type (tuple, NumPy bool, ...) must still become
                # text: Table.write() applies len() and str formatting.
                value = str(value)
            self.strings.append(value)

        return numbers