Coverage for ase / db / table.py: 90.07%

141 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3 

4import numpy as np 

5 

6from ase.db.core import float_to_time_string, now 

7 

# Names of the standard table columns; get_sql_columns() translates them
# to the SQL column names that must be fetched to compute each one.
all_columns = (
    'id', 'age', 'user', 'formula', 'calculator',
    'energy', 'natoms', 'fmax', 'pbc', 'volume',
    'charge', 'mass', 'smax', 'magmom',
)

11 

12 

def get_sql_columns(columns):
    """Translate table column names into the SQL columns to fetch.

    'age' is replaced by 'mtime' and 'ctime' (appended at the end of the
    list), and several display names are renamed in place to the stored
    column they are derived from (e.g. 'fmax' -> 'forces').
    'key_value_pairs' and 'constraints' are always requested, and 'id' is
    appended when not already present.
    """
    renames = {
        'user': 'username',
        'formula': 'numbers',
        'fmax': 'forces',
        'smax': 'stress',
        'volume': 'cell',
        'mass': 'masses',
        'charge': 'charges',
    }

    wanted = list(columns)
    if 'age' in columns:
        wanted.remove('age')
        wanted.extend(('mtime', 'ctime'))

    # Only the first occurrence of each display name is renamed.
    for display_name, stored_name in renames.items():
        if display_name in columns:
            wanted[wanted.index(display_name)] = stored_name

    wanted.extend(('key_value_pairs', 'constraints'))
    if 'id' not in wanted:
        wanted.append('id')
    return wanted

41 

42 

def plural(n, word):
    """Return '<n> <word>', adding an 's' to ``word`` unless ``n == 1``."""
    return '1 ' + word if n == 1 else '%d %ss' % (n, word)

47 

48 

def cut(txt, length):
    """Shorten ``txt`` so that it fits in ``length`` characters.

    Text longer than ``length`` is truncated and ends in '...'.  A
    ``length`` of 0 disables truncation.  When ``length`` is 1 or 2 there
    is no room for the ellipsis, so the text is simply clipped.
    """
    if len(txt) <= length or length == 0:
        return txt
    if 0 < length < 3:
        # txt[:length - 3] would be a negative slice here and return a
        # string longer than the one we were asked to shorten.
        return txt[:length]
    return txt[:length - 3] + '...'

53 

54 

def cutlist(lst, length):
    """Shorten a list of strings for display.

    If ``lst`` has more than ``length`` items (``length == 0`` means never
    shorten), keep the first few items (at most 9, and never more than
    ``length``) and append a '... (N more)' marker for the rest.
    """
    if len(lst) <= length or length == 0:
        return lst
    # Previously always 9, which for length < 9 cut nothing and reported a
    # negative remainder.  Negative lengths keep the historical 9.
    shown = min(9, length) if length > 0 else 9
    return lst[:shown] + [f'... ({len(lst) - shown} more)']

59 

60 

class Table:
    """Rows selected from a database connection, printable as a table.

    Call ``select()`` to fetch rows, then ``write()`` for an aligned text
    table or ``write_csv()`` for comma-separated output.
    """

    def __init__(
        self,
        connection,
        unique_key: str = 'id',
        verbosity: int = 1,
        cut: int = 35,
    ):
        """
        Parameters
        ----------
        connection:
            Database connection; ``select()`` and ``count()`` are called on it.
        unique_key:
            Name of the row attribute stored as each ``Row.uid``.
        verbosity:
            Passed on to ``connection.select()``.  ``write()`` prints the
            header line only when verbosity > 0 and the row/key summary
            whenever verbosity != 0; ``write_csv()`` prints its header only
            when verbosity > 0.
        cut:
            If more than ``cut`` keys are found, the key list printed by
            ``write()`` is shortened (0 means never shorten).
        """
        self.connection = connection
        self.verbosity = verbosity
        self.cut = cut
        self.rows: list[Row] = []
        self.columns = None
        self.id = None
        self.right = None
        self.keys = None
        self.unique_key = unique_key
        self.addcolumns: list[str] | None = None
        # Assigned by select(); created here so the object always has them
        # (write() reads self.limit).
        self.limit = None
        self.offset = None

    def select(self, query, columns, sort, limit, offset,
               show_empty_columns=False):
        """Query the database and create rows.

        ``columns`` are table column names; they are translated to SQL
        columns with ``get_sql_columns()``.  Unless ``show_empty_columns``
        is true, columns whose value is None in every selected row are
        dropped from ``self.columns`` and from each row's ``values``.
        """
        sql_columns = get_sql_columns(columns)
        self.limit = limit
        self.offset = offset
        self.rows = [Row(row, columns, self.unique_key)
                     for row in self.connection.select(
                         query, verbosity=self.verbosity,
                         limit=limit, offset=offset, sort=sort,
                         include_data=False, columns=sql_columns)]

        self.columns = list(columns)

        if not show_empty_columns:
            # Indices of the columns for which at least one row has a value.
            used = {n
                    for row in self.rows
                    for n, value in enumerate(row.values)
                    if value is not None}
            keep = [n for n in range(len(self.columns)) if n in used]
            for row in self.rows:
                row.values = [row.values[n] for n in keep]
            self.columns = [self.columns[n] for n in keep]

    def format(self, subscript=None):
        """Format every row and collect alignment and key information.

        Sets ``self.right`` (per column: right-align?) from the columns the
        rows report as numeric, plus 'age', and ``self.keys`` to the sorted
        union of all rows' key-value-pair keys.
        """
        right = set()  # right-adjust numbers
        allkeys = set()
        for row in self.rows:
            numbers = row.format(self.columns, subscript)
            right.update(numbers)
            allkeys.update(row.dct.get('key_value_pairs', {}))

        right.add('age')
        self.right = [column in right for column in self.columns]

        self.keys = sorted(allkeys)

    def write(self, query=None):
        """Print the rows as an aligned, '|'-separated text table.

        ``query`` is only used to count the total number of matching rows
        when the selection hit its limit.
        """
        self.format()
        # Column width = longest cell or header in that column.
        L = [[len(s) for s in row.strings]
             for row in self.rows]
        L.append([len(c) for c in self.columns])
        N = np.max(L, axis=0)

        fmt = '{:{align}{width}}'
        if self.verbosity > 0:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in zip(self.columns, self.right, N)))
        for row in self.rows:
            print('|'.join(fmt.format(c, align='<>'[a], width=w)
                           for c, a, w in
                           zip(row.strings, self.right, N)))

        if self.verbosity == 0:
            return

        nrows = len(self.rows)

        if self.limit and nrows == self.limit:
            # The limit was reached: there may be more matching rows.
            n = self.connection.count(query)
            print('Rows:', n, f'(showing first {self.limit})')
        else:
            print('Rows:', nrows)

        if self.keys:
            print('Keys:', ', '.join(cutlist(self.keys, self.cut)))

    def write_csv(self):
        """Print the raw row values as comma-separated lines."""
        if self.verbosity > 0:
            print(', '.join(self.columns))
        for row in self.rows:
            print(', '.join(str(val) for val in row.values))

156 

157 

class Row:
    """One database row prepared for display in a ``Table``."""

    def __init__(self, dct, columns, unique_key='id'):
        """
        Parameters
        ----------
        dct:
            Row object from ``connection.select()``; column values are read
            from it as attributes.
        columns:
            Names of the columns to extract into ``self.values``.
        unique_key:
            Name of the attribute of ``dct`` stored as ``self.uid``.
        """
        self.dct = dct
        self.values = None
        self.strings = None
        self.set_columns(columns)
        self.uid = getattr(dct, unique_key)

    def set_columns(self, columns):
        """Extract the raw value of each column into ``self.values``.

        'age' is the elapsed time since ``dct.ctime`` rendered by
        ``float_to_time_string``; 'pbc' becomes a string of 'T'/'F' flags;
        any other column is read with ``getattr`` and is None if missing.
        """
        self.values = []
        for c in columns:
            if c == 'age':
                value = float_to_time_string(now() - self.dct.ctime)
            elif c == 'pbc':
                value = ''.join('FT'[int(p)] for p in self.dct.pbc)
            else:
                value = getattr(self.dct, c, None)
            self.values.append(value)

    def format(self, columns, subscript=None):
        """Convert ``self.values`` into display strings in ``self.strings``.

        Parameters
        ----------
        columns:
            Column names, in the same order as ``self.values``.
        subscript:
            Optional compiled regex.  If given, each match in the 'formula'
            column is replaced by ``<sub>`` + its first group + ``</sub>``.

        Returns
        -------
        set
            Names of the columns that held numeric values (integers, bools
            and floats, including NumPy scalars), so the caller can
            right-align them.
        """
        self.strings = []
        numbers = set()
        for value, column in zip(self.values, columns):
            if column == 'formula' and subscript:
                value = subscript.sub(r'<sub>\1</sub>', value)
            elif isinstance(value, (dict, list)):
                value = str(value)
            elif isinstance(value, np.ndarray):
                value = str(value.tolist())
            elif isinstance(value, (int, np.integer, np.bool_)):
                # np.int64 etc. are not int subclasses and used to fall
                # through unconverted; bools count as numbers like ints.
                value = str(value)
                numbers.add(column)
            elif isinstance(value, (float, np.floating)):
                numbers.add(column)
                value = f'{value:.3f}'
            elif value is None:
                value = ''
            elif not isinstance(value, str):
                # Table.write() takes len() of every cell and applies a
                # width format, so every cell must end up a string.
                value = str(value)
            self.strings.append(value)

        return numbers