ipynb: Update
[persistence.git] / union_find.py
1 """UnionFind.py
2
3 Union-find data structure. Based on Josiah Carlson's code,
4 http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/215912
5 with significant additional changes by D. Eppstein.
6 """
7
8
class UnionFind:
    """Union-find data structure over weighted, hashable objects.

    Each UnionFind instance X maintains a family of disjoint sets of
    hashable objects. Every object must be registered with
    ``X.add(item, weight)`` before it can be looked up or merged:

    - X.add(item, weight) registers item as a new singleton set with
      the given weight. Adding an item that is already known is a no-op.

    - X[item] returns a name for the set containing the given item.
      Each set is named by one of its members (its current root); as
      long as the set remains unchanged it will keep the same name.
      Raises KeyError if the item was never added.

    - X.union(item1, item2, ...) merges the sets containing each item
      into a single larger set, named after the heaviest of their roots.
      Raises KeyError if any item was never added.
    """

    def __init__(self):
        """Create a new empty union-find structure."""
        # weights: object -> weight supplied to add()
        self.weights = {}
        # parents: object -> parent object; a root is its own parent
        self.parents = {}

    def add(self, object, weight):
        """Register object as a singleton set with the given weight.

        If object is already known, the call is ignored and both its
        parent link and its stored weight are left untouched.
        """
        if object not in self.parents:
            self.parents[object] = object
            self.weights[object] = weight

    def __contains__(self, object):
        """Return True if object has been added to this structure."""
        return object in self.parents

    def __getitem__(self, object):
        """Find and return the name of the set containing the object.

        Raises:
            KeyError: if object was never registered through add().
        """
        # Unknown objects are rejected explicitly. An ``assert`` would be
        # stripped under ``python -O`` and let the lookup fall through to
        # silently creating a set with a made-up weight.
        if object not in self.parents:
            raise KeyError(object)

        # find path of objects leading to the root
        path = [object]
        root = self.parents[object]
        while root != path[-1]:
            path.append(root)
            root = self.parents[root]

        # compress the path: point every visited object straight at the root
        for ancestor in path:
            self.parents[ancestor] = root
        return root

    def __iter__(self):
        """Iterate through all items that have been added to this structure."""
        return iter(self.parents)

    def union(self, *objects):
        """Find the sets containing the objects and merge them all.

        The root with the greatest weight becomes the root of the merged
        set; roots of equal weight are ordered by comparing the roots
        themselves. Stored weights are not modified by a merge. Calling
        union() with no arguments does nothing.

        Raises:
            KeyError: if any object was never registered through add().
        """
        if not objects:
            # max() below would fail on an empty sequence
            return
        roots = [self[x] for x in objects]
        heaviest = max((self.weights[r], r) for r in roots)[1]
        for r in roots:
            if r != heaviest:
                self.parents[r] = heaviest