This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 176fa068a5 [SYSTEMDS-3701] Add LSTM representation to scuro 176fa068a5 is described below commit 176fa068a5b285f43b8df6d52acbf852519c939b Author: Christina Dionysio <diony...@tu-berlin.de> AuthorDate: Mon Sep 2 10:51:20 2024 +0200 [SYSTEMDS-3701] Add LSTM representation to scuro Closes #2086. --- .../systemds/scuro/representations/__init__.py | 12 +++- .../python/systemds/scuro/representations/lstm.py | 75 ++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) diff --git a/src/main/python/systemds/scuro/representations/__init__.py b/src/main/python/systemds/scuro/representations/__init__.py index 38df913019..9a2007319d 100644 --- a/src/main/python/systemds/scuro/representations/__init__.py +++ b/src/main/python/systemds/scuro/representations/__init__.py @@ -23,6 +23,16 @@ from systemds.scuro.representations.average import Average from systemds.scuro.representations.concatenation import Concatenation from systemds.scuro.representations.fusion import Fusion from systemds.scuro.representations.unimodal import UnimodalRepresentation, HDF5, NPY, Pickle, JSON +from systemds.scuro.representations.lstm import LSTM -__all__ = ["Representation", "Average", "Concatenation", "Fusion", "UnimodalRepresentation", "HDF5", "NPY", "Pickle", "JSON"] +__all__ = ["Representation", + "Average", + "Concatenation", + "Fusion", + "UnimodalRepresentation", + "HDF5", + "NPY", + "Pickle", + "JSON", + "LSTM"] diff --git a/src/main/python/systemds/scuro/representations/lstm.py b/src/main/python/systemds/scuro/representations/lstm.py new file mode 100644 index 0000000000..a38ca1e577 --- /dev/null +++ b/src/main/python/systemds/scuro/representations/lstm.py @@ -0,0 +1,75 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
# See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------
import torch

from torch import nn
from typing import List

import numpy as np

# Fixed import paths: the package __init__ imports this module as
# systemds.scuro.representations.lstm, so sibling modules must be
# referenced through the full systemds.scuro package as well.
from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.fusion import Fusion


class LSTM(Fusion):
    def __init__(self, width=128, depth=1, dropout_rate=0.1):
        """
        Combines modalities by passing each through a stack of
        bidirectional LSTM layers and concatenating the results.

        :param width: hidden size of each LSTM layer (each layer outputs
            2 * width features because the LSTM is bidirectional)
        :param depth: number of stacked bidirectional LSTM layers
        :param dropout_rate: dropout probability applied to the final
            LSTM output
        """
        super().__init__('LSTM')
        self.depth = depth
        self.width = width
        self.dropout_rate = dropout_rate
        # Per-modality embedding cache so each modality's LSTM pass is
        # computed at most once for this fusion instance.
        self.unimodal_embeddings = {}

    def fuse(self, modalities: List[Modality], train_indices=None):
        """
        Fuse modalities by embedding each with an LSTM and concatenating
        the embeddings along the feature axis.

        :param modalities: modalities to fuse; all are assumed to contain
            the same number of samples -- TODO confirm with callers
        :param train_indices: currently unused; kept for interface
            compatibility with other Fusion implementations
        :return: numpy array of shape (num_samples, total_embedding_width)
        """
        num_samples = len(modalities[0].data)
        # Start with a (num_samples, 0) array so concatenation along the
        # last axis works uniformly for the first modality.
        fused = np.zeros((num_samples, 0))

        for modality in modalities:
            embedding = self.unimodal_embeddings.get(modality.name)
            if embedding is None:
                embedding = self.run_lstm(modality.data)
                self.unimodal_embeddings[modality.name] = embedding
            fused = np.concatenate([fused, embedding], axis=-1)

        return fused

    def run_lstm(self, data):
        """
        Run data through freshly initialized bidirectional LSTM layers
        followed by dropout, and return the result as a numpy array.

        NOTE(review): the LSTM weights are randomly re-initialized on every
        call and never trained, and the dropout module is in training mode,
        so the embedding is a stochastic random projection -- confirm this
        is intended.

        :param data: numpy array; assumed (samples, seq_len, features) --
            TODO confirm
        :return: numpy float32 array, flattened to 2-D when the LSTM output
            has more than two dimensions
        """
        d = torch.from_numpy(data.astype(np.float32))
        input_dim = d.shape[-1]
        dropout_layer = nn.Dropout(self.dropout_rate)

        # Inference only: skip building the autograd graph entirely
        # (the original detached afterwards; no_grad avoids the graph).
        with torch.no_grad():
            for _ in range(self.depth):
                lstm_layer = nn.LSTM(
                    input_dim, self.width, batch_first=True, bidirectional=True
                )
                # Bidirectional output doubles the feature width for the
                # next stacked layer.
                input_dim = 2 * self.width
                d = lstm_layer(d)[0]

            out = dropout_layer(d)

        # Collapse (batch, seq, features) to (batch, seq * features) so the
        # embedding can be concatenated per sample in fuse().
        if d.ndim > 2:
            out = torch.flatten(out, 1)

        return out.detach().numpy()