This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 0793a18  [SYSTEMDS-314] New Python SystemDS context manager
0793a18 is described below

commit 0793a183cf874faaf1f7d143d6b4e64b48e35db9
Author: Kevin Innerebner <kevin.innereb...@yahoo.com>
AuthorDate: Fri Apr 10 17:53:34 2020 +0200

    [SYSTEMDS-314] New Python SystemDS context manager
    
    Closes #874.
---
 docs/Tasks.txt                                     |   1 +
 src/main/python/docs/source/matrix.rst             |  42 +++++-
 src/main/python/docs/source/simple_examples.rst    |  67 +++++----
 src/main/python/systemds/__init__.py               |   2 +-
 src/main/python/systemds/{ => context}/__init__.py |   4 +-
 .../python/systemds/context/systemds_context.py    | 149 +++++++++++++++++++++
 src/main/python/systemds/matrix/matrix.py          |  44 +++---
 src/main/python/systemds/matrix/operation_node.py  | 124 +++++++++--------
 src/main/python/systemds/script_building/dag.py    |  36 +++--
 src/main/python/systemds/script_building/script.py |  37 ++---
 .../systemds/{__init__.py => utils/consts.py}      |   9 +-
 src/main/python/systemds/utils/converters.py       |   6 +-
 src/main/python/systemds/utils/helpers.py          |  46 +------
 src/main/python/tests/test_l2svm.py                |   9 +-
 src/main/python/tests/test_l2svm_lineage.py        |  20 +--
 src/main/python/tests/test_lineagetrace.py         |  23 ++--
 src/main/python/tests/test_matrix_aggregations.py  |  30 ++---
 src/main/python/tests/test_matrix_binary_op.py     |  32 ++---
 18 files changed, 446 insertions(+), 235 deletions(-)

diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index d19672f..d1168e6 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -231,6 +231,7 @@ SYSTEMDS-310 Python Bindings
  * 311 Initial Python Binding for federated execution                 OK
  * 312 Python 3.6 compatibility                                       OK
  * 313 Python Documentation upload via Github Actions                 OK
+ * 314 Python SystemDS context manager                                OK
 
 SYSTEMDS-320 Merge SystemDS into Apache SystemML                      OK
  * 321 Merge histories of SystemDS and SystemML                       OK
diff --git a/src/main/python/docs/source/matrix.rst b/src/main/python/docs/source/matrix.rst
index f2f2fdc..dd88c7c 100644
--- a/src/main/python/docs/source/matrix.rst
+++ b/src/main/python/docs/source/matrix.rst
@@ -23,6 +23,39 @@
 Matrix API
 ==========
 
+SystemDSContext
+---------------
+
+All operations using SystemDS need a running Java instance.
+The connection is established by a ``SystemDSContext`` object.
+A ``SystemDSContext`` object can be created using:
+
+.. code-block:: python
+  sysds = SystemDSContext()
+
+When the calculations are finished the context has to be closed again:
+
+.. code-block:: python
+  sysds.close()
+
+Since closing the context manually every time is tedious, ``SystemDSContext``
+implements the Python context management protocol, which supports the following syntax:
+
+.. code-block:: python
+  with SystemDSContext() as sds:
+    # do something with sds, which is a SystemDSContext
+    pass
+
+This will automatically close the ``SystemDSContext`` once the with-block is left.
+
+.. note::
+
+  Creating a context is expensive, because a subprocess running a JVM might have to be started; therefore
+  try to do this only once in your program, or always leave at least one context open.
+
+.. autoclass:: systemds.context.SystemDSContext
+  :members:
+
 OperationNode
 -------------
 
@@ -49,13 +82,12 @@ Matrix
 ------
 
 A ``Matrix`` is represented either by an ``OperationNode``, or the derived class ``Matrix``.
-An Matrix can recognized it by checking the ``output_type`` of the object.
+A Matrix can be recognized by checking the ``output_type`` of the object.
 
-Matrices are the most fundamental objects we operate on.
-If one generate the matrix in SystemDS directly via a function call,
-it can be used in an function which will generate an ``OperationNode`` e.g. ``federated``, ``full``, ``seq``.
+Matrices are the most fundamental objects SystemDS operates on.
 
-If we want to work on an numpy array we need to use the class ``Matrix``.
+Although it is possible to generate matrices with the function calls or object construction specified below,
+the recommended way is to use the methods defined on ``SystemDSContext``.
 
 .. autoclass:: systemds.matrix.Matrix
     :members:
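
The context management protocol described in the documentation change above can be illustrated with a minimal sketch. The class `DemoContext` below is a hypothetical stand-in, not the committed `SystemDSContext`: it opens no Java gateway and only records whether `close()` was called. It shows that `__exit__` runs both on normal exit and when the body raises.

```python
class DemoContext:
    """Hypothetical stand-in for SystemDSContext; it opens no JVM connection."""

    def __init__(self):
        self.closed = False

    def __enter__(self):
        # The value returned here is bound to the name after `as`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        # Returning a falsy value lets any exception propagate.
        return None

    def close(self):
        self.closed = True


with DemoContext() as ctx:
    closed_inside = ctx.closed  # still open inside the block

closed_after = ctx.closed  # close() ran when the block was left

# close() also runs when the body raises.
try:
    with DemoContext() as failing:
        raise ValueError("simulated failure")
except ValueError:
    pass
closed_after_error = failing.closed
```

Because `__exit__` returns a falsy value, errors raised inside the with-block still reach the caller after the cleanup has run.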
diff --git a/src/main/python/docs/source/simple_examples.rst b/src/main/python/docs/source/simple_examples.rst
index b9c35c3..2175fd4 100644
--- a/src/main/python/docs/source/simple_examples.rst
+++ b/src/main/python/docs/source/simple_examples.rst
@@ -27,18 +27,24 @@ Let's take a look at some code examples.
 Matrix Operations
 -----------------
 
-Making use of SystemDS, let us multiply an Matrix with an scalar::
-
-  # Import full
-  from systemds.matrix import full
-  # Full generates a matrix completely filled with one number.
-  # Generate a 5x10 matrix filled with 4.2
-  m = full((5, 10), 4.20)
-  # multiply with scala. Nothing is executed yet!
-  m_res = m * 3.1
-  # Do the calculation in SystemDS by calling compute().
-  # The returned value is an numpy array that can be directly printed.
-  print(m_res.compute())
+Making use of SystemDS, let us multiply a Matrix with a scalar:
+
+.. code-block:: python
+
+  # Import SystemDSContext
+  from systemds.context import SystemDSContext
+  # Create a context and if necessary (no SystemDS py4j instance running)
+  # it starts a subprocess which does the execution in SystemDS
+  with SystemDSContext() as sds:
+      # Full generates a matrix completely filled with one number.
+      # Generate a 5x10 matrix filled with 4.2
+      m = sds.full((5, 10), 4.20)
+      # multiply with scalar. Nothing is executed yet!
+      m_res = m * 3.1
+      # Do the calculation in SystemDS by calling compute().
+      # The returned value is a numpy array that can be directly printed.
+      print(m_res.compute())
+  # context will automatically be closed and process stopped
 
 As output we get::
 
@@ -50,10 +56,14 @@ As output we get::
 
 The Python SystemDS package is compatible with numpy arrays.
 Let us do a quick element-wise matrix multiplication of numpy arrays with SystemDS.
-Remember to first start up a new terminal::
+Remember to first start up a new terminal:
+
+.. code-block:: python
 
   import numpy as np  # import numpy
-  from systemds.matrix import Matrix  # import Matrix class
+
+  # Import SystemDSContext
+  from systemds.context import SystemDSContext
 
   # create a random array
   m1 = np.array(np.random.randint(100, size=5 * 5) + 1.01, dtype=np.double)
@@ -62,22 +72,26 @@ Remember to first start up a new terminal::
   m2 = np.array(np.random.randint(5, size=5 * 5) + 1, dtype=np.double)
   m2.shape = (5, 5)
 
-  # element-wise matrix multiplication, note that nothing is executed yet!
-  m_res = Matrix(m1) * Matrix(m2)
-  # lets do the actual computation in SystemDS! We get an numpy array as a result
-  m_res_np = m_res.compute()
-  print(m_res_np)
+  # Create a context
+  with SystemDSContext() as sds:
+      # element-wise matrix multiplication, note that nothing is executed yet!
+      m_res = sds.matrix(m1) * sds.matrix(m2)
+      # let's do the actual computation in SystemDS! The result is a numpy array
+      m_res_np = m_res.compute()
+      print(m_res_np)
 
 More complex operations
 -----------------------
 
-SystemDS provides algorithm level functions as buildin functions to simplify development.
-One example of this is l2SVM.
-high level functions for Data-Scientists, lets take a look at l2svm::
+SystemDS provides algorithm level functions as built-in functions to simplify development.
+One example of this is l2SVM, a high-level function for data scientists. Let's take a look at l2svm:
+
+.. code-block:: python
 
   # Import numpy and SystemDS matrix
   import numpy as np
-  from systemds.matrix import Matrix
+  from systemds.context import SystemDSContext
+
   # Set a seed
   np.random.seed(0)
   # Generate random features and labels in numpy
@@ -85,13 +99,16 @@ high level functions for Data-Scientists, lets take a look at l2svm::
   features = np.array(np.random.randint(100, size=10 * 10) + 1.01, dtype=np.double)
   features.shape = (10, 10)
   labels = np.zeros((10, 1))
+
   # l2svm labels can only be 0 or 1
   for i in range(10):
       if np.random.random() > 0.5:
           labels[i][0] = 1
+
   # compute our model
-  model = Matrix(features).l2svm(Matrix(labels)).compute()
-  print(model)
+  with SystemDSContext() as sds:
+      model = sds.matrix(features).l2svm(sds.matrix(labels)).compute()
+      print(model)
 
 The output should be similar to::
 
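
The examples above rely on operations only being recorded until `compute()` is called ("Nothing is executed yet!"). A minimal sketch of that lazy-evaluation idea follows. `LazyNode` is a hypothetical illustration working on plain floats; it does not generate DML and is not the committed `OperationNode`.

```python
class LazyNode:
    """Hypothetical stand-in for an operation node; values are plain floats."""

    def __init__(self, operation, inputs):
        self.operation = operation
        self.inputs = inputs
        self.evaluations = 0  # counts how often this node was evaluated

    def __mul__(self, other):
        # Only records the operation; no arithmetic happens here.
        return LazyNode('*', [self, other])

    def compute(self):
        self.evaluations += 1
        # Evaluate child nodes first, pass plain numbers through unchanged.
        values = [i.compute() if isinstance(i, LazyNode) else i
                  for i in self.inputs]
        if self.operation == 'const':
            return values[0]
        if self.operation == '*':
            return values[0] * values[1]
        raise ValueError(f'unsupported operation {self.operation}')


m = LazyNode('const', [4.0])
m_res = m * 3.0                      # builds a node, evaluates nothing
evaluations_before = m.evaluations   # no evaluation has happened yet
result = m_res.compute()             # the whole graph is evaluated here
```

Deferring the work this way lets the whole expression be handed over in one piece, which is what makes a single script per `compute()` call possible.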
diff --git a/src/main/python/systemds/__init__.py b/src/main/python/systemds/__init__.py
index ed5f2bd..e51fbf8 100644
--- a/src/main/python/systemds/__init__.py
+++ b/src/main/python/systemds/__init__.py
@@ -19,4 +19,4 @@
 #
 #-------------------------------------------------------------
 
-__all__ = ['matrix']
+__all__ = ['context', 'matrix']
diff --git a/src/main/python/systemds/__init__.py b/src/main/python/systemds/context/__init__.py
similarity index 93%
copy from src/main/python/systemds/__init__.py
copy to src/main/python/systemds/context/__init__.py
index ed5f2bd..4d80c0b 100644
--- a/src/main/python/systemds/__init__.py
+++ b/src/main/python/systemds/context/__init__.py
@@ -19,4 +19,6 @@
 #
 #-------------------------------------------------------------
 
-__all__ = ['matrix']
+from .systemds_context import *
+
+__all__ = systemds_context.__all__
diff --git a/src/main/python/systemds/context/systemds_context.py b/src/main/python/systemds/context/systemds_context.py
new file mode 100644
index 0000000..d5bdeb8
--- /dev/null
+++ b/src/main/python/systemds/context/systemds_context.py
@@ -0,0 +1,149 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+__all__ = ["SystemDSContext"]
+
+import os
+import subprocess
+import threading
+from typing import Optional, Sequence, Union, Dict, Tuple, Iterable
+
+import numpy as np
+from py4j.java_gateway import JavaGateway
+from py4j.protocol import Py4JNetworkError
+
+from systemds.matrix import full, seq, federated, Matrix, OperationNode
+from systemds.utils.helpers import get_module_dir
+from systemds.utils.consts import VALID_INPUT_TYPES
+
+PROCESS_LOCK: threading.Lock = threading.Lock()
+PROCESS: Optional[subprocess.Popen] = None
+ACTIVE_PROCESS_CONNECTIONS: int = 0
+
+
+class SystemDSContext(object):
+    """A context with a connection to the java instance with which SystemDS operations are executed.
+    If necessary this class might also start a java process which is used for the SystemDS operations,
+    before connecting."""
+    _java_gateway: Optional[JavaGateway]
+
+    def __init__(self):
+        global PROCESS_LOCK
+        global PROCESS
+        global ACTIVE_PROCESS_CONNECTIONS
+        # make sure that a process is only started if necessary and no other thread
+        # is killing the process while the connection is established
+        PROCESS_LOCK.acquire()
+        try:
+            # attempt connection to manually started java instance
+            self._java_gateway = JavaGateway(eager_load=True)
+        except Py4JNetworkError:
+            # if no java instance is running start it
+            systemds_java_path = os.path.join(get_module_dir(), 'systemds-java')
+            cp_separator = ':'
+            if os.name == 'nt':  # 'nt' means it is Windows
+                cp_separator = ';'
+            lib_cp = os.path.join(systemds_java_path, 'lib', '*')
+            systemds_cp = os.path.join(systemds_java_path, '*')
+            classpath = cp_separator.join([lib_cp, systemds_cp])
+            process = subprocess.Popen(['java', '-cp', classpath, 'org.apache.sysds.pythonapi.PythonDMLScript'],
+                                       stdout=subprocess.PIPE, stdin=subprocess.PIPE)
+            process.stdout.readline()  # wait for 'Gateway Server Started\n' written by server
+            assert process.poll() is None, "Could not start JMLC server"
+            self._java_gateway = JavaGateway()
+            PROCESS = process
+        if PROCESS is not None:
+            ACTIVE_PROCESS_CONNECTIONS += 1
+        PROCESS_LOCK.release()
+
+    @property
+    def java_gateway(self):
+        return self._java_gateway
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+        # no errors to handle to allow continuation
+        return None
+
+    def close(self):
+        """Close the connection to the java process and do necessary cleanup."""
+
+        global PROCESS_LOCK
+        global PROCESS
+        global ACTIVE_PROCESS_CONNECTIONS
+        self._java_gateway.shutdown()
+        PROCESS_LOCK.acquire()
+        # check if no other thread is connected to the process, if it was started as a subprocess
+        if PROCESS is not None and ACTIVE_PROCESS_CONNECTIONS == 1:
+            # stop the process by sending a new line (it will shutdown on its own)
+            PROCESS.communicate(input=b'\n')
+            PROCESS.wait()
+            PROCESS = None
+            ACTIVE_PROCESS_CONNECTIONS = 0
+        PROCESS_LOCK.release()
+
+    def matrix(self, mat: Union[np.array, os.PathLike], *args: Sequence[VALID_INPUT_TYPES],
+               **kwargs: Dict[str, VALID_INPUT_TYPES]) -> 'Matrix':
+        """ Create matrix.
+
+        :param mat: Matrix given by numpy array or path to matrix file
+        :param args: additional arguments
+        :param kwargs: additional named arguments
+        :return: the OperationNode representing this operation
+        """
+        return Matrix(self, mat, *args, **kwargs)
+
+    def federated(self, addresses: Iterable[str], ranges: Iterable[Tuple[Iterable[int], Iterable[int]]],
+                  *args: Sequence[VALID_INPUT_TYPES], **kwargs: Dict[str, VALID_INPUT_TYPES]) -> 'OperationNode':
+        """Create federated matrix object.
+    
+        :param addresses: addresses of the federated workers
+        :param ranges: for each federated worker a pair of begin and end index of their held matrix
+        :param args: unnamed params
+        :param kwargs: named params
+        :return: the OperationNode representing this operation
+        """
+        return federated(self, addresses, ranges, *args, **kwargs)
+
+    def full(self, shape: Tuple[int, int], value: Union[float, int]) -> 'OperationNode':
+        """Generates a matrix completely filled with a value
+
+        :param shape: shape (rows and cols) of the matrix
+        :param value: the value to fill all cells with
+        :return: the OperationNode representing this operation
+        """
+        return full(self, shape, value)
+
+    def seq(self, start: Union[float, int], stop: Union[float, int] = None,
+            step: Union[float, int] = 1) -> 'OperationNode':
+        """Create a single column vector with values from `start` to `stop` and an increment of `step`.
+        If no stop is defined and only one parameter is given, then start will be 0 and the parameter will be
+        interpreted as stop.
+
+        :param start: the starting value
+        :param stop: the maximum value
+        :param step: the step size
+        :return: the OperationNode representing this operation
+        """
+        return seq(self, start, stop, step)
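
In the `close()` method above, `ACTIVE_PROCESS_CONNECTIONS` is only reset when it equals 1, and the lock in `__init__` is released by an explicit `release()` call. One possible way to keep a shared resource and its connection counter consistent is sketched below. All names (`SharedHandle`, `_SHARED`, `_ACTIVE`, `state`) are hypothetical, a plain object stands in for the Java subprocess, and this is an illustration under those assumptions rather than the committed implementation.

```python
import threading

# Hypothetical module-level state, loosely mirroring PROCESS_LOCK, PROCESS and
# ACTIVE_PROCESS_CONNECTIONS; a plain object stands in for the subprocess.
_LOCK = threading.Lock()
_SHARED = None
_ACTIVE = 0


class SharedHandle:
    """One connection to a resource shared by all handles."""

    def __init__(self):
        global _SHARED, _ACTIVE
        # `with` releases the lock even if starting the resource raises.
        with _LOCK:
            if _SHARED is None:
                _SHARED = object()  # "start" the shared resource once
            _ACTIVE += 1
        self._open = True

    def close(self):
        global _SHARED, _ACTIVE
        with _LOCK:
            if not self._open:
                return  # closing twice must not decrement twice
            self._open = False
            _ACTIVE -= 1  # decrement on every close
            if _ACTIVE == 0:
                _SHARED = None  # last handle "stops" the resource


def state():
    """Return (number of open handles, whether the resource is running)."""
    with _LOCK:
        return _ACTIVE, _SHARED is not None


first = SharedHandle()
second = SharedHandle()
first.close()
after_first_close = state()
first.close()  # repeated close is a no-op
after_repeated_close = state()
second.close()
after_second_close = state()
```

Decrementing on every close means the resource is released exactly when the last open handle goes away, regardless of the order in which handles are closed.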
diff --git a/src/main/python/systemds/matrix/matrix.py b/src/main/python/systemds/matrix/matrix.py
index 35aae8e..1509dbe 100644
--- a/src/main/python/systemds/matrix/matrix.py
+++ b/src/main/python/systemds/matrix/matrix.py
@@ -22,22 +22,29 @@
 __all__ = ['Matrix', 'federated', 'full', 'seq']
 
 import os
-from typing import Union, Optional, Iterable, Dict, Tuple, Sequence
+from typing import Union, Optional, Iterable, Dict, Tuple, Sequence, TYPE_CHECKING
 
 import numpy as np
 from py4j.java_gateway import JVMView, JavaObject
 
 from systemds.utils.converters import numpy_to_matrix_block
-from systemds.script_building.dag import VALID_INPUT_TYPES
 from systemds.matrix.operation_node import OperationNode
 
+from systemds.utils.consts import VALID_INPUT_TYPES
+
+if TYPE_CHECKING:
+    # to avoid cyclic dependencies during runtime
+    from systemds.context import SystemDSContext
 
 # TODO maybe instead of having a new class we could have a function `matrix` instead, adding behaviour to
 #  `OperationNode` would be necessary
+
+
 class Matrix(OperationNode):
-    np_array: Optional[np.array]
+    _np_array: Optional[np.array]
 
-    def __init__(self, mat: Union[np.array, os.PathLike], *args: Sequence[VALID_INPUT_TYPES],
+    def __init__(self, sds_context: 'SystemDSContext', mat: Union[np.array, os.PathLike],
+                 *args: Sequence[VALID_INPUT_TYPES],
                  **kwargs: Dict[str, VALID_INPUT_TYPES]) -> None:
         """Generate DAGNode representing matrix with data either given by a numpy array, which will be sent to SystemDS
         on need, or a path pointing to a matrix.
@@ -49,19 +56,19 @@ class Matrix(OperationNode):
         if isinstance(mat, str):
             unnamed_params = [f'\'{mat}\'']
             named_params = {}
-            self.np_array = None
+            self._np_array = None
         else:
             unnamed_params = ['\'./tmp/{file_name}\'']  # TODO better alternative than format string?
             named_params = {'rows': -1, 'cols': -1}
-            self.np_array = mat
+            self._np_array = mat
         unnamed_params.extend(args)
         named_params.update(kwargs)
-        super().__init__('read', unnamed_params, named_params, is_python_local_data=self._is_numpy())
+        super().__init__(sds_context, 'read', unnamed_params, named_params, is_python_local_data=self._is_numpy())
 
     def pass_python_data_to_prepared_script(self, jvm: JVMView, var_name: str, prepared_script: JavaObject) -> None:
         assert self.is_python_local_data, 'Can only pass data to prepared script if it is python local!'
         if self._is_numpy():
-            prepared_script.setMatrix(var_name, numpy_to_matrix_block(jvm, self.np_array), True)  # True for reuse
+            prepared_script.setMatrix(var_name, numpy_to_matrix_block(jvm, self._np_array), True)  # True for reuse
 
     def code_line(self, var_name: str, unnamed_input_vars: Sequence[str],
                   named_input_vars: Dict[str, str]) -> str:
@@ -74,18 +81,20 @@ class Matrix(OperationNode):
         if self._is_numpy():
             if verbose:
                 print('[Numpy Array - No Compilation necessary]')
-            return self.np_array
+            return self._np_array
         else:
             return super().compute(verbose, lineage)
 
     def _is_numpy(self) -> bool:
-        return self.np_array is not None
+        return self._np_array is not None
 
 
-def federated(addresses: Iterable[str], ranges: Iterable[Tuple[Iterable[int], Iterable[int]]], *args,
+def federated(sds_context: 'SystemDSContext', addresses: Iterable[str],
+              ranges: Iterable[Tuple[Iterable[int], Iterable[int]]], *args,
               **kwargs) -> OperationNode:
     """Create federated matrix object.
 
+    :param sds_context: the SystemDS context
     :param addresses: addresses of the federated workers
     :param ranges: for each federated worker a pair of begin and end index of their held matrix
     :param args: unnamed params
@@ -99,26 +108,29 @@ def federated(addresses: Iterable[str], ranges: Iterable[Tuple[Iterable[int], It
     ranges_str += ')'
     named_params = {'addresses': addresses_str, 'ranges': ranges_str}
     named_params.update(kwargs)
-    return OperationNode('federated', args, named_params)
+    return OperationNode(sds_context, 'federated', args, named_params)
 
 
-def full(shape: Tuple[int, int], value: Union[float, int]) -> OperationNode:
+def full(sds_context: 'SystemDSContext', shape: Tuple[int, int], value: Union[float, int]) -> OperationNode:
     """Generates a matrix completely filled with a value
 
+    :param sds_context: SystemDS context
     :param shape: shape (rows and cols) of the matrix TODO tensor
     :param value: the value to fill all cells with
     :return: the OperationNode representing this operation
     """
     unnamed_input_nodes = [value]
     named_input_nodes = {'rows': shape[0], 'cols': shape[1]}
-    return OperationNode('matrix', unnamed_input_nodes, named_input_nodes)
+    return OperationNode(sds_context, 'matrix', unnamed_input_nodes, named_input_nodes)
 
 
-def seq(start: Union[float, int], stop: Union[float, int] = None, step: Union[float, int] = 1) -> OperationNode:
+def seq(sds_context: 'SystemDSContext', start: Union[float, int], stop: Union[float, int] = None,
+        step: Union[float, int] = 1) -> OperationNode:
     """Create a single column vector with values from `start` to `stop` and an increment of `step`.
     If no stop is defined and only one parameter is given, then start will be 0 and the parameter will be interpreted as
     stop.
 
+    :param sds_context: SystemDS context
     :param start: the starting value
     :param stop: the maximum value
     :param step: the step size
@@ -128,4 +140,4 @@ def seq(start: Union[float, int], stop: Union[float, int] = None, step: Union[fl
         stop = start
         start = 0
     unnamed_input_nodes = [start, stop, step]
-    return OperationNode('seq', unnamed_input_nodes)
+    return OperationNode(sds_context, 'seq', unnamed_input_nodes)
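
The `if TYPE_CHECKING:` import used in `matrix.py` above avoids a cyclic import at runtime while still letting type checkers resolve the quoted annotation. A minimal, self-contained sketch of the pattern follows; the function `halve` and the use of `fractions.Fraction` are hypothetical choices for illustration only.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime,
    # so this import cannot take part in an import cycle.
    from fractions import Fraction


def halve(value: 'Fraction') -> 'Fraction':
    # The quoted annotation is stored as a string and is not evaluated
    # at runtime, so the name does not need to exist when this runs.
    return value / 2


annotation = halve.__annotations__['value']  # the plain string 'Fraction'
halved = halve(3.0)  # works at runtime although Fraction was never imported
```

The annotation has to stay quoted (or postponed via `from __future__ import annotations`), otherwise the name would be looked up when the function is defined.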
diff --git a/src/main/python/systemds/matrix/operation_node.py b/src/main/python/systemds/matrix/operation_node.py
index 5aa5efb..9a7eff1 100644
--- a/src/main/python/systemds/matrix/operation_node.py
+++ b/src/main/python/systemds/matrix/operation_node.py
@@ -19,86 +19,92 @@
 #
 #-------------------------------------------------------------
 
+from typing import Union, Optional, Iterable, Dict, Sequence, Tuple, TYPE_CHECKING
+
 import numpy as np
 from py4j.java_gateway import JVMView, JavaObject
-from typing import Union, Optional, Iterable, Dict, Sequence, Tuple
 
-from systemds.utils.helpers import get_gateway, create_params_string
+from systemds.utils.consts import VALID_INPUT_TYPES, BINARY_OPERATIONS, VALID_ARITHMETIC_TYPES
+from systemds.utils.helpers import create_params_string
 from systemds.utils.converters import matrix_block_to_numpy
 from systemds.script_building.script import DMLScript
-from systemds.script_building.dag import OutputType, DAGNode, VALID_INPUT_TYPES
+from systemds.script_building.dag import OutputType, DAGNode
 
-BINARY_OPERATIONS = ['+', '-', '/', '//', '*', '<', '<=', '>', '>=', '==', '!=']
-# TODO add numpy array
-VALID_ARITHMETIC_TYPES = Union[DAGNode, int, float]
+if TYPE_CHECKING:
+    # to avoid cyclic dependencies during runtime
+    from systemds.context import SystemDSContext
 
 __all__ = ["OperationNode"]
 
 
 class OperationNode(DAGNode):
-    result_var: Optional[Union[float, np.array]]
-    lineage_trace: str
-    script: Optional[DMLScript]
+    """A Node representing an operation in SystemDS"""
+    _result_var: Optional[Union[float, np.array]]
+    _lineage_trace: Optional[str]
+    _script: Optional[DMLScript]
 
-    def __init__(self, operation: str, unnamed_input_nodes: Iterable[VALID_INPUT_TYPES] = None,
+    def __init__(self, sds_context: 'SystemDSContext', operation: str,
+                 unnamed_input_nodes: Iterable[VALID_INPUT_TYPES] = None,
                  named_input_nodes: Dict[str, VALID_INPUT_TYPES] = None,
                  output_type: OutputType = OutputType.MATRIX, is_python_local_data: bool = False):
         """
         Create general `OperationNode`
 
+        :param sds_context: The SystemDS context for performing the operations
         :param operation: The name of the DML function to execute
         :param unnamed_input_nodes: inputs identified by their position, not name
         :param named_input_nodes: inputs with their respective parameter name
         :param output_type: type of the output in DML (double, matrix etc.)
         :param is_python_local_data: if the data is local in python e.g. numpy arrays
         """
+        self.sds_context = sds_context
         if unnamed_input_nodes is None:
             unnamed_input_nodes = []
         if named_input_nodes is None:
             named_input_nodes = {}
         self.operation = operation
-        self.unnamed_input_nodes = unnamed_input_nodes
-        self.named_input_nodes = named_input_nodes
+        self._unnamed_input_nodes = unnamed_input_nodes
+        self._named_input_nodes = named_input_nodes
         self.output_type = output_type
-        self.is_python_local_data = is_python_local_data
-        self.result_var = None
-        self.lineage_trace = None
-        self.script = None
-
-    def compute(self, verbose: bool = False, lineage: bool = False) -> Union[float, np.array, Tuple[Union[float, np.array], str]]:
-        if self.result_var is None:
-            self.script = DMLScript()
-            self.script.build_code(self)
+        self._is_python_local_data = is_python_local_data
+        self._result_var = None
+        self._lineage_trace = None
+        self._script = None
+
+    def compute(self, verbose: bool = False, lineage: bool = False) -> \
+            Union[float, np.array, Tuple[Union[float, np.array], str]]:
+        if self._result_var is None or self._lineage_trace is None:
+            self._script = DMLScript(self.sds_context)
+            self._script.build_code(self)
             if lineage:
-                result_variables, ltrace = self.script.execute(lineage)
+                result_variables, self._lineage_trace = self._script.execute(lineage)
             else:
-                result_variables = self.script.execute(lineage)
+                result_variables = self._script.execute(lineage)
             if self.output_type == OutputType.DOUBLE:
-                self.result_var = result_variables.getDouble(self.script.out_var_name)
+                self._result_var = result_variables.getDouble(self._script.out_var_name)
             elif self.output_type == OutputType.MATRIX:
-                self.result_var = matrix_block_to_numpy(get_gateway().jvm,
-                                                        result_variables.getMatrixBlock(self.script.out_var_name))
+                self._result_var = matrix_block_to_numpy(self.sds_context._java_gateway.jvm,
+                                                         result_variables.getMatrixBlock(self._script.out_var_name))
         if verbose:
-            print(self.script.dml_script)
+            print(self._script.dml_script)
             # TODO further info
 
         if lineage:
-            return self.result_var, ltrace
+            return self._result_var, self._lineage_trace
         else:
-            return self.result_var
+            return self._result_var
 
-    def getLineageTrace(self) -> str:
+    def get_lineage_trace(self) -> str:
         """Get the lineage trace for this node.
 
         :return: Lineage trace
         """
-        if self.lineage_trace is None:
-            self.script = DMLScript()
-            self.script.build_code(self)
-            self.lineage_trace = self.script.getlineage()
+        if self._lineage_trace is None:
+            self._script = DMLScript(self.sds_context)
+            self._script.build_code(self)
+            self._lineage_trace = self._script.get_lineage()
 
-        return self.lineage_trace
-        
+        return self._lineage_trace
 
     def code_line(self, var_name: str, unnamed_input_vars: Sequence[str],
                   named_input_vars: Dict[str, str]) -> str:
@@ -121,37 +127,37 @@ class OperationNode(DAGNode):
         assert self.output_type == OutputType.MATRIX, f'{self.operation} only supported for matrices'
 
     def __add__(self, other: VALID_ARITHMETIC_TYPES):
-        return OperationNode('+', [self, other])
+        return OperationNode(self.sds_context, '+', [self, other])
 
     def __sub__(self, other: VALID_ARITHMETIC_TYPES):
-        return OperationNode('-', [self, other])
+        return OperationNode(self.sds_context, '-', [self, other])
 
     def __mul__(self, other: VALID_ARITHMETIC_TYPES):
-        return OperationNode('*', [self, other])
+        return OperationNode(self.sds_context, '*', [self, other])
 
     def __truediv__(self, other: VALID_ARITHMETIC_TYPES):
-        return OperationNode('/', [self, other])
+        return OperationNode(self.sds_context, '/', [self, other])
 
     def __floordiv__(self, other: VALID_ARITHMETIC_TYPES):
-        return OperationNode('//', [self, other])
+        return OperationNode(self.sds_context, '//', [self, other])
 
     def __lt__(self, other) -> 'OperationNode':
-        return OperationNode('<', [self, other])
+        return OperationNode(self.sds_context, '<', [self, other])
 
     def __le__(self, other):
-        return OperationNode('<=', [self, other])
+        return OperationNode(self.sds_context, '<=', [self, other])
 
     def __gt__(self, other):
-        return OperationNode('>', [self, other])
+        return OperationNode(self.sds_context, '>', [self, other])
 
     def __ge__(self, other):
-        return OperationNode('>=', [self, other])
+        return OperationNode(self.sds_context, '>=', [self, other])
 
     def __eq__(self, other):
-        return OperationNode('==', [self, other])
+        return OperationNode(self.sds_context, '==', [self, other])
 
     def __ne__(self, other):
-        return OperationNode('!=', [self, other])
+        return OperationNode(self.sds_context, '!=', [self, other])
 
     def l2svm(self, labels: DAGNode, **kwargs) -> 'OperationNode':
         """Perform l2svm on matrix with labels given.
@@ -161,7 +167,7 @@ class OperationNode(DAGNode):
         self._check_matrix_op()
         params_dict = {'X': self, 'Y': labels}
         params_dict.update(kwargs)
-        return OperationNode('l2svm', named_input_nodes=params_dict)
+        return OperationNode(self.sds_context, 'l2svm', named_input_nodes=params_dict)
 
     def sum(self, axis: int = None) -> 'OperationNode':
         """Calculate sum of matrix.
@@ -171,11 +177,11 @@ class OperationNode(DAGNode):
         """
         self._check_matrix_op()
         if axis == 0:
-            return OperationNode('colSums', [self])
+            return OperationNode(self.sds_context, 'colSums', [self])
         elif axis == 1:
-            return OperationNode('rowSums', [self])
+            return OperationNode(self.sds_context, 'rowSums', [self])
         elif axis is None:
-            return OperationNode('sum', [self], output_type=OutputType.DOUBLE)
+            return OperationNode(self.sds_context, 'sum', [self], output_type=OutputType.DOUBLE)
         raise ValueError(f"Axis has to be either 0, 1 or None, for column, row or complete {self.operation}")
 
     def mean(self, axis: int = None) -> 'OperationNode':
@@ -186,11 +192,11 @@ class OperationNode(DAGNode):
         """
         self._check_matrix_op()
         if axis == 0:
-            return OperationNode('colMeans', [self])
+            return OperationNode(self.sds_context, 'colMeans', [self])
         elif axis == 1:
-            return OperationNode('rowMeans', [self])
+            return OperationNode(self.sds_context, 'rowMeans', [self])
         elif axis is None:
-            return OperationNode('mean', [self], output_type=OutputType.DOUBLE)
+            return OperationNode(self.sds_context, 'mean', [self], output_type=OutputType.DOUBLE)
         raise ValueError(f"Axis has to be either 0, 1 or None, for column, row or complete {self.operation}")
 
     def var(self, axis: int = None) -> 'OperationNode':
@@ -201,11 +207,11 @@ class OperationNode(DAGNode):
         """
         self._check_matrix_op()
         if axis == 0:
-            return OperationNode('colVars', [self])
+            return OperationNode(self.sds_context, 'colVars', [self])
         elif axis == 1:
-            return OperationNode('rowVars', [self])
+            return OperationNode(self.sds_context, 'rowVars', [self])
         elif axis is None:
-            return OperationNode('var', [self], output_type=OutputType.DOUBLE)
+            return OperationNode(self.sds_context, 'var', [self], output_type=OutputType.DOUBLE)
         raise ValueError(f"Axis has to be either 0, 1 or None, for column, row or complete {self.operation}")
 
     def abs(self) -> 'OperationNode':
@@ -213,7 +219,7 @@ class OperationNode(DAGNode):
 
         :return: `OperationNode` representing operation
         """
-        return OperationNode('abs', [self])
+        return OperationNode(self.sds_context, 'abs', [self])
 
     def moment(self, moment, weights: DAGNode = None) -> 'OperationNode':
         # TODO write tests
@@ -222,4 +228,4 @@ class OperationNode(DAGNode):
         if weights is not None:
             unnamed_inputs.append(weights)
         unnamed_inputs.append(moment)
-        return OperationNode('moment', unnamed_inputs, output_type=OutputType.DOUBLE)
+        return OperationNode(self.sds_context, 'moment', unnamed_inputs, output_type=OutputType.DOUBLE)
diff --git a/src/main/python/systemds/script_building/dag.py b/src/main/python/systemds/script_building/dag.py
index b2ac8a7..146db0c 100644
--- a/src/main/python/systemds/script_building/dag.py
+++ b/src/main/python/systemds/script_building/dag.py
@@ -20,11 +20,15 @@
 #-------------------------------------------------------------
 
 from enum import Enum, auto
-from typing import Any, Dict, Union, Sequence
+from typing import Any, Dict, Union, Sequence, TYPE_CHECKING
 from abc import ABC
 
 from py4j.java_gateway import JavaObject, JVMView
 
+if TYPE_CHECKING:
+    # to avoid cyclic dependencies during runtime
+    from systemds.context import SystemDSContext
+
 
 class OutputType(Enum):
     MATRIX = auto()
@@ -33,24 +37,27 @@ class OutputType(Enum):
 
 class DAGNode(ABC):
     """A Node in the directed-acyclic-graph (DAG) defining all operations."""
-    unnamed_input_nodes: Sequence[Union['DAGNode', str, int, float, bool]]
-    named_input_nodes: Dict[str, Union['DAGNode', str, int, float, bool]]
-    output_type: OutputType
-    is_python_local_data: bool
+    sds_context: 'SystemDSContext'
+    _unnamed_input_nodes: Sequence[Union['DAGNode', str, int, float, bool]]
+    _named_input_nodes: Dict[str, Union['DAGNode', str, int, float, bool]]
+    _output_type: OutputType
+    _is_python_local_data: bool
 
     def compute(self, verbose: bool = False, lineage: bool = False) -> Any:
         """Get result of this operation. Builds the dml script and executes it in SystemDS, before this method is called
         all operations are only building the DAG without actually executing (lazy evaluation).
 
         :param verbose: Can be activated to print additional information such as created DML-Script
-        :lineage: Can be activated to print lineage trace till this node
+        :param lineage: Can be activated to print lineage trace till this node
         :return: the output as a Python builtin data type or numpy array
         """
         raise NotImplementedError
 
-    def getLineageTrace(self) -> str:
-        """Get lineage trace of this operation. This executes the dml script but unlike compute, doesn't store the results
-        """
+    def get_lineage_trace(self) -> str:
+        """Get lineage trace of this operation. This executes the dml script but unlike compute,
+        doesn't store the results"""
+        # TODO why do we not want to store the results? The execution script should stay the same
+        #  therefore we could cache the result.
         raise NotImplementedError
 
     def code_line(self, var_name: str, unnamed_input_vars: Sequence[str], named_input_vars: Dict[str, str]) -> str:
@@ -72,5 +79,14 @@ class DAGNode(ABC):
         """
         raise NotImplementedError
 
+    @property
+    def unnamed_input_nodes(self):
+        return self._unnamed_input_nodes
+
+    @property
+    def named_input_nodes(self):
+        return self._named_input_nodes
 
-VALID_INPUT_TYPES = Union[DAGNode, str, int, float, bool]
+    @property
+    def is_python_local_data(self):
+        return self._is_python_local_data
diff --git a/src/main/python/systemds/script_building/script.py b/src/main/python/systemds/script_building/script.py
index 940cc5b..30857c3 100644
--- a/src/main/python/systemds/script_building/script.py
+++ b/src/main/python/systemds/script_building/script.py
@@ -19,14 +19,17 @@
 #
 #-------------------------------------------------------------
 
-from typing import Any, Dict, Optional, Collection, KeysView, Union, Tuple
+from typing import Any, Collection, KeysView, Tuple, Union, Optional, Dict, TYPE_CHECKING
 
 from py4j.java_collections import JavaArray
-from py4j.java_gateway import JavaObject
+from py4j.java_gateway import JavaObject, JavaGateway
 
-from systemds.utils.helpers import get_gateway
-from systemds.script_building.dag import DAGNode, VALID_INPUT_TYPES
-from typing import Union, Optional, Dict
+from systemds.script_building.dag import DAGNode
+from systemds.utils.consts import VALID_INPUT_TYPES
+
+if TYPE_CHECKING:
+    # to avoid cyclic dependencies during runtime
+    from systemds.context import SystemDSContext
 
 
 class DMLScript:
@@ -39,13 +42,15 @@ class DMLScript:
 
     TODO rerun with different inputs without recompilation
     """
+    sds_context: 'SystemDSContext'
     dml_script: str
     inputs: Dict[str, DAGNode]
     prepared_script: Optional[Any]
     out_var_name: str
     _variable_counter: int
 
-    def __init__(self) -> None:
+    def __init__(self, context: 'SystemDSContext') -> None:
+        self.sds_context = context
         self.dml_script = ''
         self.inputs = {}
         self.prepared_script = None
@@ -75,14 +80,14 @@ class DMLScript:
         """
         # we could use the gateway directly, non defined functions will be automatically
         # sent to the entry_point, but this is safer
-        gateway = get_gateway()
+        gateway = self.sds_context.java_gateway
         entry_point = gateway.entry_point
         if self.prepared_script is None:
             input_names = self.inputs.keys()
             connection = entry_point.getConnection()
             self.prepared_script = connection.prepareScript(self.dml_script,
-                                                            _list_to_java_array(input_names),
-                                                            _list_to_java_array([self.out_var_name]))
+                                                            _list_to_java_array(gateway, input_names),
+                                                            _list_to_java_array(gateway, [self.out_var_name]))
             for (name, input_node) in self.inputs.items():
                 input_node.pass_python_data_to_prepared_script(gateway.jvm, name, self.prepared_script)
 
@@ -95,24 +100,23 @@ class DMLScript:
 
         return ret
 
-    def getlineage(self) -> str:
-        gateway = get_gateway()
+    def get_lineage(self) -> str:
+        gateway = self.sds_context.java_gateway
         entry_point = gateway.entry_point
         if self.prepared_script is None:
             input_names = self.inputs.keys()
             connection = entry_point.getConnection()
             self.prepared_script = connection.prepareScript(self.dml_script,
-                                                            _list_to_java_array(input_names),
-                                                            _list_to_java_array([self.out_var_name]))
+                                                            _list_to_java_array(gateway, input_names),
+                                                            _list_to_java_array(gateway, [self.out_var_name]))
             for (name, input_node) in self.inputs.items():
                 input_node.pass_python_data_to_prepared_script(gateway.jvm, name, self.prepared_script)
 
             connection.setLineage(True)
 
-        ret = self.prepared_script.executeScript()
+        self.prepared_script.executeScript()
         lineage = self.prepared_script.getLineageTrace(self.out_var_name)
         return lineage
-        
 
     def build_code(self, dag_root: DAGNode) -> None:
         """Builds code from our DAG
@@ -156,13 +160,12 @@ class DMLScript:
 
 
 # Helper Functions
-def _list_to_java_array(py_list: Union[Collection[str], KeysView[str]]) -> JavaArray:
+def _list_to_java_array(gateway: JavaGateway, py_list: Union[Collection[str], KeysView[str]]) -> JavaArray:
     """Convert python collection to java array.
 
     :param py_list: python collection
     :return: java array
     """
-    gateway = get_gateway()
     array = gateway.new_array(gateway.jvm.java.lang.String, len(py_list))
     for (i, e) in enumerate(py_list):
         array[i] = e
diff --git a/src/main/python/systemds/__init__.py b/src/main/python/systemds/utils/consts.py
similarity index 74%
copy from src/main/python/systemds/__init__.py
copy to src/main/python/systemds/utils/consts.py
index ed5f2bd..34506e2 100644
--- a/src/main/python/systemds/__init__.py
+++ b/src/main/python/systemds/utils/consts.py
@@ -17,6 +17,11 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-#-------------------------------------------------------------
+from typing import Union
+
+MODULE_NAME = 'systemds'
+VALID_INPUT_TYPES = Union['DAGNode', str, int, float, bool]
+BINARY_OPERATIONS = ['+', '-', '/', '//', '*', '<', '<=', '>', '>=', '==', '!=']
+# TODO add numpy array and implement for numpy array
+VALID_ARITHMETIC_TYPES = Union['DAGNode', int, float]
 
-__all__ = ['matrix']
diff --git a/src/main/python/systemds/utils/converters.py b/src/main/python/systemds/utils/converters.py
index 37c9c20..bbbba78 100644
--- a/src/main/python/systemds/utils/converters.py
+++ b/src/main/python/systemds/utils/converters.py
@@ -44,7 +44,7 @@ def numpy_to_matrix_block(jvm: JVMView, np_arr: np.array):
 
 
 def matrix_block_to_numpy(jvm: JVMView, mb: JavaObject):
-    numRows = mb.getNumRows()
-    numCols = mb.getNumColumns()
+    num_rows = mb.getNumRows()
+    num_cols = mb.getNumColumns()
     buf = jvm.org.apache.sysds.runtime.compress.utils.Py4jConverterUtils.convertMBtoPy4JDenseArr(mb)
-    return np.frombuffer(buf, count=numRows * numCols, dtype=np.float64).reshape((numRows, numCols))
+    return np.frombuffer(buf, count=num_rows * num_cols, dtype=np.float64).reshape((num_rows, num_cols))
diff --git a/src/main/python/systemds/utils/helpers.py b/src/main/python/systemds/utils/helpers.py
index 32d0499..7aae17e 100644
--- a/src/main/python/systemds/utils/helpers.py
+++ b/src/main/python/systemds/utils/helpers.py
@@ -20,53 +20,11 @@
 #-------------------------------------------------------------
 
 import os
-import subprocess
 from itertools import chain
 from typing import Iterable, Dict
 from importlib.util import find_spec
 
-from py4j.java_gateway import JavaGateway
-from py4j.protocol import Py4JNetworkError
-
-JAVA_GATEWAY = None
-MODULE_NAME = 'systemds'
-PROC = None
-
-
-def get_gateway() -> JavaGateway:
-    """
-    Gives the gateway with which we can communicate with the SystemDS instance running a
-    JMLC (Java Machine Learning Compactor) API.
-
-    :return: the java gateway object
-    """
-    global JAVA_GATEWAY
-    global PROC
-    if JAVA_GATEWAY is None:
-        try:
-            JAVA_GATEWAY = JavaGateway(eager_load=True)
-        except Py4JNetworkError:  # if no java instance is running start it
-            systemds_java_path = os.path.join(_get_module_dir(), 'systemds-java')
-            cp_separator = ':'
-            if os.name == 'nt':  # nt means its Windows
-                cp_separator = ';'
-            lib_cp = os.path.join(systemds_java_path, 'lib', '*')
-            systemds_cp = os.path.join(systemds_java_path, '*')
-            classpath = cp_separator.join([lib_cp, systemds_cp])
-            process = subprocess.Popen(['java', '-cp', classpath, 'org.apache.sysds.pythonapi.PythonDMLScript'],
-                                       stdout=subprocess.PIPE, stdin=subprocess.PIPE)
-            print(process.stdout.readline())  # wait for 'Gateway Server Started\n' written by server
-            assert process.poll() is None, "Could not start JMLC server"
-            JAVA_GATEWAY = JavaGateway()
-            PROC = process
-    return JAVA_GATEWAY
-
-
-def shutdown():
-    global JAVA_GATEWAY
-    global PROC
-    JAVA_GATEWAY.shutdown()
-    PROC.communicate(input=b'\n')
+from systemds.utils.consts import MODULE_NAME
 
 
 def create_params_string(unnamed_parameters: Iterable[str], named_parameters: Dict[str, str]) -> str:
@@ -82,7 +40,7 @@ def create_params_string(unnamed_parameters: Iterable[str], named_parameters: Di
     return ','.join(chain(unnamed_parameters, named_input_strs))
 
 
-def _get_module_dir() -> os.PathLike:
+def get_module_dir() -> os.PathLike:
     """
     Gives the path to our module
 
diff --git a/src/main/python/tests/test_l2svm.py b/src/main/python/tests/test_l2svm.py
index 5287381..1412d05 100644
--- a/src/main/python/tests/test_l2svm.py
+++ b/src/main/python/tests/test_l2svm.py
@@ -30,8 +30,11 @@ import numpy as np
 
 path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
 sys.path.insert(0, path)
+
+from systemds.context import SystemDSContext
 from systemds.matrix import Matrix
-from systemds.utils import helpers
+
+sds = SystemDSContext()
 
 class TestL2svm(unittest.TestCase):
 
@@ -69,9 +72,9 @@ def generate_matrices_for_l2svm(dims: int, seed: int = 1234) -> Tuple[Matrix, Ma
     for i in range(dims):
         if np.random.random() > 0.5:
             m2[i][0] = 1
-    return Matrix(m1), Matrix(m2)
+    return sds.matrix(m1), sds.matrix(m2)
 
 
 if __name__ == "__main__":
     unittest.main(exit=False)
-    helpers.shutdown()
\ No newline at end of file
+    sds.close()
diff --git a/src/main/python/tests/test_l2svm_lineage.py b/src/main/python/tests/test_l2svm_lineage.py
index c8cf66e..9bd535e 100644
--- a/src/main/python/tests/test_l2svm_lineage.py
+++ b/src/main/python/tests/test_l2svm_lineage.py
@@ -30,7 +30,10 @@ import numpy as np
 path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
 sys.path.insert(0, path)
 from systemds.matrix import Matrix
-from systemds.utils import helpers
+from systemds.context import SystemDSContext
+
+sds = SystemDSContext()
+
 
 class TestAPI(unittest.TestCase):
 
@@ -46,8 +49,8 @@ class TestAPI(unittest.TestCase):
 
     def test_getl2svm_lineage(self):
         features, labels = generate_matrices_for_l2svm(10, seed=1304)
-        #get the lineage trace
-        lt = features.l2svm(labels).getLineageTrace()
+        # get the lineage trace
+        lt = features.l2svm(labels).get_lineage_trace()
         with open(os.path.join("tests", "lt_l2svm.txt"), "r") as file:
             data = file.read()
         file.close()
@@ -55,8 +58,8 @@ class TestAPI(unittest.TestCase):
 
     def test_getl2svm_lineage2(self):
         features, labels = generate_matrices_for_l2svm(10, seed=1304)
-        #get the lineage trace
-        model, lt = features.l2svm(labels).compute(lineage = True)
+        # get the lineage trace
+        model, lt = features.l2svm(labels).compute(lineage=True)
         with open(os.path.join("tests", "lt_l2svm.txt"), "r") as file:
             data = file.read()
         file.close()
@@ -71,14 +74,15 @@ def generate_matrices_for_l2svm(dims: int, seed: int = 1234) -> Tuple[Matrix, Ma
     for i in range(dims):
         if np.random.random() > 0.5:
             m2[i][0] = 1
-    return Matrix(m1), Matrix(m2)
+    return sds.matrix(m1), sds.matrix(m2)
+
 
 def reVars(s: str) -> str:
     s = re.sub(r'\b_mVar\d*\b', '', s)
     s = re.sub(r'\b_Var\d*\b', '', s)
     return s
-    
+
 
 if __name__ == "__main__":
     unittest.main(exit=False)
-    helpers.shutdown()
+    sds.close()
diff --git a/src/main/python/tests/test_lineagetrace.py b/src/main/python/tests/test_lineagetrace.py
index 371f424..d2064bd 100644
--- a/src/main/python/tests/test_lineagetrace.py
+++ b/src/main/python/tests/test_lineagetrace.py
@@ -27,8 +27,10 @@ import re
 
 path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
 sys.path.insert(0, path)
-from systemds.matrix import Matrix, full, seq
-from systemds.utils import helpers
+from systemds.context import SystemDSContext
+
+sds = SystemDSContext()
+
 
 class TestLineageTrace(unittest.TestCase):
 
@@ -42,19 +44,19 @@ class TestLineageTrace(unittest.TestCase):
                                 message="unclosed",
                                 category=ResourceWarning)
 
-    def test_compare_trace1(self): #test getLineageTrace() on an intermediate
-        m = full((5, 10), 4.20)
+    def test_compare_trace1(self):  # test get_lineage_trace() on an intermediate
+        m = sds.full((5, 10), 4.20)
         m_res = m * 3.1
-        m_sum = m_res.sum()       
+        m_sum = m_res.sum()
         with open(os.path.join("tests", "lt.txt"), "r") as file:
             data = file.read()
         file.close()
-        self.assertEqual(reVars(m_res.getLineageTrace()), reVars(data))
+        self.assertEqual(reVars(m_res.get_lineage_trace()), reVars(data))
 
-    def test_compare_trace2(self): #test (lineage=True) as an argument to compute
-        m = full((5, 10), 4.20)
+    def test_compare_trace2(self):  # test (lineage=True) as an argument to compute
+        m = sds.full((5, 10), 4.20)
         m_res = m * 3.1
-        sum, lt = m_res.sum().compute(lineage = True)       
+        sum, lt = m_res.sum().compute(lineage=True)
         lt = re.sub(r'\b_mVar\d*\b', '', lt)
         with open(os.path.join("tests", "lt2.txt"), "r") as file:
             data = file.read()
@@ -67,6 +69,7 @@ def reVars(s: str) -> str:
     s = re.sub(r'\b_Var\d*\b', '', s)
     return s
 
+
 if __name__ == "__main__":
     unittest.main(exit=False)
-    helpers.shutdown()
+    sds.close()
diff --git a/src/main/python/tests/test_matrix_aggregations.py b/src/main/python/tests/test_matrix_aggregations.py
index d17a634..ba9cb4b 100644
--- a/src/main/python/tests/test_matrix_aggregations.py
+++ b/src/main/python/tests/test_matrix_aggregations.py
@@ -28,9 +28,7 @@ import numpy as np
 
 path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
 sys.path.insert(0, path)
-
-from systemds.matrix import Matrix, full, seq
-from systemds.utils import helpers
+from systemds.context import SystemDSContext
 
 dim = 5
 m1 = np.array(np.random.randint(100, size=dim * dim) + 1.01, dtype=np.double)
@@ -38,6 +36,8 @@ m1.shape = (dim, dim)
 m2 = np.array(np.random.randint(5, size=dim * dim) + 1, dtype=np.double)
 m2.shape = (dim, dim)
 
+sds = SystemDSContext()
+
 
 class TestMatrixAggFn(unittest.TestCase):
 
@@ -52,39 +52,39 @@ class TestMatrixAggFn(unittest.TestCase):
                                 category=ResourceWarning)
 
     def test_sum1(self):
-        self.assertTrue(np.allclose(Matrix(m1).sum().compute(), m1.sum()))
+        self.assertTrue(np.allclose(sds.matrix(m1).sum().compute(), m1.sum()))
 
     def test_sum2(self):
-        self.assertTrue(np.allclose(Matrix(m1).sum(axis=0).compute(), m1.sum(axis=0)))
+        self.assertTrue(np.allclose(sds.matrix(m1).sum(axis=0).compute(), m1.sum(axis=0)))
 
     def test_sum3(self):
-        self.assertTrue(np.allclose(Matrix(m1).sum(axis=1).compute(), m1.sum(axis=1).reshape(dim, 1)))
+        self.assertTrue(np.allclose(sds.matrix(m1).sum(axis=1).compute(), m1.sum(axis=1).reshape(dim, 1)))
 
     def test_mean1(self):
-        self.assertTrue(np.allclose(Matrix(m1).mean().compute(), m1.mean()))
+        self.assertTrue(np.allclose(sds.matrix(m1).mean().compute(), m1.mean()))
 
     def test_mean2(self):
-        self.assertTrue(np.allclose(Matrix(m1).mean(axis=0).compute(), m1.mean(axis=0)))
+        self.assertTrue(np.allclose(sds.matrix(m1).mean(axis=0).compute(), m1.mean(axis=0)))
 
     def test_mean3(self):
-        self.assertTrue(np.allclose(Matrix(m1).mean(axis=1).compute(), m1.mean(axis=1).reshape(dim, 1)))
+        self.assertTrue(np.allclose(sds.matrix(m1).mean(axis=1).compute(), m1.mean(axis=1).reshape(dim, 1)))
 
     def test_full(self):
-        self.assertTrue(np.allclose(full((2, 3), 10.1).compute(), np.full((2, 3), 10.1)))
+        self.assertTrue(np.allclose(sds.full((2, 3), 10.1).compute(), np.full((2, 3), 10.1)))
 
     def test_seq(self):
-        self.assertTrue(np.allclose(seq(3).compute(), np.arange(4).reshape(4, 1)))
+        self.assertTrue(np.allclose(sds.seq(3).compute(), np.arange(4).reshape(4, 1)))
 
     def test_var1(self):
-        self.assertTrue(np.allclose(Matrix(m1).var().compute(), m1.var(ddof=1)))
+        self.assertTrue(np.allclose(sds.matrix(m1).var().compute(), m1.var(ddof=1)))
 
     def test_var2(self):
-        self.assertTrue(np.allclose(Matrix(m1).var(axis=0).compute(), m1.var(axis=0, ddof=1)))
+        self.assertTrue(np.allclose(sds.matrix(m1).var(axis=0).compute(), m1.var(axis=0, ddof=1)))
 
     def test_var3(self):
-        self.assertTrue(np.allclose(Matrix(m1).var(axis=1).compute(), m1.var(axis=1, ddof=1).reshape(dim, 1)))
+        self.assertTrue(np.allclose(sds.matrix(m1).var(axis=1).compute(), m1.var(axis=1, ddof=1).reshape(dim, 1)))
 
 
 if __name__ == "__main__":
     unittest.main(exit=False)
-    helpers.shutdown()
+    sds.close()
diff --git a/src/main/python/tests/test_matrix_binary_op.py b/src/main/python/tests/test_matrix_binary_op.py
index b5403eb..95dac02 100644
--- a/src/main/python/tests/test_matrix_binary_op.py
+++ b/src/main/python/tests/test_matrix_binary_op.py
@@ -29,8 +29,7 @@ import numpy as np
 path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")
 sys.path.insert(0, path)
 
-from systemds.matrix import Matrix
-from systemds.utils import helpers
+from systemds.context import SystemDSContext
 
 dim = 5
 m1 = np.array(np.random.randint(100, size=dim * dim) + 1.01, dtype=np.double)
@@ -39,6 +38,7 @@ m2 = np.array(np.random.randint(5, size=dim * dim) + 1, dtype=np.double)
 m2.shape = (dim, dim)
 s = 3.02
 
+sds = SystemDSContext()
 
 class TestBinaryOp(unittest.TestCase):
 
@@ -53,51 +53,51 @@ class TestBinaryOp(unittest.TestCase):
                                 category=ResourceWarning)
 
     def test_plus(self):
-        self.assertTrue(np.allclose((Matrix(m1) + Matrix(m2)).compute(), m1 + m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) + sds.matrix(m2)).compute(), m1 + m2))
 
     def test_minus(self):
-        self.assertTrue(np.allclose((Matrix(m1) - Matrix(m2)).compute(), m1 - m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) - sds.matrix(m2)).compute(), m1 - m2))
 
     def test_mul(self):
-        self.assertTrue(np.allclose((Matrix(m1) * Matrix(m2)).compute(), m1 * m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) * sds.matrix(m2)).compute(), m1 * m2))
 
     def test_div(self):
-        self.assertTrue(np.allclose((Matrix(m1) / Matrix(m2)).compute(), m1 / m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) / sds.matrix(m2)).compute(), m1 / m2))
 
     # TODO arithmetic with numpy rhs
 
     # TODO arithmetic with numpy lhs
 
     def test_plus3(self):
-        self.assertTrue(np.allclose((Matrix(m1) + s).compute(), m1 + s))
+        self.assertTrue(np.allclose((sds.matrix(m1) + s).compute(), m1 + s))
 
     def test_minus3(self):
-        self.assertTrue(np.allclose((Matrix(m1) - s).compute(), m1 - s))
+        self.assertTrue(np.allclose((sds.matrix(m1) - s).compute(), m1 - s))
 
     def test_mul3(self):
-        self.assertTrue(np.allclose((Matrix(m1) * s).compute(), m1 * s))
+        self.assertTrue(np.allclose((sds.matrix(m1) * s).compute(), m1 * s))
 
     def test_div3(self):
-        self.assertTrue(np.allclose((Matrix(m1) / s).compute(), m1 / s))
+        self.assertTrue(np.allclose((sds.matrix(m1) / s).compute(), m1 / s))
 
     # TODO arithmetic with scalar lhs
 
     def test_lt(self):
-        self.assertTrue(np.allclose((Matrix(m1) < Matrix(m2)).compute(), m1 < m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) < sds.matrix(m2)).compute(), m1 < m2))
 
     def test_gt(self):
-        self.assertTrue(np.allclose((Matrix(m1) > Matrix(m2)).compute(), m1 > m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) > sds.matrix(m2)).compute(), m1 > m2))
 
     def test_le(self):
-        self.assertTrue(np.allclose((Matrix(m1) <= Matrix(m2)).compute(), m1 <= m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) <= sds.matrix(m2)).compute(), m1 <= m2))
 
     def test_ge(self):
-        self.assertTrue(np.allclose((Matrix(m1) >= Matrix(m2)).compute(), m1 >= m2))
+        self.assertTrue(np.allclose((sds.matrix(m1) >= sds.matrix(m2)).compute(), m1 >= m2))
 
     def test_abs(self):
-        self.assertTrue(np.allclose(Matrix(m1).abs().compute(), np.abs(m1)))
+        self.assertTrue(np.allclose(sds.matrix(m1).abs().compute(), np.abs(m1)))
 
 
 if __name__ == "__main__":
     unittest.main(exit=False)
-    helpers.shutdown()
+    sds.close()
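
The tests in this diff create one SystemDSContext per module, build nodes through it (sds.matrix, sds.full, sds.seq), and call sds.close() at the end. Every OperationNode now forwards self.sds_context to the nodes it creates, so the gateway is reached through the context instead of a module-level global. The sketch below is one way to picture that pattern in plain Python. ToyGateway, ToyContext and ToyNode are hypothetical stand-ins, not the SystemDS classes, and the with-statement support is an assumption suggested by the commit title rather than something shown in this part of the diff.

```python
from contextlib import AbstractContextManager


class ToyGateway:
    """Hypothetical stand-in for a py4j JavaGateway (no JVM involved)."""

    def __init__(self):
        self.open = True

    def shutdown(self):
        self.open = False


class ToyContext(AbstractContextManager):
    """Owns the gateway; nodes reach it through their context reference."""

    def __init__(self):
        self.java_gateway = ToyGateway()

    def full(self, shape, value):
        # Factory method on the context, mirroring sds.full(...) in the tests.
        return ToyNode(self, 'full', [shape, value])

    def close(self):
        self.java_gateway.shutdown()

    def __exit__(self, exc_type, exc_value, traceback):
        # Release the gateway when the with-block ends, even on error.
        self.close()
        return None


class ToyNode:
    """DAG node that passes its context on to every node derived from it."""

    def __init__(self, context, operation, inputs):
        self.sds_context = context
        self.operation = operation
        self.inputs = inputs

    def __mul__(self, other):
        return ToyNode(self.sds_context, '*', [self, other])

    def sum(self):
        return ToyNode(self.sds_context, 'sum', [self])


with ToyContext() as ctx:
    node = (ctx.full((5, 10), 4.2) * 3.1).sum()
    same_context = node.sds_context is ctx
    gateway_open_inside = ctx.java_gateway.open
gateway_open_after = ctx.java_gateway.open
```

Because the context is threaded through every node, two contexts could in principle coexist without sharing a gateway, which the removed get_gateway() global did not allow.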
