Coverage for ase / io / zmatrix.py: 95.42%
131 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
3import re
4from collections import namedtuple
5from numbers import Real
6from string import digits
8import numpy as np
10from ase import Atoms
11from ase.units import Angstrom, Bohr, nm
13# split on newlines or semicolons
14_re_linesplit = re.compile(r'\n|;')
15# split definitions on whitespace or on "=" (possibly also with whitespace)
16_re_defs = re.compile(r'\s*=\s*|\s+')
19_ZMatrixRow = namedtuple(
20 '_ZMatrixRow', 'ind1 dist ind2 a_bend ind3 a_dihedral',
21)
24ThreeFloats = tuple[float, float, float] | np.ndarray
def require(condition):
    """Raise ``RuntimeError`` unless ``condition`` is truthy.

    Stands in for ``assert`` in internal consistency checks, so the checks
    still run when Python is started with ``-O``.
    """
    # (This is not good error handling, but it replaces assertions.)
    if condition:
        return
    raise RuntimeError('Internal requirement violated')
class _ZMatrixToAtoms:
    """Incrementally convert Z-matrix rows into an :class:`~ase.Atoms`.

    Rows are fed one at a time to :meth:`add_row`. Each row is parsed (see
    :meth:`parse_row` for the accepted layouts), the Cartesian position of
    the atom it describes is computed, and the atom is appended.
    :meth:`to_atoms` then builds the final structure.

    Parameters
    ----------
    dconv:
        Distance unit: a name from ``known_units['distance']``
        (case-insensitive), or a numeric factor by which every length in
        the Z-matrix is multiplied.
    aconv:
        Angle unit: a name from ``known_units['angle']`` (case-insensitive),
        or a numeric factor that converts the given angles to radians.
    defs:
        Optional definitions for symbolic values; see :meth:`set_defs`.
    """

    # Multiplicative factors converting input lengths to ASE length units
    # and input angles to radians.
    known_units = dict(
        distance={'angstrom': Angstrom, 'bohr': Bohr, 'au': Bohr, 'nm': nm},
        angle={'radians': 1., 'degrees': np.pi / 180},
    )

    def __init__(
        self,
        dconv: str | Real,
        aconv: str | Real,
        defs: dict[str, float] | str | list[str] | None = None
    ) -> None:
        self.dconv = self.get_units('distance', dconv)  # type: float
        self.aconv = self.get_units('angle', aconv)  # type: float
        self.set_defs(defs)
        # Maps atom names to 0-based row indices so later rows may refer to
        # atoms by name; set to None as soon as a name is reused (see
        # set_index), after which only numeric references work.
        self.name_to_index: dict[str, int] | None = {}
        self.symbols: list[str] = []
        self.positions: list[ThreeFloats] = []

    @property
    def nrows(self):
        """Number of atoms (rows) added so far."""
        return len(self.symbols)

    def get_units(self, kind: str, value: str | Real) -> float:
        """Return the conversion factor described by ``value``.

        Numeric values are returned as floats; strings are looked up
        case-insensitively in ``known_units[kind]``.

        Raises
        ------
        ValueError
            If ``value`` is a string that is not a known unit name.
        """
        if isinstance(value, Real):
            return float(value)
        out = self.known_units[kind].get(value.lower())
        if out is None:
            raise ValueError("Unknown {} units: {}"
                             .format(kind, value))
        return out

    def set_defs(self, defs: dict[str, float] | str | list[str] | None) -> None:
        """Store the definitions used to resolve symbolic values.

        ``defs`` may be a mapping of symbol -> value, a string of
        ``name value`` / ``name = value`` entries separated by newlines or
        semicolons, or a list of such entries. Entries are processed in
        order, so a value may itself be an (optionally signed) symbol
        defined by an earlier entry. Blank entries are ignored.

        Raises
        ------
        ValueError
            If an entry does not consist of exactly a name and a value, or
            if its value is neither a number nor an already-defined symbol.
        """
        self.defs: dict[str, float] = {}
        if defs is None:
            return

        if isinstance(defs, dict):
            # update(defs) rather than update(**defs): no str-key requirement.
            self.defs.update(defs)
            return

        if isinstance(defs, str):
            defs = _re_linesplit.split(defs.strip())

        for row in defs:
            # Strip each entry: _re_defs.split would otherwise yield empty
            # leading/trailing fields for indented or padded lines.
            entry = row.strip()
            if not entry:
                # Blank lines and doubled/trailing semicolons define nothing.
                continue
            fields = _re_defs.split(entry)
            if len(fields) != 2:
                raise ValueError('Invalid Z-matrix definition: {}'
                                 .format(row))
            key, val = fields
            self.defs[key] = self.get_var(val)

    def get_var(self, val: str) -> float:
        """Convert a value token to a float.

        ``val`` is either a literal number or the name of a defined symbol,
        optionally prefixed by ``+``/``-``. A leading ``-`` negates the
        symbol's value.

        Raises
        ------
        ValueError
            If ``val`` is neither a number nor a defined symbol.
        """
        try:
            return float(val)
        except ValueError as e:
            val_out = self.defs.get(val.lstrip('+-'))
            if val_out is None:
                raise ValueError('Invalid value encountered in Z-matrix: {}'
                                 .format(val)) from e
            return val_out * (-1 if val.startswith('-') else 1)

    def get_index(self, name: str) -> int:
        """Find the 0-based atom index referred to by ``name``.

        ``name`` is either a 1-based row number (so ``'0'`` maps to -1,
        which parse_row treats as the Cartesian-coordinates marker) or the
        name of an earlier atom, as long as all names seen so far are
        unique.

        Raises
        ------
        ValueError
            If ``name`` is not an integer and cannot be resolved by name.
        """
        try:
            return int(name) - 1
        except ValueError as e:
            if self.name_to_index is None or name not in self.name_to_index:
                raise ValueError('Failed to determine index for name "{}"'
                                 .format(name)) from e
            return self.name_to_index[name]

    def set_index(self, name: str) -> None:
        """Assign index to a given atom name for name -> index lookup."""
        if self.name_to_index is None:
            return

        if name in self.name_to_index:
            # "name" has been encountered before, so name_to_index is no
            # longer meaningful. Destroy the map.
            self.name_to_index = None
            return

        self.name_to_index[name] = self.nrows

    # Use typehint *indices: str from python3.11+
    def validate_indices(self, *indices) -> None:
        """Raise ValueError if the indices referenced by a row are invalid.

        Indices are 0-based. Each must refer to an atom defined by an
        earlier row (``0 <= index < nrows``), and no index may repeat.
        """
        # Negative indices would otherwise silently pick atoms counted from
        # the end of self.positions.
        if any(index < 0 for index in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom numbers {}, but atom numbers must be '
                             'positive integers!'
                             .format(self.nrows,
                                     tuple(index + 1 for index in indices)))

        if any(index >= self.nrows for index in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom indices {}, at least one of which '
                             "hasn't been defined yet!"
                             .format(self.nrows, indices))

        if len(indices) != len(set(indices)):
            raise ValueError('An atom index has been used more than once a '
                             'row of the Z-matrix! Row numbers {}, '
                             'referred indices: {}'
                             .format(self.nrows, indices))

    def parse_row(self, row: str) -> tuple[
        str, _ZMatrixRow | ThreeFloats,
    ]:
        """Parse one Z-matrix row into an atom name and its coordinates.

        Accepted layouts (atom references are 1-based row numbers or
        earlier atom names; values are numbers or defined symbols):

        * ``name`` -- first atom, placed at the origin;
        * ``name 0 x y z`` -- explicit Cartesian coordinates;
        * ``name i dist`` -- second atom, ``dist`` along +x from atom ``i``;
        * ``name i dist j angle`` -- third atom;
        * ``name i dist j angle k dihedral`` -- any later atom.

        Returns ``(name, coords)``. ``coords`` is a length-3 position array
        for the first three layouts. For the others it is a
        :class:`_ZMatrixRow` with lengths multiplied by ``dconv`` and angles
        by ``aconv``.
        """
        tokens = row.split()
        name = tokens[0]
        self.set_index(name)
        if len(tokens) == 1:
            require(self.nrows == 0)
            return name, np.zeros(3, dtype=float)

        ind1 = self.get_index(tokens[1])
        if ind1 == -1:
            # A reference of "0" marks explicit Cartesian coordinates. They
            # are lengths too, so they get the same unit conversion as bond
            # distances.
            require(len(tokens) == 5)
            coords = np.array([self.get_var(tok) for tok in tokens[2:]],
                              dtype=float)
            return name, self.dconv * coords

        dist = self.dconv * self.get_var(tokens[2])

        if len(tokens) == 3:
            require(self.nrows == 1)
            self.validate_indices(ind1)
            # Offset from the bonded atom, which is not at the origin if it
            # was given in Cartesian form.
            origin = np.asarray(self.positions[ind1], dtype=float)
            return name, origin + np.array([dist, 0, 0], dtype=float)

        ind2 = self.get_index(tokens[3])
        a_bend = self.aconv * self.get_var(tokens[4])

        if len(tokens) == 5:
            require(self.nrows == 2)
            self.validate_indices(ind1, ind2)
            return name, _ZMatrixRow(ind1, dist, ind2, a_bend, None, None)

        ind3 = self.get_index(tokens[5])
        a_dihedral = self.aconv * self.get_var(tokens[6])
        self.validate_indices(ind1, ind2, ind3)
        return name, _ZMatrixRow(ind1, dist, ind2, a_bend, ind3,
                                 a_dihedral)

    def add_atom(self, name: str, pos: ThreeFloats) -> None:
        """Append an atom at ``pos``.

        The chemical symbol is ``name`` with all digits removed, then
        capitalized (e.g. ``'h1'`` -> ``'H'``).
        """
        self.symbols.append(
            ''.join([c for c in name if c not in digits]).capitalize()
        )
        self.positions.append(pos)

    def add_row(self, row: str) -> None:
        """Parse ``row`` and append the atom it describes."""
        name, zrow = self.parse_row(row)

        if not isinstance(zrow, _ZMatrixRow):
            # parse_row already produced a Cartesian position.
            self.add_atom(name, zrow)
            return

        if zrow.ind3 is None:
            # This is the third atom, so only a bond distance and an angle
            # have been provided. This assumes the first two atoms lie along
            # x, as they do when the second atom was given as a distance.
            # (ind2 - ind1) is then +1 or -1 and selects the direction from
            # atom ind1 towards ind2. The new atom goes into the xy-plane,
            # on the +y side.
            pos = self.positions[zrow.ind1].copy()
            pos[0] += zrow.dist * np.cos(zrow.a_bend) * (zrow.ind2 - zrow.ind1)
            pos[1] += zrow.dist * np.sin(zrow.a_bend)
            self.add_atom(name, pos)
            return

        # ax1 is the dihedral axis, which is defined by the bond vector
        # between the two inner atoms in the dihedral, ind1 and ind2
        ax1 = self.positions[zrow.ind2] - self.positions[zrow.ind1]
        ax1 /= np.linalg.norm(ax1)

        # ax2 lies within the 1-2-3 plane, and it is perpendicular
        # to the dihedral axis
        ax2 = self.positions[zrow.ind2] - self.positions[zrow.ind3]
        ax2 -= ax1 * (ax2 @ ax1)
        ax2 /= np.linalg.norm(ax2)

        # ax3 is a vector that forms the appropriate dihedral angle, though
        # the bending angle is 90 degrees, rather than a_bend. It is formed
        # from a linear combination of ax2 and (ax2 x ax1)
        ax3 = (ax2 * np.cos(zrow.a_dihedral)
               + np.cross(ax2, ax1) * np.sin(zrow.a_dihedral))

        # The final position vector is a linear combination of ax1 and ax3.
        pos = ax1 * np.cos(zrow.a_bend) - ax3 * np.sin(zrow.a_bend)
        pos *= zrow.dist / np.linalg.norm(pos)
        pos += self.positions[zrow.ind1]
        self.add_atom(name, pos)

    def to_atoms(self) -> Atoms:
        """Return the atoms added so far as an :class:`~ase.Atoms` object."""
        return Atoms(self.symbols, self.positions)
def parse_zmatrix(
    zmat: str | list[str],
    distance_units: str | Real = 'angstrom',
    angle_units: str | Real = 'degrees',
    defs: dict[str, float] | str | list[str] | None = None
) -> Atoms:
    """Converts a Z-matrix into an Atoms object.

    Parameters
    ----------

    zmat: Iterable or str
        The Z-matrix to be parsed. Iteration over `zmat` should yield the rows
        of the Z-matrix. If `zmat` is a str, it will be automatically split
        into a list at newlines and semicolons. Blank rows are ignored.
    distance_units: str or float, optional
        The units of distance in the provided Z-matrix: 'angstrom', 'bohr',
        'au' or 'nm' (case-insensitive), or a numeric conversion factor.
        Defaults to Angstrom.
    angle_units: str or float, optional
        The units for angles in the provided Z-matrix: 'degrees' or
        'radians' (case-insensitive), or a numeric factor converting the
        given angles to radians. Defaults to degrees.
    defs: dict, str or list of str, optional
        If `zmat` contains symbols for bond distances, bending angles, and/or
        dihedral angles instead of numeric values, then the definition of
        those symbols should be passed to this function using this keyword
        argument, either as a mapping of symbol -> value or as
        "name value" / "name = value" entries (a str is split at newlines
        and semicolons).
        Note: The symbol definitions are typically printed adjacent to the
        Z-matrix itself, but this function will not automatically separate
        the symbol definitions from the Z-matrix.

    Returns
    -------

    atoms: Atoms object
    """
    zmatrix = _ZMatrixToAtoms(distance_units, angle_units, defs=defs)

    # zmat should be a list containing the rows of the z-matrix.
    # for convenience, allow block strings and split at newlines.
    if isinstance(zmat, str):
        zmat = _re_linesplit.split(zmat.strip())

    for row in zmat:
        if not row.strip():
            # Blank lines (or doubled/trailing semicolons) describe no atom;
            # passing them on would fail with an IndexError.
            continue
        zmatrix.add_row(row)

    return zmatrix.to_atoms()