Coverage for /builds/ase/ase/ase/io/pickletrajectory.py: 76.87%

268 statements  

coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

# fmt: off

import collections
import errno
import os
import pickle
import sys
import warnings

import numpy as np

from ase.atoms import Atoms
from ase.calculators.calculator import PropertyNotImplementedError
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.parallel import barrier, world


class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file."""
    # Per default, write these quantities
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    write_momenta = True
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file. Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and
            no atoms argument should be specified.

            'w' is write mode. If the file already exists, it is
            renamed by appending .bak to the file name. The atoms
            argument specifies the Atoms object to be written to the
            file; if not given, it must instead be passed as an argument
            to the write() method.

            'a' is append mode. It acts as write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing. The
            default is that process number 0 does this. If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.
        """

        if _warn:
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n $ python3 -m ase.io.trajectory ' +
                        filename + '\n')

            raise RuntimeError(msg)

        self.numbers = None
        self.pbc = None
        self.sanitycheck = True
        self.pre_observers = []  # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)
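    # Illustrative usage sketch (added commentary, not part of the original
    # module).  It shows how the modes documented in __init__ map onto calls;
    # 'old.traj', 'new.traj' and `atoms` are hypothetical names.  `_warn=False`
    # is required because the default `_warn=True` raises RuntimeError to push
    # users towards the modern trajectory format:
    #
    #     traj = PickleTrajectory('old.traj', mode='r', _warn=False)
    #     first, last = traj[0], traj[-1]   # lazy, offset-based reads
    #     traj.close()
    #
    #     out = PickleTrajectory('new.traj', mode='w', atoms=atoms, _warn=False)
    #     out.write()                       # writes the associated Atoms object
    #     out.close()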

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.
        """
        self.fd = filename
        if mode == 'r':
            if isinstance(filename, str):
                self.fd = open(filename, 'rb')
            self.read_header()
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = os.path.isfile(filename)
                if exists:
                    exists = os.path.getsize(filename) > 0
                if exists:
                    self.fd = open(filename, 'rb')
                    self.read_header()
                    self.fd.close()
                barrier()
                if self.master:
                    self.fd = open(filename, 'ab+')
                else:
                    self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except OSError as e:
                            # this must run on Win only! Not atomic!
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    return
        self.fd.seek(0)
        try:
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise OSError('This is not a trajectory file!')
            d = pickle.load(self.fd)
        except EOFError:
            raise EOFError('Bad trajectory file.')

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        self.offsets.append(self.fd.tell())
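    # Added descriptive note (editorial, not original): on disk a trajectory
    # starts with the magic bytes b'PickleTrajectory', followed by one pickled
    # header dict ('pbc', 'numbers', 'tags', 'masses', 'constraints_string',
    # ...) and then one pickled dict per image.  self.offsets records the byte
    # position after each pickle so __getitem__ and __len__ can seek straight
    # to a given frame.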

    def write(self, atoms=None):
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            self.write_header(atoms)
        else:
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        if self.master:
            pickle.dump(d, self.fd, protocol=2)
            self.fd.flush()
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints, protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header. Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            self.fd.seek(self.offsets[i])
            try:
                d = pickle.load(self.fd, encoding='bytes')
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError
            if i == N - 1:
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        if len(self.offsets) == 0:
            return 0
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                pickle.load(self.fd)
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1
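    # Illustrative sketch (added commentary, not part of the original module):
    # reading supports negative indices and slices; 'md.traj' is a hypothetical
    # filename.
    #
    #     traj = PickleTrajectory('md.traj', _warn=False)
    #     n = len(traj)            # walks the file once, caching offsets
    #     last = traj[-1]
    #     every_other = traj[::2]  # slices return a list of Atoms objects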

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))
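    # Illustrative sketch (added commentary, not part of the original module):
    # attaching observers as described in the two docstrings above.  `log_step`
    # is a hypothetical callback; extra positional and keyword arguments are
    # stored and passed on at call time.
    #
    #     def log_step(msg):
    #         print(msg)
    #
    #     traj.pre_write_attach(log_step, 5, 'about to write')   # every 5th write
    #     traj.post_write_attach(log_step, 1, 'image written')   # every write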

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()


def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable. Items with non-string keys or
    unpicklable values are dropped and a warning is issued."""
    stringnified = {}
    for k, v in info.items():
        if not isinstance(k, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(k), UserWarning)
            continue
        try:
            # Should highest protocol be used here for efficiency?
            # Protocol 2 seems not to raise an exception when one
            # tries to pickle a file object, so by using that, we
            # might end up with file objects in inconsistent states.
            s = pickle.dumps(v, protocol=0)
        except pickle.PicklingError:
            warnings.warn('Skipping not picklable info-dict item: ' +
                          f'"{k}" ({sys.exc_info()[1]})', UserWarning)
        else:
            stringnified[k] = s
    return stringnified

def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it. Objects that cannot be unpickled will be
    skipped and a warning will be issued."""
    info = {}
    for k, s in stringnified.items():
        try:
            v = pickle.loads(s)
        except pickle.UnpicklingError:
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          f'"{k}" ({sys.exc_info()[1]})', UserWarning)
        else:
            info[k] = v
    return info
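# Illustrative sketch (added commentary, not part of the original module): the
# two helpers above give a lossy but safe round trip for Atoms.info -- string-
# keyed, picklable values survive, everything else is dropped with a warning.
#
#     info = {'name': 'slab', 'step': 3}
#     assert unstringnify_info(stringnify_info(info)) == info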

def dict2constraints(d):
    """Convert dict unpickled from trajectory file to list of constraints."""

    version = d.get('version', 1)

    if version == 1:
        return d['constraints']
    elif version in (2, 3):
        try:
            constraints = pickle.loads(d['constraints_string'])
            for c in constraints:
                if isinstance(c, FixAtoms) and c.index.dtype == bool:
                    # Special handling of old pickles:
                    c.index = np.arange(len(c.index))[c.index]
            return constraints
        except (AttributeError, KeyError, EOFError, ImportError, TypeError):
            warnings.warn('Could not unpickle constraints!')
            return []
    else:
        return []
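# Editorial note (added commentary, not part of the original module): this
# reader is kept so that old pickle-based .traj files remain accessible.  The
# RuntimeError raised in __init__ points at the supported migration path,
# roughly ('old.traj' is a hypothetical filename):
#
#     $ python3 -m ase.io.trajectory old.traj
#
# after which the converted file can be read with the modern ase.io.Trajectory
# reader or ase.io.read.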