This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new b492ac4  [SYSTEMDS-2711] Fix Python One hot encode
b492ac4 is described below

commit b492ac4fc18092f0f591a83a50a3d7e0b46fb1b1
Author: baunsgaard <[email protected]>
AuthorDate: Mon Nov 2 09:08:33 2020 +0100

    [SYSTEMDS-2711] Fix Python One hot encode
    
    This commit fixes one-hot encoding in the Python API. Previously,
    one-hot encoding only accepted one-dimensional (vector) input, which
    produced errors for column vectors, since those are represented as a
    two-dimensional data structure.
---
 src/main/python/systemds/operator/operation_node.py |  4 ++--
 src/main/python/tests/matrix/test_to_one_hot.py     | 15 +++++++++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/src/main/python/systemds/operator/operation_node.py b/src/main/python/systemds/operator/operation_node.py
index 96ac5cf..40e4e68 100644
--- a/src/main/python/systemds/operator/operation_node.py
+++ b/src/main/python/systemds/operator/operation_node.py
@@ -480,9 +480,9 @@ class OperationNode(DAGNode):
         """
 
         self._check_matrix_op()
-        if len(self.shape) != 1:
+        if len(self.shape) == 2 and self.shape[1] != 1:
             raise ValueError(
-                "Only Matrixes  with a single column or row is valid in One Hot, " + str(self.shape) + " is invalid")
+                "Only Matrixes with a single column is valid in One Hot, " + str(self.shape) + " is invalid")
 
         if num_classes < 2:
             raise ValueError("Number of classes should be larger than 1")
diff --git a/src/main/python/tests/matrix/test_to_one_hot.py b/src/main/python/tests/matrix/test_to_one_hot.py
index e660905..95d5bc6 100644
--- a/src/main/python/tests/matrix/test_to_one_hot.py
+++ b/src/main/python/tests/matrix/test_to_one_hot.py
@@ -74,6 +74,21 @@ class TestMatrixOneHot(unittest.TestCase):
     #     with self.assertRaises(ValueError) as context:
     #         res = Matrix(self.sds, m1).to_one_hot(2).compute()
 
+    def test_one_hot_matrix_1(self):
+        m1 = np.array([[1],[2],[3]])
+        res = Matrix(self.sds, m1).to_one_hot(3).compute()
+        self.assertTrue((res == [[1,0,0], [0,1,0], [0,0,1]]).all())
+    
+    def test_one_hot_matrix_2(self):
+        m1 = np.array([[1],[3],[3]])
+        res = Matrix(self.sds, m1).to_one_hot(3).compute()
+        self.assertTrue((res == [[1,0,0], [0,0,1], [0,0,1]]).all())
+
+    def test_one_hot_matrix_3(self):
+        m1 = np.array([[1],[2],[1]])
+        res = Matrix(self.sds, m1).to_one_hot(2).compute()
+        self.assertTrue((res == [[1,0], [0,1], [1,0]]).all())
+
     def test_neg_one_hot_numClasses(self):
         m1 = np.array([1])
         with self.assertRaises(ValueError) as context:

Reply via email to