This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new d6fe87e6cc [SystemDS-3750] Python API Builtin prod
d6fe87e6cc is described below

commit d6fe87e6ccbcb363eca92ac2f52d3aea25fe6559
Author: e-strauss <92718421+e-stra...@users.noreply.github.com>
AuthorDate: Wed Sep 4 19:26:18 2024 +0200

    [SystemDS-3750] Python API Builtin prod
    
    This commit also adds the missing error-case tests
    for the sum, mean, var, min, and max aggregate operations.
    
    Closes #2095
---
 src/main/python/systemds/operator/nodes/matrix.py | 16 +++++++++
 src/main/python/tests/matrix/test_aggregations.py | 42 +++++++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/src/main/python/systemds/operator/nodes/matrix.py b/src/main/python/systemds/operator/nodes/matrix.py
index fafb815ca4..d8132829ae 100644
--- a/src/main/python/systemds/operator/nodes/matrix.py
+++ b/src/main/python/systemds/operator/nodes/matrix.py
@@ -238,6 +238,22 @@ class Matrix(OperationNode):
             f"Axis has to be either 0, 1 or None, for column, row or complete 
{self.operation}"
         )
 
+    def prod(self, axis: int = None) -> "OperationNode":
+        """Calculate the product of the cells in the matrix.
+
+        :param axis: 0 for column products, 1 for row products, None for the product of all cells
+        :return: `Matrix` of column/row products for axis 0/1, `Scalar` for axis None
+        """
+        if axis == 0:
+            return Matrix(self.sds_context, "colProds", [self])
+        elif axis == 1:
+            return Matrix(self.sds_context, "rowProds", [self])
+        elif axis is None:
+            return Scalar(self.sds_context, "prod", [self])
+        raise ValueError(
+            f"Axis has to be either 0, 1 or None, for column, row or complete product, got {axis}"
+        )
+
     def mean(self, axis: int = None) -> "OperationNode":
         """Calculate mean of matrix.
 
diff --git a/src/main/python/tests/matrix/test_aggregations.py b/src/main/python/tests/matrix/test_aggregations.py
index d02d5dfb3e..597bcfc9f5 100644
--- a/src/main/python/tests/matrix/test_aggregations.py
+++ b/src/main/python/tests/matrix/test_aggregations.py
@@ -61,6 +61,32 @@ class TestMatrixAggFn(unittest.TestCase):
             )
         )
 
+    def test_sum4(self):  # axis=2 is out of range, so sum must raise ValueError
+        with self.assertRaises(ValueError):
+            self.sds.from_numpy(m1).sum(2)
+
+    def test_prod1(self):  # axis=None: product over all cells, compared to np.prod
+        self.assertTrue(
+            np.allclose(self.sds.from_numpy(m1).prod().compute(), np.prod(m1))
+        )
+
+    def test_prod2(self):  # axis=0: column products, compared to np.prod(..., 0)
+        self.assertTrue(
+            np.allclose(self.sds.from_numpy(m1).prod(0).compute(), np.prod(m1, 0))
+        )
+
+    def test_prod3(self):  # axis=1: row products, returned as a (dim, 1) column
+        self.assertTrue(
+            np.allclose(
+                self.sds.from_numpy(m1).prod(axis=1).compute(),
+                np.prod(m1, 1).reshape(dim, 1),
+            )
+        )
+
+    def test_prod4(self):  # axis=2 is out of range, so prod must raise ValueError
+        with self.assertRaises(ValueError):
+            self.sds.from_numpy(m1).prod(2)
+
     def test_mean1(self):
         self.assertTrue(
             np.allclose(self.sds.from_numpy(m1).mean().compute(), m1.mean())
@@ -79,6 +105,10 @@ class TestMatrixAggFn(unittest.TestCase):
             )
         )
 
+    def test_mean4(self):  # axis=2 is out of range, so mean must raise ValueError
+        with self.assertRaises(ValueError):
+            self.sds.from_numpy(m1).mean(2)
+
     def test_full(self):
         self.assertTrue(
             np.allclose(self.sds.full((2, 3), 10.1).compute(), np.full((2, 3), 
10.1))
@@ -109,6 +139,10 @@ class TestMatrixAggFn(unittest.TestCase):
             )
         )
 
+    def test_var4(self):  # axis=2 is out of range, so var must raise ValueError
+        with self.assertRaises(ValueError):
+            self.sds.from_numpy(m1).var(2)
+
     def test_min1(self):
         self.assertTrue(np.allclose(self.sds.from_numpy(m1).min().compute(), 
m1.min()))
 
@@ -125,6 +159,10 @@ class TestMatrixAggFn(unittest.TestCase):
             )
         )
 
+    def test_min4(self):  # axis=2 is out of range, so min must raise ValueError
+        with self.assertRaises(ValueError):
+            self.sds.from_numpy(m1).min(2)
+
     def test_max1(self):
         self.assertTrue(np.allclose(self.sds.from_numpy(m1).max().compute(), 
m1.max()))
 
@@ -141,6 +179,10 @@ class TestMatrixAggFn(unittest.TestCase):
             )
         )
 
+    def test_max4(self):  # axis=2 is out of range, so max must raise ValueError
+        with self.assertRaises(ValueError):
+            self.sds.from_numpy(m1).max(2)
+
     def test_trace1(self):
         self.assertTrue(
             np.allclose(self.sds.from_numpy(m1).trace().compute(), m1.trace())

Reply via email to