This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new f7b5f88 [SYSTEMDS-2821] Python Stability
f7b5f88 is described below
commit f7b5f8811ad0c264f6ee1c7ecb9d16f4315698ca
Author: baunsgaard <[email protected]>
AuthorDate: Mon Feb 1 11:58:55 2021 +0100
[SYSTEMDS-2821] Python Stability
There have been some startup issues in the Python API, where some tests
would not properly connect to the JVM.
This task addresses them by retrying the startup of the context
in case of failure.
---
src/main/java/org/apache/sysds/conf/DMLConfig.java | 2 +-
.../python/systemds/context/systemds_context.py | 146 ++++++++++++---------
.../python/systemds/operator/operation_node.py | 62 +++++----
src/main/python/systemds/script_building/script.py | 58 +++++---
4 files changed, 158 insertions(+), 110 deletions(-)
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index fdedc14..bc84318 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -71,7 +71,7 @@ public class DMLConfig
public static final String COMPRESSED_LOSSY =
"sysds.compressed.lossy";
public static final String COMPRESSED_VALID_COMPRESSIONS =
"sysds.compressed.valid.compressions";
public static final String COMPRESSED_OVERLAPPING =
"sysds.compressed.overlapping"; // true, false
- public static final String COMPRESSED_SAMPLING_RATIO =
"sysds.compressed.sampling.ratio"; // 0.1
+ public static final String COMPRESSED_SAMPLING_RATIO =
"sysds.compressed.sampling.ratio";
public static final String COMPRESSED_COCODE =
"sysds.compressed.cocode"; // COST
public static final String COMPRESSED_TRANSPOSE =
"sysds.compressed.transpose"; // true, false, auto.
public static final String NATIVE_BLAS = "sysds.native.blas";
diff --git a/src/main/python/systemds/context/systemds_context.py
b/src/main/python/systemds/context/systemds_context.py
index 5160c2b..3e29f62 100644
--- a/src/main/python/systemds/context/systemds_context.py
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -23,6 +23,7 @@ __all__ = ["SystemDSContext"]
import copy
import os
+import socket
import time
from glob import glob
from queue import Empty, Queue
@@ -34,10 +35,10 @@ from typing import Dict, Iterable, Sequence, Tuple, Union
import numpy as np
from py4j.java_gateway import GatewayParameters, JavaGateway
from py4j.protocol import Py4JNetworkError
-from systemds.utils.consts import VALID_INPUT_TYPES
-from systemds.utils.helpers import get_module_dir
from systemds.operator import OperationNode
from systemds.script_building import OutputType
+from systemds.utils.consts import VALID_INPUT_TYPES
+from systemds.utils.helpers import get_module_dir
class SystemDSContext(object):
@@ -55,67 +56,12 @@ class SystemDSContext(object):
Standard out and standard error form the JVM is also handled in this
class, filling up Queues,
that can be read from to get the printed statements from the JVM.
"""
-
- root = os.environ.get("SYSTEMDS_ROOT")
- if root == None:
- # If there is no systemds install default to use the PIP packaged
java files.
- root = os.path.join(get_module_dir(), "systemds-java")
-
- # nt means its Windows
- cp_separator = ";" if os.name == "nt" else ":"
-
- if os.environ.get("SYSTEMDS_ROOT") != None:
- lib_cp = os.path.join(root, "target", "lib", "*")
- systemds_cp = os.path.join(root, "target", "SystemDS.jar")
- classpath = cp_separator.join([lib_cp, systemds_cp])
-
- command = ["java", "-cp", classpath]
- files = glob(os.path.join(root, "conf", "log4j*.properties"))
- if len(files) > 1:
- print(
- "WARNING: Multiple logging files found selecting: " +
files[0])
- if len(files) == 0:
- print("WARNING: No log4j file found at: "
- + os.path.join(root, "conf")
- + " therefore using default settings")
- else:
- command.append("-Dlog4j.configuration=file:" + files[0])
- else:
- lib_cp = os.path.join(root, "lib", "*")
- command = ["java", "-cp", lib_cp]
-
- command.append("org.apache.sysds.api.PythonDMLScript")
-
+ command = self.__build_startup_command()
# TODO add an argument parser here
-
- # Find a random port, and hope that no other process
- # steals it while we wait for the JVM to startup
port = self.__get_open_port()
command.append(str(port))
- process = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE)
- first_stdout = process.stdout.readline()
-
- if(not b"GatewayServer Started" in first_stdout):
- stderr = process.stderr.readline().decode("utf-8")
- if(len(stderr) > 1):
- raise Exception(
- "Exception in startup of GatewayServer: " + stderr)
- outputs = []
- outputs.append(first_stdout.decode("utf-8"))
- max_tries = 10
- for i in range(max_tries):
- next_line = process.stdout.readline()
- if(b"GatewayServer Started" in next_line):
- print("WARNING: Stdout corrupted by prints: " +
str(outputs))
- print("Startup success")
- break
- else:
- outputs.append(next_line)
-
- if (i == max_tries-1):
- raise Exception("Error in startup of systemDS gateway
process: \n gateway StdOut: " + str(
- outputs) + " \n gateway StdErr" +
process.stderr.readline().decode("utf-8"))
+ process = self.__try_startup(command)
# Handle Std out from the subprocess.
self.__stdout = Queue()
@@ -166,7 +112,79 @@ class SystemDSContext(object):
print("exception")
print(e)
self.close()
- exit()
+
+
+    def __try_startup(self, command, rep = 0):
+        """Start the JVM gateway process, retrying with a growing back-off on failure.
+
+        :param command: the argument list used to launch the JVM process
+        :param rep: number of retries already performed; callers normally leave it at 0
+        :return: the started process, once its startup has been verified
+        :raises Exception: if startup still fails after the retry limit is exceeded
+        """
+        process = None
+        try:
+            process = Popen(command, stdout=PIPE, stdin=PIPE, stderr=PIPE)
+            self.__verify_startup(process)
+            return process
+        except Exception as e:
+            # Do not leave a half-started JVM behind before retrying or giving up.
+            if process is not None:
+                process.kill()
+            if rep > 3:
+                raise Exception("Failed to start SystemDS context with " + str(rep)
+                                + " repeated tries") from e
+            rep += 1
+            print("Failed to startup JVM process, retrying: " + str(rep))
+            # Sleeping increasingly long time, maybe this helps.
+            time.sleep(rep)
+            return self.__try_startup(command, rep)
+
+    def __verify_startup(self, process):
+        """Check that the JVM process reports a started gateway on its standard out.
+
+        The first standard out line is expected to contain "GatewayServer Started".
+        If it does not, a standard error line longer than one character is treated
+        as a startup failure; otherwise up to 10 further standard out lines are
+        inspected before giving up.
+
+        :param process: the process whose standard out and standard error are read
+        :raises Exception: if the gateway does not report a successful startup
+        """
+        first_stdout = process.stdout.readline()
+        if b"GatewayServer Started" in first_stdout:
+            return
+
+        stderr = process.stderr.readline().decode("utf-8")
+        if len(stderr) > 1:
+            raise Exception(
+                "Exception in startup of GatewayServer: " + stderr)
+
+        outputs = [first_stdout.decode("utf-8")]
+        max_tries = 10
+        for _ in range(max_tries):
+            next_line = process.stdout.readline()
+            if b"GatewayServer Started" in next_line:
+                print("WARNING: Stdout corrupted by prints: " + str(outputs))
+                print("Startup success")
+                return
+            # Decode so every collected line is reported as text, like the first one.
+            outputs.append(next_line.decode("utf-8"))
+
+        raise Exception("Error in startup of systemDS gateway process: \n gateway StdOut: " + str(
+            outputs) + " \n gateway StdErr" + process.stderr.readline().decode("utf-8"))
+
+    def __build_startup_command(self):
+        """Build the java command line used to start the PythonDMLScript gateway.
+
+        If SYSTEMDS_ROOT is set, the classpath points at the target folder of that
+        install, and a log4j properties file from its conf folder is passed along
+        when one is found. Otherwise the java files packaged with the python module
+        are used.
+
+        :return: the command as a list of arguments, ending with the main class name
+        """
+        command = ["java", "-cp"]
+        root = os.environ.get("SYSTEMDS_ROOT")
+
+        if root is None:
+            # If there is no systemds install default to use the PIP packaged java files.
+            root = os.path.join(get_module_dir(), "systemds-java")
+            command.append(os.path.join(root, "lib", "*"))
+        else:
+            # nt means its Windows
+            cp_separator = ";" if os.name == "nt" else ":"
+            lib_cp = os.path.join(root, "target", "lib", "*")
+            systemds_cp = os.path.join(root, "target", "SystemDS.jar")
+            command.append(cp_separator.join([lib_cp, systemds_cp]))
+
+            files = glob(os.path.join(root, "conf", "log4j*.properties"))
+            if len(files) > 1:
+                print(
+                    "WARNING: Multiple logging files found selecting: " + files[0])
+            if len(files) == 0:
+                print("WARNING: No log4j file found at: "
+                      + os.path.join(root, "conf")
+                      + " therefore using default settings")
+            else:
+                command.append("-Dlog4j.configuration=file:" + files[0])
+
+        command.append("org.apache.sysds.api.PythonDMLScript")
+
+        return command
def __enter__(self):
return self
@@ -191,11 +209,11 @@ class SystemDSContext(object):
queue.put(line.decode("utf-8").strip())
def __get_open_port(self):
- """Get a random available port."""
- # TODO Verify that it is not taking some critical ports change to
select a good port range.
- # TODO If it tries to select a port already in use, find another.
+ """Get a random available port.
+ and hope that no other process steals it while we wait for the JVM to
startup
+ """
#
https://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python
- import socket
+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.listen(1)
diff --git a/src/main/python/systemds/operator/operation_node.py
b/src/main/python/systemds/operator/operation_node.py
index 40e4e68..e2ac681 100644
--- a/src/main/python/systemds/operator/operation_node.py
+++ b/src/main/python/systemds/operator/operation_node.py
@@ -95,29 +95,12 @@ class OperationNode(DAGNode):
print(self._script.dml_script)
if lineage:
- result_variables, self._lineage_trace = self._script.execute(
- lineage)
+ result_variables, self._lineage_trace =
self._script.execute_with_lineage()
else:
- result_variables = self._script.execute(lineage)
-
- if self.output_type == OutputType.DOUBLE:
- self._result_var = result_variables.getDouble(
- self._script.out_var_name[0])
- elif self.output_type == OutputType.MATRIX:
- self._result_var =
matrix_block_to_numpy(self.sds_context.java_gateway.jvm,
-
result_variables.getMatrixBlock(self._script.out_var_name[0]))
- elif self.output_type == OutputType.LIST:
- self._result_var = []
- for idx, v in enumerate(self._script.out_var_name):
- if(self._output_types == None):
-
self._result_var.append(matrix_block_to_numpy(self.sds_context.java_gateway.jvm,
-
result_variables.getMatrixBlock(v)))
- elif(self._output_types[idx] == OutputType.MATRIX):
-
self._result_var.append(matrix_block_to_numpy(self.sds_context.java_gateway.jvm,
-
result_variables.getMatrixBlock(v)))
- else:
- self._result_var.append(result_variables.getDouble(
- self._script.out_var_name[idx]))
+ result_variables = self._script.execute()
+
+ self._result_var =
self.__parse_output_result_variables(result_variables)
+
if verbose:
for x in self.sds_context.get_stdout():
print(x)
@@ -129,6 +112,31 @@ class OperationNode(DAGNode):
else:
return self._result_var
+    def __parse_output_result_variables(self, result_variables):
+        """Convert the result variables of an execution according to this node's output type.
+
+        :param result_variables: the object returned by executing the script
+        :return: a double, a matrix or a list of results; None for any other output type
+        """
+        kind = self.output_type
+        if kind == OutputType.DOUBLE:
+            first_name = self._script.out_var_name[0]
+            return self.__parse_output_result_double(result_variables, first_name)
+        if kind == OutputType.MATRIX:
+            first_name = self._script.out_var_name[0]
+            return self.__parse_output_result_matrix(result_variables, first_name)
+        if kind == OutputType.LIST:
+            return self.__parse_output_result_list(result_variables)
+        return None
+
+    def __parse_output_result_double(self, result_variables, var_name):
+        """Read the output called var_name from result_variables via getDouble."""
+        value = result_variables.getDouble(var_name)
+        return value
+
+    def __parse_output_result_matrix(self, result_variables, var_name):
+        """Fetch the matrix block called var_name and convert it with matrix_block_to_numpy."""
+        jvm = self.sds_context.java_gateway.jvm
+        matrix_block = result_variables.getMatrixBlock(var_name)
+        return matrix_block_to_numpy(jvm, matrix_block)
+
+    def __parse_output_result_list(self, result_variables):
+        """Convert every output variable of the script into a list of results.
+
+        An output is read as a matrix when no per-output types are recorded or when
+        its recorded type is MATRIX; every other output is read as a double.
+
+        :param result_variables: the object returned by executing the script
+        :return: the converted outputs, in the order of the script's output names
+        """
+        result_var = []
+        for idx, v in enumerate(self._script.out_var_name):
+            if self._output_types is None or self._output_types[idx] == OutputType.MATRIX:
+                result_var.append(
+                    self.__parse_output_result_matrix(result_variables, v))
+            else:
+                # v is the same name as out_var_name[idx]; reuse the double helper.
+                result_var.append(
+                    self.__parse_output_result_double(result_variables, v))
+        return result_var
+
def get_lineage_trace(self) -> str:
"""Get the lineage trace for this node.
@@ -501,8 +509,9 @@ class OperationNode(DAGNode):
other._check_matrix_op()
if self.shape[1] != other.shape[1]:
- raise ValueError("The input matrices to rbind does not have the
same number of columns")
-
+ raise ValueError(
+ "The input matrices to rbind does not have the same number of
columns")
+
return OperationNode(self.sds_context, 'rbind', [self, other],
shape=(self.shape[0] + other.shape[0], self.shape[1]))
def cbind(self, other) -> 'OperationNode':
@@ -516,6 +525,7 @@ class OperationNode(DAGNode):
other._check_matrix_op()
if self.shape[0] != other.shape[0]:
- raise ValueError("The input matrices to cbind does not have the
same number of columns")
-
+ raise ValueError(
+ "The input matrices to cbind does not have the same number of
columns")
+
return OperationNode(self.sds_context, 'cbind', [self, other],
shape=(self.shape[0], self.shape[1] + other.shape[1]))
diff --git a/src/main/python/systemds/script_building/script.py
b/src/main/python/systemds/script_building/script.py
index 1f9d5de..7084e1e 100644
--- a/src/main/python/systemds/script_building/script.py
+++ b/src/main/python/systemds/script_building/script.py
@@ -19,7 +19,7 @@
#
# -------------------------------------------------------------
-from typing import Any, Collection, KeysView, Tuple, Union, Optional, Dict,
TYPE_CHECKING
+from typing import Any, Collection, KeysView, Tuple, Union, Optional, Dict,
TYPE_CHECKING, List
from py4j.java_collections import JavaArray
from py4j.java_gateway import JavaObject, JavaGateway
@@ -44,7 +44,7 @@ class DMLScript:
dml_script: str
inputs: Dict[str, DAGNode]
prepared_script: Optional[Any]
- out_var_name: str
+ out_var_name: List[str]
_variable_counter: int
def __init__(self, context: 'SystemDSContext') -> None:
@@ -70,7 +70,7 @@ class DMLScript:
"""
self.inputs[var_name] = input_var
- def execute(self, lineage: bool = False) -> Union[JavaObject,
Tuple[JavaObject, str]]:
+ def execute(self) -> JavaObject:
"""If not already created, create a preparedScript from our DMLCode,
pass python local data to our prepared
script, then execute our script and return the resultVariables
@@ -78,27 +78,29 @@ class DMLScript:
"""
# we could use the gateway directly, non defined functions will be
automatically
# sent to the entry_point, but this is safer
- gateway = self.sds_context.java_gateway
- entry_point = gateway.entry_point
- if self.prepared_script is None:
- input_names = self.inputs.keys()
- connection = entry_point.getConnection()
- self.prepared_script = connection.prepareScript(
- self.dml_script,
- _list_to_java_array(gateway, input_names),
- _list_to_java_array(gateway, self.out_var_name))
- for (name, input_node) in self.inputs.items():
- input_node.pass_python_data_to_prepared_script(
- self.sds_context, name, self.prepared_script)
- if lineage:
- connection.setLineage(True)
try:
+ self.__prepare_script()
ret = self.prepared_script.executeScript()
+ return ret
except Exception as e:
self.sds_context.exception_and_close(e)
+ return None
+
+ def execute_with_lineage(self) -> Tuple[JavaObject, str]:
+ """If not already created, create a preparedScript from our DMLCode,
pass python local data to our prepared
+ script, then execute our script and return the resultVariables
+
+ :return: resultVariables of our execution and the string lineage trace
+ """
+ # we could use the gateway directly, non defined functions will be
automatically
+ # sent to the entry_point, but this is safer
+ try:
+ connection = self.__prepare_script()
+ connection.setLineage(True)
+ ret = self.prepared_script.executeScript()
+
- if lineage:
if len(self.out_var_name) == 1:
return ret,
self.prepared_script.getLineageTrace(self.out_var_name[0])
else:
@@ -106,8 +108,26 @@ class DMLScript:
for output in self.out_var_name:
traces.append(self.prepared_script.getLineageTrace(output))
return ret, traces
+
+ except Exception as e:
+ self.sds_context.exception_and_close(e)
+ return None, None
+
+    def __prepare_script(self):
+        """Prepare the script on the JVM side if that has not been done yet.
+
+        On the first call the DML script is prepared on the connection and the
+        python inputs are passed to the prepared script. Later calls only return
+        the connection.
+
+        :return: the connection obtained from the gateway entry point
+        """
+        gateway = self.sds_context.java_gateway
+        entry_point = gateway.entry_point
+        # Fetch the connection unconditionally so it is bound for the return even
+        # when the script was already prepared by an earlier call.
+        # NOTE(review): assumes getConnection() hands back the connection the
+        # script was prepared on - confirm against the Java entry point.
+        connection = entry_point.getConnection()
+        if self.prepared_script is None:
+            input_names = self.inputs.keys()
+            self.prepared_script = connection.prepareScript(
+                self.dml_script,
+                _list_to_java_array(gateway, input_names),
+                _list_to_java_array(gateway, self.out_var_name))
+            for (name, input_node) in self.inputs.items():
+                input_node.pass_python_data_to_prepared_script(
+                    self.sds_context, name, self.prepared_script)
+        return connection
- return ret
def get_lineage(self) -> str:
gateway = self.sds_context.java_gateway