Add union_find from PADS
[persistence.git] / union_find.py
"""union_find.py

Union-find (disjoint-set) data structure, adapted from UnionFind.py in
D. Eppstein's PADS library. Based on Josiah Carlson's code,
http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
(now hosted at https://code.activestate.com/recipes/215912/),
with significant additional changes by D. Eppstein.
"""
7
8
class UnionFind:
    """Union-find (disjoint-set) data structure.

    Each UnionFind instance X maintains a family of disjoint sets of
    hashable objects, supporting the following operations:

    - X[item] returns a name for the set containing the given item.
      Each set is named by an arbitrarily-chosen one of its members; as
      long as the set remains unchanged it will keep the same name. If
      the item is not yet part of a set in X, a new singleton set is
      created for it.

    - X.union(item1, item2, ...) merges the sets containing each item
      into a single larger set. If any item is not yet part of a set
      in X, it is added to X as one of the members of the merged set.

    - iter(X) iterates over every item X has ever seen, whether it was
      introduced through X[item] or X.union(...).

    Items only need to be hashable: they are never ordered with ``<``,
    so items of unrelated or unorderable types may share one structure.
    """

    def __init__(self):
        """Create a new empty union-find structure."""
        # weights[root] is the number of items in the set named by root.
        # Entries for items that are no longer roots are stale and unused.
        self.weights = {}
        # parents[item] is the item's parent in its tree; a root is its
        # own parent.
        self.parents = {}

    def __getitem__(self, object):
        """Find and return the name of the set containing the object.

        An unknown object is added as a new singleton set and is returned
        as that set's name. For a known object, every object on its path
        to the root is re-pointed directly at the root (path compression)
        before the root is returned.
        """
        # NOTE: the parameter name shadows the builtin ``object``; it is
        # kept unchanged for backward compatibility.

        # check for previously unknown object
        if object not in self.parents:
            self.parents[object] = object
            self.weights[object] = 1
            return object

        # find path of objects leading to the root (a root is its own parent)
        path = [object]
        root = self.parents[object]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # compress the path and return
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items ever found or unioned by this structure."""
        return iter(self.parents)

    def union(self, *objects):
        """Find the sets containing the objects and merge them all.

        Objects not yet in the structure are added first. The root of the
        heaviest (largest) involved set becomes the name of the merged
        set, and the other roots are hung directly beneath it (union by
        size). Calling with no objects does nothing.
        """
        # dict.fromkeys drops repeated roots (several arguments from the
        # same set) while keeping first-seen order. Without this, a
        # repeated root's weight would be added to the merged set more
        # than once.
        roots = list(dict.fromkeys(self[x] for x in objects))
        if not roots:
            return

        # Compare weights only. Comparing (weight, root) tuples would fall
        # back to ordering the roots themselves on ties, which raises
        # TypeError for unorderable items such as 1 and "a". Among equal
        # weights, max returns the first root encountered.
        weights = self.weights
        heaviest = max(roots, key=lambda r: weights[r])
        for r in roots:
            if r != heaviest:
                weights[heaviest] += weights[r]
                self.parents[r] = heaviest