Use boolean_vector in boolean_matrix
[libstick.git] / include / libstick-0.1 / booleanmatrix.h
index 1fb4dd28f6d98e02d70cabb6372e62ddf6d9180a..594eb92c8a410bc749042afc541d572fea12994e 100644 (file)
 #include <ostream>
 #include <iterator>
 
+#include "booleanvector.h"
 
 namespace libstick {
 
 
+template<class T>
+std::ostream& operator<<(std::ostream &, const std::vector<T> &);
+
+
 /** The base class of boolean_colmatrix and boolean_colrowmatrix which implements
  * the common logic of both. */
 template<class IT, class D>
@@ -22,7 +27,7 @@ class boolean_colmatrix_base {
 
     public:
         typedef IT index_type;
-        typedef std::vector<index_type> column_type;
+        typedef sorted_boolean_vector<IT> column_type;
         typedef D derived;
 
     protected:
@@ -32,16 +37,31 @@ class boolean_colmatrix_base {
         }
 
     public:
-        /** Get height resp. width of the matrix. */
+        /** A casting constructor for any colmatrix with same entries type. */
+        template<class D2>
+        boolean_colmatrix_base(const boolean_colmatrix_base<IT, D2> &mat) :
+            cols(mat.get_columns()) {
+        }
+
+        /** Get width of the matrix. */
         size_t width() const {
             return cols.size();
         }
 
+        /** Get height of the matrix, i.e. maximum row-index + 1 among all
+         * columns. */
+        size_t height() const {
+            IT h = 0;
+            for (unsigned c=0; c < width(); ++c)
+                if (cols[c].size() > 0)
+                    h = std::max(h, cols[c].back());
+            return h+1;
+        }
+
         /** Get the matrix entry at row 'r' and column 'c'. */
         bool get(index_type r, index_type c) const {
             assert(c < width());
-            const column_type &col = get_column(c);
-            return binary_search(col.begin(), col.end(), r);
+            return get_column(c).get(r);
         }
 
         /** Set the matrix entry at row 'r' and column 'c'. */
@@ -49,7 +69,8 @@ class boolean_colmatrix_base {
             get_derived()->_set(r, c, value);
         }
 
-        /** For each of the 'count'-many (row-index, column-pair) pair in 'indices', set the specific value. */
+        /** For each of the 'count'-many (row-index, column-pair) pair in
+         * 'indices', set the specific value. */
         void set_all(index_type indices[][2], size_t count, bool value) {
             for (unsigned i=0; i < count; ++i)
                 set(indices[i][0], indices[i][1], value);
@@ -61,23 +82,31 @@ class boolean_colmatrix_base {
             return cols[c];
         }
 
+        /** Get all columns */
+        const std::vector<column_type>& get_columns() const {
+            return cols;
+        }
+
         /** Add the column-vector 'col' to the c-th column. Note that 'col'
          * actually contains the list of row-indices that are 1. */
         void add_column(index_type c, const column_type &col) {
             assert(c < width());
 
             // Flip all entries that are set in 'col'.
-            for (typename column_type::const_iterator it = col.begin(); it != col.end(); ++it)
+            for (typename column_type::indexarray::const_iterator it = col.get_ones().begin();
+                    it != col.get_ones().end(); ++it)
                 set(*it, c, !get(*it, c));
         }
 
         /** Two matrices are equal iff they have the same entries */
-        bool operator==(const boolean_colmatrix_base<IT, D> &m) const {
-            return cols == m.cols;
+        template<class D2>
+        bool operator==(const boolean_colmatrix_base<IT, D2> &m) const {
+            return cols == m.get_columns();
         }
 
         /** Two matrices are equal iff they have the same entries */
-        bool operator!=(const boolean_colmatrix_base<IT, D> &m) const {
+        template<class D2>
+        bool operator!=(const boolean_colmatrix_base<IT, D2> &m) const {
             return !(*this == m);
         }
 
@@ -88,7 +117,11 @@ class boolean_colmatrix_base {
             os << "{";
             for (unsigned c=0; c < width(); ++c) {
                 const column_type &col = get_column(c);
-                for (typename column_type::const_iterator it = col.begin(); it != col.end(); ++it) {
+                const typename column_type::indexarray &ones = col.get_ones();
+                typename column_type::indexarray::iterator it = ones.begin();
+
+                for (typename column_type::indexarray::iterator it = ones.begin();
+                        it != ones.end(); ++it) {
                     if (first)
                         first = false;
                     else
@@ -103,35 +136,10 @@ class boolean_colmatrix_base {
         /** Set the matrix entry at row 'r' and column 'c'. */
         void _set(index_type r, index_type c, bool value) {
             assert(c < width());
-
-            column_type &col = cols.at(c);
-            // Let us see where to insert the new element
-            typename column_type::iterator it = lower_bound(col.begin(), col.end(), r);
-            bool exists = (it != col.end() && *it == r);
-            assert(get(r,c) == exists);
-
-            // Add 'r' to c-th column
-            if (value) {
-                // r is new, insert it
-                if (!exists)
-                    col.insert(it, r);
-                assert(get(r,c));
-            }
-            // Remove the element
-            else {
-                if (exists)
-                    col.erase(it);
-                assert(!get(r,c));
-            }
-
-#ifndef NDEBUG
-            // C++11 would have is_sorted
-            for (unsigned i=1; i < col.size(); i++)
-                assert(col[i-1] < col[i]);
-#endif
+            cols.at(c).set(r, value);
         }
 
-    private:
+    protected:
         /** The matrix is the set of columns. */
         std::vector<column_type> cols;
 
@@ -151,6 +159,7 @@ class boolean_colmatrix : public boolean_colmatrix_base<IT, boolean_colmatrix<IT
     public:
         typedef IT index_type;
         typedef boolean_colmatrix_base<IT, boolean_colmatrix<IT> > base;
+        typedef typename base::column_type column_type;
 
         /** Create a matrix with 'width' columns, initalized with zero entries. */
         boolean_colmatrix(size_t columns) :
@@ -161,6 +170,14 @@ class boolean_colmatrix : public boolean_colmatrix_base<IT, boolean_colmatrix<IT
         void _set(index_type r, index_type c, bool value) {
             base::_set(r, c, value);
         }
+
+        /** A faster implementation of boolean_colmatrix_base::add_column(). */
+        void add_column(index_type c, const column_type &col) {
+            assert(c < base::width());
+            base::cols[c].add(col);
+        }
+
+    private:
 };
 
 
@@ -171,7 +188,7 @@ class boolean_rowmatrix_base {
 
     public:
         typedef IT index_type;
-        typedef std::vector<index_type> row_type;
+        typedef sorted_boolean_vector<IT> row_type;
         typedef D derived;
 
     protected:
@@ -182,7 +199,7 @@ class boolean_rowmatrix_base {
         }
 
     public:
-        /** Get height resp. width of the matrix. */
+        /** Get height of the matrix. */
         size_t height() const {
             return rows.size();
         }
@@ -190,8 +207,7 @@ class boolean_rowmatrix_base {
         /** Get the matrix entry at row 'r' and column 'c'. */
         bool get(index_type r, index_type c) const {
             assert(r < height());
-            const row_type &row = get_row(r);
-            return binary_search(row.begin(), row.end(), c);
+            return get_row(r).get(c);
         }
 
         /** Set the matrix entry at row 'r' and column 'c'. */
@@ -199,7 +215,8 @@ class boolean_rowmatrix_base {
             get_derived()->_set(r, c, value);
         }
 
-        /** For each of the 'count'-many (row-index, column-pair) pair in 'indices', set the specific value. */
+        /** For each of the 'count'-many (row-index, column-pair) pair in
+         * 'indices', set the specific value. */
         void set_all(index_type indices[][2], size_t count, bool value) {
             for (unsigned i=0; i < count; ++i)
                 set(indices[i][0], indices[i][1], value);
@@ -217,7 +234,8 @@ class boolean_rowmatrix_base {
             assert(r < height());
 
             // Flip all entries that are set in 'row'.
-            for (typename row_type::const_iterator it = row.begin(); it != row.end(); ++it)
+            for (typename row_type::indexarray::const_iterator it = row.get_ones().begin();
+                    it != row.get_ones().end(); ++it)
                 set(r, *it, !get(r, *it));
         }
 
@@ -235,35 +253,10 @@ class boolean_rowmatrix_base {
         /** Set the matrix entry at row 'r' and column 'c'. */
         void _set(index_type r, index_type c, bool value) {
             assert(r < height());
-
-            row_type &row = rows.at(r);
-            // Let us see where to insert/remove the new element
-            typename row_type::iterator it = lower_bound(row.begin(), row.end(), c);
-            bool exists = (it != row.end() && *it == c);
-            assert(get(r,c) == exists);
-
-            // Add 'r' to c-th column
-            if (value) {
-                // r is new, insert it
-                if (!exists)
-                    row.insert(it, c);
-                assert(get(r,c));
-            }
-            // Remove the element
-            else {
-                if (exists)
-                    row.erase(it);
-                assert(!get(r,c));
-            }
-
-#ifndef NDEBUG
-            // C++11 would have is_sorted
-            for (unsigned i=1; i < row.size(); i++)
-                assert(row[i-1] < row[i]);
-#endif
+            rows.at(r).set(c, value);
         }
 
-    private:
+    protected:
         derived* get_derived() {
             return static_cast<derived*>(this);
         }
@@ -314,6 +307,23 @@ class boolean_colrowmatrix : public boolean_colmatrix_base<IT, boolean_colrowmat
             rowbase(size) {
         }
 
+        /** Casting a colmatrix into a colrow matrix */
+        template<class D>
+        boolean_colrowmatrix(const boolean_colmatrix_base<IT, D>& mat) :
+            colbase(std::max(mat.width(), mat.height())),
+            rowbase(std::max(mat.width(), mat.height())) {
+            for (unsigned c=0; c < mat.width(); ++c) {
+                const typename colbase::column_type &col = mat.get_column(c);
+                for (unsigned i=0; i < col.size(); ++i)
+                    set(col.get_ones()[i], c, true);
+            }
+            for (unsigned r=0; r < size(); ++r) {
+                const typename rowbase::row_type &row = rowbase::get_row(r);
+                for (unsigned i=0; i < row.size(); ++i)
+                    assert(get(r, row.get_ones()[i]) == true);
+            }
+        }
+
         /** Override implementation. */
         void _set(index_type r, index_type c, bool value) {
             colbase::_set(r, c, value);
@@ -335,7 +345,8 @@ class boolean_colrowmatrix : public boolean_colmatrix_base<IT, boolean_colrowmat
             colbase::set(r, c, value);
         }
 
-        /** For each of the 'count'-many (row-index, column-pair) pair in 'indices', set the specific value. */
+        /** For each of the 'count'-many (row-index, column-pair) pair in
+         * 'indices', set the specific value. */
         void set_all(index_type indices[][2], size_t count, bool value) {
             colbase::set_all(indices, count, value);
         }
@@ -352,6 +363,12 @@ class boolean_colrowmatrix : public boolean_colmatrix_base<IT, boolean_colrowmat
             return colbase::operator==(m);
         }
 
+        /** Two matrices are equal iff they have the same entries */
+        template<class D>
+        bool operator==(const boolean_colmatrix_base<IT, D> &m) const {
+            return colbase::operator==(m);
+        }
+
         /** Two matrices are equal iff they have the same entries */
         bool operator!=(const boolean_colrowmatrix<IT> &m) const {
             return !(*this == m);
@@ -370,12 +387,12 @@ class boolean_colrowmatrix : public boolean_colmatrix_base<IT, boolean_colrowmat
 /** Counts the number of common elements in two sorted vectors. Equal to
  * counting the number of elements given by std::set_intersection. */
 template <class InputIterator1, class InputIterator2>
-size_t count_set_intersection (InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2)
+size_t count_set_intersection(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2 last2)
 {
     size_t count = 0;
 
     // As long as we did not either end, look for common elements
-    while (first1!=last1 && first2!=last2)
+    while (first1 != last1 && first2 != last2)
     {
         if (*first1 < *first2)
             ++first1;
@@ -391,7 +408,8 @@ size_t count_set_intersection (InputIterator1 first1, InputIterator1 last1, Inpu
     return count;
 }
 
-/** Multiply a*b and save the product in 'result'. It is assumed that 'result' is intially empty and has appropriate size. */
+/** Multiply a*b and save the product in 'result'. It is assumed that 'result'
+ * is intially empty and has appropriate size. */
 template<class IT, class D1, class D2, class RT>
 void multiply_matrix(RT &result, const boolean_rowmatrix_base<IT, D1> &a, const boolean_colmatrix_base<IT, D2> &b) {
     assert(a.height() == b.width());
@@ -400,7 +418,9 @@ void multiply_matrix(RT &result, const boolean_rowmatrix_base<IT, D1> &a, const
         const typename boolean_rowmatrix_base<IT, D1>::row_type &row = a.get_row(r);
         for (unsigned c=0; c < b.width(); ++c) {
             const typename boolean_colmatrix_base<IT, D2>::column_type &col = b.get_column(c);
-            if (count_set_intersection(row.begin(), row.end(), col.begin(), col.end()) % 2 == 1)
+            if (count_set_intersection(
+                        row.get_ones().begin(), row.get_ones().end(),
+                        col.get_ones().begin(), col.get_ones().end()) % 2 == 1)
                 result.set(r, c, true);
         }
     }
@@ -524,6 +544,21 @@ std::ostream& operator<<(std::ostream &os, boolean_colmatrix_base<IT, D> &mat) {
     return os << m;
 }
 
+template<class T>
+std::ostream& operator<<(std::ostream& os, const std::vector<T> &vec) {
+    os << "[";
+
+    typename std::vector<T>::const_iterator it = vec.begin();
+    while ( it != vec.end()) {
+        os << *it;
+        if (++it != vec.end())
+            os << " ";
+    }
+
+    os << "]";
+    return os;
+}
+
 }
 
 #endif