This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 176fa068a5 [SYSTEMDS-3701] Add LSTM representation to scuro
176fa068a5 is described below

commit 176fa068a5b285f43b8df6d52acbf852519c939b
Author: Christina Dionysio <diony...@tu-berlin.de>
AuthorDate: Mon Sep 2 10:51:20 2024 +0200

    [SYSTEMDS-3701] Add LSTM representation to scuro
    
    Closes #2086.
---
 .../systemds/scuro/representations/__init__.py     | 12 +++-
 .../python/systemds/scuro/representations/lstm.py  | 75 ++++++++++++++++++++++
 2 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/src/main/python/systemds/scuro/representations/__init__.py 
b/src/main/python/systemds/scuro/representations/__init__.py
index 38df913019..9a2007319d 100644
--- a/src/main/python/systemds/scuro/representations/__init__.py
+++ b/src/main/python/systemds/scuro/representations/__init__.py
@@ -23,6 +23,16 @@ from systemds.scuro.representations.average import Average
 from systemds.scuro.representations.concatenation import Concatenation
 from systemds.scuro.representations.fusion import Fusion
 from systemds.scuro.representations.unimodal import UnimodalRepresentation, 
HDF5, NPY, Pickle, JSON
+from systemds.scuro.representations.lstm import LSTM
 
 
-__all__ = ["Representation", "Average", "Concatenation", "Fusion", 
"UnimodalRepresentation", "HDF5", "NPY", "Pickle", "JSON"]
+__all__ = ["Representation",
+           "Average",
+           "Concatenation",
+           "Fusion",
+           "UnimodalRepresentation",
+           "HDF5",
+           "NPY",
+           "Pickle",
+           "JSON",
+           "LSTM"]
diff --git a/src/main/python/systemds/scuro/representations/lstm.py 
b/src/main/python/systemds/scuro/representations/lstm.py
new file mode 100644
index 0000000000..a38ca1e577
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -0,0 +1,75 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import torch
+
+from torch import nn
+from typing import List
+
+import numpy as np
+
+from modality.modality import Modality
+from representations.fusion import Fusion
+
+
class LSTM(Fusion):
    """
    Fusion representation that combines modalities by passing each one
    through a stack of randomly initialized bidirectional LSTM layers and
    concatenating the resulting per-modality embeddings.

    NOTE(review): the module-level imports (`from modality.modality import
    Modality`, `from representations.fusion import Fusion`) do not match the
    absolute `systemds.scuro.*` style used in the package __init__ — confirm
    the package resolves these paths.
    """

    def __init__(self, width=128, depth=1, dropout_rate=0.1):
        """
        Combines modalities using an LSTM

        :param width: hidden size of each LSTM layer; each layer emits
            2 * width features because the LSTM is bidirectional
        :param depth: number of stacked LSTM layers to apply
        :param dropout_rate: dropout probability applied to the final
            LSTM output
        """
        super().__init__('LSTM')
        self.depth = depth
        self.width = width
        self.dropout_rate = dropout_rate
        # Cache of per-modality embeddings keyed by modality name, so that
        # repeated fuse() calls do not rerun the randomly initialized LSTMs
        # (which would produce different embeddings each time).
        self.unimodal_embeddings = {}

    def fuse(self, modalities: List[Modality], train_indices=None):
        """
        Concatenate the LSTM embeddings of all modalities along the last axis.

        :param modalities: modalities to fuse; all are assumed to have the
            same number of samples — TODO confirm with callers
        :param train_indices: unused here; kept for a uniform Fusion API
        :return: numpy array of shape (num_samples, total_embedding_dim)
        """
        size = len(modalities[0].data)

        result = np.zeros((size, 0))

        for modality in modalities:
            # Single dict lookup instead of `in d.keys()` followed by get().
            out = self.unimodal_embeddings.get(modality.name)
            if out is None:
                out = self.run_lstm(modality.data)
                self.unimodal_embeddings[modality.name] = out

            result = np.concatenate([result, out], axis=-1)

        return result

    def run_lstm(self, data):
        """
        Embed `data` with `depth` freshly (randomly) initialized
        bidirectional LSTM layers, apply dropout, and flatten to 2-d.

        :param data: numpy array whose last dimension is the feature
            dimension; presumably (batch, seq, features) — TODO confirm
        :return: 2-d numpy array of embeddings
        """
        d = data.astype(np.float32)
        dim = d.shape[-1]
        d = torch.from_numpy(d)
        # NOTE(review): the Dropout module is left in its default training
        # mode, so dropout IS applied here — confirm this is intended for
        # representation extraction.
        dropout_layer = torch.nn.Dropout(self.dropout_rate)

        # The LSTMs act as fixed random projections; no gradients are ever
        # needed, so skip building the autograd graph entirely instead of
        # building it and detaching at the end.
        with torch.no_grad():
            for _ in range(self.depth):
                lstm_x = nn.LSTM(dim, self.width, batch_first=True,
                                 bidirectional=True)
                # Bidirectional output doubles the feature size for the
                # next layer.
                dim = 2 * self.width
                d = lstm_x(d)[0]

            out = dropout_layer(d)

            if d.ndim > 2:
                # Collapse (batch, seq, features) to (batch, seq * features).
                out = torch.flatten(out, 1)

        # Tensors created under no_grad carry no graph, so .numpy() is safe
        # without an explicit .detach().
        return out.numpy()

Reply via email to