Coverage for /builds/ase/ase/ase/io/netcdftrajectory.py: 83.01%
359 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3"""
4netcdftrajectory - I/O trajectory files in the AMBER NetCDF convention
6More information on the AMBER NetCDF conventions can be found at
7http://ambermd.org/netcdf/. This module supports extensions to
8these conventions, such as writing of additional fields and writing to
9HDF5 (NetCDF-4) files.
11The netCDF4-python package is required by this module:
13 netCDF4-python - https://github.com/Unidata/netcdf4-python
15NetCDF files can be directly visualized using the libAtoms flavor of
16AtomEye (http://www.libatoms.org/),
17VMD (http://www.ks.uiuc.edu/Research/vmd/)
18or Ovito (http://www.ovito.org/, starting with version 2.3).
19"""
22import collections
23import os
24import warnings
25from functools import reduce
27import numpy as np
29import ase
30from ase.data import atomic_masses
31from ase.geometry import cellpar_to_cell
class NetCDFTrajectory:
    """
    Reads/writes Atoms objects into an AMBER-style .nc trajectory file.

    The file is opened lazily on the first read or write. The object
    supports indexing (``traj[i]``, ``traj[a:b]``), ``len()`` and the
    context-manager protocol.
    """

    # Default dimension names
    _frame_dim = 'frame'
    _spatial_dim = 'spatial'
    _atom_dim = 'atom'
    _cell_spatial_dim = 'cell_spatial'
    _cell_angular_dim = 'cell_angular'
    _label_dim = 'label'
    _Voigt_dim = 'Voigt'  # For stress/strain tensors

    # Default field names. If it is a list, check for any of these names upon
    # opening. Upon writing, use the first name.
    _spatial_var = 'spatial'
    _cell_spatial_var = 'cell_spatial'
    _cell_angular_var = 'cell_angular'
    _time_var = 'time'
    _numbers_var = ['atom_types', 'type', 'Z']
    _positions_var = 'coordinates'
    _velocities_var = 'velocities'
    _cell_origin_var = 'cell_origin'
    _cell_lengths_var = 'cell_lengths'
    _cell_angles_var = 'cell_angles'

    # Names handled explicitly by the reader; everything else found in a file
    # is treated as an "extra" array or attribute.
    _default_vars = reduce(lambda x, y: x + y,
                           [_numbers_var, [_positions_var], [_velocities_var],
                            [_cell_origin_var], [_cell_lengths_var],
                            [_cell_angles_var]])

    def __init__(self, filename, mode='r', atoms=None, types_to_numbers=None,
                 double=True, netcdf_format='NETCDF3_CLASSIC', keep_open=True,
                 index_var='id', chunk_size=1000000):
        """
        A NetCDFTrajectory can be created in read, write or append mode.

        Parameters:

        filename:
            The name of the trajectory file. Should end in .nc.

        mode='r':
            The mode.

            'r' is read mode, the file should already exist, and no atoms
            argument should be specified.

            'w' is write mode. The atoms argument specifies the Atoms object
            to be written to the file, if not given it must instead be given
            as an argument to the write() method.

            'a' is append mode. It acts as write mode, except that data is
            appended to a preexisting file.

        atoms=None:
            The Atoms object to be written in write or append mode.

        types_to_numbers=None:
            Dictionary or list for conversion of atom types to atomic numbers
            when reading a trajectory file.

        double=True:
            Create new variable in double precision.

        netcdf_format='NETCDF3_CLASSIC':
            Format string for the underlying NetCDF file format. Only relevant
            if a new file is created. More information can be found at
            https://www.unidata.ucar.edu/software/netcdf/docs/netcdf/File-Format.html

            'NETCDF3_CLASSIC' is the original binary format.

            'NETCDF3_64BIT' can be used to write larger files.

            'NETCDF4_CLASSIC' is HDF5 with some NetCDF limitations.

            'NETCDF4' is HDF5.

        keep_open=True:
            Keep the file open during consecutive read/write operations.
            Set to false if you experience data corruption. This will close the
            file after each read/write operation but comes with a serious
            performance penalty.

        index_var='id':
            Name of variable containing the atom indices. Atoms are reordered
            by this index upon reading if this variable is present. Default
            value is for LAMMPS output. None switches atom indices off.

        chunk_size=1000000:
            Maximum size of consecutive number of records (along the 'atom'
            dimension) read when reading from a NetCDF file. This is used to
            reduce the memory footprint of a read operation on very large
            files.
        """
        self.nc = None
        self.chunk_size = chunk_size

        self.numbers = None
        self.pre_observers = []   # Callback functions before write
        self.post_observers = []  # Callback functions after write are called
        # Number of completed write() calls; consulted by _call_observers to
        # honour the observers' call interval.
        self.write_counter = 0

        self.has_header = False
        self._set_atoms(atoms)

        self.types_to_numbers = None
        if isinstance(types_to_numbers, list):
            types_to_numbers = dict(enumerate(types_to_numbers))
        if types_to_numbers is not None:
            self.types_to_numbers = types_to_numbers

        self.index_var = index_var

        if self.index_var is not None:
            # Build a per-instance list. An in-place "+=" would extend the
            # list shared by the class and therefore by every instance.
            self._default_vars = self._default_vars + [self.index_var]

        # 'l' should be a valid type according to the netcdf4-python
        # documentation, but does not appear to work.
        self.dtype_conv = {'l': 'i'}
        if not double:
            self.dtype_conv.update(dict(d='f'))

        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        # per frame atts are global quantities, not quantities stored for each
        # atom
        self.extra_per_frame_atts = []

        self.mode = mode
        self.netcdf_format = netcdf_format

        if atoms:
            self.n_atoms = len(atoms)
        else:
            self.n_atoms = None

        self.filename = filename
        if keep_open is None:
            # Only netCDF4-python supports append to files
            self.keep_open = self.mode == 'r'
        else:
            self.keep_open = keep_open

    def __del__(self):
        self.close()

    def _open(self):
        """
        Opens the file.

        For internal use only.
        """
        import netCDF4
        if self.nc is not None:
            return
        # Appending to a file that does not exist yet means creating it.
        if self.mode == 'a' and not os.path.exists(self.filename):
            self.mode = 'w'
        self.nc = netCDF4.Dataset(self.filename, self.mode,
                                  format=self.netcdf_format)

        self.frame = 0
        if self.mode == 'r' or self.mode == 'a':
            self._read_header()
            self.frame = self._len()

    def _set_atoms(self, atoms=None):
        """
        Associate an Atoms object with the trajectory.

        For internal use only.
        """
        if atoms is not None and not hasattr(atoms, 'get_positions'):
            raise TypeError('"atoms" argument is not an Atoms object.')
        self.atoms = atoms

    def _read_header(self):
        """
        Classify the variables of the open file.

        Variables that are not in ``_default_vars`` are sorted by their
        dimensions into per-frame per-atom arrays, per-file per-atom arrays
        and per-frame attributes.
        """
        if not self.n_atoms:
            self.n_atoms = len(self.nc.dimensions[self._atom_dim])

        # Start from scratch: the header is re-read on every reopen (e.g. with
        # keep_open=False), and appending would accumulate duplicates.
        self.extra_per_frame_vars = []
        self.extra_per_file_vars = []
        self.extra_per_frame_atts = []

        for name, var in self.nc.variables.items():
            # This can be unicode which confuses ASE
            name = str(name)
            # _default_vars is taken care of already
            if name in self._default_vars:
                continue
            dims = var.dimensions
            if len(dims) >= 2:
                if dims[0] == self._frame_dim:
                    if dims[1] == self._atom_dim:
                        self.extra_per_frame_vars += [name]
                    else:
                        self.extra_per_frame_atts += [name]
            elif len(dims) == 1:
                if dims[0] == self._atom_dim:
                    self.extra_per_file_vars += [name]
                elif dims[0] == self._frame_dim:
                    self.extra_per_frame_atts += [name]

        self.has_header = True

    def write(self, atoms=None, frame=None, arrays=None, time=None):
        """
        Write the atoms to the file.

        If the atoms argument is not given, the atoms object specified
        when creating the trajectory object is used.

        frame: Index of the frame to write to. Defaults to the next frame
        after the ones already in the file.

        arrays: Names of additional arrays of the atoms object that are
        written alongside the default fields.

        time: Value stored in the time variable for this frame, if given.
        """
        self._open()
        self._call_observers(self.pre_observers)
        if atoms is None:
            atoms = self.atoms

        if hasattr(atoms, 'interpolate'):
            # seems to be a NEB
            neb = atoms
            assert not neb.parallel
            try:
                neb.get_energies_and_forces(all=True)
            except AttributeError:
                pass
            for image in neb.images:
                self.write(image)
            return

        if not self.has_header:
            self._define_file_structure(atoms)
        else:
            if len(atoms) != self.n_atoms:
                raise ValueError('Bad number of atoms!')

        if frame is None:
            i = self.frame
        else:
            i = frame

        # Number can be per file variable
        numbers = self._get_variable(self._numbers_var)
        if numbers.dimensions[0] == self._frame_dim:
            numbers[i] = atoms.get_atomic_numbers()
        else:
            if np.any(numbers != atoms.get_atomic_numbers()):
                raise ValueError('Atomic numbers do not match!')
        self._get_variable(self._positions_var)[i] = atoms.get_positions()
        if atoms.has('momenta'):
            self._add_velocities()
            self._get_variable(self._velocities_var)[i] = \
                atoms.get_momenta() / atoms.get_masses().reshape(-1, 1)
        a, b, c, alpha, beta, gamma = atoms.cell.cellpar()
        if np.any(np.logical_not(atoms.pbc)):
            warnings.warn('Atoms have nonperiodic directions. Cell lengths in '
                          'these directions are lost and will be '
                          'shrink-wrapped when reading the NetCDF file.')
        # A zero cell length marks a non-periodic direction in the file.
        cell_lengths = np.array([a, b, c]) * atoms.pbc
        self._get_variable(self._cell_lengths_var)[i] = cell_lengths
        self._get_variable(self._cell_angles_var)[i] = [alpha, beta, gamma]
        self._get_variable(self._cell_origin_var)[i] = \
            atoms.get_celldisp().reshape(3)
        if arrays is not None:
            self._store_arrays(atoms, i, arrays)
        if time is not None:
            self._add_time()
            self._get_variable(self._time_var)[i] = time

        self.sync()

        self._call_observers(self.post_observers)
        self.frame += 1
        # Incremented after the post observers so that the pre and post
        # observers of one write see the same counter value.
        self.write_counter += 1
        self._close()

    def write_arrays(self, atoms, frame, arrays):
        """
        Write only the named arrays of the atoms object to the given frame.

        The pre and post write observers are called around the operation.
        """
        self._open()
        self._call_observers(self.pre_observers)
        self._store_arrays(atoms, frame, arrays)
        self._call_observers(self.post_observers)
        self._close()

    def _store_arrays(self, atoms, frame, arrays):
        """Store the named arrays of atoms in the given frame of the file."""
        for array in arrays:
            data = atoms.get_array(array)
            if array in self.extra_per_file_vars:
                # This field exists but is per file data. Check that the
                # data remains consistent.
                if np.any(self._get_variable(array) != data):
                    raise ValueError('Trying to write Atoms object with '
                                     'incompatible data for the {} '
                                     'array.'.format(array))
            else:
                self._add_array(atoms, array, data.dtype, data.shape)
                self._get_variable(array)[frame] = data

    def _define_file_structure(self, atoms):
        """
        Create the global attributes, dimensions and default variables.

        Dimensions and variables that already exist are left untouched.
        """
        self.nc.Conventions = 'AMBER'
        self.nc.ConventionVersion = '1.0'
        self.nc.program = 'ASE'
        self.nc.programVersion = ase.__version__
        self.nc.title = "MOL"

        # The frame dimension is unlimited (size None); the others are fixed.
        dimension_sizes = [
            (self._frame_dim, None),
            (self._spatial_dim, 3),
            (self._atom_dim, len(atoms)),
            (self._cell_spatial_dim, 3),
            (self._cell_angular_dim, 3),
            (self._label_dim, 5),
        ]
        for dim_name, dim_size in dimension_sizes:
            if dim_name not in self.nc.dimensions:
                self.nc.createDimension(dim_name, dim_size)

        # Self-describing variables from AMBER convention
        if not self._has_variable(self._spatial_var):
            self.nc.createVariable(self._spatial_var, 'S1',
                                   (self._spatial_dim,))
            self.nc.variables[self._spatial_var][:] = ['x', 'y', 'z']
        if not self._has_variable(self._cell_spatial_var):
            self.nc.createVariable(self._cell_spatial_var, 'S1',
                                   (self._cell_spatial_dim,))
            self.nc.variables[self._cell_spatial_var][:] = ['a', 'b', 'c']
        if not self._has_variable(self._cell_angular_var):
            self.nc.createVariable(self._cell_angular_var, 'S1',
                                   (self._cell_angular_dim, self._label_dim,))
            # Labels are padded to the length of the label dimension (5).
            self.nc.variables[self._cell_angular_var][0] = list('alpha')
            self.nc.variables[self._cell_angular_var][1] = list('beta ')
            self.nc.variables[self._cell_angular_var][2] = list('gamma')

        if not self._has_variable(self._numbers_var):
            self.nc.createVariable(self._numbers_var[0], 'i',
                                   (self._frame_dim, self._atom_dim,))
        if not self._has_variable(self._positions_var):
            self.nc.createVariable(self._positions_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            self.nc.variables[self._positions_var].units = 'Angstrom'
            self.nc.variables[self._positions_var].scale_factor = 1.
        if not self._has_variable(self._cell_lengths_var):
            self.nc.createVariable(self._cell_lengths_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_lengths_var].units = 'Angstrom'
            self.nc.variables[self._cell_lengths_var].scale_factor = 1.
        if not self._has_variable(self._cell_angles_var):
            self.nc.createVariable(self._cell_angles_var, 'd',
                                   (self._frame_dim, self._cell_angular_dim))
            self.nc.variables[self._cell_angles_var].units = 'degree'
        if not self._has_variable(self._cell_origin_var):
            self.nc.createVariable(self._cell_origin_var, 'd',
                                   (self._frame_dim, self._cell_spatial_dim))
            self.nc.variables[self._cell_origin_var].units = 'Angstrom'
            self.nc.variables[self._cell_origin_var].scale_factor = 1.

    def _add_time(self):
        """Create the per-frame time variable if it does not exist yet."""
        if not self._has_variable(self._time_var):
            self.nc.createVariable(self._time_var, 'f8', (self._frame_dim,))

    def _add_velocities(self):
        """Create the velocities variable if it does not exist yet."""
        if not self._has_variable(self._velocities_var):
            self.nc.createVariable(self._velocities_var, 'f4',
                                   (self._frame_dim, self._atom_dim,
                                    self._spatial_dim))
            # Attributes belong on the velocities variable itself, not on the
            # positions variable.
            velocities = self.nc.variables[self._velocities_var]
            velocities.units = 'Angstrom/Femtosecond'
            velocities.scale_factor = 1.

    def _add_array(self, atoms, array_name, type, shape):
        """
        Create a per-frame variable for an extra array if it is missing.

        Each entry of shape is mapped to a dimension: the number of atoms,
        3 (spatial) or 6 (Voigt). Any other extent raises TypeError.
        """
        if self._has_variable(array_name):
            return
        dims = [self._frame_dim]
        for extent in shape:
            if extent == len(atoms):
                dims += [self._atom_dim]
            elif extent == 3:
                dims += [self._spatial_dim]
            elif extent == 6:
                # This can only be stress/strain tensor in Voigt notation
                if self._Voigt_dim not in self.nc.dimensions:
                    self.nc.createDimension(self._Voigt_dim, 6)
                dims += [self._Voigt_dim]
            else:
                raise TypeError("Don't know how to dump array of shape {}"
                                " into NetCDF trajectory.".format(shape))
        if hasattr(type, 'char'):
            # Translate dtypes according to dtype_conv (see __init__).
            t = self.dtype_conv.get(type.char, type)
        else:
            t = type
        self.nc.createVariable(array_name, t, dims)

    def _get_variable(self, name, exc=True):
        """
        Return the variable called name from the open file.

        name may be a list of alternative names; the first one present in the
        file is returned. If nothing is found, RuntimeError is raised when exc
        is true, otherwise None is returned.
        """
        if isinstance(name, list):
            for n in name:
                if n in self.nc.variables:
                    return self.nc.variables[n]
            if exc:
                raise RuntimeError(
                    'None of the variables {} was found in the '
                    'NetCDF trajectory.'.format(', '.join(name)))
        else:
            if name in self.nc.variables:
                return self.nc.variables[name]
            if exc:
                raise RuntimeError('Variable {} was not found in the NetCDF '
                                   'trajectory.'.format(name))
        return None

    def _has_variable(self, name):
        """Return whether name (or any name of a list) is in the file."""
        if isinstance(name, list):
            return any(n in self.nc.variables for n in name)
        return name in self.nc.variables

    def _get_data(self, name, frame, index, exc=True):
        """
        Read a variable for one frame, reordered by index.

        Entry k read from the file is stored at position index[k] of the
        result. Variables without a leading frame dimension are read as a
        whole. Returns None if the variable is missing and exc is false.
        """
        var = self._get_variable(name, exc=exc)
        if var is None:
            return None
        if var.dimensions[0] == self._frame_dim:
            data = np.zeros(var.shape[1:], dtype=var.dtype)
            s = var.shape[1]
            if s < self.chunk_size:
                data[index] = var[frame]
            else:
                # If this is a large data set, only read chunks from it to
                # reduce memory footprint of the NetCDFTrajectory reader.
                for i in range((s - 1) // self.chunk_size + 1):
                    sl = slice(i * self.chunk_size,
                               min((i + 1) * self.chunk_size, s))
                    data[index[sl]] = var[frame, sl]
        else:
            data = np.zeros(var.shape, dtype=var.dtype)
            s = var.shape[0]
            if s < self.chunk_size:
                data[index] = var[...]
            else:
                # If this is a large data set, only read chunks from it to
                # reduce memory footprint of the NetCDFTrajectory reader.
                for i in range((s - 1) // self.chunk_size + 1):
                    sl = slice(i * self.chunk_size,
                               min((i + 1) * self.chunk_size, s))
                    data[index[sl]] = var[sl]
        return data

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()

    def close(self):
        """Close the trajectory file."""
        if self.nc is not None:
            self.nc.close()
            self.nc = None

    def _close(self):
        """Close the file after an operation unless keep_open is set."""
        if not self.keep_open:
            self.close()
            # The file now exists; later writes must append, not truncate.
            if self.mode == 'w':
                self.mode = 'a'

    def sync(self):
        """Call sync() on the underlying NetCDF dataset."""
        self.nc.sync()

    def __getitem__(self, i=-1):
        """
        Return frame i as an Atoms object, or a list of them for a slice.

        Negative indices count from the end. IndexError is raised if the
        index is out of range.
        """
        self._open()

        if isinstance(i, slice):
            return [self[j] for j in range(*i.indices(self._len()))]

        N = self._len()
        if not 0 <= i < N:
            # Interpret as an index relative to the end of the trajectory.
            i = N + i
            if i < 0 or i >= N:
                self._close()
                raise IndexError('Trajectory index out of range.')

        # Non-periodic boundaries have cell_length == 0.0
        cell_lengths = \
            np.array(self.nc.variables[self._cell_lengths_var][i][:])
        pbc = np.abs(cell_lengths) > 1e-6

        # Do we have a cell origin?
        if self._has_variable(self._cell_origin_var):
            origin = np.array(
                self.nc.variables[self._cell_origin_var][i][:])
        else:
            origin = np.zeros([3], dtype=float)

        # Do we have an index variable?
        if (self.index_var is not None and
                self._has_variable(self.index_var)):
            index = np.array(self.nc.variables[self.index_var][i][:])
            # The index variable can be non-consecutive, we here construct
            # a consecutive one.
            consecutive_index = np.zeros_like(index)
            consecutive_index[np.argsort(index)] = np.arange(self.n_atoms)
        else:
            consecutive_index = np.arange(self.n_atoms)

        # Read element numbers
        self.numbers = self._get_data(self._numbers_var, i,
                                      consecutive_index, exc=False)
        if self.numbers is None:
            self.numbers = np.ones(self.n_atoms, dtype=int)
        if self.types_to_numbers is not None:
            # Types without an explicit mapping are kept as they are.
            d = set(self.numbers).difference(self.types_to_numbers.keys())
            if len(d) > 0:
                self.types_to_numbers.update({num: num for num in d})
            func = np.vectorize(self.types_to_numbers.get)
            self.numbers = func(self.numbers)
        self.masses = atomic_masses[self.numbers]

        # Read positions
        positions = self._get_data(self._positions_var, i,
                                   consecutive_index)

        # Determine cell size for non-periodic directions from shrink
        # wrapped cell.
        for dim in np.arange(3)[np.logical_not(pbc)]:
            origin[dim] = positions[:, dim].min()
            cell_lengths[dim] = positions[:, dim].max() - origin[dim]

        # Construct cell shape from cell lengths and angles
        cell = cellpar_to_cell(
            list(cell_lengths) +
            list(self.nc.variables[self._cell_angles_var][i])
        )

        # Compute momenta from velocities (if present)
        momenta = self._get_data(self._velocities_var, i,
                                 consecutive_index, exc=False)
        if momenta is not None:
            momenta *= self.masses.reshape(-1, 1)

        info = {
            name: np.array(self.nc.variables[name][i])
            for name in self.extra_per_frame_atts
        }
        # Create atoms object
        atoms = ase.Atoms(
            positions=positions,
            numbers=self.numbers,
            cell=cell,
            celldisp=origin,
            momenta=momenta,
            masses=self.masses,
            pbc=pbc,
            info=info
        )

        # Attach additional arrays found in the NetCDF file
        for name in self.extra_per_frame_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        for name in self.extra_per_file_vars:
            atoms.set_array(name, self._get_data(name, i,
                                                 consecutive_index))
        self._close()
        return atoms

    def _len(self):
        """Return the number of frames of the open file (0 if none)."""
        if self._frame_dim in self.nc.dimensions:
            return int(self._get_variable(self._positions_var).shape[0])
        else:
            return 0

    def __len__(self):
        self._open()
        n_frames = self._len()
        self._close()
        return n_frames

    def pre_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called before writing begins.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.pre_observers.append((function, interval, args, kwargs))

    def post_write_attach(self, function, interval=1, *args, **kwargs):
        """
        Attach a function to be called after writing ends.

        function: The function or callable object to be called.

        interval: How often the function is called. Default: every time (1).

        All other arguments are stored, and passed to the function.
        """
        if not callable(function):
            raise ValueError('Callback object must be callable.')
        self.post_observers.append((function, interval, args, kwargs))

    def _call_observers(self, obs):
        """Call pre/post write observers."""
        for function, interval, args, kwargs in obs:
            # write_counter counts completed write() calls (see __init__).
            if self.write_counter % interval == 0:
                function(*args, **kwargs)
def read_netcdftrajectory(filename, index=-1):
    """Read the frame(s) selected by index from a NetCDF trajectory file.

    The file is opened in read mode and closed again before returning.
    index is passed straight to NetCDFTrajectory indexing.
    """
    trajectory = NetCDFTrajectory(filename, mode='r')
    with trajectory:
        return trajectory[index]
653def write_netcdftrajectory(filename, images):
654 if hasattr(images, 'get_positions'):
655 images = [images]
657 with NetCDFTrajectory(filename, mode='w') as traj:
658 for atoms in images:
659 traj.write(atoms)