Coverage for /builds/ase/ase/ase/spacegroup/utils.py: 90.16%
61 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3from typing import List
5import numpy as np
7from ase import Atoms
9from .spacegroup import _SPACEGROUP, Spacegroup
# Public API of this module: only the dispatcher below is exported; the
# underscore-prefixed helpers are implementation details.
__all__ = ('get_basis',)
14def _has_spglib() -> bool:
15 """Check if spglib is available"""
16 try:
17 import spglib
18 assert spglib # silence flakes
19 except ImportError:
20 return False
21 return True
def _get_basis_ase(atoms: Atoms,
                   spacegroup: _SPACEGROUP,
                   tol: float = 1e-5) -> np.ndarray:
    """Get a reduced basis by removing symmetry-equivalent sites.

    The first not-yet-assigned scaled position is taken as a basis site,
    every remaining position equivalent to it under ``spacegroup`` is
    discarded, and the procedure repeats until no positions are left.
    Basis sites are returned in the order they were first encountered.

    :param atoms: Atoms object to get basis from.
    :param spacegroup: ``int``, ``str``, or
        :class:`ase.spacegroup.Spacegroup` object.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :returns: array of basis sites in scaled coordinates, one per row.
    """
    scaled_positions = np.asarray(atoms.get_scaled_positions(), dtype=float)
    spacegroup = Spacegroup(spacegroup)

    # Iterative rather than recursive: the previous implementation used one
    # stack frame per basis site and hit the recursion limit for cells with
    # many inequivalent sites (e.g. large P1 structures).
    remaining = np.ones(len(scaled_positions), dtype=bool)
    all_basis = []
    while remaining.any():
        # First position that has not yet been assigned to a basis site.
        index = int(np.argmax(remaining))
        basis = scaled_positions[index]
        all_basis.append(basis.tolist())  # Add the site as a basis

        # Get equivalent sites
        sites, _ = spacegroup.equivalent_sites(basis)
        sites = np.asarray(sites, dtype=float).reshape(-1, 3)

        candidates = scaled_positions[remaining]
        # Fractional displacement from every candidate to every equivalent
        # site, reduced to the nearest periodic image so that coordinates
        # wrapped to opposite faces of the cell (0.0 vs 0.9999999) compare
        # equal.
        delta = sites[np.newaxis, :, :] - candidates[:, np.newaxis, :]
        delta -= np.round(delta)
        # Same per-component rule as np.allclose(site, pos, atol=tol) with
        # its default rtol=1e-5, applied to the periodic displacement.
        limit = tol + 1e-5 * np.abs(candidates)[:, np.newaxis, :]
        matched = np.all(np.abs(delta) <= limit, axis=2).any(axis=1)

        # Remove equivalent positions from further consideration.
        remaining[np.flatnonzero(remaining)[matched]] = False
        # A site is trivially equivalent to itself; drop it explicitly so the
        # loop always makes progress even if rounding defeated the comparison.
        remaining[index] = False

    return np.array(all_basis)
72 return _get_basis(scaled_positions, spacegroup)
def _get_basis_spglib(atoms: Atoms, tol: float = 1e-5) -> np.ndarray:
    """Reduce ``atoms`` to its symmetry-inequivalent sites using spglib.

    spglib infers the symmetry itself, so no space group is accepted here.

    :param atoms: Atoms, atoms object to get basis from
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :raises ImportError: if spglib is not installed.
    :returns: scaled positions of the atoms selected by
        ``_get_reduced_indices``.
    """
    if _has_spglib():
        positions = atoms.get_scaled_positions()
        keep = _get_reduced_indices(atoms, tol=tol)
        return positions[keep]

    # Point the user to the spglib-free alternative.
    raise ImportError(
        'This function requires spglib. Use "get_basis" and specify '
        'the spacegroup instead, or install spglib.')
def _can_use_spglib(spacegroup: _SPACEGROUP = None) -> bool:
    """Decide whether the spglib-based implementation is applicable.

    It is usable only when spglib can be imported and no explicit space
    group was requested; the spglib implementation cannot accept one.
    """
    installed = _has_spglib()
    return installed and spacegroup is None
# Dispatcher function for choosing the get_basis implementation.
def get_basis(atoms: Atoms,
              spacegroup: _SPACEGROUP = None,
              method: str = 'auto',
              tol: float = 1e-5) -> np.ndarray:
    """Determine a reduced basis of an atoms object.

    Two implementations are available: an ASE-native algorithm, which needs
    an explicit space group, and an spglib-based one, which infers the
    symmetry itself and (currently) cannot take an explicit space group.
    With ``method='auto'`` the implementation is chosen from the
    ``spacegroup`` argument and from whether spglib is installed.

    :param atoms: ase Atoms object to get basis from
    :param spacegroup: Optional, ``int``, ``str``
        or :class:`ase.spacegroup.Spacegroup` object.
        If unspecified, the spacegroup can be inferred using spglib,
        if spglib is installed, and ``method`` is set to either
        ``'spglib'`` or ``'auto'``.
        Inferring the spacegroup requires spglib.
        Note: with ``method='spglib'`` this argument is not used.
    :param method: ``str``, one of: ``'auto'`` | ``'ase'`` | ``'spglib'``.
        Selection of which implementation to use.
        It is recommended to use ``'auto'``, which is also the default.
    :param tol: ``float``, numeric tolerance for positional comparisons
        Default: ``1e-5``
    :raises ValueError: if ``method`` is not one of the allowed values, or
        if the ASE implementation is selected without a space group.
    """
    allowed_methods = ('auto', 'ase', 'spglib')
    if method not in allowed_methods:
        raise ValueError('Expected one of {} methods, got {}'.format(
            allowed_methods, method))

    # An explicit choice wins; 'auto' picks spglib only when it is usable.
    use_spglib = (_can_use_spglib(spacegroup=spacegroup)
                  if method == 'auto' else method == 'spglib')

    if not use_spglib:
        # Reached when ASE was requested explicitly, or under 'auto' when
        # spglib is missing or a space group was given.
        if spacegroup is None:
            raise ValueError(
                'A space group must be specified for the native ASE '
                'implementation. Try using the spglib version instead, '
                'or explicitly specifying a space group.')
        return _get_basis_ase(atoms, spacegroup, tol=tol)

    # The spglib implementation cannot take an explicit space group, so it
    # is deliberately not forwarded.
    return _get_basis_spglib(atoms, tol=tol)
def _get_reduced_indices(atoms: Atoms, tol: float = 1e-5) -> List[int]:
    """Get a list of the reduced atomic indices using spglib.

    Each distinct entry of spglib's ``equivalent_atoms`` array is returned
    once, as a Python ``int``, in ascending index order.
    Note: Does no checks to see if spglib is installed.

    :param atoms: ase Atoms object to reduce
    :param tol: ``float``, numeric tolerance for positional comparisons
        (passed to spglib as ``symprec``)
    :returns: sorted list of distinct atom indices.
    """
    from ase.spacegroup.symmetrize import spglib_get_symmetry_dataset

    # Create input for spglib
    spglib_cell = (atoms.get_cell(), atoms.get_scaled_positions(),
                   atoms.numbers)
    symmetry_data = spglib_get_symmetry_dataset(spglib_cell,
                                                symprec=tol)
    # list(set(...)) follows hash-table layout rather than index order, which
    # made the order of the returned basis arbitrary; sort explicitly.
    return sorted({int(index) for index in symmetry_data.equivalent_atoms})