Coverage for ase / io / netcdftrajectory.py: 83.01%
359 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3"""
4netcdftrajectory - I/O trajectory files in the AMBER NetCDF convention
6More information on the AMBER NetCDF conventions can be found at
7http://ambermd.org/netcdf/. This module supports extensions to
8these conventions, such as writing of additional fields and writing to
9HDF5 (NetCDF-4) files.
The netCDF4-python package is required by this module:
13 netCDF4-python - https://github.com/Unidata/netcdf4-python
15NetCDF files can be directly visualized using the libAtoms flavor of
16AtomEye (http://www.libatoms.org/),
17VMD (http://www.ks.uiuc.edu/Research/vmd/)
18or Ovito (http://www.ovito.org/, starting with version 2.3).
19"""
22import collections
23import os
24import warnings
25from functools import reduce
27import numpy as np
29import ase
30from ase.data import atomic_masses
31from ase.geometry import cellpar_to_cell
class NetCDFTrajectory:
    """
    Reads/writes Atoms objects into an AMBER-style .nc trajectory file.
    """

    # Default dimension names
    _frame_dim = 'frame'
    _spatial_dim = 'spatial'
    _atom_dim = 'atom'
    _cell_spatial_dim = 'cell_spatial'
    _cell_angular_dim = 'cell_angular'
    _label_dim = 'label'
    _Voigt_dim = 'Voigt'  # For stress/strain tensors

    # Default field names. If it is a list, check for any of these names upon
    # opening. Upon writing, use the first name.
    _spatial_var = 'spatial'
    _cell_spatial_var = 'cell_spatial'
    _cell_angular_var = 'cell_angular'
    _time_var = 'time'
    _numbers_var = ['atom_types', 'type', 'Z']
    _positions_var = 'coordinates'
    _velocities_var = 'velocities'
    _cell_origin_var = 'cell_origin'
    _cell_lengths_var = 'cell_lengths'
    _cell_angles_var = 'cell_angles'

    # Flat list of the variable names handled explicitly by this class. Any
    # other variable found in a file is classified as an "extra" field by
    # _read_header. Instances copy this list and never modify it in place.
    _default_vars = reduce(lambda x, y: x + y,
                           [_numbers_var, [_positions_var], [_velocities_var],
                            [_cell_origin_var], [_cell_lengths_var],
                            [_cell_angles_var]])

    def __init__(self, filename, mode='r', atoms=None, types_to_numbers=None,
                 double=True, netcdf_format='NETCDF3_CLASSIC', keep_open=True,
                 index_var='id', chunk_size=1000000):
        """
        A NetCDFTrajectory can be created in read, write or append mode.

        Parameters
        ----------

        filename:
            The name of the trajectory file. Should end in .nc.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and no atoms
            argument should be specified.

            'w' is write mode. The atoms argument specifies the Atoms object
            to be written to the file, if not given it must instead be given
            as an argument to the write() method.

            'a' is append mode. It acts as write mode, except that data is
            appended to a preexisting file. If the file does not exist yet,
            it is created as in write mode.

        atoms=None:
            The Atoms object to be written in write or append mode.

        types_to_numbers=None:
            Dictionary or list for conversion of atom types to atomic numbers
            when reading a trajectory file. A list (or tuple) maps each
            position to the atomic number stored there. The mapping is
            copied, so the caller's object is never modified.

        double=True:
            Create new variable in double precision. This applies to the
            additional arrays written via the ``arrays`` argument of write()
            and via write_arrays(); the standard variables use fixed types.

        netcdf_format='NETCDF3_CLASSIC':
            Format string for the underlying NetCDF file format. Only relevant
            if a new file is created. More information can be found at
            https://www.unidata.ucar.edu/software/netcdf/docs/netcdf/File-Format.html

            'NETCDF3_CLASSIC' is the original binary format.

            'NETCDF3_64BIT' can be used to write larger files.

            'NETCDF4_CLASSIC' is HDF5 with some NetCDF limitations.

            'NETCDF4' is HDF5.

        keep_open=True:
            Keep the file open during consecutive read/write operations.
            Set to false if you experience data corruption. This will close
            the file after each read/write operation but comes with a serious
            performance penalty. None keeps the file open only in read mode.

        index_var='id':
            Name of variable containing the atom indices. Atoms are reordered
            by this index upon reading if this variable is present. Default
            value is for LAMMPS output. None switches atom indices off.

        chunk_size=1000000:
            Maximum size of consecutive number of records (along the 'atom')
            dimension read when reading from a NetCDF file. This is used to
            reduce the memory footprint of a read operation on very large
            files.
        """
        self.nc = None
        self.chunk_size = chunk_size

        self.numbers = None
        self.pre_observers = []  # Callback functions before write
        self.post_observers = []  # Callback functions after write are called
        # Number of completed write() calls; observers fire when it is a
        # multiple of their interval.
        self.write_counter = 0

        self.has_header = False
        self._set_atoms(atoms)

        self.types_to_numbers = None
        if isinstance(types_to_numbers, (list, tuple)):
            types_to_numbers = dict(enumerate(types_to_numbers))
        if types_to_numbers is not None:
            # Copy: __getitem__ adds identity entries for unknown types and
            # must not modify the dictionary owned by the caller.
            self.types_to_numbers = dict(types_to_numbers)

        self.index_var = index_var

        # Build a per-instance list. ``+=`` on the inherited class attribute
        # would extend the list shared by every NetCDFTrajectory instance.
        self._default_vars = list(self._default_vars)
        if self.index_var is not None:
            self._default_vars.append(self.index_var)

        # 'l' should be a valid type according to the netcdf4-python
        # documentation, but does not appear to work.
        self.dtype_conv = {'l': 'i'}
        if not double:
            self.dtype_conv['d'] = 'f'

        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        # per frame atts are global quantities, not quantities stored for each
        # atom
        self.extra_per_frame_atts = []

        self.mode = mode
        self.netcdf_format = netcdf_format

        if atoms is not None:
            self.n_atoms = len(atoms)
        else:
            self.n_atoms = None

        self.filename = filename
        if keep_open is None:
            # Only netCDF4-python supports append to files
            self.keep_open = self.mode == 'r'
        else:
            self.keep_open = keep_open

    def __del__(self):
        self.close()

    def _open(self):
        """
        Opens the file.

        Append mode on a missing file falls back to write mode. In read and
        append mode the header is parsed and the write position is set to the
        number of frames already present.

        For internal use only.
        """
        import netCDF4
        if self.nc is not None:
            return
        if self.mode == 'a' and not os.path.exists(self.filename):
            self.mode = 'w'
        self.nc = netCDF4.Dataset(self.filename, self.mode,
                                  format=self.netcdf_format)

        self.frame = 0
        if self.mode == 'r' or self.mode == 'a':
            self._read_header()
            self.frame = self._len()

    def _set_atoms(self, atoms=None):
        """
        Associate an Atoms object with the trajectory.

        For internal use only.

        Raises TypeError if ``atoms`` is neither None nor Atoms-like (an
        object without ``get_positions``).
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def _read_header(self):
        """
        Read the atom count and classify the non-default variables of an
        open file.

        Variables shaped (frame, atom, ...) become per-frame arrays,
        (atom,) per-file arrays, and (frame,) or (frame, <other>, ...)
        per-frame global quantities stored in ``Atoms.info``.

        For internal use only.
        """
        if not self.n_atoms:
            self.n_atoms = len(self.nc.dimensions[self._atom_dim])

        # Classify from scratch: with keep_open=False this runs on every
        # reopen, and appending again would duplicate the names.
        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        self.extra_per_frame_atts = []

        for name, var in self.nc.variables.items():
            # This can be unicode which confuses ASE
            name = str(name)
            # _default_vars is taken care of already
            if name in self._default_vars:
                continue
            dims = var.dimensions
            if len(dims) >= 2:
                if dims[0] == self._frame_dim:
                    if dims[1] == self._atom_dim:
                        self.extra_per_frame_vars.append(name)
                    else:
                        self.extra_per_frame_atts.append(name)
            elif len(dims) == 1:
                if dims[0] == self._atom_dim:
                    self.extra_per_file_vars.append(name)
                elif dims[0] == self._frame_dim:
                    self.extra_per_frame_atts.append(name)

        self.has_header = True

    def write(self, atoms=None, frame=None, arrays=None, time=None):
        """
        Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.

        Parameters
        ----------
        atoms:
            Atoms object to write. An object with an ``interpolate`` method
            (a NEB) is written image by image.
        frame:
            Frame index to write to. Defaults to the next frame.
        arrays:
            Names of additional per-atom arrays of ``atoms`` to store.
        time:
            Optional value stored in the 'time' variable for this frame.

        Raises ValueError if the atom count or the atomic numbers are
        inconsistent with data already in the file.
        """
        self._open()
        self._call_observers(self.pre_observers)
        if atoms is None:
            atoms = self.atoms

        if hasattr(atoms, 'interpolate'):
            # seems to be a NEB
            neb = atoms
            assert not neb.parallel
            try:
                neb.get_energies_and_forces(all=True)
            except AttributeError:
                pass
            for image in neb.images:
                self.write(image)
            return

        if not self.has_header:
            self._define_file_structure(atoms)
        elif len(atoms) != self.n_atoms:
            raise ValueError('Bad number of atoms!')

        i = self.frame if frame is None else frame

        # Number can be per file variable
        numbers = self._get_variable(self._numbers_var)
        if numbers.dimensions[0] == self._frame_dim:
            numbers[i] = atoms.get_atomic_numbers()
        elif np.any(numbers[:] != atoms.get_atomic_numbers()):
            raise ValueError('Atomic numbers do not match!')
        self._get_variable(self._positions_var)[i] = atoms.get_positions()
        if atoms.has('momenta'):
            self._add_velocities()
            self._get_variable(self._velocities_var)[i] = \
                atoms.get_momenta() / atoms.get_masses().reshape(-1, 1)
        a, b, c, alpha, beta, gamma = atoms.cell.cellpar()
        if not np.all(atoms.pbc):
            warnings.warn('Atoms have nonperiodic directions. Cell lengths in '
                          'these directions are lost and will be '
                          'shrink-wrapped when reading the NetCDF file.')
        # Non-periodic directions are stored with length 0 so the reader can
        # recover the pbc flags from the cell lengths.
        cell_lengths = np.array([a, b, c]) * atoms.pbc
        self._get_variable(self._cell_lengths_var)[i] = cell_lengths
        self._get_variable(self._cell_angles_var)[i] = [alpha, beta, gamma]
        self._get_variable(self._cell_origin_var)[i] = \
            atoms.get_celldisp().reshape(3)
        if arrays is not None:
            self._write_extra_arrays(atoms, i, arrays)
        if time is not None:
            self._add_time()
            self._get_variable(self._time_var)[i] = time

        self.sync()

        self._call_observers(self.post_observers)
        self.frame += 1
        self.write_counter += 1
        self._close()

    def write_arrays(self, atoms, frame, arrays):
        """
        Write the named per-atom arrays of ``atoms`` into frame ``frame``.

        Only the additional arrays are written; positions, cell and the other
        standard variables are left untouched.
        """
        self._open()
        self._call_observers(self.pre_observers)
        self._write_extra_arrays(atoms, frame, arrays)
        self._call_observers(self.post_observers)
        self._close()

    def _write_extra_arrays(self, atoms, frame, arrays):
        """
        Store each named array of ``atoms`` at index ``frame``.

        Arrays that already exist as per-file variables are not written.
        Instead they are checked for consistency, and a ValueError is raised
        on mismatch. Missing variables are created on demand.

        For internal use only.
        """
        for array in arrays:
            data = atoms.get_array(array)
            if array in self.extra_per_file_vars:
                # This field exists but is per file data. Check that the
                # data remains consistent.
                if np.any(self._get_variable(array)[:] != data):
                    raise ValueError('Trying to write Atoms object with '
                                     'incompatible data for the {} '
                                     'array.'.format(array))
            else:
                self._add_array(atoms, array, data.dtype, data.shape)
                self._get_variable(array)[frame] = data

    def _define_file_structure(self, atoms):
        """
        Create the global attributes, dimensions and standard variables of a
        new file, skipping anything that already exists.

        The atom dimension is sized from ``atoms``. Afterwards the header is
        marked as known, so later writes are checked against that atom count.

        For internal use only.
        """
        self.nc.Conventions = 'AMBER'
        self.nc.ConventionVersion = '1.0'
        self.nc.program = 'ASE'
        self.nc.programVersion = ase.__version__
        self.nc.title = "MOL"

        if self._frame_dim not in self.nc.dimensions:
            self.nc.createDimension(self._frame_dim, None)
        if self._spatial_dim not in self.nc.dimensions:
            self.nc.createDimension(self._spatial_dim, 3)
        if self._atom_dim not in self.nc.dimensions:
            self.nc.createDimension(self._atom_dim, len(atoms))
        if self._cell_spatial_dim not in self.nc.dimensions:
            self.nc.createDimension(self._cell_spatial_dim, 3)
        if self._cell_angular_dim not in self.nc.dimensions:
            self.nc.createDimension(self._cell_angular_dim, 3)
        if self._label_dim not in self.nc.dimensions:
            self.nc.createDimension(self._label_dim, 5)

        # Self-describing variables from AMBER convention
        if not self._has_variable(self._spatial_var):
            self.nc.createVariable(self._spatial_var, 'S1',
                                   (self._spatial_dim,))
            self.nc.variables[self._spatial_var][:] = ['x', 'y', 'z']
        if not self._has_variable(self._cell_spatial_var):
            self.nc.createVariable(self._cell_spatial_var, 'S1',
                                   (self._cell_spatial_dim,))
            self.nc.variables[self._cell_spatial_var][:] = ['a', 'b', 'c']
        if not self._has_variable(self._cell_angular_var):
            self.nc.createVariable(self._cell_angular_var, 'S1',
                                   (self._cell_angular_dim, self._label_dim,))
            # Labels are padded to the label dimension length of 5.
            self.nc.variables[self._cell_angular_var][0] = list('alpha')
            self.nc.variables[self._cell_angular_var][1] = list('beta ')
            self.nc.variables[self._cell_angular_var][2] = list('gamma')

        if not self._has_variable(self._numbers_var):
            self.nc.createVariable(self._numbers_var[0], 'i',
                                   (self._frame_dim, self._atom_dim,))
        if not self._has_variable(self._positions_var):
            self.nc.createVariable(self._positions_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            self.nc.variables[self._positions_var].units = 'Angstrom'
            self.nc.variables[self._positions_var].scale_factor = 1.
        if not self._has_variable(self._cell_lengths_var):
            self.nc.createVariable(self._cell_lengths_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_lengths_var].units = 'Angstrom'
            self.nc.variables[self._cell_lengths_var].scale_factor = 1.
        if not self._has_variable(self._cell_angles_var):
            self.nc.createVariable(self._cell_angles_var, 'd',
                                   (self._frame_dim, self._cell_angular_dim))
            self.nc.variables[self._cell_angles_var].units = 'degree'
        if not self._has_variable(self._cell_origin_var):
            self.nc.createVariable(self._cell_origin_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_origin_var].units = 'Angstrom'
            self.nc.variables[self._cell_origin_var].scale_factor = 1.

        # The atom dimension is now fixed. Remember it so that write() can
        # reject frames with a different number of atoms.
        self.n_atoms = len(self.nc.dimensions[self._atom_dim])
        self.has_header = True

    def _add_time(self):
        """Create the per-frame time variable if it does not exist yet."""
        if not self._has_variable(self._time_var):
            self.nc.createVariable(self._time_var, 'f8', (self._frame_dim,))

    def _add_velocities(self):
        """Create the per-frame velocities variable if it does not exist."""
        if not self._has_variable(self._velocities_var):
            self.nc.createVariable(self._velocities_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            # Attach the metadata to the velocities variable itself; the
            # coordinates keep their own 'Angstrom' units.
            # NOTE(review): write() stores momenta / masses in ASE units;
            # confirm these match the 'Angstrom/Femtosecond' label.
            self.nc.variables[self._velocities_var].units = \
                'Angstrom/Femtosecond'
            self.nc.variables[self._velocities_var].scale_factor = 1.

    def _add_array(self, atoms, array_name, type, shape):
        """
        Create a per-frame variable for an additional array of ``atoms``.

        Each axis of ``shape`` is mapped to a NetCDF dimension by its length:
        len(atoms) -> atom, 3 -> spatial, 6 -> Voigt. Any other length raises
        TypeError. ``type`` is translated through ``self.dtype_conv`` when it
        is a numpy dtype. Does nothing if the variable already exists.
        """
        if not self._has_variable(array_name):
            dims = [self._frame_dim]
            for n in shape:
                if n == len(atoms):
                    dims.append(self._atom_dim)
                elif n == 3:
                    dims.append(self._spatial_dim)
                elif n == 6:
                    # This can only be stress/strain tensor in Voigt notation
                    if self._Voigt_dim not in self.nc.dimensions:
                        self.nc.createDimension(self._Voigt_dim, 6)
                    dims.append(self._Voigt_dim)
                else:
                    raise TypeError("Don't know how to dump array of shape {}"
                                    " into NetCDF trajectory.".format(shape))
            if hasattr(type, 'char'):
                t = self.dtype_conv.get(type.char, type)
            else:
                t = type
            self.nc.createVariable(array_name, t, dims)

    def _get_variable(self, name, exc=True):
        """
        Return the NetCDF variable ``name``.

        ``name`` may be a list of alternative names; the first one present in
        the file is returned. If nothing matches, raise RuntimeError when
        ``exc`` is true and return None otherwise.
        """
        if isinstance(name, list):
            for n in name:
                if n in self.nc.variables:
                    return self.nc.variables[n]
            if exc:
                raise RuntimeError(
                    'None of the variables {} was found in the '
                    'NetCDF trajectory.'.format(', '.join(name)))
        else:
            if name in self.nc.variables:
                return self.nc.variables[name]
            if exc:
                raise RuntimeError('Variable {} was not found in the NetCDF '
                                   'trajectory.'.format(name))
        return None

    def _has_variable(self, name):
        """Return True if ``name`` (or any name in a list) is in the file."""
        if isinstance(name, list):
            return any(n in self.nc.variables for n in name)
        return name in self.nc.variables

    def _get_data(self, name, frame, index, exc=True):
        """
        Read variable ``name`` and reorder it along its first data axis.

        Record ``k`` ends up at position ``index[k]``. For per-frame
        variables the data of ``frame`` is read; per-file variables (first
        dimension not 'frame') ignore ``frame``. Large variables are read in
        chunks of ``self.chunk_size`` records. Returns None if the variable
        is missing and ``exc`` is false.

        For internal use only.
        """
        var = self._get_variable(name, exc=exc)
        if var is None:
            return None
        per_frame = var.dimensions[0] == self._frame_dim
        shape = var.shape[1:] if per_frame else var.shape
        data = np.zeros(shape, dtype=var.dtype)
        n_records = shape[0]
        if n_records < self.chunk_size:
            data[index] = var[frame] if per_frame else var[...]
        else:
            # If this is a large data set, only read chunks from it to
            # reduce memory footprint of the NetCDFTrajectory reader.
            for start in range(0, n_records, self.chunk_size):
                sl = slice(start, min(start + self.chunk_size, n_records))
                data[index[sl]] = var[frame, sl] if per_frame else var[sl]
        return data

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Close the trajectory file."""
        # getattr: __del__ may run on an instance whose __init__ never ran.
        if getattr(self, 'nc', None) is not None:
            self.nc.close()
            self.nc = None

    def _close(self):
        """
        Close the file after an operation, unless keep_open is set.

        A file created in write mode is switched to append mode, so that
        reopening it does not truncate the frames already written.
        """
        if not self.keep_open:
            self.close()
            if self.mode == 'w':
                self.mode = 'a'

    def sync(self):
        """Flush pending data to disk."""
        self.nc.sync()

    def __getitem__(self, i=-1):
        """
        Return frame ``i`` as an Atoms object, or a list of them for a slice.

        Negative indices count from the end. Out-of-range indices raise
        IndexError.
        """
        self._open()

        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(self._len()))]

        N = self._len()
        if 0 <= i < N:
            # Non-periodic boundaries have cell_length == 0.0
            cell_lengths = \
                np.array(self.nc.variables[self._cell_lengths_var][i][:])
            pbc = np.abs(cell_lengths) > 1e-6

            # Do we have a cell origin?
            if self._has_variable(self._cell_origin_var):
                origin = np.array(
                    self.nc.variables[self._cell_origin_var][i][:])
            else:
                origin = np.zeros([3], dtype=float)

            # Do we have an index variable?
            if (self.index_var is not None and
                    self._has_variable(self.index_var)):
                index_data = self.nc.variables[self.index_var]
                if index_data.dimensions[0] == self._frame_dim:
                    index = np.array(index_data[i][:])
                else:
                    # Per-file index: the same ordering applies to all frames.
                    index = np.array(index_data[:])
                # The index variable can be non-consecutive, we here construct
                # a consecutive one.
                consecutive_index = np.zeros_like(index)
                consecutive_index[np.argsort(index)] = np.arange(self.n_atoms)
            else:
                consecutive_index = np.arange(self.n_atoms)

            # Read element numbers
            self.numbers = self._get_data(self._numbers_var, i,
                                          consecutive_index, exc=False)
            if self.numbers is None:
                self.numbers = np.ones(self.n_atoms, dtype=int)
            if self.types_to_numbers is not None:
                # Unknown types map to themselves.
                d = set(self.numbers).difference(self.types_to_numbers.keys())
                if len(d) > 0:
                    self.types_to_numbers.update({num: num for num in d})
                func = np.vectorize(self.types_to_numbers.get)
                self.numbers = func(self.numbers)
            self.masses = atomic_masses[self.numbers]

            # Read positions
            positions = self._get_data(self._positions_var, i,
                                       consecutive_index)

            # Determine cell size for non-periodic directions from shrink
            # wrapped cell.
            for dim in np.flatnonzero(~pbc):
                origin[dim] = positions[:, dim].min()
                cell_lengths[dim] = positions[:, dim].max() - origin[dim]

            # Construct cell shape from cell lengths and angles
            cell = cellpar_to_cell(
                list(cell_lengths) +
                list(self.nc.variables[self._cell_angles_var][i])
            )

            # Compute momenta from velocities (if present). Multiply out of
            # place so the result is not truncated to the stored precision.
            momenta = self._get_data(self._velocities_var, i,
                                     consecutive_index, exc=False)
            if momenta is not None:
                momenta = momenta * self.masses.reshape(-1, 1)

            info = {
                name: np.array(self.nc.variables[name][i])
                for name in self.extra_per_frame_atts
            }
            # Create atoms object
            atoms = ase.Atoms(
                positions=positions,
                numbers=self.numbers,
                cell=cell,
                celldisp=origin,
                momenta=momenta,
                masses=self.masses,
                pbc=pbc,
                info=info
            )

            # Attach additional arrays found in the NetCDF file
            for name in self.extra_per_frame_vars:
                atoms.set_array(name, self._get_data(name, i,
                                                     consecutive_index))
            for name in self.extra_per_file_vars:
                atoms.set_array(name, self._get_data(name, i,
                                                     consecutive_index))
            self._close()
            return atoms

        i = N + i
        if i < 0 or i >= N:
            self._close()
            raise IndexError('Trajectory index out of range.')
        return self[i]

    def _len(self):
        """Return the number of frames in the open file (0 if none)."""
        if self._frame_dim in self.nc.dimensions:
            return int(self._get_variable(self._positions_var).shape[0])
        else:
            return 0

    def __len__(self):
        self._open()
        n_frames = self._len()
        self._close()
        return n_frames

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """
        Call pre/post write observers.

        Each observer fires when the number of completed writes is a
        multiple of its interval.
        """
        for function, interval, args, kwargs in obs:
            if self.write_counter % interval == 0:
                function(*args, **kwargs)
def read_netcdftrajectory(filename, index=-1):
    """
    Read frame(s) ``index`` from the NetCDF trajectory ``filename``.

    ``index`` is anything accepted by ``NetCDFTrajectory.__getitem__``: an
    integer (the last frame by default) or a slice. The file is always
    closed afterwards, even if reading fails.
    """
    trajectory = NetCDFTrajectory(filename, mode='r')
    try:
        return trajectory[index]
    finally:
        trajectory.close()
def write_netcdftrajectory(filename, images):
    """
    Write ``images`` to a new NetCDF trajectory file ``filename``.

    ``images`` is either a single Atoms-like object (anything with
    ``get_positions``) or an iterable of them. Each image becomes one frame.
    """
    # Wrap a lone Atoms object so that it can be iterated like a sequence.
    frames = [images] if hasattr(images, 'get_positions') else images
    with NetCDFTrajectory(filename, mode='w') as trajectory:
        for frame_atoms in frames:
            trajectory.write(frame_atoms)