This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 96add83  [SYSTEMDS-3223] Python functions with list arguments
96add83 is described below

commit 96add83f1d0aebfe1744a7c2fd014a2a603a9162
Author: baunsgaard <[email protected]>
AuthorDate: Fri Nov 19 16:44:35 2021 +0100

    [SYSTEMDS-3223] Python functions with list arguments
    
    This commit fixes a bug where the sourced functions would not
    correctly build the scripts in cases where a list is an input to the
    function defined in a sourced script.
    
    Closes #1460
---
 .../python/systemds/context/systemds_context.py    |  7 ++-
 src/main/python/systemds/operator/nodes/list.py    |  5 +-
 src/main/python/systemds/operator/nodes/scalar.py  |  2 +-
 src/main/python/systemds/script_building/script.py |  9 ++--
 .../python/tests/source/source_with_list_input.dml | 29 ++++++++++
 src/main/python/tests/source/test_source_list.py   | 63 ++++++++++++++++++++++
 6 files changed, 102 insertions(+), 13 deletions(-)

diff --git a/src/main/python/systemds/context/systemds_context.py b/src/main/python/systemds/context/systemds_context.py
index 14f7260..d25805c 100644
--- a/src/main/python/systemds/context/systemds_context.py
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -64,8 +64,7 @@ class SystemDSContext(object):
         if process.poll() is None:
             self.__start_gateway(actual_port)
         else:
-            self.exception_and_close(
-                "Java process stopped before gateway could connect")
+            self.exception_and_close("Java process stopped before gateway 
could connect")
 
     def get_stdout(self, lines: int = -1):
         """Getter for the stdout of the java subprocess
@@ -89,7 +88,7 @@ class SystemDSContext(object):
         else:
             return [self.__stderr.get() for x in range(lines)]
 
-    def exception_and_close(self, exception_str: str, trace_back_limit: int = 
None):
+    def exception_and_close(self, exception, trace_back_limit: int = None):
         """
         Method for printing exception, printing stdout and error, while also 
closing the context correctly.
 
@@ -104,7 +103,7 @@ class SystemDSContext(object):
         if stdErr:
             message += "standard error  :\n" + "\n".join(stdErr)
         message += "\n\n"
-        message += exception_str
+        message += str(exception)
         sys.tracebacklimit = trace_back_limit
         self.close()
         raise RuntimeError(message)
diff --git a/src/main/python/systemds/operator/nodes/list.py b/src/main/python/systemds/operator/nodes/list.py
index 6ad69ca..578535e 100644
--- a/src/main/python/systemds/operator/nodes/list.py
+++ b/src/main/python/systemds/operator/nodes/list.py
@@ -76,9 +76,8 @@ class List(OperationNode):
 
     def code_line(self, var_name: str, unnamed_input_vars: Sequence[str],
                   named_input_vars: Dict[str, str]) -> str:
-        inputs_comma_sep = create_params_string(
-            unnamed_input_vars, named_input_vars)
-        return f'{var_name}={self.operation}({inputs_comma_sep});'
+        code_line = super().code_line(var_name, unnamed_input_vars, 
named_input_vars)
+        return code_line
 
     def compute(self, verbose: bool = False, lineage: bool = False) -> 
np.array:
         return super().compute(verbose, lineage)
diff --git a/src/main/python/systemds/operator/nodes/scalar.py b/src/main/python/systemds/operator/nodes/scalar.py
index 511ad34..a4d6292 100644
--- a/src/main/python/systemds/operator/nodes/scalar.py
+++ b/src/main/python/systemds/operator/nodes/scalar.py
@@ -57,7 +57,7 @@ class Scalar(OperationNode):
         else:
             return super().code_line(var_name, unnamed_input_vars, 
named_input_vars)
 
-    def compute(self, verbose: bool = False, lineage: bool = False) -> 
Union[np.array]:
+    def compute(self, verbose: bool = False, lineage: bool = False):
         return super().compute(verbose, lineage)
 
     def _parse_output_result_variables(self, result_variables):
diff --git a/src/main/python/systemds/script_building/script.py b/src/main/python/systemds/script_building/script.py
index 06753ab..05f1cfe 100644
--- a/src/main/python/systemds/script_building/script.py
+++ b/src/main/python/systemds/script_building/script.py
@@ -200,15 +200,13 @@ class DMLScript:
         # for each node do the dfs operation and save the variable names in 
`input_var_names`
         # get variable names of unnamed parameters
 
-        unnamed_input_vars = [self._dfs_dag_nodes(
-            input_node) for input_node in dag_node.unnamed_input_nodes]
+        unnamed_input_vars = []
+        for un_node in dag_node.unnamed_input_nodes:
+            unnamed_input_vars.append(self._dfs_dag_nodes(un_node))
 
         named_input_vars = {}
         for name, input_node in dag_node.named_input_nodes.items():
             named_input_vars[name] = self._dfs_dag_nodes(input_node)
-            if isinstance(input_node, DAGNode) and input_node._output_type == 
OutputType.LIST:
-                dag_node.dml_name = named_input_vars[name] + name
-                return dag_node.dml_name
 
         # check if the node gets a name after multireturns
         # If it has, great, return that name
@@ -222,6 +220,7 @@ class DMLScript:
 
         code_line = dag_node.code_line(
             dag_node.dml_name, unnamed_input_vars, named_input_vars)
+
         self.add_code(code_line)
         return dag_node.dml_name
 
diff --git a/src/main/python/tests/source/source_with_list_input.dml b/src/main/python/tests/source/source_with_list_input.dml
new file mode 100644
index 0000000..2e5a415
--- /dev/null
+++ b/src/main/python/tests/source/source_with_list_input.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+func = function(list[unknown] a) return (matrix[double] b){
+    b = as.matrix(a[1])
+}
+
+func2 = function(list[unknown] a) return (matrix[double] b, matrix[double] c){
+    b = as.matrix(a[1])
+    c = as.matrix(a[2])
+}
diff --git a/src/main/python/tests/source/test_source_list.py b/src/main/python/tests/source/test_source_list.py
new file mode 100644
index 0000000..d4ab391
--- /dev/null
+++ b/src/main/python/tests/source/test_source_list.py
@@ -0,0 +1,63 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+
+import numpy as np
+from systemds.context import SystemDSContext
+from systemds.operator.algorithm.builtin.scale import scale
+
+
+class TestSource_01(unittest.TestCase):
+
+    sds: SystemDSContext = None
+    source_path: str = "./tests/source/source_with_list_input.dml"
+
+    @classmethod
+    def setUpClass(cls):
+        cls.sds = SystemDSContext()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sds.close()
+
+    def test_single_return(self):
+        arr = self.sds.array(self.sds.full((10, 10), 4))
+        c = self.sds.source(self.source_path, "test").func(arr)
+        res = c.sum().compute()
+        self.assertTrue(res == 10*10*4)
+
+    def test_input_multireturn(self):
+        m = self.sds.full((10, 10), 2)
+        [a, b, c] = scale(m, True, True)
+        arr = self.sds.array(a, b, c)
+        c = self.sds.source(self.source_path, "test").func(arr)
+        res = c.sum().compute(verbose=True)
+        self.assertTrue(res == 0)
+
+    # [SYSTEMDS-3224] https://issues.apache.org/jira/browse/SYSTEMDS-3224
+    # def test_multi_return(self):
+    #     arr = self.sds.array(
+    #         self.sds.full((10, 10), 4),
+    #         self.sds.full((3, 3), 5))
+    #     [b, c] = self.sds.source(self.source_path, "test", True).func2(arr)
+    #     res = c.sum().compute()
+    #     self.assertTrue(res == 10*10*4)

Reply via email to