Coverage for /builds/ase/ase/ase/geometry/dimensionality/disjoint_set.py: 100.00%
41 statements
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2025-08-02 00:12 +0000
1# fmt: off
3import numpy as np
class DisjointSet:
    """Union-find (disjoint-set) structure over the integers ``0 .. n-1``.

    Every element starts in its own singleton set. ``union`` merges two
    sets using union by size, ``find`` returns the root (representative)
    of one element's set, and ``find_all`` returns a label for every
    element at once.

    Attributes
    ----------
    sizes : ndarray of int
        ``sizes[r]`` is the number of elements in the set rooted at ``r``.
        Only entries for current roots are kept up to date.
    parents : ndarray of int
        Parent pointer of each element; a root is its own parent.
    nc : int
        Current number of disjoint sets (components).
    """

    def __init__(self, n):
        self.sizes = np.ones(n, dtype=int)
        self.parents = np.arange(n)
        self.nc = n

    def _compress(self):
        """Point every element directly at its root.

        Vectorised pointer jumping: each pass replaces every parent with
        its grandparent (``a[a]``), doubling the distance jumped, until the
        array stops changing.
        """
        a = self.parents
        b = a[a]
        while (a != b).any():
            a = b
            b = a[a]
        self.parents = a

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b``.

        The root of the smaller set is attached under the root of the
        larger one. On a size tie, ``b``'s root goes under ``a``'s root.

        Returns
        -------
        bool
            True if two distinct sets were merged, False if ``a`` and ``b``
            were already in the same set.
        """
        a = self.find(a)
        b = self.find(b)
        if a == b:
            return False

        sizes = self.sizes
        parents = self.parents
        if sizes[a] < sizes[b]:
            parents[a] = b
            sizes[b] += sizes[a]
        else:
            parents[b] = a
            sizes[a] += sizes[b]

        self.nc -= 1
        return True

    def find(self, index):
        """Return the root of the set containing ``index``.

        Every node visited on the way up is re-pointed directly at the root
        (full path compression), so later lookups along that path are
        cheaper. The returned root is the same as without compression.
        """
        parents = self.parents

        # First pass: walk up to the root.
        root = parents[index]
        while root != parents[root]:
            root = parents[root]

        # Second pass: re-point the whole path, not just ``index`` itself.
        node = index
        while parents[node] != root:
            next_node = parents[node]
            parents[node] = root
            node = next_node
        return root

    def find_all(self, relabel=False):
        """Return a component label for every element.

        Parameters
        ----------
        relabel : bool
            If False, each element's label is the index of its root. If
            True, components are renumbered with consecutive integers
            starting at 0, ordered by decreasing component size (label 0 is
            the largest component).

        Returns
        -------
        ndarray of int
            One label per element. This is a fresh array, so later calls to
            ``union`` do not change it and modifying it does not affect the
            disjoint set.
        """
        self._compress()
        if not relabel:
            # Copy: returning ``self.parents`` itself would alias internal
            # state that ``union`` mutates in place.
            return self.parents.copy()

        # order elements by frequency
        _unique, inverse, counts = np.unique(self.parents,
                                             return_inverse=True,
                                             return_counts=True)
        # Stable ascending sort by size, then reversed so the largest
        # component comes first. Among equally sized components, the one
        # with the larger root index therefore gets the smaller label.
        indices = np.argsort(counts, kind='stable')[::-1]
        return np.argsort(indices)[inverse]