This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8e1e53bacb [SYSTEMDS-3694] Python NN Sequence and layer interface
8e1e53bacb is described below

commit 8e1e53bacbf444625272462c90e1e9c2bfaf206a
Author: Nakroma <tarackobar...@gmail.com>
AuthorDate: Tue May 28 10:40:44 2024 +0200

    [SYSTEMDS-3694] Python NN Sequence and layer interface
    
    This commit:
    - Adds a Layer interface to the Python API.
    - Changes the Affine and ReLU classes to extend this interface.
    - Fixes some small formatting issues in the modified classes.
    - Adds a Sequential primitive to the nn Python API.
      It is able to combine multiple nn layers into one sequential
      module (see the usage sketch below).
    - Fixes the Python MultiReturn so outputs of the instance can be
      properly accessed.
    - Adds the backward pass to the Sequential primitive.
    - Adds variations to Sequential testing involving MultiReturns.
    - Tests that the input gradient is set correctly on the backward pass
      and fixes a bug where this was not the case in the affine layer.
    - Adds tests to verify that the layer gets updated correctly during
      the forward and backward passes.
    
    AMLS project SoSe'24
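
    A minimal usage sketch of the new API (hypothetical shapes and data;
    it mirrors the tests added in this commit):

        import numpy as np
        from systemds.context import SystemDSContext
        from systemds.operator.nn.affine import Affine
        from systemds.operator.nn.relu import ReLU
        from systemds.operator.nn.sequential import Sequential

        with SystemDSContext() as sds:
            # combine layers into one sequential module
            model = Sequential(Affine(sds, 2, 2), ReLU(sds), Affine(sds, 2, 2))
            X = sds.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]]))
            out = model.forward(X)                 # forward through all layers
            dX = model.backward(out, X).compute()  # backward in reverse order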
    
    Closes #2025
---
 src/main/python/systemds/operator/nn/affine.py     |  34 +--
 .../systemds/operator/nn/{relu.py => layer.py}     |  63 ++---
 src/main/python/systemds/operator/nn/relu.py       |  24 +-
 src/main/python/systemds/operator/nn/sequential.py |  97 +++++++
 .../python/systemds/operator/nodes/multi_return.py |   2 +-
 src/main/python/tests/nn/test_affine.py            |   6 +-
 src/main/python/tests/nn/test_layer.py             |  80 ++++++
 src/main/python/tests/nn/test_sequential.py        | 304 +++++++++++++++++++++
 8 files changed, 535 insertions(+), 75 deletions(-)

diff --git a/src/main/python/systemds/operator/nn/affine.py b/src/main/python/systemds/operator/nn/affine.py
index 44c67d1eda..35935871aa 100644
--- a/src/main/python/systemds/operator/nn/affine.py
+++ b/src/main/python/systemds/operator/nn/affine.py
@@ -18,21 +18,15 @@
 # under the License.
 #
 # -------------------------------------------------------------
-import os
-
 from systemds.context import SystemDSContext
-from systemds.operator import Matrix, Source, MultiReturn
-from systemds.utils.helpers import get_path_to_script_layers
+from systemds.operator import Matrix, MultiReturn
+from systemds.operator.nn.layer import Layer
 
 
-class Affine:
-    _source: Source = None
+class Affine(Layer):
     weight: Matrix
     bias: Matrix
 
-    def __new__(cls, *args, **kwargs):
-        return super().__new__(cls)
-
     def __init__(self, sds_context: SystemDSContext, d, m, seed=-1):
         """
         sds_context: The systemdsContext to construct the layer inside of
@@ -40,11 +34,8 @@ class Affine:
         m: The number of neurons that are contained in the layer, 
             and the number of features output
         """
-        Affine._create_source(sds_context)
-
-        # bypassing overload limitation in python
-        self.forward = self._instance_forward
-        self.backward = self._instance_backward
+        super().__init__(sds_context, 'affine.dml')
+        self._X = None
 
         # init weight and bias
         self.weight = Matrix(sds_context, '')
@@ -64,7 +55,7 @@ class Affine:
         b: The bias added in the output.
         return out: An output matrix.
         """
-        Affine._create_source(X.sds_context)
+        Affine._create_source(X.sds_context, "affine.dml")
         return Affine._source.forward(X, W, b)
 
     @staticmethod
@@ -77,7 +68,7 @@ class Affine:
         return dX, dW, db: The gradients of: input X, weights and bias.
         """
         sds = X.sds_context
-        Affine._create_source(sds)
+        Affine._create_source(sds, "affine.dml")
         params_dict = {'dout': dout, 'X': X, 'W': W, 'b': b}
         dX = Matrix(sds, '')
         dW = Matrix(sds, '')
@@ -104,11 +95,6 @@ class Affine:
         X: The input to this layer.
         return dX, dW, db: gradients of input, weights and bias, respectively
         """
-        return Affine.backward(dout, X, self.weight, self.bias)
-
-    @staticmethod
-    def _create_source(sds: SystemDSContext):
-        if Affine._source is None or Affine._source.sds_context != sds:
-            path = get_path_to_script_layers()
-            path = os.path.join(path, "affine.dml")
-            Affine._source = sds.source(path, "affine")
+        gradients = Affine.backward(dout, X, self.weight, self.bias)
+        self._X = gradients[0]
+        return gradients
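
A short sketch of using the refactored Affine after this change (hypothetical
shapes; the instance calls also cache state in self._X, as the updated tests
below exercise):

    import numpy as np
    from systemds.context import SystemDSContext
    from systemds.operator.nn.affine import Affine

    with SystemDSContext() as sds:
        X = sds.from_numpy(np.random.rand(5, 6))
        dout = sds.from_numpy(np.random.rand(5, 6))
        affine = Affine(sds, 6, 6, seed=10)      # d=6 inputs, m=6 neurons
        out = affine.forward(X).compute()        # instance forward
        dX, dW, db = affine.backward(dout, X).compute()  # unpack the MultiReturn
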
diff --git a/src/main/python/systemds/operator/nn/relu.py b/src/main/python/systemds/operator/nn/layer.py
similarity index 50%
copy from src/main/python/systemds/operator/nn/relu.py
copy to src/main/python/systemds/operator/nn/layer.py
index 99833e6d86..255fa2d4d1 100644
--- a/src/main/python/systemds/operator/nn/relu.py
+++ b/src/main/python/systemds/operator/nn/layer.py
@@ -18,51 +18,52 @@
 # under the License.
 #
 # -------------------------------------------------------------
-import os.path
+import os
 
 from systemds.context import SystemDSContext
-from systemds.operator import Matrix, Source
+from systemds.operator import Source
 from systemds.utils.helpers import get_path_to_script_layers
 
 
-class ReLU:
+class Layer:
+    """
+    Interface for neural network layers
+    """
+
     _source: Source = None
 
-    def __init__(self, sds: SystemDSContext):
-        ReLU._create_source(sds)
+    def __init__(self, sds_context: SystemDSContext = None, dml_script: str = None):
+        if sds_context is not None and dml_script is not None:
+            self.__class__._create_source(sds_context, dml_script)
+
+        # bypassing overload limitation in python
         self.forward = self._instance_forward
         self.backward = self._instance_backward
 
-    @staticmethod
-    def forward(X: Matrix):
-        """
-        X: input matrix
-        return out: output matrix
-        """
-        ReLU._create_source(X.sds_context)
-        return ReLU._source.forward(X)
-
-    @staticmethod
-    def backward(dout: Matrix, X: Matrix):
+    @classmethod
+    def _create_source(cls, sds_context: SystemDSContext, dml_script: str):
         """
-        dout: gradient of output, passed from the upstream
-        X: input matrix
-        return dX: gradient of input
+        Create SystemDS source
+        :param sds_context: SystemDS context
+        :param dml_script: DML script inside /scripts/nn/layers/
+        :return:
         """
-        ReLU._create_source(dout.sds_context)
-        return ReLU._source.backward(dout, X)
+        if cls._source is None or cls._source.sds_context != sds_context:
+            script_path = get_path_to_script_layers()
+            path = os.path.join(script_path, dml_script)
+            name = dml_script.split(".")[0]
+            cls._source = sds_context.source(path, name)
 
-    def _instance_forward(self, X: Matrix):
-        self._X = X
-        return ReLU.forward(X)
+    def _instance_forward(self, *args):
+        raise NotImplementedError
 
-    def _instance_backward(self, dout: Matrix, X: Matrix):
-        return ReLU.backward(dout, X)
+    def _instance_backward(self, *args):
+        raise NotImplementedError
 
     @staticmethod
-    def _create_source(sds: SystemDSContext):
-        if ReLU._source is None or ReLU._source.sds_context != sds:
-            path = get_path_to_script_layers()
-            path = os.path.join(path, "relu.dml")
-            ReLU._source = sds.source(path, "relu")
+    def forward(*args):
+        raise NotImplementedError
 
+    @staticmethod
+    def backward(*args):
+        raise NotImplementedError
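
With the new interface, a custom layer only overrides the instance hooks; the
base constructor wires self.forward/self.backward to them. A toy sketch (no
DML source; the AddConst name is made up for illustration):

    from systemds.operator.nn.layer import Layer

    class AddConst(Layer):
        def __init__(self, value):
            super().__init__()       # no sds_context/dml_script: no DML source
            self.value = value

        def _instance_forward(self, X):
            return X + self.value    # builds a lazy SystemDS op on the node

        def _instance_backward(self, dout, X):
            return dout              # d(X + c)/dX is the identity
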
diff --git a/src/main/python/systemds/operator/nn/relu.py b/src/main/python/systemds/operator/nn/relu.py
index 99833e6d86..e124e350d9 100644
--- a/src/main/python/systemds/operator/nn/relu.py
+++ b/src/main/python/systemds/operator/nn/relu.py
@@ -18,20 +18,16 @@
 # under the License.
 #
 # -------------------------------------------------------------
-import os.path
-
 from systemds.context import SystemDSContext
 from systemds.operator import Matrix, Source
-from systemds.utils.helpers import get_path_to_script_layers
+from systemds.operator.nn.layer import Layer
 
 
-class ReLU:
+class ReLU(Layer):
     _source: Source = None
 
-    def __init__(self, sds: SystemDSContext):
-        ReLU._create_source(sds)
-        self.forward = self._instance_forward
-        self.backward = self._instance_backward
+    def __init__(self, sds_context: SystemDSContext):
+        super().__init__(sds_context, "relu.dml")
 
     @staticmethod
     def forward(X: Matrix):
@@ -39,7 +35,7 @@ class ReLU:
         X: input matrix
         return out: output matrix
         """
-        ReLU._create_source(X.sds_context)
+        ReLU._create_source(X.sds_context, "relu.dml")
         return ReLU._source.forward(X)
 
     @staticmethod
@@ -49,7 +45,7 @@ class ReLU:
         X: input matrix
         return dX: gradient of input
         """
-        ReLU._create_source(dout.sds_context)
+        ReLU._create_source(dout.sds_context, "relu.dml")
         return ReLU._source.backward(dout, X)
 
     def _instance_forward(self, X: Matrix):
@@ -58,11 +54,3 @@ class ReLU:
 
     def _instance_backward(self, dout: Matrix, X: Matrix):
         return ReLU.backward(dout, X)
-
-    @staticmethod
-    def _create_source(sds: SystemDSContext):
-        if ReLU._source is None or ReLU._source.sds_context != sds:
-            path = get_path_to_script_layers()
-            path = os.path.join(path, "relu.dml")
-            ReLU._source = sds.source(path, "relu")
-
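
Both call styles survive the refactor; a brief sketch (assuming a local
SystemDS setup):

    import numpy as np
    from systemds.context import SystemDSContext
    from systemds.operator.nn.relu import ReLU

    with SystemDSContext() as sds:
        X = sds.from_numpy(np.array([[-1.0, 2.0], [3.0, -4.0]]))
        out_static = ReLU.forward(X).compute()  # static call, sources relu.dml
        relu = ReLU(sds)
        out_inst = relu.forward(X).compute()    # instance call, caches X
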
diff --git a/src/main/python/systemds/operator/nn/sequential.py b/src/main/python/systemds/operator/nn/sequential.py
new file mode 100644
index 0000000000..bf54ff5a65
--- /dev/null
+++ b/src/main/python/systemds/operator/nn/sequential.py
@@ -0,0 +1,97 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from systemds.operator import MultiReturn
+from systemds.operator.nn.layer import Layer
+
+
+class Sequential(Layer):
+    def __init__(self, *args):
+        super().__init__()
+
+        self.layers = []
+        if len(args) == 1 and isinstance(args[0], list):
+            self.layers = args[0]
+        else:
+            self.layers = list(args)
+
+    def __len__(self):
+        return len(self.layers)
+
+    def __getitem__(self, idx):
+        return self.layers[idx]
+
+    def __setitem__(self, idx, value):
+        self.layers[idx] = value
+
+    def __delitem__(self, idx):
+        del self.layers[idx]
+
+    def __iter__(self):
+        return iter(self.layers)
+
+    def __reversed__(self):
+        return reversed(self.layers)
+
+    def push(self, layer: Layer):
+        """
+        Add layer
+        :param layer: Layer
+        :return:
+        """
+        self.layers.append(layer)
+
+    def pop(self):
+        """
+        Remove last layer
+        :return: Layer
+        """
+        return self.layers.pop()
+
+    def _instance_forward(self, X):
+        """
+        Forward pass
+        :param X: Input matrix
+        :return: output matrix
+        """
+        out = X
+        for layer in self:
+            out = layer.forward(out)
+
+            # if MultiReturn, take only output matrix
+            if isinstance(out, MultiReturn):
+                out = out[0]
+        return out
+
+    def _instance_backward(self, dout, X):
+        """
+        Backward pass
+        :param dout: gradient of output, passed from the upstream
+        :param X: input matrix
+        :return: output matrix
+        """
+        dx = dout
+        for layer in reversed(self):
+            dx = layer.backward(dx, X)
+
+            # if MultiReturn, take only gradient of input
+            if isinstance(dx, MultiReturn):
+                dx = dx[0]
+        return dx
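
Besides forward/backward, Sequential behaves like a mutable sequence of
layers; a sketch of the container API defined above:

    from systemds.context import SystemDSContext
    from systemds.operator.nn.affine import Affine
    from systemds.operator.nn.relu import ReLU
    from systemds.operator.nn.sequential import Sequential

    with SystemDSContext() as sds:
        model = Sequential()
        model.push(Affine(sds, 4, 2))      # append a layer
        model.push(ReLU(sds))
        assert len(model) == 2
        last = model.pop()                 # remove and return the ReLU
        model[0] = Affine(sds, 4, 4)       # replace by index
        for layer in reversed(model):      # iterate back-to-front
            print(type(layer).__name__)
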
diff --git a/src/main/python/systemds/operator/nodes/multi_return.py b/src/main/python/systemds/operator/nodes/multi_return.py
index cb6b923d2c..e2fa09b3db 100644
--- a/src/main/python/systemds/operator/nodes/multi_return.py
+++ b/src/main/python/systemds/operator/nodes/multi_return.py
@@ -47,7 +47,7 @@ class MultiReturn(OperationNode):
                          named_input_nodes, OutputType.MULTI_RETURN, False)
 
     def __getitem__(self, key):
-        self._outputs[key]
+        return self._outputs[key]
 
     def code_line(self, var_name: str, unnamed_input_vars: Sequence[str],
                   named_input_vars: Dict[str, str]) -> str:
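
The one-line fix above is what makes indexing a MultiReturn usable: the
selected output node is now returned instead of None. A sketch (hypothetical
shapes):

    import numpy as np
    from systemds.context import SystemDSContext
    from systemds.operator.nn.affine import Affine

    with SystemDSContext() as sds:
        X = sds.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]]))
        dout = sds.from_numpy(np.array([[0.1, 0.2], [0.3, 0.4]]))
        gradients = Affine(sds, 2, 2).backward(dout, X)  # MultiReturn (dX, dW, db)
        dX = gradients[0]                 # previously evaluated to None
        print(dX.compute())
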
diff --git a/src/main/python/tests/nn/test_affine.py b/src/main/python/tests/nn/test_affine.py
index 955945b29c..a7de2c383d 100644
--- a/src/main/python/tests/nn/test_affine.py
+++ b/src/main/python/tests/nn/test_affine.py
@@ -77,6 +77,7 @@ class TestAffine(unittest.TestCase):
         out = affine.forward(Xm).compute()
         self.assertEqual(len(out), 5)
         self.assertEqual(len(out[0]), 6)
+        assert_almost_equal(affine._X.compute(), Xm.compute())
 
         # test static method
         out = Affine.forward(Xm, Wm, bm).compute()
@@ -91,10 +92,13 @@ class TestAffine(unittest.TestCase):
 
         # test class method
         affine = Affine(self.sds, dim, m, 10)
-        [dx, dw, db] = affine.backward(doutm, Xm).compute()
+        gradients = affine.backward(doutm, Xm)
+        intermediate = affine._X.compute()
+        [dx, dw, db] = gradients.compute()
         assert len(dx) == 5 and len(dx[0]) == 6
         assert len(dw) == 6 and len(dx[0]) == 6
         assert len(db) == 1 and len(db[0]) == 6
+        assert_almost_equal(intermediate, dx)
 
         # test static method
         res = Affine.backward(doutm, Xm, Wm, bm).compute()
diff --git a/src/main/python/tests/nn/test_layer.py b/src/main/python/tests/nn/test_layer.py
new file mode 100644
index 0000000000..0b6a0eb2e1
--- /dev/null
+++ b/src/main/python/tests/nn/test_layer.py
@@ -0,0 +1,80 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+
+from systemds.context import SystemDSContext
+from systemds.operator.nn.layer import Layer
+
+
+class TestLayer(unittest.TestCase):
+    sds: SystemDSContext = None
+
+    @classmethod
+    def setUpClass(cls):
+        cls.sds = SystemDSContext()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sds.close()
+
+    def test_init(self):
+        """
+        Test that the source is created correctly from dml_script param when layer is initialized
+        """
+        _ = Layer(self.sds, "relu.dml")
+        self.assertIsNotNone(Layer._source)
+        self.assertTrue(Layer._source.operation.endswith('relu.dml"'))
+        self.assertEqual(Layer._source._Source__name, "relu")
+
+    def test_notimplemented(self):
+        """
+        Test that NotImplementedError is raised
+        """
+
+        class TestLayerImpl(Layer):
+            pass
+
+        layer = TestLayerImpl(self.sds, "relu.dml")
+        with self.assertRaises(NotImplementedError):
+            layer.forward(None)
+        with self.assertRaises(NotImplementedError):
+            layer.backward(None)
+        with self.assertRaises(NotImplementedError):
+            TestLayerImpl.forward(None)
+        with self.assertRaises(NotImplementedError):
+            TestLayerImpl.backward(None)
+
+    def test_class_source_assignments(self):
+        """
+        Test that the source is not shared between interface and implementation class
+        """
+
+        class TestLayerImpl(Layer):
+            @classmethod
+            def _create_source(cls, sds_context: SystemDSContext, dml_script: str):
+                cls._source = "test"
+
+        _ = Layer(self.sds, "relu.dml")
+        _ = TestLayerImpl(self.sds, "relu.dml")
+
+        self.assertNotEqual(Layer._source, "test")
+        self.assertEqual(TestLayerImpl._source, "test")
diff --git a/src/main/python/tests/nn/test_sequential.py b/src/main/python/tests/nn/test_sequential.py
new file mode 100644
index 0000000000..a7a361e40f
--- /dev/null
+++ b/src/main/python/tests/nn/test_sequential.py
@@ -0,0 +1,304 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+
+import numpy as np
+from numpy.testing import assert_almost_equal
+
+from systemds.operator.nn.affine import Affine
+from systemds.operator.nn.relu import ReLU
+from systemds.operator.nn.sequential import Sequential
+from systemds.operator import Matrix, MultiReturn
+from systemds.operator.nn.layer import Layer
+from systemds.context import SystemDSContext
+
+
+class TestLayerImpl(Layer):
+    def __init__(self, test_id):
+        super().__init__()
+        self.test_id = test_id
+
+    def _instance_forward(self, X: Matrix):
+        return X + self.test_id
+
+    def _instance_backward(self, dout: Matrix, X: Matrix):
+        return dout - self.test_id
+
+
+class MultiReturnImpl(Layer):
+    def __init__(self, sds):
+        super().__init__()
+        self.sds = sds
+
+    def _instance_forward(self, X: Matrix):
+        return MultiReturn(self.sds, "test.dml", output_nodes=[X, 'some_random_return'])
+
+    def _instance_backward(self, dout: Matrix, X: Matrix):
+        return MultiReturn(self.sds, "test.dml", output_nodes=[dout, X, 'some_random_return'])
+
+
+class TestSequential(unittest.TestCase):
+    sds: SystemDSContext = None
+
+    @classmethod
+    def setUpClass(cls):
+        cls.sds = SystemDSContext()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sds.close()
+
+    def test_init_with_multiple_args(self):
+        """
+        Test that Sequential is correctly initialized if multiple layers are passed as arguments
+        """
+        model = Sequential(TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3))
+        self.assertEqual(len(model.layers), 3)
+        self.assertEqual(model.layers[0].test_id, 1)
+        self.assertEqual(model.layers[1].test_id, 2)
+        self.assertEqual(model.layers[2].test_id, 3)
+
+    def test_init_with_list(self):
+        """
+        Test that Sequential is correctly initialized if list of layers is passed as argument
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        self.assertEqual(len(model.layers), 3)
+        self.assertEqual(model.layers[0].test_id, 1)
+        self.assertEqual(model.layers[1].test_id, 2)
+        self.assertEqual(model.layers[2].test_id, 3)
+
+    def test_len(self):
+        """
+        Test that len() returns the number of layers
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        self.assertEqual(len(model), 3)
+
+    def test_getitem(self):
+        """
+        Test that Sequential[index] returns the layer at the given index
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        self.assertEqual(model[1].test_id, 2)
+
+    def test_setitem(self):
+        """
+        Test that Sequential[index] = layer sets the layer at the given index
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        model[1] = TestLayerImpl(4)
+        self.assertEqual(model[1].test_id, 4)
+
+    def test_delitem(self):
+        """
+        Test that del Sequential[index] removes the layer at the given index
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        del model[1]
+        self.assertEqual(len(model.layers), 2)
+        self.assertEqual(model[1].test_id, 3)
+
+    def test_iter(self):
+        """
+        Test that iter() returns an iterator over the layers
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        for i, layer in enumerate(model):
+            self.assertEqual(layer.test_id, i + 1)
+
+    def test_push(self):
+        """
+        Test that push() adds a layer
+        """
+        model = Sequential()
+        model.push(TestLayerImpl(1))
+        self.assertEqual(len(model.layers), 1)
+        self.assertEqual(model.layers[0].test_id, 1)
+
+    def test_pop(self):
+        """
+        Test that pop() removes the last layer
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        layer = model.pop()
+        self.assertEqual(len(model.layers), 2)
+        self.assertEqual(layer.test_id, 3)
+
+    def test_reversed(self):
+        """
+        Test that reversed() returns an iterator over the layers in reverse order
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        for i, layer in enumerate(reversed(model)):
+            self.assertEqual(layer.test_id, 3 - i)
+
+    def test_forward(self):
+        """
+        Test that forward() calls forward() on all layers
+        """
+        model = Sequential([TestLayerImpl(1), TestLayerImpl(2), TestLayerImpl(3)])
+        in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]]))
+        out_matrix = model.forward(in_matrix).compute()
+        self.assertEqual(out_matrix.tolist(), [[7, 8], [9, 10]])
+
+    def test_forward_actual_layers(self):
+        """
+        Test forward() with actual layers
+        """
+        params = [
+            np.array([[0.5, -0.5], [-0.5, 0.5]]),
+            np.array([[0.1, -0.1]]),
+            np.array([[0.4, -0.4], [-0.4, 0.4]]),
+            np.array([[0.2, -0.2]]),
+            np.array([[0.3, -0.3], [-0.3, 0.3]]),
+            np.array([[0.3, -0.3]]),
+        ]
+
+        model = Sequential(
+            [
+                Affine(self.sds, 2, 2),
+                ReLU(self.sds),
+                Affine(self.sds, 2, 2),
+                ReLU(self.sds),
+                Affine(self.sds, 2, 2),
+            ]
+        )
+
+        for i, layer in enumerate(model):
+            if isinstance(layer, Affine):
+                layer.weight = self.sds.from_numpy(params[i])
+                layer.bias = self.sds.from_numpy(params[i + 1])
+
+        in_matrix = self.sds.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]]))
+        out_matrix = model.forward(in_matrix).compute()
+        expected = np.array([[0.3120, -0.3120], [0.3120, -0.3120]])
+        assert_almost_equal(out_matrix, expected)
+
+    def test_backward_actual_layers(self):
+        """
+        Test backward() with actual layers
+        """
+        params = [
+            np.array([[0.5, -0.5], [-0.5, 0.5]]),
+            np.array([[0.1, -0.1]]),
+            np.array([[0.4, -0.4], [-0.4, 0.4]]),
+            np.array([[0.2, -0.2]]),
+            np.array([[0.3, -0.3], [-0.3, 0.3]]),
+            np.array([[0.3, -0.3]]),
+        ]
+
+        model = Sequential(
+            [
+                Affine(self.sds, 2, 2),
+                ReLU(self.sds),
+                Affine(self.sds, 2, 2),
+                ReLU(self.sds),
+                Affine(self.sds, 2, 2),
+            ]
+        )
+
+        for i, layer in enumerate(model):
+            if isinstance(layer, Affine):
+                layer.weight = self.sds.from_numpy(params[i])
+                layer.bias = self.sds.from_numpy(params[i + 1])
+
+        in_matrix = self.sds.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]]))
+        out_matrix = model.forward(in_matrix)
+        gradient = model.backward(out_matrix, in_matrix).compute()
+
+        # Test returned gradient
+        expected = np.array([[0.14976, -0.14976], [0.14976, -0.14976]])
+        assert_almost_equal(gradient, expected)
+
+        # Test if layers have been updated correctly
+        expected_gradients = [
+            np.array([[0.14976, -0.14976], [0.14976, -0.14976]]),
+            np.array([[0.14976, -0.14976], [0.14976, -0.14976]]),
+            np.array([[0.1872, -0.1872], [0.1872, -0.1872]]),
+        ]
+        for i, layer in enumerate(model):
+            if isinstance(layer, Affine):
+                assert_almost_equal(layer._X.compute(), expected_gradients[int(i / 2)])
+
+    def test_multireturn_forward_pass(self):
+        """
+        Test that forward() handles MultiReturn correctly
+        """
+        model = Sequential(MultiReturnImpl(self.sds), TestLayerImpl(1))
+        in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]]))
+        out_matrix = model.forward(in_matrix).compute()
+        self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]])
+
+    def test_multireturn_backward_pass(self):
+        """
+        Test that backward() handles MultiReturn correctly
+        """
+        model = Sequential(TestLayerImpl(1), MultiReturnImpl(self.sds))
+        in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]]))
+        out_matrix = self.sds.from_numpy(np.array([[2, 3], [4, 5]]))
+        gradient = model.backward(out_matrix, in_matrix).compute()
+        self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]])
+
+    def test_multireturn_variation_multiple(self):
+        """
+        Test that multiple consecutive MultiReturns are handled correctly
+        """
+        model = Sequential(MultiReturnImpl(self.sds), MultiReturnImpl(self.sds))
+        in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]]))
+        out_matrix = model.forward(in_matrix).compute()
+        self.assertEqual(out_matrix.tolist(), [[1, 2], [3, 4]])
+        gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute()
+        self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]])
+
+    def test_multireturn_variation_single_to_multiple(self):
+        """
+        Test that a single return into multiple MultiReturns is handled correctly
+        """
+        model = Sequential(TestLayerImpl(1), MultiReturnImpl(self.sds), MultiReturnImpl(self.sds))
+        in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]]))
+        out_matrix = model.forward(in_matrix).compute()
+        self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]])
+        gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute()
+        self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]])
+
+    def test_multireturn_variation_multiple_to_single(self):
+        """
+        Test that multiple MultiReturns into a single return are handled correctly
+        """
+        model = Sequential(MultiReturnImpl(self.sds), MultiReturnImpl(self.sds), TestLayerImpl(1))
+        in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]]))
+        out_matrix = model.forward(in_matrix).compute()
+        self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]])
+        gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute()
+        self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]])
+
+    def test_multireturn_variation_sandwich(self):
+        """
+        Test that a single return between two MultiReturns is handled correctly
+        """
+        model = Sequential(MultiReturnImpl(self.sds), TestLayerImpl(1), MultiReturnImpl(self.sds))
+        in_matrix = self.sds.from_numpy(np.array([[1, 2], [3, 4]]))
+        out_matrix = model.forward(in_matrix).compute()
+        self.assertEqual(out_matrix.tolist(), [[2, 3], [4, 5]])
+        gradient = model.backward(self.sds.from_numpy(out_matrix), in_matrix).compute()
+        self.assertEqual(gradient.tolist(), [[1, 2], [3, 4]])
