Coverage for /builds/ase/ase/ase/utils/structure_comparator.py: 94.88%
293 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1"""Determine symmetry equivalence of two structures.
2Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012)."""
4from collections import Counter
5from itertools import combinations, filterfalse, product
7import numpy as np
8from scipy.spatial import cKDTree as KDTree
10from ase import Atom, Atoms
11from ase.build.tools import niggli_reduce
def normalize(cell):
    """Scale the first three rows of *cell* to unit length, in place.

    Each row ``cell[0]``, ``cell[1]`` and ``cell[2]`` is divided by its own
    Euclidean norm.  Nothing is returned; the argument itself is modified.
    """
    for axis in (0, 1, 2):
        length = np.linalg.norm(cell[axis])
        cell[axis] /= length
class SpgLibNotFoundError(Exception):
    """Signal that spglib is needed for an operation but is unavailable.

    The single required argument *msg* becomes the exception message.
    """

    def __init__(self, msg):
        # An explicit signature keeps ``msg`` mandatory for callers.
        super().__init__(msg)
class SymmetryEquivalenceCheck:
    """Compare two structures to determine if they are symmetry equivalent.

    Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012).

    Parameters:

    angle_tol: float
        angle tolerance for the lattice vectors in degrees

    ltol: float
        relative tolerance for the length of the lattice vectors (per atom)

    stol: float
        position tolerance for the site comparison in units of
        (V/N)^(1/3) (average length between atoms)

    vol_tol: float
        volume tolerance in angstrom cubed to compare the volumes of
        the two structures

    scale_volume: bool
        if True the volumes of the two structures are scaled to be equal

    to_primitive: bool
        if True the structures are reduced to their primitive cells
        note that this feature requires spglib to be installed

    Examples:

    >>> from ase.build import bulk
    >>> from ase.utils.structure_comparator import SymmetryEquivalenceCheck
    >>> comp = SymmetryEquivalenceCheck()

    Compare a cell with a rotated version

    >>> a = bulk('Al', orthorhombic=True)
    >>> b = a.copy()
    >>> b.rotate(60, 'x', rotate_cell=True)
    >>> comp.compare(a, b)
    True

    Transform to the primitive cell and then compare

    >>> pa = bulk('Al')
    >>> comp.compare(a, pa)
    False
    >>> comp = SymmetryEquivalenceCheck(to_primitive=True)
    >>> comp.compare(a, pa)
    True

    Compare one structure with a list of other structures

    >>> import numpy as np
    >>> from ase import Atoms
    >>> s1 = Atoms('H3', positions=[[0.5, 0.5, 0],
    ...                             [0.5, 1.5, 0],
    ...                             [1.5, 1.5, 0]],
    ...            cell=[2, 2, 2], pbc=True)
    >>> comp = SymmetryEquivalenceCheck(stol=0.068)
    >>> s2_list = []
    >>> for d in np.linspace(0.1, 1.0, 5):
    ...     s2 = s1.copy()
    ...     s2.positions[0] += [d, 0, 0]
    ...     s2_list.append(s2)
    >>> comp.compare(s1, s2_list[:-1])
    False
    >>> comp.compare(s1, s2_list)
    True

    """

    def __init__(
        self,
        angle_tol=1.0,
        ltol=0.05,
        stol=0.05,
        vol_tol=0.1,
        scale_volume=False,
        to_primitive=False,
    ):
        self.angle_tol = angle_tol * np.pi / 180.0  # convert to radians
        self.scale_volume = scale_volume
        self.stol = stol
        self.ltol = ltol
        self.vol_tol = vol_tol
        # Absolute position tolerance; recomputed inside compare()
        self.position_tolerance = 0.0
        self.to_primitive = to_primitive

        # Variables to be used in the compare function
        self.s1 = None
        self.s2 = None
        self.expanded_s1 = None
        self.expanded_s2 = None
        self.least_freq_element = None

    def _niggli_reduce(self, atoms):
        """Reduce to niggli cells.

        Reduce the atoms to niggli cells, then rotates the niggli cells to
        the so called "standard" orientation with one lattice vector along the
        x-axis and a second vector in the xy plane.

        *atoms* is modified in place.
        """
        niggli_reduce(atoms)
        self._standarize_cell(atoms)

    def _standarize_cell(self, atoms):
        """Rotate the first vector such that it points along the x-axis.
        Then rotate around the first vector so the second vector is in the
        xy plane.

        The cell and the positions of *atoms* are updated in place, the
        positions are wrapped, and the same *atoms* object is returned.

        The rotation angles are obtained with ``arctan2`` so that vectors
        lying exactly on a coordinate axis (zero projected length, or a
        vector along the negative x-axis) are handled without producing
        NaN or being left pointing along -x.
        """
        # Work with the lattice vectors as columns
        cell = atoms.get_cell().T
        total_rot_mat = np.eye(3)

        # Rotate about the y-axis to remove the z-component of vector 1
        v1 = cell[:, 0]
        angle = np.arctan2(v1[2], v1[0])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        # Rotate about the z-axis to remove the y-component of vector 1,
        # which leaves it along the positive x-axis
        v1 = cell[:, 0]
        angle = np.arctan2(v1[1], v1[0])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[ca, sa, 0.0], [-sa, ca, 0.0], [0.0, 0.0, 1.0]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        # Rotate around x axis such that the second vector is in the xy plane
        v2 = cell[:, 1]
        angle = np.arctan2(v2[2], v2[1])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[1.0, 0.0, 0.0], [0.0, ca, sa], [0.0, -sa, ca]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        # Apply the accumulated rotation to the atoms as well
        atoms.set_cell(cell.T)
        atoms.set_positions(total_rot_mat.dot(atoms.get_positions().T).T)
        atoms.wrap(pbc=[1, 1, 1])
        return atoms

    def _get_element_count(self, struct):
        """Count the number of elements in each of the structures.

        Returns a ``Counter`` keyed by the entries of ``struct.numbers``.
        """
        return Counter(struct.numbers)

    def _get_angles(self, cell):
        """Get the internal angles of the unit cell.

        Returns an array ``[angle01, angle02, angle12]`` in radians, where
        ``angleij`` is the angle between rows i and j of *cell*.  The
        argument itself is not modified (a copy is normalized).
        """
        cell = cell.copy()

        normalize(cell)

        dot = cell.dot(cell.T)

        # Extract only the relevant dot products
        dot = [dot[0, 1], dot[0, 2], dot[1, 2]]

        # Guard against round-off pushing a cosine slightly outside
        # [-1, 1], which would make arccos return NaN
        return np.arccos(np.clip(dot, -1.0, 1.0))

    def _has_same_elements(self):
        """Check if two structures have same elements."""
        elem1 = self._get_element_count(self.s1)
        return elem1 == self._get_element_count(self.s2)

    def _has_same_angles(self):
        """Check that the Niggli unit vectors has the same internal angles.

        The sorted angle triplets of ``self.s1`` and ``self.s2`` must agree
        within ``self.angle_tol`` (absolute, radians).
        """
        ang1 = np.sort(self._get_angles(self.s1.get_cell()))
        ang2 = np.sort(self._get_angles(self.s2.get_cell()))

        return np.allclose(ang1, ang2, rtol=0, atol=self.angle_tol)

    def _has_same_volume(self):
        """Check that the volumes of s1 and s2 differ by less than vol_tol."""
        vol1 = self.s1.get_volume()
        vol2 = self.s2.get_volume()
        return np.abs(vol1 - vol2) < self.vol_tol

    def _scale_volumes(self):
        """Scale the cell of s2 to have the same volume as s1."""
        cell2 = self.s2.get_cell()
        # Get the volumes.  The absolute value is used so that a
        # left-handed cell (negative determinant) does not lead to a
        # negative ratio and hence a NaN scaling factor.
        v2 = np.abs(np.linalg.det(cell2))
        v1 = np.abs(np.linalg.det(self.s1.get_cell()))

        # Scale the cells
        coordinate_scaling = (v1 / v2) ** (1.0 / 3.0)
        cell2 *= coordinate_scaling
        self.s2.set_cell(cell2, scale_atoms=True)

    def compare(self, s1, s2):
        """Compare the two structures.

        Return *True* if the two structures are equivalent, *False* otherwise.

        Parameters:

        s1: Atoms object.
            Transformation matrices are calculated based on this structure.
            The object passed in is not modified.

        s2: Atoms or list
            s1 can be compared to one structure or many structures supplied in
            a list. If s2 is a list it returns True if any structure in s2
            matches s1, False otherwise.
        """
        if self.to_primitive:
            s1 = self._reduce_to_primitive(s1)
        else:
            # The reference is shifted and wrapped below; operate on a
            # private copy so the caller's structure is left untouched.
            s1 = s1.copy()
        self._set_least_frequent_element(s1)
        self._least_frequent_element_to_origin(s1)
        self.s1 = s1.copy()
        vol = self.s1.get_volume()
        self.expanded_s1 = None

        if isinstance(s2, Atoms):
            # Just make it a list of length 1
            s2 = [s2]

        matrices = None
        translations = None
        transposed_matrices = None
        for struct in s2:
            self.s2 = struct.copy()
            self.expanded_s2 = None

            if self.to_primitive:
                self.s2 = self._reduce_to_primitive(self.s2)

            # Compare number of elements in structures
            if len(self.s1) != len(self.s2):
                continue

            # Compare chemical formulae
            if not self._has_same_elements():
                continue

            # Compare angles.  self.s1 is reset from the un-reduced ``s1``
            # at the end of every full iteration, so it is reduced anew
            # for each candidate.
            self._niggli_reduce(self.s1)
            self._niggli_reduce(self.s2)
            if not self._has_same_angles():
                continue

            # Compare volumes
            if self.scale_volume:
                self._scale_volumes()
            if not self._has_same_volume():
                continue

            if matrices is None:
                matrices = self._get_rotation_reflection_matrices()
            if matrices is None:
                continue

            if translations is None:
                translations = self._get_least_frequent_positions(self.s1)

            # After the candidate translation based on s1 has been computed
            # we need potentially to swap s1 and s2 for robust comparison
            self._least_frequent_element_to_origin(self.s2)
            switch = self._switch_reference_struct()
            if switch:
                # Remember the matrices and translations used before
                old_matrices = matrices
                old_translations = translations

                # If a s1 and s2 has been switched we need to use the
                # transposed version of the matrices to map atoms the
                # other way
                if transposed_matrices is None:
                    transposed_matrices = np.transpose(
                        matrices, axes=[0, 2, 1]
                    )
                matrices = transposed_matrices
                translations = self._get_least_frequent_positions(self.s1)

            # Calculate tolerance on positions
            self.position_tolerance = self.stol * (vol / len(self.s2)) ** (
                1.0 / 3.0
            )

            if self._positions_match(matrices, translations):
                return True

            # Set the reference structure back to its original
            self.s1 = s1.copy()
            if switch:
                self.expanded_s1 = self.expanded_s2
                matrices = old_matrices
                translations = old_translations
        return False

    def _set_least_frequent_element(self, atoms):
        """Save the atomic number of the least frequent element."""
        elem1 = self._get_element_count(atoms)
        # most_common() is ordered from most to least frequent
        self.least_freq_element = elem1.most_common()[-1][0]

    def _get_least_frequent_positions(self, atoms):
        """Get the positions of the least frequent element in atoms."""
        pos = atoms.get_positions(wrap=True)
        return pos[atoms.numbers == self.least_freq_element]

    def _get_only_least_frequent_of(self, struct):
        """Get the atoms object with all other elements than the least frequent
        one removed. Wrap the positions to get everything in the cell."""
        pos = struct.get_positions(wrap=True)

        indices = struct.numbers == self.least_freq_element
        least_freq_struct = struct[indices]
        least_freq_struct.set_positions(pos[indices])

        return least_freq_struct

    def _switch_reference_struct(self):
        """There is an intrinsic asymmetry in the system because
        one of the atoms are being expanded, while the other is not.
        This can cause the algorithm to return different result
        depending on which structure is passed first.
        We adopt the convention of using the atoms object
        having the fewest atoms in its expanded cell as the
        reference object.
        We return True if a switch of structures has been performed."""

        # First expand the cells
        if self.expanded_s1 is None:
            self.expanded_s1 = self._expand(self.s1)
        if self.expanded_s2 is None:
            self.expanded_s2 = self._expand(self.s2)

        exp1 = self.expanded_s1
        exp2 = self.expanded_s2
        if len(exp1) < len(exp2):
            # _positions_match uses expanded_s2 as the reference, so the
            # structure with the smaller expansion (currently s1) must
            # take the s2 role: swap s1 and s2
            s1_temp = self.s1.copy()
            self.s1 = self.s2
            self.s2 = s1_temp
            exp1_temp = self.expanded_s1.copy()
            self.expanded_s1 = self.expanded_s2
            self.expanded_s2 = exp1_temp
            return True
        return False

    def _positions_match(self, rotation_reflection_matrices, translations):
        """Check if the position and elements match.

        Every combination of candidate translation and candidate
        rotation/reflection matrix is applied to s1 and tested against the
        expanded s2; True is returned on the first match.

        Note that this function changes self.s1 and self.s2 to the rotation and
        translation that matches best. Hence, it is crucial that this function
        calls the element comparison, not the other way around.
        """
        pos1_ref = self.s1.get_positions(wrap=True)

        # Get the expanded reference object
        exp2 = self.expanded_s2
        # Build a KD tree to enable fast look-up of nearest neighbours
        tree = KDTree(exp2.get_positions())
        for translation in translations:
            # Translate
            pos1_trans = pos1_ref - translation
            for matrix in rotation_reflection_matrices:
                # Rotate
                pos1 = matrix.dot(pos1_trans.T).T

                # Update the atoms positions
                self.s1.set_positions(pos1)
                self.s1.wrap(pbc=[1, 1, 1])
                if self._elements_match(self.s1, exp2, tree):
                    return True
        return False

    def _expand(self, ref_atoms, tol=0.0001):
        """If an atom is closer to a boundary than tol it is repeated at the
        opposite boundaries.

        This ensures that atoms having crossed the cell boundaries due to
        numerical noise are properly detected.

        The distance between a position and cell boundary is calculated as:
        dot(position, n), where n is the unit normal of the face, e.g.
        (b_vec x c_vec) / |b_vec x c_vec|, and x is the cross product.

        Returns a copy of *ref_atoms* with the extra atoms appended.
        """
        syms = ref_atoms.get_chemical_symbols()
        cell = ref_atoms.get_cell()
        positions = ref_atoms.get_positions(wrap=True)
        expanded_atoms = ref_atoms.copy()

        # Calculate normal vectors to the unit cell faces
        normal_vectors = np.array(
            [
                np.cross(cell[1, :], cell[2, :]),
                np.cross(cell[0, :], cell[2, :]),
                np.cross(cell[0, :], cell[1, :]),
            ]
        )
        normalize(normal_vectors)

        # Get the distance to the unit cell faces from each atomic position
        pos2faces = np.abs(positions.dot(normal_vectors.T))

        # And the opposite faces
        pos2oppofaces = np.abs(
            np.dot(positions - np.sum(cell, axis=0), normal_vectors.T)
        )

        for i, i2face in enumerate(pos2faces):
            # Append indices for positions close to the other faces
            # and convert to boolean array signifying if the position at
            # index i is close to the faces bordering origo (0, 1, 2) or
            # the opposite faces (3, 4, 5)
            i_close2face = np.append(i2face, pos2oppofaces[i]) < tol
            close_faces = np.nonzero(i_close2face)[0]
            # For each position i.e. row it holds that
            # 1 x True -> close to face -> 1 extra atom at opposite face
            # 2 x True -> close to edge -> 3 extra atoms at opposite edges
            # 3 x True -> close to corner -> 7 extra atoms opposite corners
            # E.g. to add atoms at all corners we need to use the cell
            # vectors: (a, b, c, a + b, a + c, b + c, a + b + c), we use
            # itertools.combinations to get them all
            for size in range(1, len(close_faces) + 1):
                for c in combinations(close_faces, size):
                    # Get the displacement vectors by adding the corresponding
                    # cell vectors, if the atom is close to an opposite face
                    # i.e. k > 2 subtract the cell vector
                    disp_vec = np.zeros(3)
                    for k in c:
                        disp_vec += cell[k % 3] * (int(k < 3) * 2 - 1)
                    pos = positions[i] + disp_vec
                    expanded_atoms.append(Atom(syms[i], position=pos))
        return expanded_atoms

    def _equal_elements_in_array(self, arr):
        """Return True if *arr* contains at least one repeated value."""
        s = np.sort(arr)
        return np.any(s[1:] == s[:-1])

    def _elements_match(self, s1, s2, kdtree):
        """Check if all the elements in s1 match corresponding position in s2

        *kdtree* must have been built from the positions of *s2*.  For each
        atom in s1 the nearest atom in s2 is looked up; the match succeeds
        only if the atomic numbers agree, every distance is within
        ``self.position_tolerance`` and no atom in s2 is used twice.

        NOTE: only the x, y, z coordinate order is tested; cyclic
        permutations of the axes are not tried.
        """
        dists, closest_in_s2 = kdtree.query(s1.get_positions())

        # Check if the elements are the same
        if not np.all(s2.numbers[closest_in_s2] == s1.numbers):
            return False

        # Check if any distance is too large
        if np.any(dists > self.position_tolerance):
            return False

        # Check for duplicates in what atom is closest
        if self._equal_elements_in_array(closest_in_s2):
            return False

        return True

    def _least_frequent_element_to_origin(self, atoms):
        """Put one of the least frequent elements at the origin.

        *atoms* is shifted in place and wrapped.  The atom ends up a small
        offset (1e-6 of the cell diagonal) away from the exact origin.
        """
        least_freq_pos = self._get_least_frequent_positions(atoms)
        cell_diag = np.sum(atoms.get_cell(), axis=0)
        d = least_freq_pos[0] - 1e-6 * cell_diag
        atoms.positions -= d
        atoms.wrap(pbc=[1, 1, 1])

    def _get_rotation_reflection_matrices(self):
        """Compute candidates for the transformation matrix.

        Returns a stack of 3x3 matrices, or None if no candidate could be
        found.
        """
        atoms1_ref = self._get_only_least_frequent_of(self.s1)
        cell = self.s1.get_cell().T
        cell_diag = np.sum(cell, axis=1)
        angle_tol = self.angle_tol

        # Additional vector that is added to make sure that
        # there always is an atom at the origin
        delta_vec = 1e-6 * cell_diag

        # Store three reference vectors and their lengths
        ref_vec = self.s2.get_cell()
        ref_vec_lengths = np.linalg.norm(ref_vec, axis=1)

        # Compute ref vec angles
        # ref_angles are arranged as [angle12, angle13, angle23]
        ref_angles = np.array(self._get_angles(ref_vec))
        large_angles = ref_angles > np.pi / 2.0
        ref_angles[large_angles] = np.pi - ref_angles[large_angles]

        # Translate by one cell diagonal so that a central cell is
        # surrounded by cells in all directions
        sc_atom_search = atoms1_ref * (3, 3, 3)
        new_sc_pos = sc_atom_search.get_positions()
        new_sc_pos -= new_sc_pos[0] + cell_diag - delta_vec

        lengths = np.linalg.norm(new_sc_pos, axis=1)

        candidate_indices = []
        rtol = self.ltol / len(self.s1)
        for k in range(3):
            correct_lengths_mask = np.isclose(
                lengths, ref_vec_lengths[k], rtol=rtol, atol=0
            )
            # The first vector is not interesting
            correct_lengths_mask[0] = False

            # If no trial vectors can be found (for any direction)
            # then the candidates are different and we return None
            if not np.any(correct_lengths_mask):
                return None

            candidate_indices.append(np.nonzero(correct_lengths_mask)[0])

        # Now we calculate all relevant angles in one step. The relevant angles
        # are the ones made by the current candidates. We will have to keep
        # track of the indices in the angles matrix and the indices in the
        # position and length arrays.

        # Get all candidate indices (aci), only unique values
        aci = np.sort(list(set().union(*candidate_indices)))

        # Make a dictionary from original positions and lengths index to
        # index in angle matrix
        i2ang = dict(zip(aci, range(len(aci))))

        # Calculate the dot product divided by the lengths:
        # cos(angle) = dot(vec1, vec2) / |vec1| |vec2|
        cosa = np.inner(new_sc_pos[aci], new_sc_pos[aci]) / np.outer(
            lengths[aci], lengths[aci]
        )
        # Make sure the inverse cosine will work
        cosa[cosa > 1] = 1
        cosa[cosa < -1] = -1
        angles = np.arccos(cosa)
        # Do trick for enantiomorphic structures
        angles[angles > np.pi / 2] = np.pi - angles[angles > np.pi / 2]

        # Check which angles match the reference angles
        # Test for all combinations on candidates. filterfalse makes sure
        # that there are no duplicate candidates. product is the same as
        # nested for loops.
        refined_candidate_list = []
        for p in filterfalse(
            self._equal_elements_in_array, product(*candidate_indices)
        ):
            a = np.array(
                [
                    angles[i2ang[p[0]], i2ang[p[1]]],
                    angles[i2ang[p[0]], i2ang[p[2]]],
                    angles[i2ang[p[1]], i2ang[p[2]]],
                ]
            )

            if np.allclose(a, ref_angles, atol=angle_tol, rtol=0):
                refined_candidate_list.append(new_sc_pos[np.array(p)].T)

        # Get the rotation/reflection matrix [R] by:
        # [R] = [V][T]^-1, where [V] is the reference vectors and
        # [T] is the trial vectors
        # XXX What do we know about the length/shape of refined_candidate_list?
        if not refined_candidate_list:
            return None
        inverted_trial = np.linalg.inv(refined_candidate_list)

        # Equivalent to np.matmul(ref_vec.T, inverted_trial)
        candidate_trans_mat = np.dot(ref_vec.T, inverted_trial.T).T
        return candidate_trans_mat

    def _reduce_to_primitive(self, structure):
        """Reduce the two structure to their primitive type

        Returns a new Atoms object built from the primitive cell reported
        by ``spglib.standardize_cell``; *structure* is only read.

        Raises SpgLibNotFoundError if spglib cannot be imported.
        """
        try:
            import spglib
        except ImportError as err:
            # Keep the original ImportError attached as the cause
            raise SpgLibNotFoundError(
                'SpgLib is required if to_primitive=True'
            ) from err
        cell = (structure.get_cell()).tolist()
        pos = structure.get_scaled_positions().tolist()
        numbers = structure.get_atomic_numbers()

        cell, scaled_pos, numbers = spglib.standardize_cell(
            (cell, pos, numbers), to_primitive=True
        )

        atoms = Atoms(
            scaled_positions=scaled_pos, numbers=numbers, cell=cell, pbc=True
        )
        return atoms