Coverage for ase / utils / structure_comparator.py: 94.90%
294 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-03-30 08:22 +0000
1"""Determine symmetry equivalence of two structures.
2Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012)."""
4from collections import Counter
5from itertools import combinations, filterfalse, product
7import numpy as np
8from scipy.spatial import cKDTree as KDTree
10from ase import Atom, Atoms
11from ase.build.tools import niggli_reduce
def normalize(cell):
    """Rescale the first three rows of *cell* to unit length, in place.

    Parameters
    ----------
    cell:
        Indexable container of three row vectors (e.g. a 3x3 float array).
        Each row is divided by its own Euclidean norm; nothing is returned.
    """
    for idx in (0, 1, 2):
        length = np.linalg.norm(cell[idx])
        cell[idx] /= length
class SpgLibNotFoundError(Exception):
    """Signal that spglib is needed but could not be found."""

    def __init__(self, msg):
        # A message is mandatory; it is forwarded unchanged to Exception.
        Exception.__init__(self, msg)
class SymmetryEquivalenceCheck:
    """Compare two structures to determine if they are symmetry equivalent.

    Based on the recipe from Comput. Phys. Commun. 183, 690-697 (2012).

    Parameters
    ----------

    angle_tol: float
        angle tolerance for the lattice vectors in degrees

    ltol: float
        relative tolerance for the length of the lattice vectors (per atom)

    stol: float
        position tolerance for the site comparison in units of
        (V/N)^(1/3) (average length between atoms)

    vol_tol: float
        volume tolerance in angstrom cubed to compare the volumes of
        the two structures

    scale_volume: bool
        if True the volumes of the two structures are scaled to be equal

    to_primitive: bool
        if True the structures are reduced to their primitive cells
        note that this feature requires spglib to be installed

    Examples
    --------

    >>> from ase.build import bulk
    >>> from ase.utils.structure_comparator import SymmetryEquivalenceCheck
    >>> comp = SymmetryEquivalenceCheck()

    Compare a cell with a rotated version

    >>> a = bulk('Al', orthorhombic=True)
    >>> b = a.copy()
    >>> b.rotate(60, 'x', rotate_cell=True)
    >>> comp.compare(a, b)
    True

    Transform to the primitive cell and then compare

    >>> pa = bulk('Al')
    >>> comp.compare(a, pa)
    False
    >>> comp = SymmetryEquivalenceCheck(to_primitive=True)
    >>> comp.compare(a, pa)
    True

    Compare one structure with a list of other structures

    >>> import numpy as np
    >>> from ase import Atoms
    >>> s1 = Atoms('H3', positions=[[0.5, 0.5, 0],
    ...                             [0.5, 1.5, 0],
    ...                             [1.5, 1.5, 0]],
    ...            cell=[2, 2, 2], pbc=True)
    >>> comp = SymmetryEquivalenceCheck(stol=0.068)
    >>> s2_list = []
    >>> for d in np.linspace(0.1, 1.0, 5):
    ...     s2 = s1.copy()
    ...     s2.positions[0] += [d, 0, 0]
    ...     s2_list.append(s2)
    >>> comp.compare(s1, s2_list[:-1])
    False
    >>> comp.compare(s1, s2_list)
    True

    """

    def __init__(
        self,
        angle_tol=1.0,
        ltol=0.05,
        stol=0.05,
        vol_tol=0.1,
        scale_volume=False,
        to_primitive=False,
    ):
        self.angle_tol = angle_tol * np.pi / 180.0  # convert to radians
        self.scale_volume = scale_volume
        self.stol = stol
        self.ltol = ltol
        self.vol_tol = vol_tol
        # Absolute position tolerance; recomputed in compare() for every
        # candidate structure from stol and the volume per atom.
        self.position_tolerance = 0.0
        self.to_primitive = to_primitive

        # Variables to be used in the compare function
        self.s1 = None
        self.s2 = None
        self.expanded_s1 = None
        self.expanded_s2 = None
        self.least_freq_element = None

    def _niggli_reduce(self, atoms):
        """Reduce to niggli cells.

        Reduce the atoms to niggli cells, then rotates the niggli cells to
        the so called "standard" orientation with one lattice vector along the
        x-axis and a second vector in the xy plane.

        The atoms object is modified in place.
        """
        niggli_reduce(atoms)
        self._standarize_cell(atoms)

    @staticmethod
    def _quadrant_angle(x, y):
        """Return the rotation angle (radians) of the 2D vector (x, y).

        The magnitude comes from ``arcsin(y / |(x, y)|)`` and is then moved
        to the proper quadrant based on the strict signs of *x* and *y*.
        Components that are exactly zero do not trigger any quadrant
        correction.
        """
        angle = np.abs(np.arcsin(y / np.sqrt(x ** 2 + y ** 2)))
        if x < 0.0 and y > 0.0:
            angle = np.pi - angle
        elif x < 0.0 and y < 0.0:
            angle = np.pi + angle
        elif x > 0.0 and y < 0.0:
            angle = -angle
        return angle

    def _standarize_cell(self, atoms):
        """Rotate the first vector such that it points along the x-axis.
        Then rotate around the first vector so the second vector is in the
        xy plane.

        The cell and the positions of *atoms* are updated in place with the
        accumulated rotation, the atoms are wrapped into the cell, and the
        same *atoms* object is returned.
        """
        # Columns of ``cell`` are the lattice vectors
        cell = atoms.get_cell().T
        total_rot_mat = np.eye(3)

        # Rotate around the y axis to remove the z component of vector 1
        v1 = cell[:, 0]
        angle = self._quadrant_angle(v1[0], v1[2])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        # Rotate around the z axis to remove the y component of vector 1
        v1 = cell[:, 0]
        angle = self._quadrant_angle(v1[0], v1[1])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[ca, sa, 0.0], [-sa, ca, 0.0], [0.0, 0.0, 1.0]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        # Rotate around x axis such that the second vector is in the xy plane
        v2 = cell[:, 1]
        angle = self._quadrant_angle(v2[1], v2[2])
        ca = np.cos(angle)
        sa = np.sin(angle)
        rotmat = np.array([[1.0, 0.0, 0.0], [0.0, ca, sa], [0.0, -sa, ca]])
        total_rot_mat = rotmat.dot(total_rot_mat)
        cell = rotmat.dot(cell)

        atoms.set_cell(cell.T)
        atoms.set_positions(total_rot_mat.dot(atoms.get_positions().T).T)
        atoms.wrap(pbc=[1, 1, 1])
        return atoms

    def _get_element_count(self, struct):
        """Count the number of elements in each of the structures."""
        return Counter(struct.numbers)

    def _get_angles(self, cell):
        """Get the internal angles of the unit cell.

        Returns the angles in radians, ordered as
        [angle(0, 1), angle(0, 2), angle(1, 2)]. The input is not modified.
        """
        cell = cell.copy()

        normalize(cell)

        dot = cell.dot(cell.T)

        # Extract only the relevant dot products
        dot = [dot[0, 1], dot[0, 2], dot[1, 2]]

        # Return angles
        return np.arccos(dot)

    def _has_same_elements(self):
        """Check if two structures have same elements."""
        elem1 = self._get_element_count(self.s1)
        return elem1 == self._get_element_count(self.s2)

    def _has_same_angles(self):
        """Check that the Niggli unit vectors has the same internal angles."""
        ang1 = np.sort(self._get_angles(self.s1.get_cell()))
        ang2 = np.sort(self._get_angles(self.s2.get_cell()))

        return np.allclose(ang1, ang2, rtol=0, atol=self.angle_tol)

    def _has_same_volume(self):
        """Check that the volumes of s1 and s2 differ by less than vol_tol."""
        vol1 = self.s1.get_volume()
        vol2 = self.s2.get_volume()
        return np.abs(vol1 - vol2) < self.vol_tol

    def _scale_volumes(self):
        """Scale the cell of s2 to have the same volume as s1."""
        cell2 = self.s2.get_cell()
        # Get the volumes
        v2 = np.linalg.det(cell2)
        v1 = np.linalg.det(self.s1.get_cell())

        # Scale the cells
        coordinate_scaling = (v1 / v2) ** (1.0 / 3.0)
        cell2 *= coordinate_scaling
        self.s2.set_cell(cell2, scale_atoms=True)

    def compare(self, s1, s2):
        """Compare the two structures.

        Return *True* if the two structures are equivalent, *False* otherwise.

        The structures passed in are not modified; all work is done on
        copies.

        Parameters
        ----------

        s1: Atoms object.
            Transformation matrices are calculated based on this structure.

        s2: Atoms or list
            s1 can be compared to one structure or many structures supplied in
            a list. If s2 is a list it returns True if any structure in s2
            matches s1, False otherwise.
        """
        if self.to_primitive:
            s1 = self._reduce_to_primitive(s1)
        else:
            # The origin shift below translates and wraps the atoms in
            # place, so operate on a private copy of the caller's structure.
            s1 = s1.copy()
        self._set_least_frequent_element(s1)
        self._least_frequent_element_to_origin(s1)
        self.s1 = s1.copy()
        vol = self.s1.get_volume()
        self.expanded_s1 = None

        if isinstance(s2, Atoms):
            # Just make it a list of length 1
            s2 = [s2]

        matrices = None
        translations = None
        transposed_matrices = None
        for struct in s2:
            self.s2 = struct.copy()
            self.expanded_s2 = None

            if self.to_primitive:
                self.s2 = self._reduce_to_primitive(self.s2)

            # Compare number of elements in structures
            if len(self.s1) != len(self.s2):
                continue

            # Compare chemical formulae
            if not self._has_same_elements():
                continue

            # Compare angles. self.s1 is reset to the unreduced copy at the
            # end of each full iteration, so it is reduced every time.
            self._niggli_reduce(self.s1)
            self._niggli_reduce(self.s2)
            if not self._has_same_angles():
                continue

            # Compare volumes
            if self.scale_volume:
                self._scale_volumes()
            if not self._has_same_volume():
                continue

            if matrices is None:
                matrices = self._get_rotation_reflection_matrices()
            if matrices is None:
                continue

            if translations is None:
                translations = self._get_least_frequent_positions(self.s1)

            # After the candidate translation based on s1 has been computed
            # we need potentially to swap s1 and s2 for robust comparison
            self._least_frequent_element_to_origin(self.s2)
            switch = self._switch_reference_struct()
            if switch:
                # Remember the matrices and translations used before
                old_matrices = matrices
                old_translations = translations

                # If a s1 and s2 has been switched we need to use the
                # transposed version of the matrices to map atoms the
                # other way
                if transposed_matrices is None:
                    transposed_matrices = np.transpose(
                        matrices, axes=[0, 2, 1]
                    )
                matrices = transposed_matrices
                translations = self._get_least_frequent_positions(self.s1)

            # Calculate tolerance on positions
            self.position_tolerance = self.stol * (vol / len(self.s2)) ** (
                1.0 / 3.0
            )

            if self._positions_match(matrices, translations):
                return True

            # Set the reference structure back to its original
            self.s1 = s1.copy()
            if switch:
                self.expanded_s1 = self.expanded_s2
                matrices = old_matrices
                translations = old_translations
        return False

    def _set_least_frequent_element(self, atoms):
        """Save the atomic number of the least frequent element."""
        elem1 = self._get_element_count(atoms)
        self.least_freq_element = elem1.most_common()[-1][0]

    def _get_least_frequent_positions(self, atoms):
        """Get the positions of the least frequent element in atoms."""
        pos = atoms.get_positions(wrap=True)
        return pos[atoms.numbers == self.least_freq_element]

    def _get_only_least_frequent_of(self, struct):
        """Get the atoms object with all other elements than the least frequent
        one removed. Wrap the positions to get everything in the cell."""
        pos = struct.get_positions(wrap=True)

        indices = struct.numbers == self.least_freq_element
        least_freq_struct = struct[indices]
        least_freq_struct.set_positions(pos[indices])

        return least_freq_struct

    def _switch_reference_struct(self):
        """There is an intrinsic asymmetry in the system because
        one of the atoms are being expanded, while the other is not.
        This can cause the algorithm to return different result
        depending on which structure is passed first.
        We adopt the convention of using the atoms object
        having the fewest atoms in its expanded cell as the
        reference object.
        We return True if a switch of structures has been performed."""

        # First expand the cells
        if self.expanded_s1 is None:
            self.expanded_s1 = self._expand(self.s1)
        if self.expanded_s2 is None:
            self.expanded_s2 = self._expand(self.s2)

        if len(self.expanded_s1) < len(self.expanded_s2):
            # The expanded s2 is the one searched against (the reference),
            # so the structure with the smaller expansion must become s2.
            # Both pairs are owned by this object, so a plain swap suffices.
            self.s1, self.s2 = self.s2, self.s1
            self.expanded_s1, self.expanded_s2 = (
                self.expanded_s2,
                self.expanded_s1,
            )
            return True
        return False

    def _positions_match(self, rotation_reflection_matrices, translations):
        """Check if the position and elements match.

        Note that this function changes self.s1 and self.s2 to the rotation and
        translation that matches best. Hence, it is crucial that this function
        calls the element comparison, not the other way around.
        """
        pos1_ref = self.s1.get_positions(wrap=True)

        # Get the expanded reference object
        exp2 = self.expanded_s2
        # Build a KD tree to enable fast look-up of nearest neighbours
        tree = KDTree(exp2.get_positions())
        for translation in translations:
            # Translate
            pos1_trans = pos1_ref - translation
            for matrix in rotation_reflection_matrices:
                # Rotate
                pos1 = matrix.dot(pos1_trans.T).T

                # Update the atoms positions
                self.s1.set_positions(pos1)
                self.s1.wrap(pbc=[1, 1, 1])
                if self._elements_match(self.s1, exp2, tree):
                    return True
        return False

    def _expand(self, ref_atoms, tol=0.0001):
        """If an atom is closer to a boundary than tol it is repeated at the
        opposite boundaries.

        This ensures that atoms having crossed the cell boundaries due to
        numerical noise are properly detected.

        The distance between a position and cell boundary is calculated as:
        dot(position, (b_vec x c_vec) / (|b_vec| |c_vec|) ), where x is the
        cross product.

        Returns a copy of *ref_atoms* with the extra periodic images
        appended; *ref_atoms* itself is left untouched.
        """
        syms = ref_atoms.get_chemical_symbols()
        cell = ref_atoms.get_cell()
        positions = ref_atoms.get_positions(wrap=True)
        expanded_atoms = ref_atoms.copy()

        # Calculate normal vectors to the unit cell faces
        normal_vectors = np.array(
            [
                np.cross(cell[1, :], cell[2, :]),
                np.cross(cell[0, :], cell[2, :]),
                np.cross(cell[0, :], cell[1, :]),
            ]
        )
        normalize(normal_vectors)

        # Get the distance to the unit cell faces from each atomic position
        pos2faces = np.abs(positions.dot(normal_vectors.T))

        # And the opposite faces
        pos2oppofaces = np.abs(
            np.dot(positions - np.sum(cell, axis=0), normal_vectors.T)
        )

        for i, i2face in enumerate(pos2faces):
            # Append indices for positions close to the other faces
            # and convert to boolean array signifying if the position at
            # index i is close to the faces bordering origo (0, 1, 2) or
            # the opposite faces (3, 4, 5)
            i_close2face = np.append(i2face, pos2oppofaces[i]) < tol
            close_faces = np.nonzero(i_close2face)[0]
            # For each position i.e. row it holds that
            # 1 x True -> close to face -> 1 extra atom at opposite face
            # 2 x True -> close to edge -> 3 extra atoms at opposite edges
            # 3 x True -> close to corner -> 7 extra atoms opposite corners
            # E.g. to add atoms at all corners we need to use the cell
            # vectors: (a, b, c, a + b, a + c, b + c, a + b + c), we use
            # itertools.combinations to get them all
            for size in range(1, len(close_faces) + 1):
                for c in combinations(close_faces, size):
                    # Get the displacement vectors by adding the corresponding
                    # cell vectors, if the atom is close to an opposite face
                    # i.e. k > 2 subtract the cell vector
                    disp_vec = np.zeros(3)
                    for k in c:
                        disp_vec += cell[k % 3] * (int(k < 3) * 2 - 1)
                    pos = positions[i] + disp_vec
                    expanded_atoms.append(Atom(syms[i], position=pos))
        return expanded_atoms

    def _equal_elements_in_array(self, arr):
        """Return a truthy value if *arr* contains any duplicated entry."""
        s = np.sort(arr)
        return np.any(s[1:] == s[:-1])

    def _elements_match(self, s1, s2, kdtree):
        """Check if all the elements in s1 match corresponding position in s2

        *kdtree* must be built from the positions of *s2*. Every atom of
        *s1* is paired with its nearest neighbour in *s2*; the structures
        match when each pair has the same atomic number, lies within
        ``self.position_tolerance`` and no atom of *s2* is used twice.
        The coordinates are used in their given (x, y, z) order.
        """
        dists, closest_in_s2 = kdtree.query(s1.get_positions())

        # Check if the elements are the same
        if not np.all(s2.numbers[closest_in_s2] == s1.numbers):
            return False

        # Check if any distance is too large
        if np.any(dists > self.position_tolerance):
            return False

        # Check for duplicates in what atom is closest
        if self._equal_elements_in_array(closest_in_s2):
            return False

        return True

    def _least_frequent_element_to_origin(self, atoms):
        """Put one of the least frequent elements at the origin.

        The atoms are translated and wrapped in place. A tiny fraction of
        the cell diagonal is kept as offset so the chosen atom does not sit
        exactly on the cell boundary.
        """
        least_freq_pos = self._get_least_frequent_positions(atoms)
        cell_diag = np.sum(atoms.get_cell(), axis=0)
        d = least_freq_pos[0] - 1e-6 * cell_diag
        atoms.positions -= d
        atoms.wrap(pbc=[1, 1, 1])

    def _get_rotation_reflection_matrices(self):
        """Compute candidates for the transformation matrix.

        Returns a stack of 3x3 candidate matrices, or None when no set of
        trial vectors in s1 matches the lengths and angles of the cell
        vectors of s2.
        """
        atoms1_ref = self._get_only_least_frequent_of(self.s1)
        cell = self.s1.get_cell().T
        cell_diag = np.sum(cell, axis=1)
        angle_tol = self.angle_tol

        # Additional vector that is added to make sure that
        # there always is an atom at the origin
        delta_vec = 1e-6 * cell_diag

        # Store three reference vectors and their lengths
        ref_vec = self.s2.get_cell()
        ref_vec_lengths = np.linalg.norm(ref_vec, axis=1)

        # Compute ref vec angles
        # ref_angles are arranged as [angle12, angle13, angle23]
        ref_angles = np.array(self._get_angles(ref_vec))
        large_angles = ref_angles > np.pi / 2.0
        ref_angles[large_angles] = np.pi - ref_angles[large_angles]

        # Translate by one cell diagonal so that a central cell is
        # surrounded by cells in all directions
        sc_atom_search = atoms1_ref * (3, 3, 3)
        new_sc_pos = sc_atom_search.get_positions()
        new_sc_pos -= new_sc_pos[0] + cell_diag - delta_vec

        lengths = np.linalg.norm(new_sc_pos, axis=1)

        candidate_indices = []
        rtol = self.ltol / len(self.s1)
        for k in range(3):
            correct_lengths_mask = np.isclose(
                lengths, ref_vec_lengths[k], rtol=rtol, atol=0
            )
            # The first vector is not interesting
            correct_lengths_mask[0] = False

            # If no trial vectors can be found (for any direction)
            # then the candidates are different and we return None
            if not np.any(correct_lengths_mask):
                return None

            candidate_indices.append(np.nonzero(correct_lengths_mask)[0])

        # Now we calculate all relevant angles in one step. The relevant angles
        # are the ones made by the current candidates. We will have to keep
        # track of the indices in the angles matrix and the indices in the
        # position and length arrays.

        # Get all candidate indices (aci), only unique values
        aci = np.sort(list(set().union(*candidate_indices)))

        # Make a dictionary from original positions and lengths index to
        # index in angle matrix
        i2ang = dict(zip(aci, range(len(aci))))

        # Calculate the dot product divided by the lengths:
        # cos(angle) = dot(vec1, vec2) / |vec1| |vec2|
        cosa = np.inner(new_sc_pos[aci], new_sc_pos[aci]) / np.outer(
            lengths[aci], lengths[aci]
        )
        # Make sure the inverse cosine will work
        cosa[cosa > 1] = 1
        cosa[cosa < -1] = -1
        angles = np.arccos(cosa)
        # Do trick for enantiomorphic structures
        angles[angles > np.pi / 2] = np.pi - angles[angles > np.pi / 2]

        # Check which angles match the reference angles
        # Test for all combinations on candidates. filterfalse makes sure
        # that there are no duplicate candidates. product is the same as
        # nested for loops.
        refined_candidate_list = []
        for p in filterfalse(
            self._equal_elements_in_array, product(*candidate_indices)
        ):
            a = np.array(
                [
                    angles[i2ang[p[0]], i2ang[p[1]]],
                    angles[i2ang[p[0]], i2ang[p[2]]],
                    angles[i2ang[p[1]], i2ang[p[2]]],
                ]
            )

            if np.allclose(a, ref_angles, atol=angle_tol, rtol=0):
                refined_candidate_list.append(new_sc_pos[np.array(p)].T)

        # Get the rotation/reflection matrix [R] by:
        # [R] = [V][T]^-1, where [V] is the reference vectors and
        # [T] is the trial vectors
        if not refined_candidate_list:
            return None
        inverted_trial = np.linalg.inv(refined_candidate_list)

        # Equivalent to np.matmul(ref_vec.T, inverted_trial)
        candidate_trans_mat = np.dot(ref_vec.T, inverted_trial.T).T
        return candidate_trans_mat

    def _reduce_to_primitive(self, structure):
        """Reduce a structure to its primitive cell using spglib.

        Returns a new, fully periodic Atoms object; *structure* is not
        modified.

        Raises
        ------
        SpgLibNotFoundError
            If spglib cannot be imported.
        """
        from ase.utils import spglib_new_errorhandling

        try:
            import spglib
        except ImportError as err:
            # Keep the original ImportError attached as the cause.
            raise SpgLibNotFoundError(
                'SpgLib is required if to_primitive=True'
            ) from err
        cell = (structure.get_cell()).tolist()
        pos = structure.get_scaled_positions().tolist()
        numbers = structure.get_atomic_numbers()

        cell, scaled_pos, numbers = spglib_new_errorhandling(
            spglib.standardize_cell
        )((cell, pos, numbers), to_primitive=True)

        atoms = Atoms(
            scaled_positions=scaled_pos, numbers=numbers, cell=cell, pbc=True
        )
        return atoms