Coverage for ase / io / zmatrix.py: 95.42%

131 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3import re 

4from collections import namedtuple 

5from numbers import Real 

6from string import digits 

7 

8import numpy as np 

9 

10from ase import Atoms 

11from ase.units import Angstrom, Bohr, nm 

12 

13# split on newlines or semicolons 

14_re_linesplit = re.compile(r'\n|;') 

15# split definitions on whitespace or on "=" (possibly also with whitespace) 

16_re_defs = re.compile(r'\s*=\s*|\s+') 

17 

18 

19_ZMatrixRow = namedtuple( 

20 '_ZMatrixRow', 'ind1 dist ind2 a_bend ind3 a_dihedral', 

21) 

22 

23 

24ThreeFloats = tuple[float, float, float] | np.ndarray 

25 

26 

def require(condition):
    """Raise ``RuntimeError`` unless `condition` is truthy.

    This stands in for ``assert`` statements; it is not meant as
    user-friendly error handling.
    """
    if condition:
        return
    raise RuntimeError('Internal requirement violated')

31 

32 

class _ZMatrixToAtoms:
    """Accumulate atoms from Z-matrix rows and convert them to ``Atoms``.

    Rows are added one at a time with :meth:`add_row`; a row may only refer
    to atoms defined by earlier rows.  Distances and angles read from an
    internal-coordinate row are multiplied by ``dconv`` and ``aconv``
    respectively.  Explicit Cartesian rows (reference index ``0``) are
    stored exactly as given, without any unit conversion.
    """

    # Conversion factors, keyed by unit kind and then by lower-case unit name.
    known_units = dict(
        distance={'angstrom': Angstrom, 'bohr': Bohr, 'au': Bohr, 'nm': nm},
        angle={'radians': 1., 'degrees': np.pi / 180},
    )

    def __init__(
        self,
        dconv: str | Real,
        aconv: str | Real,
        defs: dict[str, float] | str | list[str] | None = None
    ) -> None:
        """Set up unit conversions, symbol definitions and empty storage.

        Parameters
        ----------
        dconv: str or Real
            Distance unit name (a key of ``known_units['distance']``) or a
            numeric conversion factor.
        aconv: str or Real
            Angle unit name (a key of ``known_units['angle']``) or a numeric
            conversion factor.
        defs: dict, str, list of str, or None
            Symbol definitions; see :meth:`set_defs`.
        """
        self.dconv: float = self.get_units('distance', dconv)
        self.aconv: float = self.get_units('angle', aconv)
        self.set_defs(defs)
        # Atom name -> row index.  Replaced by None as soon as a name is
        # seen twice, because the lookup is then ambiguous.
        self.name_to_index: dict[str, int] | None = {}
        self.symbols: list[str] = []
        self.positions: list[ThreeFloats] = []

    @property
    def nrows(self):
        """Number of atoms added so far."""
        return len(self.symbols)

    def get_units(self, kind: str, value: str | Real) -> float:
        """Return the conversion factor for `value`.

        A real number is returned as a float unchanged; a string is looked
        up case-insensitively in ``known_units[kind]``.

        Raises
        ------
        ValueError
            If `value` is a string that is not a known unit name.
        """
        if isinstance(value, Real):
            return float(value)
        out = self.known_units[kind].get(value.lower())
        if out is None:
            raise ValueError(f"Unknown {kind} units: {value}")
        return out

    def set_defs(self, defs: dict[str, float] | str | list[str] | None) -> None:
        """Replace the symbol table used to resolve non-numeric values.

        `defs` may be a mapping of names to values, a block string whose
        rows are separated by newlines or semicolons, or an iterable of
        rows.  Each row has the form ``name value`` or ``name = value``.
        A value may itself refer to a previously defined name.

        Blank rows are ignored and surrounding whitespace is stripped, so
        indented block strings and trailing separators are accepted.

        Raises
        ------
        ValueError
            If a row does not consist of exactly a name and a value, or if
            the value can be neither parsed nor resolved.
        """
        self.defs: dict[str, float] = {}
        if defs is None:
            return

        if isinstance(defs, dict):
            self.defs.update(**defs)
            return

        if isinstance(defs, str):
            defs = _re_linesplit.split(defs.strip())

        for row in defs:
            # Leading/trailing whitespace would otherwise produce empty
            # fields when splitting, breaking the (name, value) unpacking.
            row = row.strip()
            if not row:
                continue
            fields = _re_defs.split(row)
            if len(fields) != 2:
                raise ValueError(
                    f'Invalid definition encountered in Z-matrix: {row!r}'
                )
            key, val = fields
            self.defs[key] = self.get_var(val)

    def get_var(self, val: str) -> float:
        """Convert a token to a float, resolving symbols through ``defs``.

        A token that is not a valid float literal is looked up in
        ``self.defs`` after stripping leading ``+``/``-`` characters; the
        result is negated when the token starts with ``-``.

        Raises
        ------
        ValueError
            If the token is neither a number nor a defined symbol.
        """
        try:
            return float(val)
        except ValueError as e:
            val_out = self.defs.get(val.lstrip('+-'))
            if val_out is None:
                raise ValueError(
                    f'Invalid value encountered in Z-matrix: {val}'
                ) from e
            return val_out * (-1 if val.startswith('-') else 1)

    def get_index(self, name: str) -> int:
        """Find index for a given atom name

        A name that parses as an integer is treated as a 1-based atom
        number and converted to a 0-based index.  Any other name is looked
        up in ``name_to_index``.

        Raises
        ------
        ValueError
            If the name is not an integer and cannot be looked up (unknown
            name, or the name map was discarded because of duplicates).
        """
        try:
            return int(name) - 1
        except ValueError as e:
            if self.name_to_index is None or name not in self.name_to_index:
                raise ValueError(
                    f'Failed to determine index for name "{name}"'
                ) from e
            return self.name_to_index[name]

    def set_index(self, name: str) -> None:
        """Assign index to a given atom name for name -> index lookup"""
        if self.name_to_index is None:
            return

        if name in self.name_to_index:
            # "name" has been encountered before, so name_to_index is no
            # longer meaningful. Destroy the map.
            self.name_to_index = None
            return

        self.name_to_index[name] = self.nrows

    def validate_indices(self, *indices: int) -> None:
        """Raises an error if indices in a Z-matrix row are invalid.

        Indices are 0-based.  They must be non-negative, refer to atoms that
        have already been added, and be pairwise distinct.

        Raises
        ------
        ValueError
            If any of those conditions is violated.
        """
        # A negative index would silently wrap around when used to index
        # ``self.positions`` and would also corrupt the sign factor used
        # when placing the third atom, so reject it explicitly.
        if any(index < 0 for index in indices):
            raise ValueError(
                f'An invalid Z-matrix was provided! Row {self.nrows} refers '
                f'to atom indices {indices}, at least one of which is not '
                'a valid atom number!'
            )

        if any(index >= self.nrows for index in indices):
            raise ValueError(
                f'An invalid Z-matrix was provided! Row {self.nrows} refers '
                f'to atom indices {indices}, at least one of which '
                "hasn't been defined yet!"
            )

        if len(indices) != len(set(indices)):
            raise ValueError(
                'An atom index has been used more than once a '
                f'row of the Z-matrix! Row numbers {self.nrows}, '
                f'referred indices: {indices}'
            )

    def parse_row(self, row: str) -> tuple[
        str, _ZMatrixRow | ThreeFloats,
    ]:
        """Parse one whitespace-separated Z-matrix row.

        Returns the atom name together with either a Cartesian position
        (for the first atom, the second atom, and explicit Cartesian rows)
        or a ``_ZMatrixRow`` holding internal coordinates.  The atom name
        is registered for name -> index lookup as a side effect.
        """
        tokens = row.split()
        name = tokens[0]
        self.set_index(name)
        if len(tokens) == 1:
            # A bare name is only valid for the first atom, at the origin.
            require(self.nrows == 0)
            return name, np.zeros(3, dtype=float)

        ind1 = self.get_index(tokens[1])
        if ind1 == -1:
            # Reference "0": the remaining three tokens are Cartesian
            # coordinates, which are used without unit conversion.
            require(len(tokens) == 5)
            return name, np.array(list(map(self.get_var, tokens[2:])),
                                  dtype=float)

        dist = self.dconv * self.get_var(tokens[2])

        if len(tokens) == 3:
            # Second atom: placed along the x axis.
            require(self.nrows == 1)
            self.validate_indices(ind1)
            return name, np.array([dist, 0, 0], dtype=float)

        ind2 = self.get_index(tokens[3])
        a_bend = self.aconv * self.get_var(tokens[4])

        if len(tokens) == 5:
            # Third atom: distance and bending angle only, no dihedral.
            require(self.nrows == 2)
            self.validate_indices(ind1, ind2)
            return name, _ZMatrixRow(ind1, dist, ind2, a_bend, None, None)

        ind3 = self.get_index(tokens[5])
        a_dihedral = self.aconv * self.get_var(tokens[6])
        self.validate_indices(ind1, ind2, ind3)
        return name, _ZMatrixRow(ind1, dist, ind2, a_bend, ind3,
                                 a_dihedral)

    def add_atom(self, name: str, pos: ThreeFloats) -> None:
        """Sets the symbol and position of an atom.

        The symbol is `name` with all digits removed, capitalized.
        """
        self.symbols.append(
            ''.join([c for c in name if c not in digits]).capitalize()
        )
        self.positions.append(pos)

    def add_row(self, row: str) -> None:
        """Parse `row` and append the resulting atom."""
        name, zrow = self.parse_row(row)

        if not isinstance(zrow, _ZMatrixRow):
            # parse_row already produced a Cartesian position.
            self.add_atom(name, zrow)
            return

        if zrow.ind3 is None:
            # This is the third atom, so only a bond distance and an angle
            # have been provided.
            pos = self.positions[zrow.ind1].copy()
            # (ind2 - ind1) is +1 or -1 here and selects the direction
            # along x that points from atom ind1 towards atom ind2.
            pos[0] += zrow.dist * np.cos(zrow.a_bend) * (zrow.ind2 - zrow.ind1)
            pos[1] += zrow.dist * np.sin(zrow.a_bend)
            self.add_atom(name, pos)
            return

        # ax1 is the dihedral axis, which is defined by the bond vector
        # between the two inner atoms in the dihedral, ind1 and ind2
        ax1 = self.positions[zrow.ind2] - self.positions[zrow.ind1]
        ax1 /= np.linalg.norm(ax1)

        # ax2 lies within the 1-2-3 plane, and it is perpendicular
        # to the dihedral axis
        ax2 = self.positions[zrow.ind2] - self.positions[zrow.ind3]
        ax2 -= ax1 * (ax2 @ ax1)
        ax2 /= np.linalg.norm(ax2)

        # ax3 is a vector that forms the appropriate dihedral angle, though
        # the bending angle is 90 degrees, rather than a_bend. It is formed
        # from a linear combination of ax2 and (ax2 x ax1)
        ax3 = (ax2 * np.cos(zrow.a_dihedral)
               + np.cross(ax2, ax1) * np.sin(zrow.a_dihedral))

        # The final position vector is a linear combination of ax1 and ax3.
        pos = ax1 * np.cos(zrow.a_bend) - ax3 * np.sin(zrow.a_bend)
        pos *= zrow.dist / np.linalg.norm(pos)
        pos += self.positions[zrow.ind1]
        self.add_atom(name, pos)

    def to_atoms(self) -> Atoms:
        """Return an ``Atoms`` object built from the accumulated atoms."""
        return Atoms(self.symbols, self.positions)

214 

215 

def parse_zmatrix(
    zmat: str | list[str],
    distance_units: str | Real = 'angstrom',
    angle_units: str | Real = 'degrees',
    defs: dict[str, float] | str | list[str] | None = None
) -> Atoms:
    """Converts a Z-matrix into an Atoms object.

    Parameters
    ----------

    zmat: Iterable or str
        The Z-matrix to be parsed. Iteration over `zmat` should yield the rows
        of the Z-matrix. If `zmat` is a str, it will be automatically split
        into a list at newlines and semicolons. Rows that are empty or
        contain only whitespace are ignored.
    distance_units: str or float, optional
        The units of distance in the provided Z-matrix, given either as a
        unit name or as a numeric conversion factor.
        Defaults to Angstrom.
    angle_units: str or float, optional
        The units for angles in the provided Z-matrix, given either as a
        unit name or as a numeric conversion factor.
        Defaults to degrees.
    defs: dict, str, or list of str, optional
        If `zmat` contains symbols for bond distances, bending angles, and/or
        dihedral angles instead of numeric values, then the definition of
        those symbols should be passed to this function using this keyword
        argument.
        Note: The symbol definitions are typically printed adjacent to the
        Z-matrix itself, but this function will not automatically separate
        the symbol definitions from the Z-matrix.

    Returns
    -------

    atoms: Atoms object
    """
    zmatrix = _ZMatrixToAtoms(distance_units, angle_units, defs=defs)

    # zmat should be a list containing the rows of the z-matrix.
    # for convenience, allow block strings and split at newlines
    # (or semicolons).
    if isinstance(zmat, str):
        zmat = _re_linesplit.split(zmat.strip())

    for row in zmat:
        # A blank row carries no atom; passing it on would fail with an
        # uninformative IndexError when the atom name is looked up.
        if not row.strip():
            continue
        zmatrix.add_row(row)

    return zmatrix.to_atoms()