Coverage for /builds/ase/ase/ase/utils/filecache.py: 95.92%

196 statements  

coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

import json
from collections.abc import Mapping, MutableMapping
from contextlib import contextmanager
from pathlib import Path

from ase.io.jsonio import encode as encode_json
from ase.io.jsonio import read_json, write_json
from ase.io.ulm import InvalidULMFileError, NDArrayReader, Writer, ulmopen
from ase.parallel import world
from ase.utils import opencew


def missing(key):
    raise KeyError(key)


class Locked(Exception):
    pass


# Note:
#
# The communicator handling is a complete hack.
# We should entirely remove communicators from these objects.
# (Actually: opencew() should not know about communicators.)
# Then the caller is responsible for handling parallelism,
# which makes life simpler for both the caller and us!
# (A sketch of that caller-side pattern follows below.)
#
# Also, things like clean()/__del__ are not correctly implemented
# in parallel. The reason why it currently "works" is that
# we don't call those functions from Vibrations etc., or they do so
# only for rank == 0.
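

# A minimal sketch (illustration only, not part of the ASE API) of the
# caller-managed parallelism suggested in the note above: only rank 0
# writes to the cache while the other ranks wait at a barrier.  This
# assumes 'comm' is an ASE-style communicator with 'rank' and 'barrier()'.
def _example_rank0_write(cache, comm, key, value):
    if comm.rank == 0:
        cache[key] = value
    comm.barrier()
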

class JSONBackend:
    extension = '.json'
    DecodeError = json.decoder.JSONDecodeError

    @staticmethod
    def open_for_writing(path, comm):
        return opencew(path, world=comm)

    @staticmethod
    def read(fname):
        return read_json(fname, always_array=False)

    @staticmethod
    def open_and_write(target, data, comm):
        if comm.rank == 0:
            write_json(target, data)

    @staticmethod
    def write(fd, value):
        fd.write(encode_json(value).encode('utf-8'))

    @classmethod
    def dump_cache(cls, path, dct, comm):
        return CombinedJSONCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        return MultiFileJSONCache(directory, comm=comm)


class ULMBackend:
    extension = '.ulm'
    DecodeError = InvalidULMFileError

    @staticmethod
    def open_for_writing(path, comm):
        fd = opencew(path, world=comm)
        if fd is not None:
            return Writer(fd, 'w', '')

    @staticmethod
    def read(fname):
        with ulmopen(fname, 'r') as r:
            data = r._data['cache']
            if isinstance(data, NDArrayReader):
                return data.read()
        return data

    @staticmethod
    def open_and_write(target, data, comm):
        if comm.rank == 0:
            with ulmopen(target, 'w') as w:
                w.write('cache', data)

    @staticmethod
    def write(fd, value):
        fd.write('cache', value)

    @classmethod
    def dump_cache(cls, path, dct, comm):
        return CombinedULMCache.dump_cache(path, dct, comm)

    @classmethod
    def create_multifile_cache(cls, directory, comm):
        return MultiFileULMCache(directory, comm=comm)


class CacheLock:
    def __init__(self, fd, key, backend):
        self.fd = fd
        self.key = key
        self.backend = backend

    def save(self, value):
        try:
            self.backend.write(self.fd, value)
        except Exception as ex:
            raise RuntimeError(f'Failed to save {value} to cache') from ex
        finally:
            self.fd.close()


class _MultiFileCacheTemplate(MutableMapping):
    writable = True

    def __init__(self, directory, comm=world):
        self.directory = Path(directory)
        self.comm = comm

    def _filename(self, key):
        return self.directory / (f'cache.{key}' + self.backend.extension)

    def _glob(self):
        return self.directory.glob('cache.*' + self.backend.extension)

    def __iter__(self):
        for path in self._glob():
            cache, key = path.stem.split('.', 1)
            if cache != 'cache':
                continue
            yield key

    def __len__(self):
        # Very inefficient, but not a big use case.
        return len(list(self._glob()))

    @contextmanager
    def lock(self, key):
        if self.comm.rank == 0:
            self.directory.mkdir(exist_ok=True, parents=True)
        path = self._filename(key)
        fd = self.backend.open_for_writing(path, self.comm)
        try:
            if fd is None:
                yield None
            else:
                yield CacheLock(fd, key, self.backend)
        finally:
            if fd is not None:
                fd.close()

    def __setitem__(self, key, value):
        with self.lock(key) as handle:
            if handle is None:
                raise Locked(key)
            handle.save(value)

    def __getitem__(self, key):
        path = self._filename(key)
        try:
            return self.backend.read(path)
        except FileNotFoundError:
            missing(key)
        except self.backend.DecodeError:
            # May be partially written, which typically means empty
            # because the file was locked with exclusive-write-open.
            #
            # Since we decide what keys we have based on which files exist,
            # we are obligated to return a value for this case too.
            # So we return None (see the sketch after this class).
            return None

    def __delitem__(self, key):
        try:
            self._filename(key).unlink()
        except FileNotFoundError:
            missing(key)

    def combine(self):
        cache = self.backend.dump_cache(
            self.directory, dict(self), comm=self.comm
        )
        assert set(cache) == set(self)
        self.clear()
        assert len(self) == 0
        return cache

    def split(self):
        return self

    def filecount(self):
        return len(self)

    def strip_empties(self):
        empties = [key for key, value in self.items() if value is None]
        for key in empties:
            del self[key]
        return len(empties)

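

# Illustration only (hypothetical helper, not part of the module API):
# the 'partially written file reads as None' contract documented in
# __getitem__ above, demonstrated by simulating a crashed writer that
# left an empty cache file behind.
def _example_partial_read(directory):
    cache = MultiFileJSONCache(directory)    # defined further down
    cache.directory.mkdir(parents=True, exist_ok=True)
    cache._filename('broken').touch()   # file exists but holds no JSON
    assert cache['broken'] is None      # DecodeError is mapped to None
    assert cache.strip_empties() == 1   # ...and strip_empties() removes it
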

class _CombinedCacheTemplate(Mapping):
    writable = False

    def __init__(self, directory, dct, comm=world):
        self.directory = Path(directory)
        self._dct = dict(dct)
        self.comm = comm

    def filecount(self):
        return int(self._filename.is_file())

    @property
    def _filename(self):
        return self.directory / ('combined' + self.backend.extension)

    def __len__(self):
        return len(self._dct)

    def __iter__(self):
        return iter(self._dct)

    def __getitem__(self, index):
        return self._dct[index]

    def _dump(self):
        target = self._filename
        if target.exists():
            raise RuntimeError(f'Already exists: {target}')
        self.directory.mkdir(exist_ok=True, parents=True)
        self.backend.open_and_write(target, self._dct, comm=self.comm)

    @classmethod
    def dump_cache(cls, path, dct, comm=world):
        cache = cls(path, dct, comm=comm)
        cache._dump()
        return cache

    @classmethod
    def load(cls, path, comm):
        # XXX Very hacky this one
        cache = cls(path, {}, comm=comm)
        dct = cls.backend.read(cache._filename)
        cache._dct.update(dct)
        return cache

    def clear(self):
        self._filename.unlink()
        self._dct.clear()

    def combine(self):
        return self

    def split(self):
        cache = self.backend.create_multifile_cache(
            self.directory, comm=self.comm
        )
        assert len(cache) == 0
        cache.update(self)
        assert set(cache) == set(self)
        self.clear()
        return cache


class MultiFileJSONCache(_MultiFileCacheTemplate):
    backend = JSONBackend()


class MultiFileULMCache(_MultiFileCacheTemplate):
    backend = ULMBackend()


class CombinedJSONCache(_CombinedCacheTemplate):
    backend = JSONBackend()


class CombinedULMCache(_CombinedCacheTemplate):
    backend = ULMBackend()


def get_json_cache(directory, comm=world):
    try:
        return CombinedJSONCache.load(directory, comm=comm)
    except FileNotFoundError:
        return MultiFileJSONCache(directory, comm=comm)
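

# Usage sketch (illustration only): a MultiFileJSONCache behaves like a
# dict backed by one JSON file per key; combine() merges all entries into
# a single 'combined.json', and split() goes back to one file per key.
# get_json_cache() returns whichever representation is currently on disk.
if __name__ == '__main__':
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        cache = MultiFileJSONCache(tmp)
        cache['energy'] = -1.23          # writes cache.energy.json
        assert cache['energy'] == -1.23

        combined = cache.combine()       # one combined.json; per-key files gone
        assert combined['energy'] == -1.23
        assert isinstance(get_json_cache(tmp), CombinedJSONCache)

        multi = combined.split()         # back to one file per key
        assert multi['energy'] == -1.23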