Coverage for ase / db / web.py: 72.64%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3"""Helper functions for Flask WSGI-app.""" 

4 

5from ase.db.core import Database 

6from ase.db.table import Table, all_columns 

7 

8 

class Session:
    """Session object.

    Stores stuff that the jinja2 templates (like templates/table.html)
    need to show.  Example from table.html::

        Displaying rows {{ s.row1 }}-{{ s.row2 }} out of {{ s.nrows }}

    where *s* is the session object.
    """

    # Id handed to the next Session that is created (shared by all sessions).
    next_id = 1
    # Every live session keyed by its id; trimmed in __init__ so it cannot
    # grow without bound.
    sessions: dict[int, 'Session'] = {}

    def __init__(self, project_name: str):
        self.id = Session.next_id
        Session.next_id += 1

        Session.sessions[self.id] = self
        if len(Session.sessions) > 2000:
            # Forget old sessions.  Ids only ever increase, so the smallest
            # ids belong to the oldest sessions.
            for sid in sorted(Session.sessions)[:400]:
                del Session.sessions[sid]

        self.columns: list[str] | None = None
        self.nrows: int | None = None
        self.nrows_total: int | None = None
        self.page = 0
        self.limit = 25
        self.sort = ''
        self.query = ''
        self.project_name = project_name
        # True while self.query is one the database rejected; see
        # create_table().
        self._bad_query = False

    def __str__(self) -> str:
        return str(self.__dict__)

    @staticmethod
    def get(id: int) -> 'Session':
        """Return the live session with the given id (KeyError if unknown)."""
        return Session.sessions[id]

    def update(self,
               what: str,
               x: str,
               args: dict[str, str],
               project) -> None:
        """Apply one user action to the session state.

        what:
            ``'query'``, ``'sort'``, ``'limit'``, ``'page'`` or ``'toggle'``.
            Any other value only initializes the columns.
        x:
            The action's argument as a string (column name or number).
        args:
            Request arguments; passed to ``project.handle_query`` when
            ``what == 'query'``.
        project:
            Object providing ``default_columns`` and ``handle_query(args)``.
        """
        if self.columns is None:
            self.columns = list(project.default_columns)

        if what == 'query':
            self.query = project.handle_query(args)
            self.nrows = None  # forces a recount in create_table()
            self.page = 0

        elif what == 'sort':
            # Repeated clicks cycle: ascending -> descending -> by id.
            if x == self.sort:
                self.sort = '-' + x
            elif '-' + x == self.sort:
                self.sort = 'id'
            else:
                self.sort = x
            self.page = 0

        elif what == 'limit':
            # x comes from the URL: a limit below 1 would make paginate()
            # divide by zero.
            self.limit = max(1, int(x))
            self.page = 0

        elif what == 'page':
            # Negative pages would produce negative row offsets.
            self.page = max(0, int(x))

        elif what == 'toggle':
            column = x
            if column == 'reset':
                self.columns = list(project.default_columns)
            elif column in self.columns:
                self.columns.remove(column)
                # Cannot keep sorting by a column that is no longer shown.
                if column == self.sort.lstrip('-'):
                    self.sort = 'id'
                    self.page = 0
            else:
                self.columns.append(column)

    @property
    def row1(self) -> int:
        """1-based number of the first row on the current page."""
        return self.page * self.limit + 1

    @property
    def row2(self) -> int:
        """1-based number of the last row on the current page."""
        assert self.nrows is not None
        return min((self.page + 1) * self.limit, self.nrows)

    def paginate(self) -> list[tuple[int, str]]:
        """Helper function for pagination stuff.

        Returns ``(page, label)`` pairs: a 'previous' entry, page numbers
        (the first five, the pages around the current one, the last five,
        with '...' for gaps) and a 'next' entry.  ``page`` is -1 for entries
        that have no target: ellipses, the current page and 'previous'/'next'
        at either end.
        """
        assert self.nrows is not None
        npages = (self.nrows + self.limit - 1) // self.limit
        p1 = min(5, npages)
        p2 = max(self.page - 4, p1)
        p3 = min(self.page + 5, npages)
        p4 = max(npages - 4, p3)
        pgs = list(range(p1))
        if p1 < p2:
            pgs.append(-1)
        pgs += list(range(p2, p3))
        if p3 < p4:
            pgs.append(-1)
        pgs += list(range(p4, npages))
        pages = [(self.page - 1, 'previous')]
        for p in pgs:
            if p == -1:
                pages.append((-1, '...'))
            elif p == self.page:
                pages.append((-1, str(p + 1)))
            else:
                pages.append((p, str(p + 1)))
        nxt = min(self.page + 1, npages - 1)
        if nxt == self.page:
            nxt = -1
        pages.append((nxt, 'next'))
        return pages

    def create_table(self,
                     db: 'Database',
                     uid_key: str,
                     keys: list[str]) -> 'Table':
        """Select and format the rows of the current page.

        Row counts are computed lazily: the total once, and the query count
        after every new query.  If the database rejects the query
        (ValueError/KeyError), the error is shown with ``flask.flash``.  The
        table then stays empty until a new query is given.
        """
        if self.nrows_total is None:
            self.nrows_total = db.count()

        if self.nrows is None:
            try:
                self.nrows = db.count(self.query)
            except (ValueError, KeyError) as e:
                error = ', '.join(['Bad query', *(str(a) for a in e.args)])
                from flask import flash
                flash(error)
                self.nrows = 0
                self._bad_query = True
            else:
                self._bad_query = False

        # nrows stays cached across sort/page/toggle updates, so a rejected
        # query must keep being replaced here.  Otherwise select() would
        # re-run it and raise.
        query = 'id=0' if self._bad_query else self.query  # 'id=0': no rows

        assert self.columns is not None
        table = Table(db, uid_key)
        table.select(query, self.columns, self.sort,
                     self.limit, offset=self.page * self.limit,
                     show_empty_columns=True)
        table.format()
        # Columns offered for adding; a key can be both a standard column and
        # a project key, so deduplicate.
        candidates = {*all_columns, *keys}
        table.addcolumns = sorted(candidates.difference(self.columns))
        return table

77 

78 elif what == 'toggle': 

79 column = x 

80 if column == 'reset': 

81 self.columns = list(project.default_columns) 

82 else: 

83 if column in self.columns: 

84 self.columns.remove(column) 

85 if column == self.sort.lstrip('-'): 

86 self.sort = 'id' 

87 self.page = 0 

88 else: 

89 self.columns.append(column) 

90 

91 @property 

92 def row1(self) -> int: 

93 return self.page * self.limit + 1 

94 

95 @property 

96 def row2(self) -> int: 

97 assert self.nrows is not None 

98 return min((self.page + 1) * self.limit, self.nrows) 

99 

100 def paginate(self) -> list[tuple[int, str]]: 

101 """Helper function for pagination stuff.""" 

102 assert self.nrows is not None 

103 npages = (self.nrows + self.limit - 1) // self.limit 

104 p1 = min(5, npages) 

105 p2 = max(self.page - 4, p1) 

106 p3 = min(self.page + 5, npages) 

107 p4 = max(npages - 4, p3) 

108 pgs = list(range(p1)) 

109 if p1 < p2: 

110 pgs.append(-1) 

111 pgs += list(range(p2, p3)) 

112 if p3 < p4: 

113 pgs.append(-1) 

114 pgs += list(range(p4, npages)) 

115 pages = [(self.page - 1, 'previous')] 

116 for p in pgs: 

117 if p == -1: 

118 pages.append((-1, '...')) 

119 elif p == self.page: 

120 pages.append((-1, str(p + 1))) 

121 else: 

122 pages.append((p, str(p + 1))) 

123 nxt = min(self.page + 1, npages - 1) 

124 if nxt == self.page: 

125 nxt = -1 

126 pages.append((nxt, 'next')) 

127 return pages 

128 

129 def create_table(self, 

130 db: Database, 

131 uid_key: str, 

132 keys: list[str]) -> Table: 

133 query = self.query 

134 

135 if self.nrows_total is None: 

136 self.nrows_total = db.count() 

137 

138 if self.nrows is None: 

139 try: 

140 self.nrows = db.count(query) 

141 except (ValueError, KeyError) as e: 

142 error = ', '.join(['Bad query'] + list(e.args)) 

143 from flask import flash 

144 flash(error) 

145 query = 'id=0' # this will return no rows 

146 self.nrows = 0 

147 

148 table = Table(db, uid_key) 

149 table.select(query, self.columns, self.sort, 

150 self.limit, offset=self.page * self.limit, 

151 show_empty_columns=True) 

152 table.format() 

153 assert self.columns is not None 

154 table.addcolumns = sorted(column for column in 

155 [*all_columns, *keys] 

156 if column not in self.columns) 

157 return table