This is an automated email from the ASF dual-hosted git repository.

holden pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 92530c7  [SPARK-9792] Make DenseMatrix equality semantical
92530c7 is described below

commit 92530c7db1e5ec2827e4211a2290555c3959397a
Author: Giovanni Lanzani <giova...@lanzani.nl>
AuthorDate: Mon Apr 1 09:30:33 2019 -0700

    [SPARK-9792] Make DenseMatrix equality semantical
    
    Before, you could have this code
    
    ```
    A = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
    B = DenseMatrix(2, 2, [2, 0, 0, 0])
    
    B == A  # False
    A == B  # True
    ```
    
    The second comparison would be `True` as `SparseMatrix` already checks
    for semantic equality. This commit changes `DenseMatrix` so that its
    equality is semantic as well.
    
    ## What changes were proposed in this pull request?
    
    Better semantic equality for DenseMatrix
    
    ## How was this patch tested?
    
    Unit tests were added, plus manual testing. Note that the code falls back
    to the old behavior when `other` is not a SparseMatrix.
    
    Closes #17968 from gglanzani/SPARK-9792.
    
    Authored-by: Giovanni Lanzani <giova...@lanzani.nl>
    Signed-off-by: Holden Karau <hol...@pigscanfly.ca>
---
 python/pyspark/ml/linalg/__init__.py      | 8 ++++----
 python/pyspark/ml/tests/test_linalg.py    | 6 ++++++
 python/pyspark/mllib/linalg/__init__.py   | 8 ++++----
 python/pyspark/mllib/tests/test_linalg.py | 6 ++++++
 4 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/ml/linalg/__init__.py b/python/pyspark/ml/linalg/__init__.py
index 9da9836..f99161c 100644
--- a/python/pyspark/ml/linalg/__init__.py
+++ b/python/pyspark/ml/linalg/__init__.py
@@ -980,14 +980,14 @@ class DenseMatrix(Matrix):
             return self.values[i + j * self.numRows]
 
     def __eq__(self, other):
-        if (not isinstance(other, DenseMatrix) or
-                self.numRows != other.numRows or
-                self.numCols != other.numCols):
+        if (self.numRows != other.numRows or self.numCols != other.numCols):
             return False
+        if isinstance(other, SparseMatrix):
+            return np.all(self.toArray() == other.toArray())
 
         self_values = np.ravel(self.toArray(), order='F')
         other_values = np.ravel(other.toArray(), order='F')
-        return all(self_values == other_values)
+        return np.all(self_values == other_values)
 
 
 class SparseMatrix(Matrix):
diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py
index 995bc35..0c25e2b 100644
--- a/python/pyspark/ml/tests/test_linalg.py
+++ b/python/pyspark/ml/tests/test_linalg.py
@@ -112,11 +112,17 @@ class VectorTests(MLlibTestCase):
         v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
         v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
         v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+        dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
+        sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
         self.assertEqual(v1, v2)
         self.assertEqual(v1, v3)
         self.assertFalse(v2 == v4)
         self.assertFalse(v1 == v5)
         self.assertFalse(v1 == v6)
+        # this is done as Dense and Sparse matrices can be semantically
+        # equal while still implementing a different __eq__ method
+        self.assertEqual(dm1, sm1)
+        self.assertEqual(sm1, dm1)
 
     def test_equals(self):
         indices = [1, 2, 4]
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 94a3e2a..df411d7 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -1135,14 +1135,14 @@ class DenseMatrix(Matrix):
             return self.values[i + j * self.numRows]
 
     def __eq__(self, other):
-        if (not isinstance(other, DenseMatrix) or
-                self.numRows != other.numRows or
-                self.numCols != other.numCols):
+        if (self.numRows != other.numRows or self.numCols != other.numCols):
             return False
+        if isinstance(other, SparseMatrix):
+            return np.all(self.toArray() == other.toArray())
 
         self_values = np.ravel(self.toArray(), order='F')
         other_values = np.ravel(other.toArray(), order='F')
-        return all(self_values == other_values)
+        return np.all(self_values == other_values)
 
 
 class SparseMatrix(Matrix):
diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py
index f26e28d..703aed2 100644
--- a/python/pyspark/mllib/tests/test_linalg.py
+++ b/python/pyspark/mllib/tests/test_linalg.py
@@ -115,11 +115,17 @@ class VectorTests(MLlibTestCase):
         v4 = SparseVector(6, [(1, 1.0), (3, 5.5)])
         v5 = DenseVector([0.0, 1.0, 0.0, 2.5])
         v6 = SparseVector(4, [(1, 1.0), (3, 2.5)])
+        dm1 = DenseMatrix(2, 2, [2, 0, 0, 0])
+        sm1 = SparseMatrix(2, 2, [0, 2, 3], [0], [2])
         self.assertEqual(v1, v2)
         self.assertEqual(v1, v3)
         self.assertFalse(v2 == v4)
         self.assertFalse(v1 == v5)
         self.assertFalse(v1 == v6)
+        # this is done as Dense and Sparse matrices can be semantically
+        # equal while still implementing a different __eq__ method
+        self.assertEqual(dm1, sm1)
+        self.assertEqual(sm1, dm1)
 
     def test_equals(self):
         indices = [1, 2, 4]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to