Coverage for /builds/ase/ase/ase/spacegroup/utils.py: 90.16%

61 statements  

« prev     ^ index     » next       coverage.py v7.5.3, created at 2025-08-02 00:12 +0000

1# fmt: off 

2 

3from typing import List 

4 

5import numpy as np 

6 

7from ase import Atoms 

8 

9from .spacegroup import _SPACEGROUP, Spacegroup 

10 

# Only the dispatcher is exported; the implementation helpers in this
# module are underscore-prefixed and considered private.
__all__ = ("get_basis",)

12 

13 

14def _has_spglib() -> bool: 

15 """Check if spglib is available""" 

16 try: 

17 import spglib 

18 assert spglib # silence flakes 

19 except ImportError: 

20 return False 

21 return True 

22 

23 

def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by removing symmetry-equivalent sites.

    The first remaining scaled position is taken as a basis site. Every
    remaining position equivalent to it under ``spacegroup`` is then
    discarded, and the procedure repeats on what is left until no
    positions remain.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance (``atol``) for positional
        comparisons. Default: ``1e-5``
    :returns: ``np.ndarray`` of the selected basis sites (scaled
        positions), in the order they were selected.
    """
    remaining = atoms.get_scaled_positions()
    spacegroup = Spacegroup(spacegroup)

    all_basis = []
    # Iterate instead of recursing. Every pass selects exactly one basis
    # site, so a recursive formulation needs one stack frame per
    # inequivalent site and raises RecursionError for large
    # low-symmetry cells.
    while len(remaining) > 0:
        basis = remaining[0]
        all_basis.append(basis.tolist())  # Add the site as a basis

        # All images of the basis site under the space-group operations.
        sites, _ = spacegroup.equivalent_sites(basis)
        sites = np.asarray(sites, dtype=float).reshape(-1, remaining.shape[1])

        # matches[i, j] is True when remaining[i] coincides with sites[j].
        # ``remaining`` is passed as the second argument so the relative
        # tolerance is scaled by it, exactly as np.allclose(site, position)
        # would do it.
        matches = np.isclose(sites[np.newaxis, :, :],
                             remaining[:, np.newaxis, :],
                             atol=tol).all(axis=2)
        keep = ~matches.any(axis=1)

        # The basis site itself must always be removed, otherwise the loop
        # could never finish. This is raised explicitly (rather than with
        # ``assert``) so the guarantee still holds under ``python -O``.
        if keep.all():
            raise AssertionError(
                'Basis site was not found among its own equivalent sites')
        remaining = remaining[keep]

    return np.array(all_basis)

73 

74 

def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Reduce ``atoms`` to a basis using spglib's symmetry detection.

    No space group is passed in; spglib works it out itself. This path
    needs the optional ``spglib`` package.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons,
        passed on to ``_get_reduced_indices``. Default: ``1e-5``
    :returns: scaled positions of the atoms selected as the reduced basis.
    :raises ImportError: if spglib is not installed.
    """
    if _has_spglib():
        positions = atoms.get_scaled_positions()
        return positions[_get_reduced_indices(atoms, tol=tol)]

    # Point the caller at the spglib-free alternative.
    raise ImportError(
        'This function requires spglib. Use "get_basis" and specify '
        'the spacegroup instead, or install spglib.')

92 

93 

def _can_use_spglib(spacegroup: _SPACEGROUP = None) -> bool:
    """Decide whether ``get_basis`` may dispatch to the spglib implementation.

    Both of these must hold:

    * spglib must be importable;
    * no explicit space group was requested, because the spglib-based
      implementation does not accept one.
    """
    return _has_spglib() and spacegroup is None

105 

106 

# Dispatcher function for choosing the get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced basis of an atoms object.

    Two implementations are available: an ASE-native algorithm and an
    spglib-based one. The native ASE version requires a space group to be
    specified, while the (current) spglib version cannot accept one. By
    default the implementation is chosen automatically, based on the
    ``spacegroup`` argument and on whether spglib is installed.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        provided spglib is installed and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
        With ``method='spglib'`` any value given here is ignored.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: ``np.ndarray`` of the scaled positions of the basis sites.
    :raises ValueError: if ``method`` is not one of the allowed values, or
        if the ASE implementation is selected without a ``spacegroup``.
    :raises ImportError: if the spglib implementation is selected but
        spglib is not installed.
    """
    allowed_methods = ('auto', 'ase', 'spglib')
    if method not in allowed_methods:
        raise ValueError('Expected one of {} methods, got {}'.format(
            allowed_methods, method))

    # 'auto' picks spglib whenever it is installed and no explicit space
    # group was requested; otherwise follow the caller's explicit choice.
    use_spglib = (_can_use_spglib(spacegroup=spacegroup)
                  if method == 'auto'
                  else method == 'spglib')

    if use_spglib:
        # The spglib implementation cannot take an explicit space group
        # (yet), so ``spacegroup`` is deliberately not forwarded.
        return _get_basis_spglib(atoms, tol=tol)

    # ASE-native path. We get here either because 'auto' ruled out spglib
    # or because 'ase' was requested explicitly.
    if spacegroup is None:
        raise ValueError(
            'A space group must be specified for the native ASE '
            'implementation. Try using the spglib version instead, '
            'or explicitly specifying a space group.')
    return _get_basis_ase(atoms, spacegroup, tol=tol)

163 

164 

def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> List[int]:
    """Get a sorted list of the reduced atomic indices using spglib.

    The result is the distinct entries of the symmetry dataset's
    ``equivalent_atoms`` array. spglib uses that array to map each atom to
    the index of a representative atom of its symmetry orbit.

    Note: Does no checks to see if spglib is installed.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons,
        passed to spglib as ``symprec``
    :returns: the distinct representative indices, in ascending order.
    """
    from ase.spacegroup.symmetrize import spglib_get_symmetry_dataset

    # Create input for spglib: (lattice, scaled positions, atomic numbers).
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib_get_symmetry_dataset(spglib_cell,
                                                symprec=tol)
    # Sort the unique representatives. Plain set iteration order follows
    # the hash-table layout rather than the values; in CPython,
    # list({0, 5, 8}) is [0, 8, 5]. Without sorting, the order of the
    # returned basis would be arbitrary.
    return sorted(set(symmetry_data.equivalent_atoms))