Coverage for ase / spacegroup / utils.py: 90.00%

60 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3 

4import numpy as np 

5 

6from ase import Atoms 

7 

8from .spacegroup import _SPACEGROUP, Spacegroup 

9 

# Public API of this module: only the ``get_basis`` dispatcher is exported.
__all__ = ("get_basis",)

11 

12 

13def _has_spglib() -> bool: 

14 """Check if spglib is available""" 

15 try: 

16 import spglib 

17 assert spglib # silence flakes 

18 except ImportError: 

19 return False 

20 return True 

21 

22 

def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by removing symmetry-equivalent sites.

    The first remaining scaled position is taken as a basis site. Every
    remaining position that coincides with one of its equivalent sites
    under ``spacegroup`` is then discarded. This repeats on whatever is left
    until no positions remain.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: array of the scaled positions chosen as basis sites, in the
        order they were picked. An empty array if ``atoms`` has no atoms.
    """
    remaining = np.asarray(atoms.get_scaled_positions(), dtype=float)
    spacegroup = Spacegroup(spacegroup)

    basis = []
    # Iterative rather than recursive: a recursive formulation needs one
    # stack frame per inequivalent site and can exceed Python's recursion
    # limit for large, low-symmetry cells.
    while len(remaining):
        site = remaining[0]
        basis.append(site.tolist())  # Add the site as a basis

        sites, _ = spacegroup.equivalent_sites(site)
        sites = np.asarray(sites, dtype=float)

        # Pairwise fractional differences, shape (n_remaining, n_sites, 3),
        # folded to the nearest periodic image so that coordinates such as
        # 0.999999 and 0.0 are recognised as the same position.
        diff = remaining[:, np.newaxis, :] - sites[np.newaxis, :, :]
        diff -= np.rint(diff)
        # Same acceptance criterion as np.allclose(site, pos, atol=tol) with
        # its default rtol=1e-5: |site - pos| <= tol + rtol * |pos|.
        bound = tol + 1e-5 * np.abs(remaining)[:, np.newaxis, :]
        equivalent = (np.abs(diff) <= bound).all(axis=-1).any(axis=-1)
        # The chosen site is always consumed, which guarantees the loop
        # terminates even if rounding keeps it from matching its own orbit.
        equivalent[0] = True
        remaining = remaining[~equivalent]

    return np.array(basis)

72 

73 

def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Reduced basis of ``atoms`` computed with the spglib-backed helper.

    No space group is taken: the reduced indices come from
    ``_get_reduced_indices``, which queries spglib.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: the scaled positions of ``atoms`` at the reduced indices.
    :raises ImportError: if spglib is not installed.
    """
    if _has_spglib():
        positions = atoms.get_scaled_positions()
        indices = _get_reduced_indices(atoms, tol=tol)
        return positions[indices]

    # Point the user at a working alternative instead of just failing.
    raise ImportError(
        'This function requires spglib. Use "get_basis" and specify '
        'the spacegroup instead, or install spglib.')

91 

92 

def _can_use_spglib(spacegroup: _SPACEGROUP | None = None) -> bool:
    """Decide whether the spglib implementation of ``get_basis`` applies.

    Both conditions must hold: spglib is importable, and no explicit space
    group was given (the spglib implementation cannot accept one).

    :param spacegroup: the space group requested by the caller, if any.
    :returns: ``True`` when the spglib implementation may be used.
    """
    spglib_available = _has_spglib()
    return spglib_available and spacegroup is None

104 

105 

# Dispatcher function for choosing the get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP | None = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Function for determining a reduced basis of an atoms object.
    Can use either an ASE native algorithm or an spglib based one.
    The native ASE version requires specifying a space group,
    while the (current) spglib version cannot.
    The default behavior is to automatically determine which implementation
    to use, based on the ``spacegroup`` parameter,
    and whether spglib is installed.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
        With ``method='spglib'`` an explicit spacegroup is ignored and a
        warning is emitted.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :raises ValueError: if ``method`` is not one of the allowed values, or
        if the ASE implementation is selected without a ``spacegroup``.
    """
    import warnings

    ALLOWED_METHODS = ('auto', 'ase', 'spglib')

    if method not in ALLOWED_METHODS:
        raise ValueError(
            f'Expected one of {ALLOWED_METHODS} methods, got {method}')

    if method == 'auto':
        # spglib is used when installed, unless an explicit space group was
        # given, which only the ASE implementation can honour.
        use_spglib = _can_use_spglib(spacegroup=spacegroup)
    else:
        # User told us which implementation they wanted
        use_spglib = method == 'spglib'

    if use_spglib:
        if spacegroup is not None:
            # The spglib implementation cannot take an explicit space group;
            # tell the caller instead of silently dropping their argument.
            warnings.warn(
                'The spglib implementation of get_basis ignores the '
                '"spacegroup" argument; use method="ase" to apply an '
                'explicit space group.', stacklevel=2)
        return _get_basis_spglib(atoms, tol=tol)

    # ASE native path: reached with method='ase', or with method='auto' when
    # spglib is not installed or an explicit space group was given.
    if spacegroup is None:
        raise ValueError(
            'A space group must be specified for the native ASE '
            'implementation. Try using the spglib version instead, '
            'or explicitly specifying a space group.')
    return _get_basis_ase(atoms, spacegroup, tol=tol)

162 

163 

def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> list[int]:
    """Get a list of the reduced atomic indices using spglib.
    Note: Does no checks to see if spglib is installed.

    The indices are the distinct values of the ``equivalent_atoms`` entry
    of the spglib symmetry dataset for ``atoms``.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons,
        passed to spglib as ``symprec``
    :returns: the distinct indices, in ascending order, as plain ``int``.
    """
    from ase.spacegroup.symmetrize import spglib_get_symmetry_dataset

    # Create input for spglib
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib_get_symmetry_dataset(spglib_cell,
                                                symprec=tol)
    # ``list(set(...))`` returned numpy integers in an unspecified order
    # (set iteration order is not sorted). np.unique yields them ascending,
    # so the basis built from these indices is deterministic and follows
    # the atom order of ``atoms``; .tolist() converts to Python ints.
    return np.unique(symmetry_data.equivalent_atoms).tolist()