Coverage for /builds/ase/ase/ase/geometry/dimensionality/rank_determination.py: 99.15%
118 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3"""
4Implements the Rank Determination Algorithm (RDA)
6Method is described in:
7Definition of a scoring parameter to identify low-dimensional materials
8components
9P.M. Larsen, M. Pandey, M. Strange, and K. W. Jacobsen
10Phys. Rev. Materials 3 034003, 2019
11https://doi.org/10.1103/PhysRevMaterials.3.034003
12"""
from collections import defaultdict, deque

import numpy as np

from ase.geometry.dimensionality.disjoint_set import DisjointSet
# NumPy has a large per-call overhead when working with many small vectors,
# and the cross product is particularly slow. Pure Python is much faster here.
def dot_product(A, B):
    """Return the dot product of the sequences ``A`` and ``B``.

    Pairs are formed with ``zip``, so extra trailing elements of the longer
    sequence are ignored; two empty sequences give ``0``.
    """
    total = 0
    for x, y in zip(A, B):
        total += x * y
    return total
def cross_product(a, b):
    """Return the 3-D cross product ``a x b`` as a list of three numbers.

    Only indices 0, 1 and 2 of each argument are read.
    """
    ax, ay, az = a[0], a[1], a[2]
    bx, by, bz = b[0], b[1], b[2]
    return [
        ay * bz - az * by,
        az * bx - ax * bz,
        ax * by - ay * bx,
    ]
def subtract(A, B):
    """Return the element-wise difference ``A - B`` as a new list.

    The result is as long as the shorter of the two inputs.
    """
    return list(map(lambda x, y: x - y, A, B))
def rank_increase(a, b):
    """Return whether appending point ``b`` raises the affine rank of ``a``.

    ``a`` is the list of points already accepted (at most four). A new point
    raises the rank when it is:

    - the first point;
    - different from a single existing point;
    - not collinear with two existing points (non-zero cross product);
    - not coplanar with three existing points (non-zero triple product).

    Four affinely independent points already span 3-D, so nothing more can
    be added. Any other length of ``a`` raises ``Exception``.
    """
    count = len(a)
    if count == 0:
        return True
    if count == 1:
        return a[0] != b
    if count == 4:
        return False

    points = a + [b]
    origin = points[0]
    normal = cross_product(subtract(points[1], origin),
                           subtract(points[2], origin))
    if count == 2:
        return any(normal)
    if count == 3:
        return dot_product(normal, subtract(points[3], origin)) != 0
    raise Exception("This shouldn't be possible.")
def bfs(adjacency, start):
    """Traverse the component graph using BFS.

    The graph is traversed until the matrix rank of the subspace spanned by
    the visited components no longer increases.

    Each node is a ``(component, position)`` pair, where ``position`` is
    the accumulated cell offset of that periodic image. The start
    component sits at ``(0, 0, 0)``. A popped node is "accepted" only when
    ``rank_increase`` reports that its position raises the rank of the
    positions already accepted for that component. Only accepted nodes have
    their neighbours expanded.

    Parameters:

    adjacency: dict     Maps each component to an iterable of
                        ``(neighbour_component, offset)`` pairs.
    start:              The component at which the traversal begins.

    Returns:

    visited: set        Every node taken off the queue, including nodes
                        rejected by the rank test.
    rank: int           Number of positions accepted for ``start`` minus
                        one, i.e. 0 to 3.
    """
    visited = set()
    cvisited = defaultdict(list)
    # deque gives O(1) pops from the front; list.pop(0) is O(n) per pop.
    queue = deque([(start, (0, 0, 0))])
    while queue:
        vertex = queue.popleft()
        if vertex in visited:
            continue

        visited.add(vertex)
        c, p = vertex
        if not rank_increase(cvisited[c], p):
            continue

        cvisited[c].append(p)

        for nc, offset in adjacency[c]:
            nbrpos = (p[0] + offset[0], p[1] + offset[1], p[2] + offset[2])
            nbrnode = (nc, nbrpos)
            if nbrnode in visited:
                continue

            # Pre-filter: skip neighbours that cannot raise the rank now.
            # The test is repeated when the node is popped, because more
            # positions may have been accepted by then.
            if rank_increase(cvisited[nc], nbrpos):
                queue.append(nbrnode)

    return visited, len(cvisited[start]) - 1
def traverse_component_graphs(adjacency):
    """Run ``bfs`` from every component in ``adjacency``.

    Returns two dicts keyed by component: the set of visited nodes of each
    traversal, and the rank that traversal found.
    """
    all_visited = {}
    ranks = {}
    for root in adjacency:
        all_visited[root], ranks[root] = bfs(adjacency, root)
    return all_visited, ranks
def build_adjacency_list(parents, bonds):
    """Build the component-level adjacency map from atom-level bonds.

    Parameters:

    parents:    Indexable of the component representative of every atom
                (the caller in this module passes ``graph.find_all()``).
    bonds:      Iterable of ``(i, j, offset)`` bonds between atoms.

    Returns:

    dict mapping each distinct parent to a set of
    ``(neighbour_parent, offset)`` pairs. Parents with no bonds map to an
    empty set.
    """
    adjacency = {root: set() for root in np.unique(parents)}
    for i, j, offset in bonds:
        adjacency[parents[i]].add((parents[j], offset))
    return adjacency
def get_dimensionality_histogram(ranks, roots):
    """Count the components of each rank.

    Returns a 4-tuple whose entry ``k`` is the number of roots ``e`` with
    ``ranks[e] == k``.
    """
    counts = [0] * 4
    for root in roots:
        counts[ranks[root]] += 1
    return tuple(counts)
def merge_mutual_visits(all_visited, ranks, graph):
    """Find components with mutual visits and merge them.

    Two roots whose traversals visited the same ``(component, position)``
    node are unioned in ``graph``. Their ranks are asserted to be equal.

    Returns ``(merged, visits, ranks)``. When no union took place, the
    input ``all_visited`` and ``ranks`` objects are returned unchanged.
    Otherwise the visits and ranks are regrouped under the new roots
    reported by ``graph.find_all()``.
    """
    seen_by = defaultdict(list)
    any_merged = False
    for root, nodes in all_visited.items():
        for node in nodes:
            earlier_roots = seen_by[node]
            for other in earlier_roots:
                assert ranks[other] == ranks[root]
                any_merged |= graph.union(other, root)
            earlier_roots.append(root)

    if not any_merged:
        return any_merged, all_visited, ranks

    new_root_of = graph.find_all()
    merged_visits = defaultdict(set)
    merged_ranks = {}
    for root, nodes in all_visited.items():
        new_root = new_root_of[root]
        merged_visits[new_root].update(nodes)
        merged_ranks[new_root] = ranks[new_root]
    return any_merged, merged_visits, merged_ranks
class RDA:

    def __init__(self, num_atoms):
        """
        Initializes the RDA class.

        A disjoint set is used to maintain the component graph.

        Parameters:

        num_atoms: int    The number of atoms in the unit cell.
        """
        self.bonds = []
        self.graph = DisjointSet(num_atoms)
        self.adjacency = None
        self.hcached = None
        self.components_cached = None
        self.cdim_cached = None
        # Results of the most recent traversal; populated by check().
        self.all_visited = None
        self.ranks = None
        self.roots = None

    def insert_bond(self, i, j, offset):
        """
        Adds a bond to the list of graph edges.

        Graph components are merged if the bond does not cross a cell boundary.
        Bonds which cross cell boundaries can inappropriately connect
        components which are not connected in the infinite crystal. This is
        tested during graph traversal.

        Parameters:

        i: int          The index of the first atom.
        j: int          The index of the second atom.
        offset: sequence of 3 ints
                        The cell offset of the second atom. Any sequence
                        (tuple, list, array) is accepted. It is stored as a
                        tuple so it can be compared with (0, 0, 0) and
                        later placed in a set.
        """
        offset = tuple(offset)
        roffset = tuple(-np.array(offset))

        if offset == (0, 0, 0):  # only want bonds in aperiodic unit cell
            self.graph.union(i, j)
        else:
            # Store the bond in both directions so that traversal can
            # follow it from either atom.
            self.bonds.append((i, j, offset))
            self.bonds.append((j, i, roffset))

    def check(self):
        """
        Determines the dimensionality histogram.

        The component graph is traversed (using BFS) until the matrix rank
        of the subspace spanned by the visited components no longer increases.

        Returns:

        hist : tuple    Dimensionality histogram. Entry k is the number of
                        components of rank k (0 to 3).
        """
        adjacency = build_adjacency_list(self.graph.find_all(),
                                         self.bonds)
        if adjacency == self.adjacency:
            # Component graph unchanged since the last call: reuse result.
            return self.hcached

        # NOTE(review): merge_mutual_visits may union components in
        # self.graph, so the adjacency cached here reflects the pre-merge
        # parents. A repeated call therefore recomputes rather than hitting
        # the cache. Confirm whether that is intended.
        self.adjacency = adjacency
        self.all_visited, self.ranks = traverse_component_graphs(adjacency)
        res = merge_mutual_visits(self.all_visited, self.ranks, self.graph)
        _, self.all_visited, self.ranks = res

        self.roots = np.unique(self.graph.find_all())
        h = get_dimensionality_histogram(self.ranks, self.roots)
        self.hcached = h
        return h

    def get_components(self):
        """
        Determines the dimensionality and constituent atoms of each component.

        check() is run first if it has not been called yet.

        Returns:

        components: array   The component ID of every atom
        dims: dict          Maps each component ID to its dimensionality
        """
        if self.hcached is None:
            self.check()

        component_dim = {e: self.ranks[e] for e in self.roots}
        relabelled_components = self.graph.find_all(relabel=True)
        # Translate each root atom index into its relabelled component ID.
        relabelled_dim = {
            relabelled_components[k]: v for k, v in component_dim.items()
        }
        self.cdim_cached = relabelled_dim
        self.components_cached = relabelled_components

        return relabelled_components, relabelled_dim