Coverage for ase / spacegroup / utils.py: 90.00%
60 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1# fmt: off
4import numpy as np
6from ase import Atoms
8from .spacegroup import _SPACEGROUP, Spacegroup
10__all__ = ('get_basis', )
13def _has_spglib() -> bool:
14 """Check if spglib is available"""
15 try:
16 import spglib
17 assert spglib # silence flakes
18 except ImportError:
19 return False
20 return True
def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by iteratively removing equivalent sites.

    Uses the first remaining position as a basis site, removes every
    remaining position that is equivalent to it under the space group,
    then repeats with the next position which has not yet been assigned
    to a basis site, until no positions are left.

    The reduction is done with a loop rather than recursion, so the number
    of inequivalent sites is not limited by the interpreter recursion limit.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: Array with the scaled positions of the basis sites, in the
        order in which they were picked.
    :raises AssertionError: If a picked basis site is not found among its
        own equivalent sites, i.e. no position could be removed.
    """
    remaining = atoms.get_scaled_positions()
    spacegroup = Spacegroup(spacegroup)

    def scaled_in_sites(scaled_pos: np.ndarray, sites: np.ndarray) -> bool:
        """Check if a scaled position matches any of the given sites"""
        return any(np.allclose(site, scaled_pos, atol=tol) for site in sites)

    all_basis = []
    while len(remaining) > 0:
        basis = remaining[0]
        all_basis.append(basis.tolist())  # Add the site as a basis

        # Get equivalent sites
        sites, _ = spacegroup.equivalent_sites(basis)

        # Keep only the positions which are not equivalent to this basis
        not_equivalent = [
            sc for sc in remaining if not scaled_in_sites(sc, sites)]

        # We should always have at least removed the basis site itself.
        # Raised explicitly (instead of a bare assert) so that the loop
        # cannot spin forever when assertions are disabled with ``-O``.
        if len(not_equivalent) >= len(remaining):
            raise AssertionError(
                'No equivalent site was removed for the current basis site')

        remaining = np.array(not_equivalent)

    return np.array(all_basis)
def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis with the help of spglib.

    The spglib package must be installed for this function to work.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: The scaled positions of the atoms selected by
        ``_get_reduced_indices``.
    :raises ImportError: If spglib is not available.
    """
    spglib_available = _has_spglib()
    if not spglib_available:
        # Point the user to a workable alternative in the error message.
        raise ImportError(
            'This function requires spglib. Use "get_basis" and specify '
            'the spacegroup instead, or install spglib.')

    # Scaled positions are fetched first, then indexed by the reduced set.
    return atoms.get_scaled_positions()[_get_reduced_indices(atoms, tol=tol)]
def _can_use_spglib(spacegroup: _SPACEGROUP | None = None) -> bool:
    """Decide whether the spglib implementation can be dispatched to.

    The spglib implementation is usable only when spglib is installed and
    no explicit space group was passed, since passing an explicit space
    group is currently not supported by the spglib implementation.

    :param spacegroup: Optional space group requested by the caller.
    :returns: ``True`` if the spglib implementation can be used.
    """
    # The availability check runs first; the space group is only inspected
    # when spglib is actually installed.
    return _has_spglib() and spacegroup is None
# Dispatcher function for choosing the get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP | None = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced basis of an atoms object.

    Either an ASE native algorithm or an spglib based one can be used.
    The native ASE version requires a space group to be specified,
    while the (current) spglib version cannot take one.
    By default, the implementation is picked automatically, based on the
    ``spacegroup`` parameter and on whether spglib is installed.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: Array of scaled positions forming the reduced basis.
    :raises ValueError: If ``method`` is not one of the allowed values, or
        if the native ASE implementation is selected without a space group.
    """
    allowed = ('auto', 'ase', 'spglib')

    if method not in allowed:
        raise ValueError('Expected one of {} methods, got {}'.format(
            allowed, method))

    # With 'auto' we work out ourselves whether spglib is usable;
    # otherwise the caller has told us which implementation to run.
    wants_spglib = (_can_use_spglib(spacegroup=spacegroup)
                    if method == 'auto' else method == 'spglib')

    if wants_spglib:
        # The spacegroup is deliberately not forwarded: the spglib
        # implementation cannot handle an explicit space group right now.
        return _get_basis_spglib(atoms, tol=tol)

    # From here on the native ASE implementation is used. We end up here
    # either because spglib is not installed, an explicit space group was
    # given, or ASE was explicitly requested.
    if spacegroup is None:
        raise ValueError(
            'A space group must be specified for the native ASE '
            'implementation. Try using the spglib version instead, '
            'or explicitly specifying a space group.')
    return _get_basis_ase(atoms, spacegroup, tol=tol)
def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> list[int]:
    """Get a list of the reduced atomic indices using spglib.
    Note: Does no checks to see if spglib is installed.

    The indices are the distinct values found in the ``equivalent_atoms``
    field of the symmetry dataset, returned in ascending order as plain
    Python ``int`` objects, so the result is deterministic.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons;
        passed to spglib as ``symprec``
    :returns: Sorted list of the unique indices in ``equivalent_atoms``.
    """
    from ase.spacegroup.symmetrize import spglib_get_symmetry_dataset

    # Create input for spglib
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib_get_symmetry_dataset(spglib_cell,
                                                symprec=tol)

    # NOTE(review): equivalent_atoms presumably holds, for every atom, the
    # index of its symmetry representative -- confirm against spglib docs.
    # Deduplicate through a set, then sort, because set iteration order is
    # arbitrary; int() converts the entries to built-in ints to match the
    # annotated return type.
    return sorted({int(index) for index in symmetry_data.equivalent_atoms})