X-Git-Url: https://git.sthu.org/?p=libstick.git;a=blobdiff_plain;f=include%2Flibstick-0.1%2Fbooleanmatrix.h;h=594eb92c8a410bc749042afc541d572fea12994e;hp=d98e7b52cb5174795563ec0c785863dbd514305a;hb=54836e4fa90f52c9682c7918c47b28177e6f4170;hpb=c24df30e6826f1eca6444681c09723198bc675b7 diff --git a/include/libstick-0.1/booleanmatrix.h b/include/libstick-0.1/booleanmatrix.h index d98e7b5..594eb92 100644 --- a/include/libstick-0.1/booleanmatrix.h +++ b/include/libstick-0.1/booleanmatrix.h @@ -11,6 +11,7 @@ #include #include +#include "booleanvector.h" namespace libstick { @@ -26,7 +27,7 @@ class boolean_colmatrix_base { public: typedef IT index_type; - typedef std::vector column_type; + typedef sorted_boolean_vector column_type; typedef D derived; protected: @@ -60,8 +61,7 @@ class boolean_colmatrix_base { /** Get the matrix entry at row 'r' and column 'c'. */ bool get(index_type r, index_type c) const { assert(c < width()); - const column_type &col = get_column(c); - return binary_search(col.begin(), col.end(), r); + return get_column(c).get(r); } /** Set the matrix entry at row 'r' and column 'c'. */ @@ -93,7 +93,8 @@ class boolean_colmatrix_base { assert(c < width()); // Flip all entries that are set in 'col'. - for (typename column_type::const_iterator it = col.begin(); it != col.end(); ++it) + for (typename column_type::indexarray::const_iterator it = col.get_ones().begin(); + it != col.get_ones().end(); ++it) set(*it, c, !get(*it, c)); } @@ -116,7 +117,11 @@ class boolean_colmatrix_base { os << "{"; for (unsigned c=0; c < width(); ++c) { const column_type &col = get_column(c); - for (typename column_type::const_iterator it = col.begin(); it != col.end(); ++it) { + const typename column_type::indexarray &ones = col.get_ones(); + typename column_type::indexarray::iterator it = ones.begin(); + + for (typename column_type::indexarray::iterator it = ones.begin(); + it != ones.end(); ++it) { if (first) first = false; else @@ -131,32 +136,7 @@ class boolean_colmatrix_base { /** Set the matrix entry at row 'r' and column 'c'. */ void _set(index_type r, index_type c, bool value) { assert(c < width()); - - column_type &col = cols.at(c); - // Let us see where to insert the new element - typename column_type::iterator it = lower_bound(col.begin(), col.end(), r); - bool exists = (it != col.end() && *it == r); - assert(get(r,c) == exists); - - // Add 'r' to c-th column - if (value) { - // r is new, insert it - if (!exists) - col.insert(it, r); - assert(get(r,c)); - } - // Remove the element - else { - if (exists) - col.erase(it); - assert(!get(r,c)); - } - -#ifndef NDEBUG - // C++11 would have is_sorted - for (unsigned i=1; i < col.size(); i++) - assert(col[i-1] < col[i]); -#endif + cols.at(c).set(r, value); } protected: @@ -191,36 +171,13 @@ class boolean_colmatrix : public boolean_colmatrix_base row_type; + typedef sorted_boolean_vector row_type; typedef D derived; protected: @@ -250,8 +207,7 @@ class boolean_rowmatrix_base { /** Get the matrix entry at row 'r' and column 'c'. */ bool get(index_type r, index_type c) const { assert(r < height()); - const row_type &row = get_row(r); - return binary_search(row.begin(), row.end(), c); + return get_row(r).get(c); } /** Set the matrix entry at row 'r' and column 'c'. */ @@ -278,7 +234,8 @@ class boolean_rowmatrix_base { assert(r < height()); // Flip all entries that are set in 'row'. - for (typename row_type::const_iterator it = row.begin(); it != row.end(); ++it) + for (typename row_type::indexarray::const_iterator it = row.get_ones().begin(); + it != row.get_ones().end(); ++it) set(r, *it, !get(r, *it)); } @@ -296,32 +253,7 @@ class boolean_rowmatrix_base { /** Set the matrix entry at row 'r' and column 'c'. */ void _set(index_type r, index_type c, bool value) { assert(r < height()); - - row_type &row = rows.at(r); - // Let us see where to insert/remove the new element - typename row_type::iterator it = lower_bound(row.begin(), row.end(), c); - bool exists = (it != row.end() && *it == c); - assert(get(r,c) == exists); - - // Add 'r' to c-th column - if (value) { - // r is new, insert it - if (!exists) - row.insert(it, c); - assert(get(r,c)); - } - // Remove the element - else { - if (exists) - row.erase(it); - assert(!get(r,c)); - } - -#ifndef NDEBUG - // C++11 would have is_sorted - for (unsigned i=1; i < row.size(); i++) - assert(row[i-1] < row[i]); -#endif + rows.at(r).set(c, value); } protected: @@ -383,12 +315,12 @@ class boolean_colrowmatrix : public boolean_colmatrix_base -size_t count_set_intersection (InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2) +size_t count_set_intersection(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2) { size_t count = 0; @@ -486,7 +418,9 @@ void multiply_matrix(RT &result, const boolean_rowmatrix_base &a, const const typename boolean_rowmatrix_base::row_type &row = a.get_row(r); for (unsigned c=0; c < b.width(); ++c) { const typename boolean_colmatrix_base::column_type &col = b.get_column(c); - if (count_set_intersection(row.begin(), row.end(), col.begin(), col.end()) % 2 == 1) + if (count_set_intersection( + row.get_ones().begin(), row.get_ones().end(), + col.get_ones().begin(), col.get_ones().end()) % 2 == 1) result.set(r, c, true); } }