Coverage for /builds/ase/ase/ase/io/pickletrajectory.py: 76.87%
268 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import collections
4import errno
5import os
6import pickle
7import sys
8import warnings
10import numpy as np
12from ase.atoms import Atoms
13from ase.calculators.calculator import PropertyNotImplementedError
14from ase.calculators.singlepoint import SinglePointCalculator
15from ase.constraints import FixAtoms
16from ase.parallel import barrier, world
class PickleTrajectory:
    """Reads/writes Atoms objects into a .traj file.

    Legacy pickle-based trajectory format.  By default the constructor
    refuses to run (see ``_warn``) because unpickling files is unsafe
    and the format has been superseded by the new trajectory format.
    """
    # Per default, write these quantities
    write_energy = True
    write_forces = True
    write_stress = True
    write_charges = True
    write_magmoms = True
    write_momenta = True
    write_info = True

    def __init__(self, filename, mode='r', atoms=None, master=None,
                 backup=True, _warn=True):
        """A PickleTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the parameter file.  Should end in .traj.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and
            no atoms argument should be specified.

            'w' is write mode.  If the file already exists, it is
            renamed by appending .bak to the file name.  The atoms
            argument specifies the Atoms object to be written to the
            file, if not given it must instead be given as an argument
            to the write() method.

            'a' is append mode.  It acts a write mode, except that
            data is appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        master=None:
            Controls which process does the actual writing. The
            default is that process number 0 does this.  If this
            argument is given, processes where it is True will write.

        backup=True:
            Use backup=False to disable renaming of an existing file.

        Raises:
            RuntimeError: always, unless ``_warn=False`` is passed
                explicitly (the format is deprecated).
        """
        # Deprecation gate: callers must opt in with _warn=False.
        if _warn:
            msg = 'Please stop using old trajectory files!'
            if mode == 'r':
                msg += ('\nConvert to the new future-proof format like this:\n'
                        '\n    $ python3 -m ase.io.trajectory ' +
                        filename + '\n')
            raise RuntimeError(msg)

        self.numbers = None  # atomic numbers of the first frame (header)
        self.pbc = None  # periodic boundary conditions (header)
        self.sanitycheck = True
        self.pre_observers = []  # Callback functions before write
        self.post_observers = []  # Callback functions after write

        # Counter used to determine when callbacks are called:
        self.write_counter = 0

        # File offsets of the header and of each frame read/written so far.
        self.offsets = []
        if master is None:
            master = (world.rank == 0)
        self.master = master
        self.backup = backup
        self.set_atoms(atoms)
        self.open(filename, mode)

    def open(self, filename, mode):
        """Opens the file.

        For internal use only.
        """
        # ``filename`` may also be an already-open file object, in which
        # case it is used directly.
        self.fd = filename
        if mode == 'r':
            if isinstance(filename, str):
                self.fd = open(filename, 'rb')
            self.read_header()
        elif mode == 'a':
            exists = True
            if isinstance(filename, str):
                exists = os.path.isfile(filename)
                if exists:
                    exists = os.path.getsize(filename) > 0
                if exists:
                    # Read the existing header so that sanity checks
                    # against numbers/pbc work for appended frames.
                    self.fd = open(filename, 'rb')
                    self.read_header()
                    self.fd.close()
                barrier()
                if self.master:
                    self.fd = open(filename, 'ab+')
                else:
                    # Non-master ranks write to /dev/null so that all
                    # ranks can follow the same code path.
                    self.fd = open(os.devnull, 'ab+')
        elif mode == 'w':
            if self.master:
                if isinstance(filename, str):
                    if self.backup and os.path.isfile(filename):
                        try:
                            os.rename(filename, filename + '.bak')
                        except OSError as e:
                            # this must run on Win only! Not atomic!
                            if e.errno != errno.EEXIST:
                                raise
                            os.unlink(filename + '.bak')
                            os.rename(filename, filename + '.bak')
                    self.fd = open(filename, 'wb')
            else:
                self.fd = open(os.devnull, 'wb')
        else:
            raise ValueError('mode must be "r", "w" or "a".')

    def set_atoms(self, atoms=None):
        """Associate an Atoms object with the trajectory.

        Mostly for internal use.
        """
        # Duck-typed check: anything with get_positions() is accepted.
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def read_header(self):
        # Read the one-time header: magic string followed by a pickled
        # dict with pbc, numbers, tags, masses and constraints.
        if hasattr(self.fd, 'name'):
            if os.path.isfile(self.fd.name):
                if os.path.getsize(self.fd.name) == 0:
                    # Empty file: nothing to read yet (e.g. append mode).
                    return
        self.fd.seek(0)
        try:
            if self.fd.read(len('PickleTrajectory')) != b'PickleTrajectory':
                raise OSError('This is not a trajectory file!')
            d = pickle.load(self.fd)
        except EOFError:
            raise EOFError('Bad trajectory file.')

        self.pbc = d['pbc']
        self.numbers = d['numbers']
        self.tags = d.get('tags')
        self.masses = d.get('masses')
        self.constraints = dict2constraints(d)
        # Remember where frame data starts.
        self.offsets.append(self.fd.tell())

    def write(self, atoms=None):
        # Write one frame per image (supports multi-image inputs).
        if atoms is None:
            atoms = self.atoms

        for image in atoms.iterimages():
            self._write_atoms(image)

    def _write_atoms(self, atoms):
        """Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.
        """
        self._call_observers(self.pre_observers)

        if len(self.offsets) == 0:
            # First frame: the header (pbc, numbers, ...) goes out first.
            self.write_header(atoms)
        else:
            # Header quantities are written only once; verify that the
            # new frame is consistent with it.
            if (atoms.pbc != self.pbc).any():
                raise ValueError('Bad periodic boundary conditions!')
            elif self.sanitycheck and len(atoms) != len(self.numbers):
                raise ValueError('Bad number of atoms!')
            elif self.sanitycheck and (atoms.numbers != self.numbers).any():
                raise ValueError('Bad atomic numbers!')

        if atoms.has('momenta'):
            momenta = atoms.get_momenta()
        else:
            momenta = None

        d = {'positions': atoms.get_positions(),
             'cell': atoms.get_cell(),
             'momenta': momenta}

        if atoms.calc is not None:
            if self.write_energy:
                d['energy'] = atoms.get_potential_energy()
            if self.write_forces:
                assert self.write_energy
                try:
                    d['forces'] = atoms.get_forces(apply_constraint=False)
                except PropertyNotImplementedError:
                    # Best effort: calculators need not implement forces.
                    pass
            if self.write_stress:
                assert self.write_energy
                try:
                    d['stress'] = atoms.get_stress()
                except PropertyNotImplementedError:
                    pass
            if self.write_charges:
                try:
                    d['charges'] = atoms.get_charges()
                except PropertyNotImplementedError:
                    pass
            if self.write_magmoms:
                try:
                    magmoms = atoms.get_magnetic_moments()
                    # Only store magmoms when at least one is non-zero.
                    if any(np.asarray(magmoms).flat):
                        d['magmoms'] = magmoms
                except (PropertyNotImplementedError, AttributeError):
                    pass

        # Fall back to the initial values when the calculator did not
        # provide magmoms/charges.
        if 'magmoms' not in d and atoms.has('initial_magmoms'):
            d['magmoms'] = atoms.get_initial_magnetic_moments()
        if 'charges' not in d and atoms.has('initial_charges'):
            charges = atoms.get_initial_charges()
            if (charges != 0).any():
                d['charges'] = charges

        if self.write_info:
            d['info'] = stringnify_info(atoms.info)

        if self.master:
            # Protocol 2 keeps the files readable by older Python versions.
            pickle.dump(d, self.fd, protocol=2)
        self.fd.flush()
        self.offsets.append(self.fd.tell())
        self._call_observers(self.post_observers)
        self.write_counter += 1

    def write_header(self, atoms):
        # Magic string identifying the file format.
        self.fd.write(b'PickleTrajectory')
        if atoms.has('tags'):
            tags = atoms.get_tags()
        else:
            tags = None
        if atoms.has('masses'):
            masses = atoms.get_masses()
        else:
            masses = None
        d = {'version': 3,
             'pbc': atoms.get_pbc(),
             'numbers': atoms.get_atomic_numbers(),
             'tags': tags,
             'masses': masses,
             'constraints': [],  # backwards compatibility
             'constraints_string': pickle.dumps(atoms.constraints, protocol=0)}
        pickle.dump(d, self.fd, protocol=2)
        self.header_written = True
        self.offsets.append(self.fd.tell())

        # Atomic numbers and periodic boundary conditions are only
        # written once - in the header.  Store them here so that we can
        # check that they are the same for all images:
        self.numbers = atoms.get_atomic_numbers()
        self.pbc = atoms.get_pbc()

    def close(self):
        """Close the trajectory file."""
        self.fd.close()

    def __getitem__(self, i=-1):
        # Slices are resolved into a plain list of frames.
        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(len(self)))]

        N = len(self.offsets)
        if 0 <= i < N:
            # Frame i has already been located: seek straight to it.
            self.fd.seek(self.offsets[i])
            try:
                d = pickle.load(self.fd, encoding='bytes')
                # Old Python 2 pickles may have bytes keys; decode them.
                d = {k.decode() if isinstance(k, bytes) else k: v
                     for k, v in d.items()}
            except EOFError:
                raise IndexError
            if i == N - 1:
                # Cache the offset of the following frame.
                self.offsets.append(self.fd.tell())
            charges = d.get('charges')
            magmoms = d.get('magmoms')
            try:
                constraints = [c.copy() for c in self.constraints]
            except Exception:
                constraints = []
                warnings.warn('Constraints did not unpickle correctly.')
            atoms = Atoms(positions=d['positions'],
                          numbers=self.numbers,
                          cell=d['cell'],
                          momenta=d['momenta'],
                          magmoms=magmoms,
                          charges=charges,
                          tags=self.tags,
                          masses=self.masses,
                          pbc=self.pbc,
                          info=unstringnify_info(d.get('info', {})),
                          constraint=constraints)
            if 'energy' in d:
                calc = SinglePointCalculator(
                    atoms,
                    energy=d.get('energy', None),
                    forces=d.get('forces', None),
                    stress=d.get('stress', None),
                    magmoms=magmoms)
                atoms.calc = calc
            return atoms

        if i >= N:
            # Frame not located yet: read forward, caching offsets as we go.
            for j in range(N - 1, i + 1):
                atoms = self[j]
            return atoms

        # Negative index: convert relative to the full (lazily counted)
        # trajectory length.
        i = len(self) + i
        if i < 0:
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def __len__(self):
        # Lazily count frames by loading forward from the last known
        # offset until EOF, caching each new offset.
        if len(self.offsets) == 0:
            return 0
        N = len(self.offsets) - 1
        while True:
            self.fd.seek(self.offsets[N])
            try:
                pickle.load(self.fd)
            except EOFError:
                return N
            self.offsets.append(self.fd.tell())
            N += 1

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called.  Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not isinstance(function, collections.abc.Callable):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()
def stringnify_info(info):
    """Return a stringnified version of the dict *info* that is
    ensured to be picklable. Items with non-string keys or
    unpicklable values are dropped and a warning is issued."""
    stringnified = {}
    for key, value in info.items():
        if not isinstance(key, str):
            warnings.warn('Non-string info-dict key is not stored in ' +
                          'trajectory: ' + repr(key), UserWarning)
        else:
            try:
                # Protocol 0 produces an ASCII payload; higher protocols
                # were avoided here because protocol 2 did not reliably
                # raise on unpicklable objects such as open files.
                payload = pickle.dumps(value, protocol=0)
            except pickle.PicklingError:
                warnings.warn('Skipping not picklable info-dict item: ' +
                              f'"{key}" ({sys.exc_info()[1]})', UserWarning)
            else:
                stringnified[key] = payload
    return stringnified
def unstringnify_info(stringnified):
    """Convert the dict *stringnified* to a dict with unstringnified
    objects and return it. Objects that cannot be unpickled will be
    skipped and a warning will be issued."""
    info = {}
    for key, payload in stringnified.items():
        try:
            # NOTE: unpickles data coming from the trajectory file;
            # only safe for trusted files.
            obj = pickle.loads(payload)
        except pickle.UnpicklingError:
            warnings.warn('Skipping not unpicklable info-dict item: ' +
                          f'"{key}" ({sys.exc_info()[1]})', UserWarning)
            continue
        info[key] = obj
    return info
def dict2constraints(d):
    """Convert dict unpickled from trajectory file to list of constraints."""

    version = d.get('version', 1)

    # Version 1 stored the constraint objects directly in the dict.
    if version == 1:
        return d['constraints']
    # Unknown versions yield no constraints.
    if version not in (2, 3):
        return []
    # Versions 2 and 3 store a pickled constraint list.
    try:
        # NOTE: unpickles data from the trajectory file; only safe for
        # trusted files.
        constraints = pickle.loads(d['constraints_string'])
        for constraint in constraints:
            if isinstance(constraint, FixAtoms) and \
                    constraint.index.dtype == bool:
                # Special handling of old pickles: convert a boolean
                # mask into an explicit index array.
                constraint.index = np.arange(
                    len(constraint.index))[constraint.index]
        return constraints
    except (AttributeError, KeyError, EOFError, ImportError, TypeError):
        warnings.warn('Could not unpickle constraints!')
        return []