projects
/
persistence.git
/ commitdiff
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (from parent 1:
4e47641
)
union_find: Do not cummulate weight
author
Stefan Huber <shuber@sthu.org>
Wed, 23 May 2018 17:03:51 +0000
(19:03 +0200)
committer
Stefan Huber <shuber@sthu.org>
Wed, 23 May 2018 17:03:51 +0000
(19:03 +0200)
union_find.py
patch
|
blob
|
history
diff --git
a/union_find.py
b/union_find.py
index 320bba116ac0c1dd1dc890b80dcd104bf63067ae..0cd86044527a6d0d4be02c011ec08c17d700707d 100644
(file)
--- a/
union_find.py
+++ b/
union_find.py
@@
-29,11
+29,20
@@
class UnionFind:
self.weights = {}
self.parents = {}
self.weights = {}
self.parents = {}
+ def add(self, object, weight):
+ if object not in self.parents:
+ self.parents[object] = object
+ self.weights[object] = weight
+
+ def __contains__(self, object):
+ return object in self.parents
+
def __getitem__(self, object):
"""Find and return the name of the set containing the object."""
# check for previously unknown object
if object not in self.parents:
def __getitem__(self, object):
"""Find and return the name of the set containing the object."""
# check for previously unknown object
if object not in self.parents:
+ assert(False)
self.parents[object] = object
self.weights[object] = 1
return object
self.parents[object] = object
self.weights[object] = 1
return object
@@
-62,5
+71,4
@@
class UnionFind:
heaviest = max([(self.weights[r], r) for r in roots])[1]
for r in roots:
if r != heaviest:
heaviest = max([(self.weights[r], r) for r in roots])[1]
for r in roots:
if r != heaviest:
- self.weights[heaviest] += self.weights[r]
self.parents[r] = heaviest
self.parents[r] = heaviest