projects
/
persistence.git
/ blobdiff
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
union_find: Do not cummulate weight
[persistence.git]
/
union_find.py
diff --git
a/union_find.py
b/union_find.py
index 320bba116ac0c1dd1dc890b80dcd104bf63067ae..0cd86044527a6d0d4be02c011ec08c17d700707d 100644
(file)
--- a/
union_find.py
+++ b/
union_find.py
@@
-29,11
+29,20
@@
class UnionFind:
self.weights = {}
self.parents = {}
self.weights = {}
self.parents = {}
+ def add(self, object, weight):
+ if object not in self.parents:
+ self.parents[object] = object
+ self.weights[object] = weight
+
+ def __contains__(self, object):
+ return object in self.parents
+
def __getitem__(self, object):
"""Find and return the name of the set containing the object."""
# check for previously unknown object
if object not in self.parents:
def __getitem__(self, object):
"""Find and return the name of the set containing the object."""
# check for previously unknown object
if object not in self.parents:
+ assert(False)
self.parents[object] = object
self.weights[object] = 1
return object
self.parents[object] = object
self.weights[object] = 1
return object
@@
-62,5
+71,4
@@
class UnionFind:
heaviest = max([(self.weights[r], r) for r in roots])[1]
for r in roots:
if r != heaviest:
heaviest = max([(self.weights[r], r) for r in roots])[1]
for r in roots:
if r != heaviest:
- self.weights[heaviest] += self.weights[r]
self.parents[r] = heaviest
self.parents[r] = heaviest