Coverage for ase / io / pickletrajectory.py: 76.87%
268 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3import collections
4import errno
5import os
6import pickle
7import sys
8import warnings
10import numpy as np
12from ase.atoms import Atoms
13from ase.calculators.calculator import PropertyNotImplementedError
14from ase.calculators.singlepoint import SinglePointCalculator
15from ase.constraints import FixAtoms
16from ase.parallel import barrier, world
class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file."""
    # Per default, write these quantities for every frame (each flag can
    # be switched off on an instance to shrink the file).
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    write_momenta = True
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters
        ----------
        filename:
            The name of the parameter file.  Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and
            no atoms argument should be specified.

            'w' is write mode.  If the file already exists, it is
            renamed by appending .bak to the file name.  The atoms
            argument specifies the Atoms object to be written to the
            file, if not given it must instead be given as an argument
            to the write() method.

            'a' is append mode.  It acts a write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing. The
            default is that process number 0 does this.  If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.

        Raises
        ------
        RuntimeError
            Always, unless the private ``_warn=False`` flag is passed:
            the pickle-based format is deprecated and unsafe to load.
        """

        if _warn:
            # The pickle format is deprecated; refuse to work unless the
            # caller explicitly opts in via the private _warn flag.
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n $ python3 -m ase.io.trajectory ' +
                        filename + '\n')
            raise RuntimeError(msg)

        # Atomic numbers / periodic boundary conditions of the first frame;
        # used to sanity-check that all later frames match.
        self.numbers = None
        self.pbc = None
        self.sanitycheck = True
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        # Byte offsets of the start of each frame in the file (offset 0 in
        # this list is the position right after the header).
        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.
        """
        # `filename` may also be an already-open file object, in which
        # case it is used directly.
        self.fd = filename
        if mode == 'r':
            if isinstance(filename, str):
                self.fd = open(filename, 'rb')
            self.read_header()
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = os.path.isfile(filename)
                if exists:
                    exists = os.path.getsize(filename) > 0
                if exists:
                    # Read the existing header first so that appended frames
                    # are checked against the stored numbers/pbc.
                    self.fd = open(filename, 'rb')
                    self.read_header()
                    self.fd.close()
                barrier()
                # Only the master process writes real data; the others write
                # to the null device so collective code paths stay in sync.
                if self.master:
                    self.fd = open(filename, 'ab+')
                else:
                    self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except OSError as e:
                            # this must run on Win only! Not atomic!
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.
        """
        # Duck-typed check: anything with get_positions() is accepted.
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        # An empty file has no header yet (fresh append mode) — nothing
        # to read.
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    return
        self.fd.seek(0)
        try:
            # The file starts with a magic string followed by a pickled
            # header dict.
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise OSError('This is not a trajectory file!')
            d = pickle.load(self.fd)
        except EOFError:
            raise EOFError('Bad trajectory file.')

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        # Remember where the first frame starts.
        self.offsets.append(self.fd.tell())

    def write(self, atoms=None):
        # Fall back to the Atoms object bound at construction time.
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            # First frame: the header (numbers, pbc, constraints) is
            # written once and reused for all later frames.
            self.write_header(atoms)
        else:
            # Later frames must match the header data exactly.
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    # Store magmoms only if at least one is non-zero.
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        # Fall back to the initial magmoms/charges stored on the Atoms
        # object when the calculator did not provide any.
        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        if self.master:
            # Protocol 2 keeps the file readable by old Python versions.
            pickle.dump(d, self.fd, protocol=2)
        self.fd.flush()
        # Record where the next frame will start.
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints, protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header.  Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        # Slices are resolved frame-by-frame via recursion.
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            self.fd.seek(self.offsets[i])
            try:
                # encoding='bytes' — presumably to read files written by
                # Python 2; keys are decoded back to str below.
                d = pickle.load(self.fd, encoding='bytes')
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError
            if i == N - 1:
                # We just read the last known frame; remember where the
                # next one starts.
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                # Attach stored results via a single-point calculator.
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            # Index beyond the frames scanned so far: read forward,
            # extending self.offsets as a side effect of __getitem__.
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        # Negative index: convert relative to the full length.
        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        if len(self.offsets) == 0:
            return 0
        # Scan forward from the last known frame until EOF, caching the
        # offset of every frame found along the way.
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                pickle.load(self.fd)
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable.  Items with non-string keys or
    unpicklable values are dropped and a warning is issued."""
    stringnified = {}
    for k, v in info.items():
        if not isinstance(k, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(k), UserWarning)
            continue
        try:
            # Should highest protocol be used here for efficiency?
            # Protocol 2 seems not to raise an exception when one
            # tries to pickle a file object, so by using that, we
            # might end up with file objects in inconsistent states.
            s = pickle.dumps(v, protocol=0)
        except (pickle.PicklingError, TypeError, AttributeError) as err:
            # pickle signals "unpicklable" inconsistently: PicklingError
            # for e.g. module-level lambdas, TypeError for file/socket-like
            # objects, AttributeError for local (nested) objects.  Catch
            # all three so an unpicklable value is dropped with a warning,
            # as documented, instead of aborting the write.
            warnings.warn('Skipping not picklable info-dict item: ' +
                          f'"{k}" ({err})', UserWarning)
        else:
            stringnified[k] = s
    return stringnified
def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it.  Objects that cannot be unpickled will be
    skipped and a warning will be issued."""
    info = {}
    for k, s in stringnified.items():
        try:
            # SECURITY: pickle.loads executes arbitrary code if the
            # trajectory file comes from an untrusted source.
            v = pickle.loads(s)
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, TypeError) as err:
            # Corrupted/truncated data raises UnpicklingError or EOFError;
            # pickles referencing classes that no longer exist or import
            # raise AttributeError/ImportError/TypeError.  Skip the item
            # with a warning, as documented, instead of aborting the read
            # (same exception set as dict2constraints).
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          f'"{k}" ({err})', UserWarning)
        else:
            info[k] = v
    return info
def dict2constraints(d):
    """Convert dict unpickled from trajectory file to list of constraints."""
    version = d.get('version', 1)

    # Version-1 headers stored the constraint objects directly.
    if version == 1:
        return d['constraints']
    # Unknown future versions: no constraints.
    if version not in (2, 3):
        return []

    # Versions 2 and 3 store the constraints as a pickled string.
    try:
        constraint_list = pickle.loads(d['constraints_string'])
        for constraint in constraint_list:
            # Special handling of old pickles: FixAtoms indices used to
            # be stored as a boolean mask — convert to integer indices.
            if isinstance(constraint, FixAtoms) and \
                    constraint.index.dtype == bool:
                mask = constraint.index
                constraint.index = np.arange(len(mask))[mask]
        return constraint_list
    except (AttributeError, KeyError, EOFError, ImportError, TypeError):
        warnings.warn('Could not unpickle constraints!')
        return []