This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 176fa068a5 [SYSTEMDS-3701] Add LSTM representation to scuro
176fa068a5 is described below
commit 176fa068a5b285f43b8df6d52acbf852519c939b
Author: Christina Dionysio <[email protected]>
AuthorDate: Mon Sep 2 10:51:20 2024 +0200
[SYSTEMDS-3701] Add LSTM representation to scuro
Closes #2086.
---
.../systemds/scuro/representations/__init__.py | 12 +++-
.../python/systemds/scuro/representations/lstm.py | 75 ++++++++++++++++++++++
2 files changed, 86 insertions(+), 1 deletion(-)
diff --git a/src/main/python/systemds/scuro/representations/__init__.py
b/src/main/python/systemds/scuro/representations/__init__.py
index 38df913019..9a2007319d 100644
--- a/src/main/python/systemds/scuro/representations/__init__.py
+++ b/src/main/python/systemds/scuro/representations/__init__.py
@@ -23,6 +23,16 @@ from systemds.scuro.representations.average import Average
from systemds.scuro.representations.concatenation import Concatenation
from systemds.scuro.representations.fusion import Fusion
from systemds.scuro.representations.unimodal import UnimodalRepresentation,
HDF5, NPY, Pickle, JSON
+from systemds.scuro.representations.lstm import LSTM
-__all__ = ["Representation", "Average", "Concatenation", "Fusion",
"UnimodalRepresentation", "HDF5", "NPY", "Pickle", "JSON"]
+__all__ = ["Representation",
+ "Average",
+ "Concatenation",
+ "Fusion",
+ "UnimodalRepresentation",
+ "HDF5",
+ "NPY",
+ "Pickle",
+ "JSON",
+ "LSTM"]
diff --git a/src/main/python/systemds/scuro/representations/lstm.py
b/src/main/python/systemds/scuro/representations/lstm.py
new file mode 100644
index 0000000000..a38ca1e577
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -0,0 +1,75 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import torch
+
+from torch import nn
+from typing import List
+
+import numpy as np
+
+from modality.modality import Modality
+from representations.fusion import Fusion
+
+
class LSTM(Fusion):
    def __init__(self, width=128, depth=1, dropout_rate=0.1):
        """
        Combines modalities by running each modality's data through a
        randomly initialized bidirectional LSTM and concatenating the
        resulting per-modality embeddings.

        :param width: hidden size of each LSTM layer (each layer emits
            2 * width features because it is bidirectional)
        :param depth: number of stacked LSTM layers to apply
        :param dropout_rate: dropout probability applied to the final
            LSTM output
        """
        super().__init__('LSTM')
        self.depth = depth
        self.width = width
        self.dropout_rate = dropout_rate
        # Cache of per-modality embeddings keyed by modality name, so each
        # modality is only pushed through the (randomly initialized) LSTM
        # once per LSTM fusion instance — this also makes repeated fuse()
        # calls deterministic for already-seen modalities.
        self.unimodal_embeddings = {}

    def fuse(self, modalities: List[Modality], train_indices=None):
        """
        Concatenate the LSTM embeddings of all modalities along the
        feature (last) axis.

        :param modalities: modalities to fuse; all are assumed to hold the
            same number of samples as the first one
        :param train_indices: unused; kept for interface compatibility
        :return: numpy array of shape (num_samples, total_embedding_dim)
        """
        size = len(modalities[0].data)

        result = np.zeros((size, 0))

        for modality in modalities:
            # Single dict lookup instead of `in ... .keys()` + `.get()`.
            out = self.unimodal_embeddings.get(modality.name)
            if out is None:
                out = self.run_lstm(modality.data)
                self.unimodal_embeddings[modality.name] = out

            result = np.concatenate([result, out], axis=-1)

        return result

    def run_lstm(self, data):
        """
        Push `data` through `depth` freshly constructed bidirectional LSTM
        layers followed by dropout, and return the result as a numpy array
        (flattened to 2D when the LSTM output has more than 2 dimensions).

        NOTE(review): the LSTM layers are re-created with fresh random
        weights on every call and dropout runs in training mode, so the
        output is stochastic; reproducibility within one fusion object
        comes only from the cache in fuse().

        :param data: numpy array whose last axis is the feature dimension
        :return: numpy array containing the extracted embedding
        """
        d = torch.from_numpy(data.astype(np.float32))
        dim = d.shape[-1]
        dropout_layer = torch.nn.Dropout(self.dropout_rate)

        # Pure feature extraction — no gradients needed, so skip building
        # the autograd graph entirely.
        with torch.no_grad():
            for _ in range(self.depth):
                lstm_x = nn.LSTM(dim, self.width, batch_first=True,
                                 bidirectional=True)
                # Bidirectional output doubles the feature dimension for
                # the next stacked layer.
                dim = 2 * self.width
                d = lstm_x(d)[0]

            out = dropout_layer(d)

            if d.ndim > 2:
                # Collapse all non-batch axes so the result is 2D.
                out = torch.flatten(out, 1)

        return out.numpy()