Coverage for ase / geometry / dimensionality / rank_determination.py: 99.15%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-30 08:22 +0000

1# fmt: off 

2 

3""" 

4Implements the Rank Determination Algorithm (RDA) 

5 

6Method is described in: 

7Definition of a scoring parameter to identify low-dimensional materials 

8components 

9P.M. Larsen, M. Pandey, M. Strange, and K. W. Jacobsen 

10Phys. Rev. Materials 3 034003, 2019 

11https://doi.org/10.1103/PhysRevMaterials.3.034003 

12""" 

13from collections import defaultdict 

14 

15import numpy as np 

16 

17from ase.geometry.dimensionality.disjoint_set import DisjointSet 

18 

19# Numpy has a large overhead for lots of small vectors. The cross product is 

20# particularly bad. Pure python is a lot faster. 

21 

22 

def dot_product(A, B):
    """Return the sum of pairwise products of the elements of A and B.

    Pairs are formed with ``zip``, so iteration stops at the shorter input.
    Written in pure Python because the vectors handled here are tiny.
    """
    total = 0
    for x, y in zip(A, B):
        total = total + x * y
    return total

25 

26 

def cross_product(a, b):
    """Return the cross product of two 3-component vectors as a list.

    Only indices 0, 1 and 2 of each argument are read.
    """
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]

29 

30 

def subtract(A, B):
    """Return the element-wise difference ``A - B`` as a list.

    Pairs are formed with ``zip``, so the result has the length of the
    shorter input.
    """
    difference = []
    for x, y in zip(A, B):
        difference.append(x - y)
    return difference

33 

34 

def rank_increase(a, b):
    """Decide whether appending point ``b`` to the point list ``a`` adds rank.

    Parameters
    ----------
    a: list
        Points accepted so far (at most four are meaningful).
    b: tuple
        Candidate point.

    Returns
    -------
    True when ``a`` is empty, when ``a`` holds one point different from
    ``b``, when ``a`` holds two points and the three points are not
    collinear (non-zero cross product), or when ``a`` holds three points
    and the four points are not coplanar (non-zero triple product).
    False when ``a`` already holds four points or the test above fails.
    """
    count = len(a)
    if count == 0:
        return True
    if count == 1:
        return a[0] != b
    if count == 4:
        return False

    points = a + [b]
    origin = points[0]
    # Normal of the plane through the first three points.
    normal = cross_product(subtract(points[1], origin),
                           subtract(points[2], origin))
    if count == 2:
        return any(normal)
    if count == 3:
        return dot_product(normal, subtract(points[3], origin)) != 0
    raise Exception("This shouldn't be possible.")

51 

52 

def bfs(adjacency, start):
    """Traverse the component graph using BFS.

    The graph is traversed until the matrix rank of the subspace spanned by
    the visited components no longer increases.

    Parameters
    ----------
    adjacency: dict
        Maps a component to an iterable of ``(neighbour, offset)`` pairs,
        where ``offset`` is a 3-component cell offset.
    start: hashable
        The component the traversal starts from, placed at ``(0, 0, 0)``.

    Returns
    -------
    visited: set
        Every ``(component, position)`` node dequeued during the traversal.
    rank: int
        Number of positions accepted for ``start``, minus one.
    """
    # Function-scope import: the module itself only imports defaultdict.
    from collections import deque

    visited = set()
    cvisited = defaultdict(list)  # component -> rank-increasing positions
    # A deque gives O(1) pops from the front; list.pop(0) is O(n).
    queue = deque([(start, (0, 0, 0))])
    while queue:
        vertex = queue.popleft()
        if vertex in visited:
            continue

        visited.add(vertex)
        c, p = vertex
        # A node can be queued while it still adds rank and lose that
        # property by the time it is dequeued, so re-check here.
        if not rank_increase(cvisited[c], p):
            continue

        cvisited[c].append(p)

        for nc, offset in adjacency[c]:
            nbrpos = (p[0] + offset[0], p[1] + offset[1], p[2] + offset[2])
            nbrnode = (nc, nbrpos)
            if nbrnode in visited:
                continue

            if rank_increase(cvisited[nc], nbrpos):
                queue.append(nbrnode)

    return visited, len(cvisited[start]) - 1

85 

86 

def traverse_component_graphs(adjacency):
    """Run ``bfs`` from every component in the adjacency mapping.

    Returns
    -------
    all_visited: dict
        Maps each component to the set of nodes its traversal visited.
    ranks: dict
        Maps each component to the rank reported by its traversal.
    """
    all_visited, ranks = {}, {}
    for component in adjacency:
        all_visited[component], ranks[component] = bfs(adjacency, component)
    return all_visited, ranks

97 

98 

def build_adjacency_list(parents, bonds):
    """Build the component-level adjacency mapping from atom-level bonds.

    Parameters
    ----------
    parents: sequence
        Component label of every atom, indexed by atom index.
    bonds: iterable
        ``(i, j, offset)`` triples; ``offset`` must be hashable.

    Returns
    -------
    dict mapping every distinct label in ``parents`` to a set of
    ``(neighbour_label, offset)`` pairs.
    """
    adjacency = {label: set() for label in np.unique(parents)}
    for i, j, offset in bonds:
        source, target = parents[i], parents[j]
        adjacency[source].add((target, offset))
    return adjacency

107 

108 

def get_dimensionality_histogram(ranks, roots):
    """Count how many root components have each rank 0 through 3.

    Returns a 4-tuple whose entry ``d`` is the number of elements of
    ``roots`` with ``ranks[root] == d``.
    """
    counts = [0] * 4
    for root in roots:
        dimension = ranks[root]
        counts[dimension] += 1
    return tuple(counts)

114 

115 

def merge_mutual_visits(all_visited, ranks, graph):
    """Find components with mutual visits and merge them.

    Two components are merged (via ``graph.union``) when their traversals
    visited a common node. Merged components are asserted to have equal
    rank.

    Returns
    -------
    merged:
        Accumulated result of the ``graph.union`` calls; False if none
        was made.
    visits:
        ``all_visited`` unchanged when nothing was merged; otherwise a
        mapping from each root (per ``graph.find_all()``) to the union of
        the visited sets of its members.
    ranks:
        ``ranks`` unchanged when nothing was merged; otherwise a dict
        holding the rank of each root.
    """
    any_merged = False
    seen_by = defaultdict(list)  # node -> components that visited it so far
    for component, nodes in all_visited.items():
        for node in nodes:
            earlier = seen_by[node]
            for other in earlier:
                assert ranks[other] == ranks[component]
                any_merged |= graph.union(other, component)
            earlier.append(component)

    if not any_merged:
        return any_merged, all_visited, ranks

    roots = graph.find_all()
    merged_visits = defaultdict(set)
    merged_ranks = {}
    for component, nodes in all_visited.items():
        root = roots[component]
        merged_visits[root].update(nodes)
        merged_ranks[root] = ranks[root]
    return any_merged, merged_visits, merged_ranks

138 

139 

class RDA:
    """Incremental driver for the Rank Determination Algorithm.

    Bonds are added with :meth:`insert_bond`; :meth:`check` computes the
    dimensionality histogram and :meth:`get_components` reports the
    component and dimensionality of every atom.
    """

    def __init__(self, num_atoms):
        """
        Initializes the RDA class.

        A disjoint set is used to maintain the component graph.

        Parameters
        ----------

        num_atoms: int    The number of atoms in the unit cell.
        """
        self.bonds = []
        self.graph = DisjointSet(num_atoms)
        self.adjacency = None
        self.hcached = None
        self.components_cached = None
        self.cdim_cached = None
        # Populated by check(); None until it has run at least once.
        self.all_visited = None
        self.ranks = None
        self.roots = None

    def insert_bond(self, i, j, offset):
        """
        Adds a bond to the list of graph edges.

        Graph components are merged if the bond does not cross a cell boundary.
        Bonds which cross cell boundaries can inappropriately connect
        components which are not connected in the infinite crystal.  This is
        tested during graph traversal.

        Parameters
        ----------

        i: int           The index of the first atom.
        j: int           The index of the second atom.
        offset: tuple    The cell offset of the second atom. Any sequence
                         of three integers is accepted; it is stored as a
                         tuple.
        """
        # Normalise so the comparison below works for lists and arrays too,
        # and so the stored offset is hashable (it ends up in a set).
        offset = tuple(offset)

        if offset == (0, 0, 0):
            # Bond inside the unit cell: the two atoms share a component.
            self.graph.union(i, j)
            return

        # Pure-Python negation; numpy is slow for tiny vectors.
        roffset = tuple(-x for x in offset)
        self.bonds.append((i, j, offset))
        self.bonds.append((j, i, roffset))

    def check(self):
        """
        Determines the dimensionality histogram.

        The component graph is traversed (using BFS) until the matrix rank
        of the subspace spanned by the visited components no longer increases.

        The result is cached: when the adjacency list is equal to the one
        stored by the last full evaluation, the cached histogram is
        returned without traversing again.

        Returns
        -------
        hist : tuple         Dimensionality histogram.
        """
        adjacency = build_adjacency_list(self.graph.find_all(),
                                         self.bonds)
        if adjacency == self.adjacency:
            return self.hcached

        self.adjacency = adjacency
        self.all_visited, self.ranks = traverse_component_graphs(adjacency)
        # May union components in self.graph, so roots are read afterwards.
        res = merge_mutual_visits(self.all_visited, self.ranks, self.graph)
        _, self.all_visited, self.ranks = res

        self.roots = np.unique(self.graph.find_all())
        h = get_dimensionality_histogram(self.ranks, self.roots)
        self.hcached = h
        return h

    def get_components(self):
        """
        Determines the dimensionality and constituent atoms of each component.

        If check() has not been run yet it is run first, so that the ranks
        and roots this method reads exist.

        Returns
        -------
        components: array    The component ID of every atom
        dimensionality: dict Maps each relabelled component ID to its rank.
        """
        if self.ranks is None:
            self.check()

        component_dim = {e: self.ranks[e] for e in self.roots}
        # NOTE(review): presumably relabel=True maps roots to consecutive
        # IDs; confirm against DisjointSet.find_all.
        relabelled_components = self.graph.find_all(relabel=True)
        relabelled_dim = {
            relabelled_components[k]: v for k, v in component_dim.items()
        }
        self.cdim_cached = relabelled_dim
        self.components_cached = relabelled_components

        return relabelled_components, relabelled_dim