Coverage for /builds/ase/ase/ase/io/zmatrix.py: 95.45%
132 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import re
4from collections import namedtuple
5from numbers import Real
6from string import digits
7from typing import Dict, List, Optional, Tuple, Union
9import numpy as np
11from ase import Atoms
12from ase.units import Angstrom, Bohr, nm
# Rows of a block string are separated by a newline or by a semicolon.
_re_linesplit = re.compile(r'\n|;')
# Within a symbol definition, the name and the value are separated either by
# "=" (optionally surrounded by whitespace) or by plain whitespace.
_re_defs = re.compile(r'\s*=\s*|\s+')


# One parsed internal-coordinate row: up to three reference atom indices,
# each paired with the coordinate measured against it (distance, bending
# angle, dihedral angle).
_ZMatrixRow = namedtuple(
    '_ZMatrixRow',
    ['ind1', 'dist', 'ind2', 'a_bend', 'ind3', 'a_dihedral'],
)


# A Cartesian position: either a 3-tuple of floats or a numpy array.
ThreeFloats = Union[Tuple[float, float, float], np.ndarray]
def require(condition):
    """Raise ``RuntimeError`` unless *condition* is truthy.

    This stands in for ``assert`` statements; it is a blunt internal check
    rather than user-facing error handling.
    """
    if condition:
        return
    raise RuntimeError('Internal requirement violated')
class _ZMatrixToAtoms:
    """Incrementally convert Z-matrix rows into symbols and positions.

    Rows are fed one at a time through ``add_row``; every row may only refer
    to atoms that were defined by earlier rows. The accumulated symbols and
    Cartesian positions are turned into an ``Atoms`` object by ``to_atoms``.
    """

    # Conversion factors for the unit names accepted by get_units().
    known_units = dict(
        distance={'angstrom': Angstrom, 'bohr': Bohr, 'au': Bohr, 'nm': nm},
        angle={'radians': 1., 'degrees': np.pi / 180},
    )

    def __init__(self, dconv: Union[str, Real], aconv: Union[str, Real],
                 defs: Optional[Union[Dict[str, float],
                                      str, List[str]]] = None) -> None:
        """Set up unit conversion factors and symbol definitions.

        Parameters:

        dconv: str or Real
            Distance unit name (a key of ``known_units['distance']``) or an
            explicit numeric conversion factor.
        aconv: str or Real
            Angle unit name (a key of ``known_units['angle']``) or an
            explicit numeric conversion factor.
        defs: dict, str or list of str, optional
            Symbol definitions, see ``set_defs``.
        """
        self.dconv = self.get_units('distance', dconv)  # type: float
        self.aconv = self.get_units('angle', aconv)  # type: float
        self.set_defs(defs)
        # Maps atom names to row indices. Set to None by set_index() once a
        # name is seen twice, because names are then ambiguous.
        self.name_to_index: Optional[Dict[str, int]] = {}
        self.symbols: List[str] = []
        self.positions: List[ThreeFloats] = []

    @property
    def nrows(self):
        """Number of atoms added so far."""
        return len(self.symbols)

    def get_units(self, kind: str, value: Union[str, Real]) -> float:
        """Return the conversion factor for `value`.

        A number is returned as a float unchanged; a string is looked up
        case-insensitively in ``known_units[kind]``.

        Raises ValueError if the unit name is not known.
        """
        if isinstance(value, Real):
            return float(value)
        out = self.known_units[kind].get(value.lower())
        if out is None:
            raise ValueError("Unknown {} units: {}"
                             .format(kind, value))
        return out

    def set_defs(self, defs: Union[Dict[str, float], str,
                                   List[str], None]) -> None:
        """Replace the symbol table with the given definitions.

        `defs` may be None (no symbols), a dict mapping names to values, a
        block string (split at newlines and semicolons), or a list of rows.
        Each row has the form ``name = value`` or ``name value``; the value
        may itself be a previously defined symbol.

        Raises ValueError if a row is malformed or its value is unknown.
        """
        self.defs = {}  # type: Dict[str, float]
        if defs is None:
            return

        if isinstance(defs, dict):
            self.defs.update(defs)
            return

        if isinstance(defs, str):
            defs = _re_linesplit.split(defs.strip())

        for row in defs:
            # Surrounding whitespace would make the regex split produce
            # empty fields, and blank rows (empty lines, trailing ";")
            # carry no definition at all.
            row = row.strip()
            if not row:
                continue
            fields = _re_defs.split(row)
            if len(fields) != 2:
                raise ValueError('Invalid symbol definition in Z-matrix: {}'
                                 .format(row))
            key, val = fields
            self.defs[key] = self.get_var(val)

    def get_var(self, val: str) -> float:
        """Return the numeric value of a token.

        The token is either a number or the name of a defined symbol,
        optionally prefixed by a sign.

        Raises ValueError if the token is neither.
        """
        try:
            return float(val)
        except ValueError as e:
            val_out = self.defs.get(val.lstrip('+-'))
            if val_out is None:
                raise ValueError('Invalid value encountered in Z-matrix: {}'
                                 .format(val)) from e
            return val_out * (-1 if val.startswith('-') else 1)

    def get_index(self, name: str) -> int:
        """Find index for a given atom name

        A token that parses as an integer is taken as a 1-based row number
        and converted to a 0-based index; any other token is looked up in
        the name -> index map.

        Raises ValueError if the name cannot be resolved.
        """
        try:
            return int(name) - 1
        except ValueError as e:
            if self.name_to_index is None or name not in self.name_to_index:
                raise ValueError('Failed to determine index for name "{}"'
                                 .format(name)) from e
        return self.name_to_index[name]

    def set_index(self, name: str) -> None:
        """Assign index to a given atom name for name -> index lookup"""
        if self.name_to_index is None:
            return

        if name in self.name_to_index:
            # "name" has been encountered before, so name_to_index is no
            # longer meaningful. Destroy the map.
            self.name_to_index = None
            return

        self.name_to_index[name] = self.nrows

    # Use typehint *indices: str from python3.11+
    def validate_indices(self, *indices) -> None:
        """Raises an error if indices in a Z-matrix row are invalid.

        Indices are 0-based and must be non-negative, refer to atoms that
        have already been added, and be pairwise distinct.
        """
        # A negative index would not fail later on: it would silently wrap
        # around and pick a position counted from the end of the list.
        if any(idx < 0 for idx in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom indices {}, at least one of which '
                             'is negative!'
                             .format(self.nrows, indices))

        if any(idx >= self.nrows for idx in indices):
            raise ValueError('An invalid Z-matrix was provided! Row {} refers '
                             'to atom indices {}, at least one of which '
                             "hasn't been defined yet!"
                             .format(self.nrows, indices))

        if len(indices) != len(set(indices)):
            raise ValueError('An atom index has been used more than once a '
                             'row of the Z-matrix! Row numbers {}, '
                             'referred indices: {}'
                             .format(self.nrows, indices))

    def parse_row(self, row: str) -> Tuple[
            str, Union[_ZMatrixRow, ThreeFloats],
    ]:
        """Parse one whitespace-separated Z-matrix row.

        Returns the atom name together with either a Cartesian position
        (name-only rows, rows whose first reference is "0", and the
        distance-only row of the second atom) or a ``_ZMatrixRow`` holding
        the internal coordinates still to be converted by ``add_row``.

        Raises ValueError for unresolvable names, values or indices, and
        RuntimeError (via ``require``) if the row kind does not match the
        number of atoms defined so far.
        """
        tokens = row.split()
        name = tokens[0]
        self.set_index(name)
        if len(tokens) == 1:
            # Name only: the very first atom, placed at the origin.
            require(self.nrows == 0)
            return name, np.zeros(3, dtype=float)

        ind1 = self.get_index(tokens[1])
        if ind1 == -1:
            # Reference "0": the remaining three tokens are taken as
            # Cartesian coordinates. They are not scaled by dconv.
            require(len(tokens) == 5)
            return name, np.array(list(map(self.get_var, tokens[2:])),
                                  dtype=float)

        dist = self.dconv * self.get_var(tokens[2])

        if len(tokens) == 3:
            # Second atom: only a bond length is given, so it is placed
            # along +x from its reference atom. The offset is relative to
            # that atom so the bond length also holds when the first atom
            # was given as a Cartesian row away from the origin.
            require(self.nrows == 1)
            self.validate_indices(ind1)
            return name, (self.positions[ind1]
                          + np.array([dist, 0, 0], dtype=float))

        ind2 = self.get_index(tokens[3])
        a_bend = self.aconv * self.get_var(tokens[4])

        if len(tokens) == 5:
            # Third atom: bond length and bending angle, no dihedral.
            require(self.nrows == 2)
            self.validate_indices(ind1, ind2)
            return name, _ZMatrixRow(ind1, dist, ind2, a_bend, None, None)

        ind3 = self.get_index(tokens[5])
        a_dihedral = self.aconv * self.get_var(tokens[6])
        self.validate_indices(ind1, ind2, ind3)
        return name, _ZMatrixRow(ind1, dist, ind2, a_bend, ind3,
                                 a_dihedral)

    def add_atom(self, name: str, pos: ThreeFloats) -> None:
        """Sets the symbol and position of an atom.

        The symbol is the name with all digits removed, capitalized.
        """
        self.symbols.append(
            ''.join([c for c in name if c not in digits]).capitalize()
        )
        self.positions.append(pos)

    def add_row(self, row: str) -> None:
        """Parse one Z-matrix row and append the resulting atom."""
        name, zrow = self.parse_row(row)

        if not isinstance(zrow, _ZMatrixRow):
            # parse_row already produced a Cartesian position.
            self.add_atom(name, zrow)
            return

        if zrow.ind3 is None:
            # This is the third atom, so only a bond distance and an angle
            # have been provided.
            # Exactly two atoms exist here and the two indices are distinct,
            # so (ind2 - ind1) is +1 or -1 and selects the x direction that
            # points from atom ind1 towards atom ind2. This assumes the
            # first two atoms are separated along x, as parse_row places
            # them when the second atom is a distance-only row.
            pos = self.positions[zrow.ind1].copy()
            pos[0] += zrow.dist * np.cos(zrow.a_bend) * (zrow.ind2 - zrow.ind1)
            pos[1] += zrow.dist * np.sin(zrow.a_bend)
            self.add_atom(name, pos)
            return

        # ax1 is the dihedral axis, which is defined by the bond vector
        # between the two inner atoms in the dihedral, ind1 and ind2
        ax1 = self.positions[zrow.ind2] - self.positions[zrow.ind1]
        ax1 /= np.linalg.norm(ax1)

        # ax2 lies within the 1-2-3 plane, and it is perpendicular
        # to the dihedral axis
        # NOTE: if atoms ind1, ind2 and ind3 are collinear, ax2 has zero
        # length after the projection is removed and the normalization
        # below produces NaN.
        ax2 = self.positions[zrow.ind2] - self.positions[zrow.ind3]
        ax2 -= ax1 * (ax2 @ ax1)
        ax2 /= np.linalg.norm(ax2)

        # ax3 is a vector that forms the appropriate dihedral angle, though
        # the bending angle is 90 degrees, rather than a_bend. It is formed
        # from a linear combination of ax2 and (ax2 x ax1)
        ax3 = (ax2 * np.cos(zrow.a_dihedral)
               + np.cross(ax2, ax1) * np.sin(zrow.a_dihedral))

        # The final position vector is a linear combination of ax1 and ax3.
        pos = ax1 * np.cos(zrow.a_bend) - ax3 * np.sin(zrow.a_bend)
        pos *= zrow.dist / np.linalg.norm(pos)
        pos += self.positions[zrow.ind1]
        self.add_atom(name, pos)

    def to_atoms(self) -> Atoms:
        """Build an ``Atoms`` object from the accumulated atoms."""
        return Atoms(self.symbols, self.positions)
def parse_zmatrix(zmat: Union[str, List[str]],
                  distance_units: Union[str, Real] = 'angstrom',
                  angle_units: Union[str, Real] = 'degrees',
                  defs: Optional[Union[Dict[str, float], str,
                                       List[str]]] = None) -> Atoms:
    """Converts a Z-matrix into an Atoms object.

    Parameters:

    zmat: Iterable or str
        The Z-matrix to be parsed. Iteration over `zmat` should yield the rows
        of the Z-matrix. If `zmat` is a str, it will be automatically split
        into a list at newlines and semicolons. Rows that are empty or
        contain only whitespace are ignored.
    distance_units: str or float, optional
        The units of distance in the provided Z-matrix.
        Defaults to Angstrom.
    angle_units: str or float, optional
        The units for angles in the provided Z-matrix.
        Defaults to degrees.
    defs: dict, str or list of str, optional
        If `zmat` contains symbols for bond distances, bending angles, and/or
        dihedral angles instead of numeric values, then the definition of
        those symbols should be passed to this function using this keyword
        argument.
        Note: The symbol definitions are typically printed adjacent to the
        Z-matrix itself, but this function will not automatically separate
        the symbol definitions from the Z-matrix.

    Returns:

    atoms: Atoms object
    """
    zmatrix = _ZMatrixToAtoms(distance_units, angle_units, defs=defs)

    # zmat should be a list containing the rows of the z-matrix.
    # for convenience, allow block strings and split at newlines
    # (or semicolons).
    if isinstance(zmat, str):
        zmat = _re_linesplit.split(zmat.strip())

    for row in zmat:
        # A trailing ";" or an empty line yields a blank row. It describes
        # no atom, and passing it on would fail with an IndexError when the
        # row's first token is read.
        if not row.strip():
            continue
        zmatrix.add_row(row)

    return zmatrix.to_atoms()