Coverage for ase / spacegroup / symmetrize.py: 92.06%

126 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-04 10:20 +0000

1# fmt: off 

2 

3""" 

4Provides utility functions for FixSymmetry class 

5""" 

6from collections.abc import MutableMapping 

7from typing import Optional 

8 

9import numpy as np 

10 

11from ase.utils import atoms_to_spglib_cell, spglib_new_errorhandling 

12 

# Names exported by a star-import of this module; the other module-level
# functions remain importable explicitly by name.
__all__ = ['refine_symmetry', 'check_symmetry']

14 

15 

def spglib_get_symmetry_dataset(*args, **kwargs):
    """Call ``spglib.get_symmetry_dataset`` with attribute-style results.

    Temporary compatibility adapter: spglib >= 2.5.0 already returns an
    object with attribute access, while older versions return a plain
    ``dict``.  Dicts are wrapped in ``SpglibDatasetWrapper`` so that ASE
    code can always write ``dataset.number`` and friends.

    All arguments are forwarded unchanged.  ``None`` from spglib is passed
    through as ``None``.
    """
    import spglib

    get_dataset = spglib_new_errorhandling(spglib.get_symmetry_dataset)
    result = get_dataset(*args, **kwargs)
    if isinstance(result, dict):  # spglib < 2.5.0
        return SpglibDatasetWrapper(result)
    return result  # spglib >= 2.5.0 object, or None

33 

34 

class SpglibDatasetWrapper(MutableMapping):
    """Give a plain spglib dataset dict attribute-style access.

    Spglib 2.5.0 returns a ``SpglibDataset`` with deprecated ``__getitem__``;
    spglib 2.4.0 and earlier return a ``dict``.  This wrapper makes a dict
    support both ``dataset.number`` and ``dataset['number']``.

    Parameters
    ----------
    spglib_dct : dict
        Dataset dictionary from an older spglib.  It is stored by reference,
        so item assignment/deletion through the wrapper modifies it.
    """

    def __init__(self, spglib_dct):
        self._spglib_dct = spglib_dct

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.  Missing keys must
        # surface as AttributeError (not KeyError) so hasattr(),
        # getattr(obj, name, default), copy and pickle behave correctly.
        # An instance created without __init__ (as copy/pickle do) has no
        # _spglib_dct yet; looking it up through here would recurse forever.
        if attr == '_spglib_dct':
            raise AttributeError(attr)
        try:
            return self._spglib_dct[attr]
        except KeyError:
            raise AttributeError(
                f'{type(self).__name__!r} object has no attribute {attr!r}'
            ) from None

    def __getitem__(self, key):
        return self._spglib_dct[key]

    def __len__(self):
        return len(self._spglib_dct)

    def __iter__(self):
        return iter(self._spglib_dct)

    def __setitem__(self, key, value):
        self._spglib_dct[key] = value

    def __delitem__(self, item):
        del self._spglib_dct[item]

    def __repr__(self):
        return f'{type(self).__name__}({self._spglib_dct!r})'

61 

62 

def print_symmetry(symprec, dataset):
    """Print a one-line summary of the space group stored in ``dataset``.

    Shows the precision used, the space-group number, and the
    Hermann-Mauguin and Hall symbols.
    """
    fields = (
        "ase.spacegroup.symmetrize: prec", symprec,
        "got symmetry group number", dataset.number,
        ", international (Hermann-Mauguin)", dataset.international,
        ", Hall ", dataset.hall,
    )
    print(*fields)

68 

69 

def refine_symmetry(atoms, symprec=0.01, verbose=False):
    """Refine the symmetry of an Atoms object in place.

    First the cell is replaced by the symmetrized cell, then the positions
    are snapped onto their ideal symmetric positions.  Both steps use the
    symmetry detected with precision ``symprec``.

    Parameters
    ----------
    atoms : Atoms
        Input Atoms object; its cell and positions are modified in place.
    symprec : float
        Symmetry precision used to detect the symmetry to enforce.
    verbose : bool
        If True, print the symmetry found before symmetrizing the cell,
        before symmetrizing the positions, and after refinement.

    Returns
    -------
    dataset
        spglib symmetry dataset of the refined structure, determined with a
        tight precision of 1e-4.
    """
    options = dict(symprec=symprec, verbose=verbose)
    _check_and_symmetrize_cell(atoms, **options)
    _check_and_symmetrize_positions(atoms, **options)
    return check_symmetry(atoms, symprec=1e-4, verbose=verbose)

89 

90 

class IntermediateDatasetError(Exception):
    """Raised when partial symmetrization changes the detected space group.

    The dataset found in `_check_and_symmetrize_positions` (after the cell
    has already been symmetrized) can disagree with the original dataset
    from `_check_and_symmetrize_cell`.  Carrying on would produce a faulty,
    partially symmetrized structure, so this error signals it instead.
    """

95 

96 

def get_symmetrized_atoms(atoms,
                          symprec: float = 0.01,
                          final_symprec: Optional[float] = None):
    """Get new Atoms object with refined symmetries.

    Checks internal consistency of the found symmetries.  The input
    ``atoms`` is left untouched; a copy is symmetrized.

    Parameters
    ----------
    atoms : Atoms
        Input atoms object.
    symprec : float
        Symmetry precision used to identify symmetries with spglib.
    final_symprec : float, optional
        Symmetry precision used for testing the symmetrization.  Defaults
        to ``symprec`` when None.

    Returns
    -------
    symatoms : Atoms
        New atoms object symmetrized according to the input symprec.
    final_dataset
        spglib symmetry dataset of ``symatoms`` found with
        ``final_symprec``.

    Raises
    ------
    IntermediateDatasetError
        If the space group found after symmetrizing the cell differs from
        the one found for the input structure.
    RuntimeError
        If the space group of the symmetrized structure at
        ``final_symprec`` differs from the original one.
    """
    atoms = atoms.copy()
    original_dataset = _check_and_symmetrize_cell(atoms, symprec=symprec)
    intermediate_dataset = _check_and_symmetrize_positions(
        atoms, symprec=symprec)
    if intermediate_dataset.number != original_dataset.number:
        raise IntermediateDatasetError(
            f'space group changed from {original_dataset.number} to '
            f'{intermediate_dataset.number} after symmetrizing the cell')

    # Only None means "not given"; an explicit value is always honored.
    if final_symprec is None:
        final_symprec = symprec
    final_dataset = check_symmetry(atoms, symprec=final_symprec)
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if final_dataset.number != original_dataset.number:
        raise RuntimeError(
            f'symmetrized structure has space group {final_dataset.number} '
            f'at symprec={final_symprec}, expected '
            f'{original_dataset.number}')
    return atoms, final_dataset

128 

129 

def _check_and_symmetrize_cell(atoms, **kwargs):
    """Detect the symmetry of ``atoms`` and symmetrize its cell in place.

    ``kwargs`` are forwarded unchanged to ``check_symmetry`` (e.g.
    ``symprec``, ``verbose``).  The dataset returned is the one found
    *before* the cell was modified.
    """
    symmetry_data = check_symmetry(atoms, **kwargs)
    _symmetrize_cell(atoms, symmetry_data)
    return symmetry_data

134 

135 

136def _symmetrize_cell(atoms, dataset): 

137 # set actual cell to symmetrized cell vectors by copying 

138 # transformed and rotated standard cell 

139 std_cell = dataset.std_lattice 

140 trans_std_cell = dataset.transformation_matrix.T @ std_cell 

141 rot_trans_std_cell = trans_std_cell @ dataset.std_rotation_matrix 

142 atoms.set_cell(rot_trans_std_cell, True) 

143 

144 

def _check_and_symmetrize_positions(atoms, *, symprec, **kwargs):
    """Detect the symmetry of ``atoms`` and snap its positions onto it.

    ``symprec`` is used both for the symmetry dataset and for spglib's
    ``find_primitive``; remaining ``kwargs`` go to ``check_symmetry``.
    Positions are modified in place.  Returns the dataset found before the
    positions were changed.
    """
    import spglib

    symmetry_data = check_symmetry(atoms, symprec=symprec, **kwargs)
    # NOTE: assumes the primitive vectors returned by find_primitive are
    # compatible with std_lattice returned by get_symmetry_dataset.
    find_primitive = spglib_new_errorhandling(spglib.find_primitive)
    primitive_cell = find_primitive(atoms_to_spglib_cell(atoms),
                                    symprec=symprec)
    _symmetrize_positions(atoms, symmetry_data, primitive_cell)
    return symmetry_data

155 

156 

157def _symmetrize_positions(atoms, dataset, primitive_spglib_cell): 

158 prim_cell, _prim_scaled_pos, _prim_types = primitive_spglib_cell 

159 

160 # calculate offset between standard cell and actual cell 

161 std_cell = dataset.std_lattice 

162 rot_std_cell = std_cell @ dataset.std_rotation_matrix 

163 rot_std_pos = dataset.std_positions @ rot_std_cell 

164 pos = atoms.get_positions() 

165 dp0 = (pos[list(dataset.mapping_to_primitive).index(0)] - rot_std_pos[ 

166 list(dataset.std_mapping_to_primitive).index(0)]) 

167 

168 # create aligned set of standard cell positions to figure out mapping 

169 rot_prim_cell = prim_cell @ dataset.std_rotation_matrix 

170 inv_rot_prim_cell = np.linalg.inv(rot_prim_cell) 

171 aligned_std_pos = rot_std_pos + dp0 

172 

173 # find ideal positions from position of corresponding std cell atom + 

174 # integer_vec . primitive cell vectors 

175 mapping_to_primitive = list(dataset.mapping_to_primitive) 

176 std_mapping_to_primitive = list(dataset.std_mapping_to_primitive) 

177 pos = atoms.get_positions() 

178 for i_at in range(len(atoms)): 

179 std_i_at = std_mapping_to_primitive.index(mapping_to_primitive[i_at]) 

180 dp = aligned_std_pos[std_i_at] - pos[i_at] 

181 dp_s = dp @ inv_rot_prim_cell 

182 pos[i_at] = (aligned_std_pos[std_i_at] - np.round(dp_s) @ rot_prim_cell) 

183 atoms.set_positions(pos) 

184 

185 

def check_symmetry(atoms, symprec=1.0e-6, verbose=False):
    """Check the symmetry of `atoms` with precision `symprec` using `spglib`.

    Parameters
    ----------
    atoms : Atoms
        Structure to analyse.
    symprec : float
        Symmetry precision passed to spglib.
    verbose : bool
        If True, print a one-line summary of the detected space group (or a
        note that no dataset was obtained).

    Returns
    -------
    dataset
        Result of `spglib.get_symmetry_dataset()` with attribute access
        for all spglib versions.  May be None if spglib returns no dataset.
    """
    dataset = spglib_get_symmetry_dataset(atoms_to_spglib_cell(atoms),
                                          symprec=symprec)
    if verbose:
        if dataset is None:
            # print_symmetry would fail on None; report instead of crashing.
            print("ase.spacegroup.symmetrize: prec", symprec,
                  "no symmetry dataset found")
        else:
            print_symmetry(symprec, dataset)
    return dataset

197 

198 

def is_subgroup(sup_data, sub_data, tol=1e-10):
    """Test if spglib dataset `sub_data` is a subgroup of dataset `sup_data`.

    Every symmetry operation (rotation, translation) of ``sub_data`` must
    also occur in ``sup_data``.  Rotations must match exactly.  Translations
    (fractional coordinates, as spglib reports them) must agree to within
    ``tol`` in Euclidean norm *modulo a lattice vector*, because translations
    differing by whole unit-cell vectors describe the same operation.

    NOTE(review): both datasets are assumed to refer to the same cell and
    origin -- confirm at call sites.

    Returns
    -------
    bool
        True if every operation of ``sub_data`` is found in ``sup_data``
        (vacuously True when ``sub_data`` has no operations).
    """
    sup_ops = list(zip(sup_data.rotations, sup_data.translations))

    def _in_sup(rot, trns):
        # True if (rot, trns) matches some operation of sup_data.
        for sup_rot, sup_trns in sup_ops:
            diff = np.asarray(trns) - sup_trns
            diff = diff - np.round(diff)  # reduce modulo lattice vectors
            if np.all(rot == sup_rot) and np.linalg.norm(diff) < tol:
                return True
        return False

    return all(_in_sup(rot, trns)
               for rot, trns in zip(sub_data.rotations, sub_data.translations))

210 

211 

def prep_symmetry(atoms, symprec=1.0e-6, verbose=False):
    """Prepare `atoms` for symmetry-preserving minimisation at `symprec`.

    Returns
    -------
    tuple
        ``(rotations, translations, symm_map)``:
        - ``rotations`` and ``translations`` are copies of the spglib
          dataset's operations.
        - ``symm_map[k][i]`` is the index of the atom whose scaled position
          is nearest (minimum image in scaled coordinates) to the image of
          atom ``i`` under operation ``k``.
    """
    dataset = spglib_get_symmetry_dataset(atoms_to_spglib_cell(atoms),
                                          symprec=symprec)
    if verbose:
        print_symmetry(symprec, dataset)
    rotations = dataset.rotations.copy()
    translations = dataset.translations.copy()
    frac = atoms.get_scaled_positions()

    def _nearest_atom(target):
        # index of the atom closest to `target`, wrapping into [-0.5, 0.5]
        delta = frac - target
        delta -= np.round(delta)
        return np.argmin(np.linalg.norm(delta, axis=1))

    symm_map = [
        [_nearest_atom(rot @ frac[i, :] + trans) for i in range(len(atoms))]
        for rot, trans in zip(rotations, translations)
    ]
    return (rotations, translations, symm_map)

236 

237 

def symmetrize_rank1(lattice, inv_lattice, forces, rot, trans, symm_map):
    """Return symmetrized forces.

    Forces are converted to scaled coordinates.  For each operation, the
    rotation is applied and the result is added onto the atoms given by
    that operation's entry in ``symm_map``.  The sum is averaged over the
    operations and converted back to Cartesian coordinates.  ``trans`` is
    only iterated alongside ``rot``; its values are not used.

    lattice vectors expected as row vectors (same as ASE get_cell() convention),
    inv_lattice is its matrix inverse (reciprocal().T)
    """
    scaled_t = inv_lattice.T @ forces.T
    accumulated_t = np.zeros(forces.T.shape)
    for r, _t, op_map in zip(rot, trans, symm_map):
        accumulated_t[:, op_map] += r @ scaled_t
    accumulated_t /= len(rot)
    return (lattice.T @ accumulated_t).T

255 

256 

def symmetrize_rank2(lattice, lattice_inv, stress_3_3, rot):
    """Return symmetrized stress.

    The 3x3 stress is transformed with the lattice (``L @ stress @ L.T``)
    and averaged over ``r.T @ S @ r`` for every rotation ``r`` in ``rot``.
    It is then transformed back with ``lattice_inv``.

    lattice vectors expected as row vectors (same as ASE get_cell()
    convention); lattice_inv is its matrix inverse (reciprocal().T).
    """
    scaled = lattice @ stress_3_3 @ lattice.T
    averaged = sum((r.T @ scaled @ r for r in rot), np.zeros((3, 3)))
    averaged = averaged / len(rot)
    return lattice_inv @ averaged @ lattice_inv.T