This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 96add83 [SYSTEMDS-3223] Python functions with list arguments
96add83 is described below
commit 96add83f1d0aebfe1744a7c2fd014a2a603a9162
Author: baunsgaard <[email protected]>
AuthorDate: Fri Nov 19 16:44:35 2021 +0100
[SYSTEMDS-3223] Python functions with list arguments
This commit fixes a bug where the sourced functions would not
correctly build the scripts in cases where a list is an input to the
function defined in a sourced script.
Closes #1460
---
.../python/systemds/context/systemds_context.py | 7 ++-
src/main/python/systemds/operator/nodes/list.py | 5 +-
src/main/python/systemds/operator/nodes/scalar.py | 2 +-
src/main/python/systemds/script_building/script.py | 9 ++--
.../python/tests/source/source_with_list_input.dml | 29 ++++++++++
src/main/python/tests/source/test_source_list.py | 63 ++++++++++++++++++++++
6 files changed, 102 insertions(+), 13 deletions(-)
diff --git a/src/main/python/systemds/context/systemds_context.py b/src/main/python/systemds/context/systemds_context.py
index 14f7260..d25805c 100644
--- a/src/main/python/systemds/context/systemds_context.py
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -64,8 +64,7 @@ class SystemDSContext(object):
if process.poll() is None:
self.__start_gateway(actual_port)
else:
- self.exception_and_close(
- "Java process stopped before gateway could connect")
+ self.exception_and_close("Java process stopped before gateway could connect")
def get_stdout(self, lines: int = -1):
"""Getter for the stdout of the java subprocess
@@ -89,7 +88,7 @@ class SystemDSContext(object):
else:
return [self.__stderr.get() for x in range(lines)]
- def exception_and_close(self, exception_str: str, trace_back_limit: int = None):
+ def exception_and_close(self, exception, trace_back_limit: int = None):
"""
Method for printing exception, printing stdout and error, while also
closing the context correctly.
@@ -104,7 +103,7 @@ class SystemDSContext(object):
if stdErr:
message += "standard error :\n" + "\n".join(stdErr)
message += "\n\n"
- message += exception_str
+ message += str(exception)
sys.tracebacklimit = trace_back_limit
self.close()
raise RuntimeError(message)
diff --git a/src/main/python/systemds/operator/nodes/list.py b/src/main/python/systemds/operator/nodes/list.py
index 6ad69ca..578535e 100644
--- a/src/main/python/systemds/operator/nodes/list.py
+++ b/src/main/python/systemds/operator/nodes/list.py
@@ -76,9 +76,8 @@ class List(OperationNode):
def code_line(self, var_name: str, unnamed_input_vars: Sequence[str],
named_input_vars: Dict[str, str]) -> str:
- inputs_comma_sep = create_params_string(
- unnamed_input_vars, named_input_vars)
- return f'{var_name}={self.operation}({inputs_comma_sep});'
+ code_line = super().code_line(var_name, unnamed_input_vars, named_input_vars)
+ return code_line
def compute(self, verbose: bool = False, lineage: bool = False) -> np.array:
return super().compute(verbose, lineage)
diff --git a/src/main/python/systemds/operator/nodes/scalar.py b/src/main/python/systemds/operator/nodes/scalar.py
index 511ad34..a4d6292 100644
--- a/src/main/python/systemds/operator/nodes/scalar.py
+++ b/src/main/python/systemds/operator/nodes/scalar.py
@@ -57,7 +57,7 @@ class Scalar(OperationNode):
else:
return super().code_line(var_name, unnamed_input_vars,
named_input_vars)
- def compute(self, verbose: bool = False, lineage: bool = False) -> Union[np.array]:
+ def compute(self, verbose: bool = False, lineage: bool = False):
return super().compute(verbose, lineage)
def _parse_output_result_variables(self, result_variables):
diff --git a/src/main/python/systemds/script_building/script.py b/src/main/python/systemds/script_building/script.py
index 06753ab..05f1cfe 100644
--- a/src/main/python/systemds/script_building/script.py
+++ b/src/main/python/systemds/script_building/script.py
@@ -200,15 +200,13 @@ class DMLScript:
# for each node do the dfs operation and save the variable names in `input_var_names`
# get variable names of unnamed parameters
- unnamed_input_vars = [self._dfs_dag_nodes(
- input_node) for input_node in dag_node.unnamed_input_nodes]
+ unnamed_input_vars = []
+ for un_node in dag_node.unnamed_input_nodes:
+ unnamed_input_vars.append(self._dfs_dag_nodes(un_node))
named_input_vars = {}
for name, input_node in dag_node.named_input_nodes.items():
named_input_vars[name] = self._dfs_dag_nodes(input_node)
- if isinstance(input_node, DAGNode) and input_node._output_type == OutputType.LIST:
- dag_node.dml_name = named_input_vars[name] + name
- return dag_node.dml_name
# check if the node gets a name after multireturns
# If it has, great, return that name
@@ -222,6 +220,7 @@ class DMLScript:
code_line = dag_node.code_line(
dag_node.dml_name, unnamed_input_vars, named_input_vars)
+
self.add_code(code_line)
return dag_node.dml_name
diff --git a/src/main/python/tests/source/source_with_list_input.dml b/src/main/python/tests/source/source_with_list_input.dml
new file mode 100644
index 0000000..2e5a415
--- /dev/null
+++ b/src/main/python/tests/source/source_with_list_input.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+func = function(list[unknown] a) return (matrix[double] b){
+ b = as.matrix(a[1])
+}
+
+func2 = function(list[unknown] a) return (matrix[double] b, matrix[double] c){
+ b = as.matrix(a[1])
+ c = as.matrix(a[2])
+}
diff --git a/src/main/python/tests/source/test_source_list.py b/src/main/python/tests/source/test_source_list.py
new file mode 100644
index 0000000..d4ab391
--- /dev/null
+++ b/src/main/python/tests/source/test_source_list.py
@@ -0,0 +1,63 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+
+import numpy as np
+from systemds.context import SystemDSContext
+from systemds.operator.algorithm.builtin.scale import scale
+
+
+class TestSource_01(unittest.TestCase):
+
+ sds: SystemDSContext = None
+ source_path: str = "./tests/source/source_with_list_input.dml"
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sds = SystemDSContext()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sds.close()
+
+ def test_single_return(self):
+ arr = self.sds.array(self.sds.full((10, 10), 4))
+ c = self.sds.source(self.source_path, "test").func(arr)
+ res = c.sum().compute()
+ self.assertTrue(res == 10*10*4)
+
+ def test_input_multireturn(self):
+ m = self.sds.full((10, 10), 2)
+ [a, b, c] = scale(m, True, True)
+ arr = self.sds.array(a, b, c)
+ c = self.sds.source(self.source_path, "test").func(arr)
+ res = c.sum().compute(verbose=True)
+ self.assertTrue(res == 0)
+
+ # [SYSTEMDS-3224] https://issues.apache.org/jira/browse/SYSTEMDS-3224
+ # def test_multi_return(self):
+ # arr = self.sds.array(
+ # self.sds.full((10, 10), 4),
+ # self.sds.full((3, 3), 5))
+ # [b, c] = self.sds.source(self.source_path, "test", True).func2(arr)
+ # res = c.sum().compute()
+ # self.assertTrue(res == 10*10*4)