X-Git-Url: https://git.sthu.org/?p=persistence.git;a=blobdiff_plain;f=union_find.py;h=0cd86044527a6d0d4be02c011ec08c17d700707d;hp=320bba116ac0c1dd1dc890b80dcd104bf63067ae;hb=0693252f1fbdd4b2fc3db010b25305cd55eb1406;hpb=4e476411154b9776e355406ca4911b8b7c448228 diff --git a/union_find.py b/union_find.py index 320bba1..0cd8604 100644 --- a/union_find.py +++ b/union_find.py @@ -29,11 +29,20 @@ class UnionFind: self.weights = {} self.parents = {} + def add(self, object, weight): + if object not in self.parents: + self.parents[object] = object + self.weights[object] = weight + + def __contains__(self, object): + return object in self.parents + def __getitem__(self, object): """Find and return the name of the set containing the object.""" # check for previously unknown object if object not in self.parents: + assert(False) self.parents[object] = object self.weights[object] = 1 return object @@ -62,5 +71,4 @@ class UnionFind: heaviest = max([(self.weights[r], r) for r in roots])[1] for r in roots: if r != heaviest: - self.weights[heaviest] += self.weights[r] self.parents[r] = heaviest