Coverage for ase / io / pickletrajectory.py: 76.87%

268 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import collections 

4import errno 

5import os 

6import pickle 

7import sys 

8import warnings 

9 

10import numpy as np 

11 

12from ase.atoms import Atoms 

13from ase.calculators.calculator import PropertyNotImplementedError 

14from ase.calculators.singlepoint import SinglePointCalculator 

15from ase.constraints import FixAtoms 

16from ase.parallel import barrier, world 

17 

18 

class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file (legacy pickle format).

    This class only exists so that old pickle-based trajectory files can
    still be read and converted: unless constructed with ``_warn=False``,
    ``__init__`` deliberately raises RuntimeError pointing the user at
    the conversion tool for the modern format.

    SECURITY NOTE(review): the format is built on ``pickle``, so reading
    a trajectory file executes the pickle stream -- only trusted files
    should ever be opened.
    """
    # Per default, write these quantities
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    write_momenta = True  # NOTE(review): not consulted in _write_atoms below
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters
        ----------

        filename:
            The name of the parameter file.  Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and
            no atoms argument should be specified.

            'w' is write mode.  If the file already exists, it is
            renamed by appending .bak to the file name.  The atoms
            argument specifies the Atoms object to be written to the
            file, if not given it must instead be given as an argument
            to the write() method.

            'a' is append mode.  It acts a write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing. The
            default is that process number 0 does this.  If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.

        _warn=True:
            Internal flag; unless explicitly set to False, construction
            is refused with a RuntimeError (see below).
        """

        if _warn:
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n    $ python3 -m ase.io.trajectory ' +
                        filename + '\n')
            # Deliberate hard failure: only internal callers that pass
            # _warn=False (e.g. the conversion machinery) get past here.
            raise RuntimeError(msg)

        self.numbers = None  # atomic numbers from the header (set later)
        self.pbc = None      # periodic boundary conditions from the header
        self.sanitycheck = True
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        # File offsets of the first image and of each image read/written
        # so far; used for random access in __getitem__/__len__.
        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.
        """
        # *filename* may also be an already-open file object, in which
        # case it is used as-is.
        self.fd = filename
        if mode == 'r':
            if isinstance(filename, str):
                # Note: the builtin open(), not this method.
                self.fd = open(filename, 'rb')
            self.read_header()
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = os.path.isfile(filename)
                if exists:
                    # A zero-size file counts as nonexistent here.
                    exists = os.path.getsize(filename) > 0
                if exists:
                    # Read the header of the preexisting file so that
                    # appended images can be checked for consistency.
                    self.fd = open(filename, 'rb')
                    self.read_header()
                    self.fd.close()
                barrier()  # sync all MPI ranks before reopening for append
                if self.master:
                    self.fd = open(filename, 'ab+')
                else:
                    # Non-master ranks write to the null device.
                    self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except OSError as e:
                            # this must run on Win only! Not atomic!
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.
        """
        # Duck-typed check: anything with get_positions() is accepted.
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        """Read the magic string and the pickled header dict.

        Records the file position after the header as the offset of the
        first image.  An empty (zero-size) file is silently accepted.
        """
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    return
        self.fd.seek(0)
        try:
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise OSError('This is not a trajectory file!')
            # SECURITY: unpickles data straight from the file.
            d = pickle.load(self.fd)
        except EOFError:
            raise EOFError('Bad trajectory file.')

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        # Position right after the header = offset of the first image.
        self.offsets.append(self.fd.tell())

    def write(self, atoms=None):
        """Write atoms to the file.

        If *atoms* is not given, the Atoms object registered with
        set_atoms() is used.  Each image yielded by atoms.iterimages()
        is written as one frame.
        """
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write one image to the file.

        The first image also triggers writing of the header; later
        images are checked for consistency against the header data.
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            self.write_header(atoms)
        else:
            # pbc must always match; the atom-count/atomic-number checks
            # can be disabled by setting self.sanitycheck = False.
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        # NOTE(review): momenta are stored regardless of the
        # write_momenta class flag -- confirm whether that is intended.
        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    # Store magnetic moments only if any are nonzero.
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        # Fall back to the initial moments/charges when the calculator
        # did not supply any.
        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        # Only the master rank actually serializes; all ranks keep the
        # bookkeeping (offsets, counters, observers) in sync.
        if self.master:
            pickle.dump(d, self.fd, protocol=2)
        self.fd.flush()
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        """Write the magic string followed by the pickled header dict."""
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints, protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header.  Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        """Return image *i* as an Atoms object.

        Supports slices, negative indices, and indices past the images
        read so far (the file is then scanned forward).
        """
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            self.fd.seek(self.offsets[i])
            try:
                # encoding='bytes' so that pickles written under
                # Python 2 still load; byte keys are decoded to str.
                d = pickle.load(self.fd, encoding='bytes')
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError
            if i == N - 1:
                # Remember where the next image starts.
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                # Best-effort: constraints that fail to copy are dropped.
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            # Index beyond the known offsets: read forward image by
            # image (each read extends self.offsets) until i is reached.
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        # Negative index: translate relative to the full length and recurse.
        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        """Number of images; scans forward from the last known offset."""
        if len(self.offsets) == 0:
            return 0
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                pickle.load(self.fd)
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            # Observers fire every *interval* writes.
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

378 

379 

def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable.  Items with non-string keys or
    unpicklable values are dropped and a warning is issued."""
    stringnified = {}
    for k, v in info.items():
        if not isinstance(k, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(k), UserWarning)
            continue
        try:
            # Protocol 0 is used on purpose: protocol 2 seems not to
            # raise an exception when one tries to pickle a file object,
            # so we might end up with file objects in inconsistent
            # states in the trajectory.
            s = pickle.dumps(v, protocol=0)
        except (pickle.PicklingError, TypeError, AttributeError) as err:
            # pickle raises TypeError/AttributeError (not PicklingError)
            # for many unpicklable objects -- e.g. open file handles or
            # locally-defined functions -- so they must be caught too
            # for the promised "drop and warn" behavior to hold.
            warnings.warn('Skipping not picklable info-dict item: ' +
                          f'"{k}" ({err})', UserWarning)
        else:
            stringnified[k] = s
    return stringnified

402 

403 

def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it.  Objects that cannot be unpickled will be
    skipped and a warning will be issued."""
    info = {}
    for k, s in stringnified.items():
        try:
            # SECURITY: unpickles data read from the trajectory file;
            # only trusted files should be loaded.
            v = pickle.loads(s)
        except (pickle.UnpicklingError, AttributeError, EOFError,
                ImportError, KeyError, TypeError) as err:
            # Broadened from bare UnpicklingError: pickle.loads raises
            # AttributeError/ImportError for stale class references and
            # EOFError/TypeError for truncated or malformed data, the
            # same set dict2constraints guards against.
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          f'"{k}" ({err})', UserWarning)
        else:
            info[k] = v
    return info

418 

419 

def dict2constraints(d):
    """Convert dict unpickled from trajectory file to list of constraints.

    Version 1 headers stored the constraints directly; versions 2 and 3
    store them pickled under 'constraints_string'.  On any failure to
    unpickle, a warning is issued and an empty list is returned.
    """
    version = d.get('version', 1)

    if version == 1:
        return d['constraints']
    elif version in (2, 3):
        try:
            # SECURITY: unpickles data read from the trajectory file;
            # only trusted files should be loaded.
            constraints = pickle.loads(d['constraints_string'])
            for c in constraints:
                if isinstance(c, FixAtoms) and c.index.dtype == bool:
                    # Special handling of old pickles:
                    c.index = np.arange(len(c.index))[c.index]
            return constraints
        except (AttributeError, KeyError, EOFError, ImportError,
                TypeError, pickle.UnpicklingError):
            # UnpicklingError added: corrupt constraint strings used to
            # escape this handler and crash instead of warning.
            warnings.warn('Could not unpickle constraints!')
            return []
    else:
        # Unknown future version: nothing we can safely interpret.
        return []