This is an automated email from the ASF dual-hosted git repository.

cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new abf179a77d [SYSTEMDS-3913] Combined unimodal representation
abf179a77d is described below

commit abf179a77d05d3a69d454de613477bfe6104e318
Author: Christina Dionysio <[email protected]>
AuthorDate: Tue Sep 2 11:02:52 2025 +0200

    [SYSTEMDS-3913] Combined unimodal representation
    
    This patch adds a combined unimodal representation feature to the
    unimodal representation optimizer. For example, multiple audio
    representations can be combined into a single representation using
    Concatenation, the Hadamard product, or Addition. The unimodal
    optimization test was adapted to ensure this feature works as
    intended.
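    
    A minimal usage sketch (illustrative only; `audio` stands for any
    loaded audio modality, the class names are those exported by this
    patch):
    
        from systemds.scuro import Pitch, ZeroCrossing, Hadamard
    
        pitch = audio.apply_representation(Pitch())
        zcr = audio.apply_representation(ZeroCrossing())
        # fuse the two unimodal representations element-wise
        combined = pitch.combine(zcr, Hadamard())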
---
 src/main/python/systemds/scuro/__init__.py         |  14 +
 .../systemds/scuro/dataloader/audio_loader.py      |   1 -
 .../scuro/drsearch/multimodal_optimizer.py         |   6 +-
 .../systemds/scuro/drsearch/operator_registry.py   |  13 +
 src/main/python/systemds/scuro/drsearch/task.py    |  73 ++++-
 .../systemds/scuro/drsearch/unimodal_optimizer.py  | 104 ++++--
 .../python/systemds/scuro/modality/modality.py     |   2 +-
 .../python/systemds/scuro/modality/transformed.py  |  23 +-
 .../systemds/scuro/modality/unimodal_modality.py   |  60 +++-
 .../representations/aggregated_representation.py   |   5 +-
 .../representations/covarep_audio_features.py      | 156 +++++++++
 .../systemds/scuro/representations/fusion.py       |  21 ++
 .../scuro/representations/mel_spectrogram.py       |   5 +-
 .../python/systemds/scuro/representations/mfcc.py  |   2 +-
 .../representations/multimodal_attention_fusion.py | 365 +++++++++++++++++++++
 .../scuro/representations/representation.py        |   1 +
 .../systemds/scuro/representations/resnet.py       |  20 +-
 .../systemds/scuro/representations/spectrogram.py  |   4 +-
 .../representations/swin_video_transformer.py      |  38 +--
 .../systemds/scuro/representations/unimodal.py     |   5 +-
 .../systemds/scuro/representations/word2vec.py     |   2 +-
 .../systemds/scuro/utils/static_variables.py       |   3 +-
 src/main/python/tests/scuro/data_generator.py      |   8 +-
 .../python/tests/scuro/test_operator_registry.py   |  41 ++-
 .../python/tests/scuro/test_unimodal_optimizer.py  |  11 +-
 .../tests/scuro/test_unimodal_representations.py   |  99 ++++--
 26 files changed, 919 insertions(+), 163 deletions(-)

diff --git a/src/main/python/systemds/scuro/__init__.py b/src/main/python/systemds/scuro/__init__.py
index ae9aed44c0..b2a5e9df37 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -38,6 +38,9 @@ from systemds.scuro.representations.glove import GloVe
 from systemds.scuro.representations.lstm import LSTM
 from systemds.scuro.representations.max import RowMax
 from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.multimodal_attention_fusion import (
+    AttentionFusion,
+)
 from systemds.scuro.representations.mfcc import MFCC
 from systemds.scuro.representations.hadamard import Hadamard
 from systemds.scuro.representations.optical_flow import OpticalFlow
@@ -73,6 +76,12 @@ from systemds.scuro.drsearch.representation_cache import RepresentationCache
 from systemds.scuro.drsearch.unimodal_representation_optimizer import (
     UnimodalRepresentationOptimizer,
 )
+from systemds.scuro.representations.covarep_audio_features import (
+    RMSE,
+    Spectral,
+    ZeroCrossing,
+    Pitch,
+)
 from systemds.scuro.drsearch.multimodal_optimizer import MultimodalOptimizer
 from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
 
@@ -131,4 +140,9 @@ __all__ = [
     "UnimodalRepresentationOptimizer",
     "UnimodalOptimizer",
     "MultimodalOptimizer",
+    "ZeroCrossing",
+    "Pitch",
+    "RMSE",
+    "Spectral",
+    "AttentionFusion",
 ]
diff --git a/src/main/python/systemds/scuro/dataloader/audio_loader.py b/src/main/python/systemds/scuro/dataloader/audio_loader.py
index 1197617673..d8080c607d 100644
--- a/src/main/python/systemds/scuro/dataloader/audio_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/audio_loader.py
@@ -19,7 +19,6 @@
 #
 # -------------------------------------------------------------
 from typing import List, Optional, Union
-
 import librosa
 import numpy as np
 
diff --git a/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
index 2da8e7ae19..ac4365ed5c 100644
--- a/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
@@ -98,6 +98,9 @@ class MultimodalOptimizer:
                             ],
                         )
 
+    # TODO: check if order matters for reused reps - only compute once - check in cache
+    # TODO: parallelize - whenever an item of len 0 comes along, give it to a new thread - merge results
+    # TODO: change the algorithm so that one representation is used until there are no more representations to add - saves a lot of memory
     def optimize_intermodal_representations(self, task):
         modality_combos = []
         n = len(self.k_best_cache[task.model.name])
@@ -122,8 +125,7 @@ class MultimodalOptimizer:
         reuse_fused_representations = False
         for i, modality_combo in enumerate(modality_combos):
             # clear reuse cache
-            if i % 5 == 0:
-                reuse_cache = self.prune_cache(modality_combos[i:], reuse_cache)
+            reuse_cache = self.prune_cache(modality_combos[i:], reuse_cache)
 
             if i != 0:
                 reuse_fused_representations = self.is_prefix_match(
diff --git a/src/main/python/systemds/scuro/drsearch/operator_registry.py b/src/main/python/systemds/scuro/drsearch/operator_registry.py
index 3909b51ff9..699dcad857 100644
--- a/src/main/python/systemds/scuro/drsearch/operator_registry.py
+++ b/src/main/python/systemds/scuro/drsearch/operator_registry.py
@@ -43,6 +43,12 @@ class Registry:
                 cls._representations[m_type] = []
         return cls._instance
 
+    def set_fusion_operators(self, fusion_operators):
+        if isinstance(fusion_operators, list):
+            self._fusion_operators = fusion_operators
+        else:
+            self._fusion_operators = [fusion_operators]
+
     def add_representation(
         self, representation: Representation, modality: ModalityType
     ):
@@ -57,6 +63,13 @@ class Registry:
     def get_representations(self, modality: ModalityType):
         return self._representations[modality]
 
+    def get_not_self_contained_representations(self, modality: ModalityType):
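+        # Representations flagged as not self-contained, i.e. meant to be combined.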
+        reps = []
+        for rep in self.get_representations(modality):
+            if not rep().self_contained:
+                reps.append(rep)
+        return reps
+
     def get_context_operators(self):
         # TODO: return modality specific context operations
         return self._context_operators
diff --git a/src/main/python/systemds/scuro/drsearch/task.py b/src/main/python/systemds/scuro/drsearch/task.py
index 7e05a489e4..d08844c7bb 100644
--- a/src/main/python/systemds/scuro/drsearch/task.py
+++ b/src/main/python/systemds/scuro/drsearch/task.py
@@ -19,8 +19,10 @@
 #
 # -------------------------------------------------------------
 import time
-from typing import List
+from typing import List, Union
 
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.representation import Representation
 from systemds.scuro.models.model import Model
 import numpy as np
 from sklearn.model_selection import KFold
@@ -57,6 +59,8 @@ class Task:
         self.inference_time = []
         self.training_time = []
         self.expected_dim = 1
+        self.train_scores = []
+        self.val_scores = []
 
     def get_train_test_split(self, data):
         X_train = [data[i] for i in self.train_indices]
@@ -73,28 +77,69 @@ class Task:
          :param data: The aligned data used in the prediction process
          :return: the validation accuracy
         """
+        self._reset_params()
+        skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
+
+        fold = 0
+        X, y, _, _ = self.get_train_test_split(data)
+
+        for train, test in skf.split(X, y):
+            train_X = np.array(X)[train]
+            train_y = np.array(y)[train]
+            test_X = np.array(X)[test]
+            test_y = np.array(y)[test]
+            self._run_fold(train_X, train_y, test_X, test_y)
+            fold += 1
+
+        if self.measure_performance:
+            self.inference_time = np.mean(self.inference_time)
+            self.training_time = np.mean(self.training_time)
+
+        return [np.mean(self.train_scores), np.mean(self.val_scores)]
+
+    def _reset_params(self):
         self.inference_time = []
         self.training_time = []
+        self.train_scores = []
+        self.val_scores = []
+
+    def _run_fold(self, train_X, train_y, test_X, test_y):
+        train_start = time.time()
+        train_score = self.model.fit(train_X, train_y, test_X, test_y)
+        train_end = time.time()
+        self.training_time.append(train_end - train_start)
+        self.train_scores.append(train_score)
+        test_start = time.time()
+        test_score = self.model.test(np.array(test_X), test_y)
+        test_end = time.time()
+        self.inference_time.append(test_end - test_start)
+        self.val_scores.append(test_score)
+
+    def create_representation_and_run(
+        self,
+        representation: Representation,
+        modalities: Union[List[Modality], Modality],
+        data,
+    ):
+        self._reset_params()
         skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
-        train_scores = []
-        test_scores = []
+
         fold = 0
-        X, y, X_test, y_test = self.get_train_test_split(data)
+        X, y, _, _ = self.get_train_test_split(data)
 
         for train, test in skf.split(X, y):
             train_X = np.array(X)[train]
             train_y = np.array(y)[train]
-            train_start = time.time()
-            train_score = self.model.fit(train_X, train_y, X_test, y_test)
-            train_end = time.time()
-            self.training_time.append(train_end - train_start)
-            train_scores.append(train_score)
-            test_start = time.time()
-            test_score = self.model.test(np.array(X_test), y_test)
-            test_end = time.time()
-            self.inference_time.append(test_end - test_start)
-            test_scores.append(test_score)
+            rep_op = representation()
+            test_X = rep_op.transform(np.array(X)[test])
+            test_y = np.array(y)[test]
+
+            if isinstance(modalities, Modality):
+                rep = modalities.apply_representation(rep_op)
+            else:
+                rep_op.transform(
+                    train_X, train_y
+                )  # TODO: think about a way how to handle masks
 
+            self._run_fold(train_X, train_y, test_X, test_y)
             fold += 1
 
         if self.measure_performance:
diff --git a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
index 030f04aa43..b84d86d94d 100644
--- a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
@@ -20,6 +20,7 @@
 # -------------------------------------------------------------
 import pickle
 import time
+import copy
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from dataclasses import dataclass, field, asdict
 
@@ -28,11 +29,17 @@ from typing import Union
 
 import numpy as np
 from systemds.scuro.representations.window_aggregation import WindowAggregation
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.hadamard import Hadamard
+from systemds.scuro.representations.sum import Sum
 
 from systemds.scuro.representations.aggregated_representation import (
     AggregatedRepresentation,
 )
-from systemds.scuro import ModalityType, Aggregation
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.aggregate import Aggregation
 from systemds.scuro.drsearch.operator_registry import Registry
 from systemds.scuro.utils.schema_helpers import get_shape
 
@@ -84,7 +91,6 @@ class UnimodalOptimizer:
     def optimize(self):
         for modality in self.modalities:
             local_result = self._process_modality(modality, False)
-            # self._merge_results(local_result)
 
     def _process_modality(self, modality, parallel):
         if parallel:
@@ -95,43 +101,59 @@ class UnimodalOptimizer:
             local_results = self.operator_performance
 
         context_operators = self.operator_registry.get_context_operators()
-
-        for context_operator in context_operators:
-            context_representation = None
-            if (
-                modality.modality_type != ModalityType.TEXT
-                and modality.modality_type != ModalityType.VIDEO
-            ):
-                con_op = context_operator()
-                context_representation = modality.context(con_op)
-                self._evaluate_local(context_representation, [con_op], local_results)
-
-            modality_specific_operators = self.operator_registry.get_representations(
+        not_self_contained_reps = (
+            self.operator_registry.get_not_self_contained_representations(
                 modality.modality_type
             )
-            for modality_specific_operator in modality_specific_operators:
-                mod_context = None
-                mod_op = modality_specific_operator()
-                if context_representation is not None:
-                    mod_context = context_representation.apply_representation(mod_op)
-                    self._evaluate_local(mod_context, [con_op, mod_op], local_results)
-
-                mod = modality.apply_representation(mod_op)
-                self._evaluate_local(mod, [mod_op], local_results)
-
-                for context_operator_after in context_operators:
-                    con_op_after = context_operator_after()
-                    if mod_context is not None:
-                        mod_context = mod_context.context(con_op_after)
-                        self._evaluate_local(
-                            mod_context, [con_op, mod_op, con_op_after], local_results
-                        )
-
-                    mod = mod.context(con_op_after)
-                    self._evaluate_local(mod, [mod_op, con_op_after], local_results)
+        )
+        modality_specific_operators = self.operator_registry.get_representations(
+            modality.modality_type
+        )
+        for modality_specific_operator in modality_specific_operators:
+            mod_op = modality_specific_operator()
+
+            mod = modality.apply_representation(mod_op)
+            self._evaluate_local(mod, [mod_op], local_results)
+
+            if not mod_op.self_contained:
+                self._combine_non_self_contained_representations(
+                    modality, mod, not_self_contained_reps, local_results
+                )
+
+            for context_operator_after in context_operators:
+                con_op_after = context_operator_after()
+                mod = mod.context(con_op_after)
+                self._evaluate_local(mod, [mod_op, con_op_after], local_results)
 
-            return local_results
+        return local_results
 
+    def _combine_non_self_contained_representations(
+        self,
+        modality: Modality,
+        representation: TransformedModality,
+        other_representations,
+        local_results,
+    ):
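+        # Fuse the base rep with each candidate via every fusion op, then apply context ops.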
+        combined = representation
+        context_operators = self.operator_registry.get_context_operators()
+        used_representations = representation.transformation
+        for other_representation in other_representations:
+            used_representations.append(other_representation())
+            for combination in [Concatenation(), Hadamard(), Sum()]:
+                combined = combined.combine(
+                    modality.apply_representation(other_representation()), combination
+                )
+                self._evaluate_local(
+                    combined, used_representations, local_results, combination
+                )
+
+                for context_op in context_operators:
+                    con_op = context_op()
+                    mod = combined.context(con_op)
+                    c_t = copy.deepcopy(used_representations)
+                    c_t.append(con_op)
+                    self._evaluate_local(mod, c_t, local_results, combination)
+
     def _merge_results(self, local_results):
         """Merge local results into the main results"""
         for modality_id in local_results.results:
@@ -145,7 +167,9 @@ class UnimodalOptimizer:
                 for key, value in local_results.cache[modality][task_name].items():
                     self.operator_performance.cache[modality][task_name][key] = value
 
-    def _evaluate_local(self, modality, representations, local_results):
+    def _evaluate_local(
+        self, modality, representations, local_results, combination=None
+    ):
         if self._tasks_require_same_dims:
             if self.expected_dimensions == 1 and get_shape(modality.metadata) > 1:
                 # for aggregation in Aggregation().get_aggregation_functions():
@@ -165,6 +189,7 @@ class UnimodalOptimizer:
                         modality,
                         task.model.name,
                         end - start,
+                        combination,
                     )
             else:
                 modality.pad()
@@ -178,6 +203,7 @@ class UnimodalOptimizer:
                         modality,
                         task.model.name,
                         end - start,
+                        combination,
                     )
         else:
             for task in self.tasks:
@@ -198,6 +224,7 @@ class UnimodalOptimizer:
                         modality,
                         task.model.name,
                         end - start,
+                        combination,
                     )
                 else:
                     # modality.pad()
@@ -210,6 +237,7 @@ class UnimodalOptimizer:
                         modality,
                         task.model.name,
                         end - start,
+                        combination,
                     )
 
 
@@ -228,7 +256,9 @@ class UnimodalResults:
                 self.cache[modality][task_name] = {}
                 self.results[modality][task_name] = []
 
-    def add_result(self, scores, representations, modality, task_name, task_time):
+    def add_result(
+        self, scores, representations, modality, task_name, task_time, combination
+    ):
         parameters = []
         representation_names = []
 
@@ -256,6 +286,7 @@ class UnimodalResults:
             val_score=scores[1],
             representation_time=modality.transform_time,
             task_time=task_time,
+            combination=combination.name if combination else "",
         )
         self.results[modality.modality_id][task_name].append(entry)
         self.cache[modality.modality_id][task_name][
@@ -302,3 +333,4 @@ class ResultEntry:
     train_score: float
     representation_time: float
     task_time: float
+    combination: str
diff --git a/src/main/python/systemds/scuro/modality/modality.py b/src/main/python/systemds/scuro/modality/modality.py
index 94e745b2cc..f1b00fefcf 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -168,6 +168,6 @@ class Modality:
                 != list(other_modality.metadata.values())[i]["data_layout"]["shape"]
             ):
                 aligned = False
-                continue
+                break
 
         return aligned
diff --git a/src/main/python/systemds/scuro/modality/transformed.py b/src/main/python/systemds/scuro/modality/transformed.py
index 6523e9502f..9481937e2c 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -32,7 +32,9 @@ import copy
 
 class TransformedModality(Modality):
 
-    def __init__(self, modality, transformation, new_modality_type=None):
+    def __init__(
+        self, modality, transformation, new_modality_type=None, self_contained=True
+    ):
         """
         Parent class of the different Modalities (unimodal & multimodal)
         :param modality_type: Type of the original modality(ies)
@@ -46,8 +48,18 @@ class TransformedModality(Modality):
             new_modality_type, modality.modality_id, metadata, modality.data_type
         )
         self.transformation = None
+        self.self_contained = self_contained and getattr(
+            transformation, "self_contained", True
+        )
         self.add_transformation(transformation, modality)
 
+        if modality.__class__.__name__ == "UnimodalModality":
+            for k, v in self.metadata.items():
+                if "attention_masks" in v:
+                    del self.metadata[k]["attention_masks"]
+
     def add_transformation(self, transformation, modality):
         if (
             transformation.__class__.__bases__[0].__name__ == "Fusion"
@@ -89,14 +101,18 @@ class TransformedModality(Modality):
 
     def window_aggregation(self, windowSize, aggregation):
         w = WindowAggregation(windowSize, aggregation)
-        transformed_modality = TransformedModality(self, w)
+        transformed_modality = TransformedModality(
+            self, w, self_contained=self.self_contained
+        )
         start = time.time()
         transformed_modality.data = w.execute(self)
         transformed_modality.transform_time = time.time() - start
         return transformed_modality
 
     def context(self, context_operator):
-        transformed_modality = TransformedModality(self, context_operator)
+        transformed_modality = TransformedModality(
+            self, context_operator, self_contained=self.self_contained
+        )
         start = time.time()
         transformed_modality.data = context_operator.execute(self)
         transformed_modality.transform_time = time.time() - start
@@ -107,6 +123,7 @@ class TransformedModality(Modality):
         new_modality = representation.transform(self)
         new_modality.update_metadata()
         new_modality.transform_time = time.time() - start
+        new_modality.self_contained = representation.self_contained
         return new_modality
 
     def combine(self, other: Union[Modality, List[Modality]], fusion_method):
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 94d1fa057d..dd1674ea85 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -110,6 +110,9 @@ class UnimodalModality(Modality):
             self,
             representation,
         )
+
+        pad_dim_one = False
+
         new_modality.data = []
         start = time.time()
         original_lengths = []
@@ -131,26 +134,39 @@ class UnimodalModality(Modality):
                 "attention_masks" in entry for entry in 
new_modality.metadata.values()
             ):
                 for d in new_modality.data:
-                    original_lengths.append(d.shape[0])
+                    if d.shape[0] == 1 and d.ndim == 2:
+                        pad_dim_one = True
+                        original_lengths.append(d.shape[1])
+                    else:
+                        original_lengths.append(d.shape[0])
+
+        new_modality.data = self.l2_normalize_features(new_modality.data)
 
         if len(original_lengths) > 0 and min(original_lengths) < max(original_lengths):
             target_length = max(original_lengths)
             padded_embeddings = []
             for embeddings in new_modality.data:
-                current_length = embeddings.shape[0]
+                current_length = (
+                    embeddings.shape[0] if not pad_dim_one else embeddings.shape[1]
+                )
                 if current_length < target_length:
                     padding_needed = target_length - current_length
-
-                    padded = np.pad(
-                        embeddings,
-                        pad_width=(
-                            (0, padding_needed),
-                            (0, 0),
-                        ),
-                        mode="constant",
-                        constant_values=0,
-                    )
-                    padded_embeddings.append(padded)
+                    if pad_dim_one:
+                        padding = np.zeros((embeddings.shape[0], padding_needed))
+                        padded_embeddings.append(
+                            np.concatenate((embeddings, padding), axis=1)
+                        )
+                    else:
+                        padded = np.pad(
+                            embeddings,
+                            pad_width=(
+                                (0, padding_needed),
+                                (0, 0),
+                            ),
+                            mode="constant",
+                            constant_values=0,
+                        )
+                        padded_embeddings.append(padded)
                 else:
                     padded_embeddings.append(embeddings)
 
@@ -164,4 +180,22 @@ class UnimodalModality(Modality):
             new_modality.data = padded_embeddings
         new_modality.update_metadata()
         new_modality.transform_time = time.time() - start
+        new_modality.self_contained = representation.self_contained
         return new_modality
+
+    def l2_normalize_features(self, feature_list):
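+        # Scale each feature array by its global L2 norm (flatten, normalize, reshape).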
+        normalized_features = []
+        for feature in feature_list:
+            original_shape = feature.shape
+            flattened = feature.flatten()
+
+            norm = np.linalg.norm(flattened)
+            if norm > 0:
+                normalized_flat = flattened / norm
+                normalized_feature = normalized_flat.reshape(original_shape)
+            else:
+                normalized_feature = feature
+
+            normalized_features.append(normalized_feature)
+
+        return normalized_features
diff --git a/src/main/python/systemds/scuro/representations/aggregated_representation.py b/src/main/python/systemds/scuro/representations/aggregated_representation.py
index 9412c5be00..9119070a02 100644
--- a/src/main/python/systemds/scuro/representations/aggregated_representation.py
+++ b/src/main/python/systemds/scuro/representations/aggregated_representation.py
@@ -26,8 +26,11 @@ class AggregatedRepresentation(Representation):
     def __init__(self, aggregation):
         super().__init__("AggregatedRepresentation", aggregation.parameters)
         self.aggregation = aggregation
+        self.self_contained = True
 
     def transform(self, modality):
-        aggregated_modality = TransformedModality(modality, self)
+        aggregated_modality = TransformedModality(
+            modality, self, self_contained=modality.self_contained
+        )
         aggregated_modality.data = self.aggregation.execute(modality)
         return aggregated_modality
diff --git a/src/main/python/systemds/scuro/representations/covarep_audio_features.py b/src/main/python/systemds/scuro/representations/covarep_audio_features.py
new file mode 100644
index 0000000000..3b4398cb11
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/covarep_audio_features.py
@@ -0,0 +1,156 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import librosa
+import numpy as np
+
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.modality.transformed import TransformedModality
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+
+@register_representation(ModalityType.AUDIO)
+class Spectral(UnimodalRepresentation):
+    def __init__(self, hop_length=512):
+        parameters = {
+            "hop_length": [256, 512, 1024, 2048],
+        }  # TODO
+        super().__init__("Spectral", ModalityType.EMBEDDING, parameters, False)
+        self.hop_length = hop_length
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(
+            modality, self, self.output_modality_type
+        )
+        result = []
+        for i, y in enumerate(modality.data):
+            sr = list(modality.metadata.values())[i]["frequency"]
+
+            spectral_centroid = librosa.feature.spectral_centroid(
+                y=y, sr=sr, hop_length=self.hop_length
+            )
+            spectral_bandwidth = librosa.feature.spectral_bandwidth(
+                y=y, sr=sr, hop_length=self.hop_length
+            )
+            spectral_rolloff = librosa.feature.spectral_rolloff(
+                y=y, sr=sr, hop_length=self.hop_length
+            )
+            spectral_flatness = librosa.feature.spectral_flatness(
+                y=y, hop_length=self.hop_length
+            )
+
+            features = np.vstack(
+                [
+                    spectral_centroid,
+                    spectral_bandwidth,
+                    spectral_rolloff,
+                    spectral_flatness,
+                ]
+            )
+
+            result.append(features.T)
+
+        transformed_modality.data = result
+
+        return transformed_modality
+
+
+@register_representation(ModalityType.AUDIO)
+class ZeroCrossing(UnimodalRepresentation):
+    def __init__(self, hop_length=512):
+        parameters = {
+            "hop_length": [256, 512, 1024, 2048],
+        }  # TODO
+        super().__init__("ZeroCrossing", ModalityType.EMBEDDING, parameters, 
False)
+        self.hop_length = hop_length
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(
+            modality, self, self.output_modality_type
+        )
+        result = []
+        for i, y in enumerate(modality.data):
+            zero_crossing_rate = librosa.feature.zero_crossing_rate(
+                y, hop_length=self.hop_length
+            )
+
+            result.append(zero_crossing_rate)
+
+        transformed_modality.data = result
+
+        return transformed_modality
+
+
+@register_representation(ModalityType.AUDIO)
+class RMSE(UnimodalRepresentation):
+    def __init__(self, frame_length=1024, hop_length=512):
+        parameters = {
+            "frame_length": [1024, 2048, 4096],
+            "hop_length": [256, 512, 1024, 2048],
+        }  # TODO
+        super().__init__("RMSE", ModalityType.EMBEDDING, parameters, False)
+        self.hop_length = hop_length
+        self.frame_length = frame_length
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(
+            modality, self, self.output_modality_type
+        )
+        result = []
+        for i, y in enumerate(modality.data):
+            rmse = librosa.feature.rms(
+                y=y, frame_length=self.frame_length, hop_length=self.hop_length
+            )
+            result.append(rmse)
+
+        transformed_modality.data = result
+
+        return transformed_modality
+
+
+@register_representation(ModalityType.AUDIO)
+class Pitch(UnimodalRepresentation):
+    def __init__(self, hop_length=512):
+        parameters = {
+            "hop_length": [256, 512, 1024, 2048],
+        }  # TODO
+        super().__init__("Pitch", ModalityType.EMBEDDING, parameters, False)
+        self.hop_length = hop_length
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(
+            modality, self, self.output_modality_type
+        )
+        result = []
+        for i, y in enumerate(modality.data):
+            sr = list(modality.metadata.values())[i]["frequency"]
+
+            pitches, magnitudes = librosa.piptrack(
+                y=y, sr=sr, hop_length=self.hop_length
+            )
+            pitch = pitches[magnitudes.argmax(axis=0), np.arange(magnitudes.shape[1])]
+
+            result.append(pitch[np.newaxis, :])
+
+        transformed_modality.data = result
+
+        return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/fusion.py b/src/main/python/systemds/scuro/representations/fusion.py
index 4b746eee21..61988abba2 100644
--- a/src/main/python/systemds/scuro/representations/fusion.py
+++ b/src/main/python/systemds/scuro/representations/fusion.py
@@ -38,6 +38,8 @@ class Fusion(Representation):
         self.associative = False
         self.commutative = False
         self.needs_alignment = False
+        self.needs_training = False
+        self.needs_instance_alignment = False
 
     def transform(self, modalities: List[Modality]):
         """
@@ -58,8 +60,27 @@ class Fusion(Representation):
             max_len = self.get_max_embedding_size(mods)
             for modality in mods:
                 modality.pad(max_len=max_len)
+
         return self.execute(mods)
 
+    def transform_with_training(
+        self, modalities: List[Modality], train_indices, labels
+    ):
+        # if self.needs_instance_alignment:
+        #     max_len = self.get_max_embedding_size(modalities)
+        #     for modality in modalities:
+        #         modality.pad(max_len=max_len)
+
+        self.execute(
+            [np.array(modality.data)[train_indices] for modality in modalities],
+            labels[train_indices],
+        )
+
+    def transform_data(self, modalities: List[Modality], val_indices):
+        return self.apply_representation(
+            [np.array(modality.data)[val_indices] for modality in modalities]
+        )
+
     def execute(self, modalities: List[Modality]):
-        raise f"Not implemented for Fusion: {self.name}"
+        raise NotImplementedError(f"Not implemented for Fusion: {self.name}")
 
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
index 8e897542b0..dca1b0eec8 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
@@ -36,7 +36,7 @@ class MelSpectrogram(UnimodalRepresentation):
             "hop_length": [256, 512, 1024, 2048],
             "n_fft": [1024, 2048, 4096],
         }
-        super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
+        super().__init__("MelSpectrogram", ModalityType.TIMESERIES, 
parameters, False)
         self.n_mels = n_mels
         self.hop_length = hop_length
         self.n_fft = n_fft
@@ -56,9 +56,8 @@ class MelSpectrogram(UnimodalRepresentation):
                 hop_length=self.hop_length,
                 n_fft=self.n_fft,
             ).astype(modality.data_type)
-            S_dB = librosa.power_to_db(S, ref=np.max)
 
-            result.append(S_dB.T)
+            result.append(S.T)
 
         transformed_modality.data = result
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/mfcc.py b/src/main/python/systemds/scuro/representations/mfcc.py
index 00f735a756..c942f3076e 100644
--- a/src/main/python/systemds/scuro/representations/mfcc.py
+++ b/src/main/python/systemds/scuro/representations/mfcc.py
@@ -37,7 +37,7 @@ class MFCC(UnimodalRepresentation):
             "hop_length": [256, 512, 1024, 2048],
             "n_mels": [20, 32, 64, 128],
         }  # TODO
-        super().__init__("MFCC", ModalityType.TIMESERIES, parameters)
+        super().__init__("MFCC", ModalityType.TIMESERIES, parameters, False)
         self.n_mfcc = n_mfcc
         self.dct_type = dct_type
         self.n_mels = n_mels
diff --git a/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py b/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
new file mode 100644
index 0000000000..7928b9988b
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
@@ -0,0 +1,365 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Dict, Optional
+import numpy as np
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.utils.static_variables import get_device
+
+
+@register_fusion_operator()
+class AttentionFusion(Fusion):
+    def __init__(
+        self,
+        hidden_dim=256,
+        num_heads=8,
+        dropout=0.1,
+        fusion_strategy="attention",
+        batch_size=32,
+        num_epochs=50,
+    ):
+        self.encoder = None
+        params = {
+            "hidden_dim": [128, 256, 512],
+            "num_heads": [1, 4, 8],
+            "dropout": [0.1, 0.2, 0.3],
+            "fusion_strategy": ["mean", "max", "attention", "cls"],
+            "batch_size": [32, 64, 128],
+            "num_epochs": [50, 70, 100, 150],
+        }
+        super().__init__("AttentionFusion", params)
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.fusion_strategy = fusion_strategy
+        self.batch_size = batch_size
+        self.needs_training = True
+        self.needs_instance_alignment = True
+        self.num_epochs = num_epochs
+
+    def execute(
+        self,
+        data: List[np.ndarray],
+        labels: np.ndarray,
+    ):
+        input_dimension = {}
+        inputs = {}
+        max_sequence_length = 0
+        masks = {}
+        for i, modality in enumerate(data):
+            modality_name = "modality_" + str(i)
+            shape = modality.shape
+            max_sequence_length = max(max_sequence_length, shape[1])
+            input_dimension[modality_name] = shape[2] if len(shape) > 2 else shape[1]
+            inputs[modality_name] = torch.from_numpy(np.stack(modality)).to(
+                get_device()
+            )
+
+            # attention_masks_list = [
+            #     entry["attention_masks"]
+            #     for entry in modality.metadata.values()
+            #     if "attention_masks" in entry
+            # ]
+            attention_masks_list = None
+            if attention_masks_list:
+                masks[modality_name] = (
+                    torch.tensor(np.array(attention_masks_list)).bool().to(get_device())
+                )
+            else:
+                masks[modality_name] = None
+
+        self.encoder = MultiModalAttentionFusion(
+            input_dimension,
+            self.hidden_dim,
+            self.num_heads,
+            self.dropout,
+            max_sequence_length,
+            self.fusion_strategy,
+        )
+
+        head = FusedClassificationHead(
+            fused_dim=self.hidden_dim, num_classes=len(np.unique(labels))
+        )
+        criterion = nn.CrossEntropyLoss()
+        optimizer = torch.optim.Adam(
+            list(self.encoder.parameters()) + list(head.parameters()), lr=0.001
+        )
+        labels = torch.from_numpy(labels).to(get_device())
+
+        for epoch in range(self.num_epochs):
+            total_loss = 0
+            total_accuracy = 0
+            for batch_idx in range(0, len(labels), self.batch_size):
+                batched_input = {}
+                for modality, modality_data in inputs.items():
+                    batched_input[modality] = modality_data[
+                        batch_idx : batch_idx + self.batch_size
+                    ]
+                loss, predictions = self.train_encoder_step(
+                    head,
+                    batched_input,
+                    labels[batch_idx : batch_idx + self.batch_size],
+                    criterion,
+                    optimizer,
+                )
+                total_loss += loss
+                total_accuracy += predictions
+
+            if epoch % 50 == 0 or epoch == self.num_epochs - 1:
+                print(
+                    f"Epoch {epoch}, Loss: {total_loss:.4f}, accuracy: 
{total_accuracy/len(data):.4f}"
+                )
+
+    # Training step (encoder + classification head)
+    def train_encoder_step(self, head, inputs, labels, criterion, optimizer):
+        self.encoder.train()
+        head.train()
+        optimizer.zero_grad()
+        output = self.encoder(inputs)
+        logits = head(output["fused"])
+        loss = criterion(logits, labels)
+        loss.backward()
+        optimizer.step()
+        _, predicted = torch.max(logits.data, 1)
+        return loss.item(), (predicted == labels).sum().item()
+
+    def apply_representation(self, modalities):
+        inputs = {}
+        for i, modality in enumerate(modalities):
+            modality_name = "modality_" + str(i)
+            inputs[modality_name] = torch.from_numpy(np.stack(modality)).to(
+                get_device()
+            )
+        self.encoder.eval()
+        with torch.no_grad():
+            output = self.encoder(inputs)
+        return output["fused"].cpu().numpy()
+
+
+class FusedClassificationHead(nn.Module):
+    """
+    Simple classification head for supervision during training.
+    """
+
+    def __init__(self, fused_dim, num_classes=2):
+        super(FusedClassificationHead, self).__init__()
+        self.head = nn.Sequential(
+            nn.Linear(fused_dim, fused_dim // 2),
+            nn.ReLU(),
+            nn.Linear(fused_dim // 2, num_classes),
+        ).to(get_device())
+
+    def forward(self, fused):
+        return self.head(fused)
+
+
+class MultiModalAttentionFusion(nn.Module):
+    def __init__(
+        self,
+        modality_dims: Dict[str, int],
+        hidden_dim: int,
+        num_heads: int,
+        dropout: float,
+        max_seq_len: int,
+        pooling_strategy: str,
+    ):
+        super().__init__()
+
+        self.modality_dims = modality_dims
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.pooling_strategy = pooling_strategy
+        self.max_seq_len = max_seq_len
+
+        # Project each modality to the same hidden dimension
+        self.modality_projections = nn.ModuleDict(
+            {
+                modality: nn.Linear(dim, hidden_dim).to(get_device())
+                for modality, dim in modality_dims.items()
+            }
+        )
+
+        # Positional encoding for sequence modalities
+        self.positional_encoding = nn.Parameter(
+            torch.randn(max_seq_len, hidden_dim, device=get_device()) * 0.1
+        )
+
+        # Cross-modal attention
+        self.cross_attention = nn.MultiheadAttention(
+            embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True
+        ).to(get_device())
+
+        # Self-attention within each modality
+        self.self_attention = nn.MultiheadAttention(
+            embed_dim=hidden_dim, num_heads=num_heads, dropout=dropout, batch_first=True
+        ).to(get_device())
+
+        # Attention-based pooling for sequences
+        if pooling_strategy == "attention":
+            self.pooling_attention = nn.Sequential(
+                nn.Linear(hidden_dim, hidden_dim // 2),
+                nn.Tanh(),
+                nn.Linear(hidden_dim // 2, 1),
+            ).to(get_device())
+
+        # Modality-level attention for final fusion
+        self.modality_attention = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim // 2),
+            nn.ReLU(),
+            nn.Linear(hidden_dim // 2, 1),
+        ).to(get_device())
+
+        # Layer normalization
+        self.layer_norm = nn.LayerNorm(hidden_dim).to(get_device())
+        self.dropout = nn.Dropout(dropout).to(get_device())
+
+        # Final projection
+        self.final_projection = nn.Linear(hidden_dim, hidden_dim).to(get_device())
+
+    def _handle_input_format(self, modality_tensor: torch.Tensor) -> torch.Tensor:
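+        # Accept (batch, feat) or (batch, seq, feat); 2D inputs gain a singleton seq dim.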
+        if len(modality_tensor.shape) == 2:
+            modality_tensor = modality_tensor.unsqueeze(1)
+        elif len(modality_tensor.shape) == 3:
+            pass
+        else:
+            raise ValueError(
+                f"Input tensor must be 2D or 3D, got 
{len(modality_tensor.shape)}D"
+            )
+
+        if modality_tensor.dtype != torch.float:
+            modality_tensor = modality_tensor.float()
+
+        return modality_tensor
+
+    def _pool_sequence(
+        self, sequence: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
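+        # Reduce (batch, seq, hidden) to (batch, hidden) per the pooling strategy.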
+        if self.pooling_strategy == "mean":
+            if mask is not None:
+                # Masked mean pooling
+                masked_seq = sequence * mask.unsqueeze(-1)
+                return masked_seq.sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(
+                    min=1
+                )
+            else:
+                return sequence.mean(dim=1)
+
+        elif self.pooling_strategy == "max":
+            if mask is not None:
+                # Set masked positions to large negative value before max pooling
+                masked_seq = sequence.masked_fill(~mask.unsqueeze(-1), float("-inf"))
+                return masked_seq.max(dim=1)[0]
+            else:
+                return sequence.max(dim=1)[0]
+
+        elif self.pooling_strategy == "cls":
+            # Use the first token (assuming it's a CLS token)
+            return sequence[:, 0, :]
+
+        elif self.pooling_strategy == "attention":
+            # Attention-based pooling
+            attention_scores = self.pooling_attention(sequence).squeeze(
+                -1
+            )  # (batch, seq)
+
+            if mask is not None:
+                attention_scores = attention_scores.masked_fill(~mask, float("-inf"))
+
+            attention_weights = F.softmax(attention_scores, dim=1)  # (batch, seq)
+            return (sequence * attention_weights.unsqueeze(-1)).sum(
+                dim=1
+            )  # (batch, hidden)
+
+        else:
+            raise ValueError(f"Unknown pooling strategy: 
{self.pooling_strategy}")
+
+    def forward(
+        self,
+        modality_inputs: Dict[str, torch.Tensor],
+        modality_masks: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Dict[str, torch.Tensor]:
+        modality_embeddings = {}
+
+        for modality, input_tensor in modality_inputs.items():
+            normalized_input = self._handle_input_format(input_tensor)
+            seq_len = normalized_input.size(1)
+
+            projected = self.modality_projections[modality](normalized_input)
+
+            if seq_len > 1:
+                pos_encoding = self.positional_encoding[:seq_len].unsqueeze(0)
+                projected = projected + pos_encoding
+
+            if seq_len > 1:
+                mask = modality_masks.get(modality) if modality_masks else None
+
+                attended, _ = self.self_attention(
+                    query=projected,
+                    key=projected,
+                    value=projected,
+                    key_padding_mask=~mask if mask is not None else None,
+                )
+
+                projected = self.layer_norm(projected + self.dropout(attended))
+
+                pooled = self._pool_sequence(projected, mask)
+            else:
+                pooled = projected.squeeze(1)
+
+            modality_embeddings[modality] = pooled
+
+        if len(modality_embeddings) > 1:
+            modality_stack = torch.stack(list(modality_embeddings.values()), dim=1)
+
+            cross_attended, cross_attention_weights = self.cross_attention(
+                query=modality_stack, key=modality_stack, value=modality_stack
+            )
+
+            cross_attended = self.layer_norm(
+                modality_stack + self.dropout(cross_attended)
+            )
+
+            updated_embeddings = {
+                modality: cross_attended[:, i, :]
+                for i, modality in enumerate(modality_embeddings.keys())
+            }
+            modality_embeddings = updated_embeddings
+
+        modality_stack = torch.stack(list(modality_embeddings.values()), dim=1)
+        modality_scores = self.modality_attention(modality_stack).squeeze(-1)
+        modality_weights = F.softmax(modality_scores, dim=1)
+
+        fused_representation = (modality_stack * modality_weights.unsqueeze(-1)).sum(
+            dim=1
+        )
+
+        output = self.final_projection(fused_representation)
+
+        return {
+            "fused": output,
+            "modality_embeddings": modality_embeddings,
+            "attention_weights": modality_weights,
+        }
diff --git a/src/main/python/systemds/scuro/representations/representation.py b/src/main/python/systemds/scuro/representations/representation.py
index 6137baf46d..144b88f34c 100644
--- a/src/main/python/systemds/scuro/representations/representation.py
+++ b/src/main/python/systemds/scuro/representations/representation.py
@@ -25,6 +25,7 @@ class Representation:
     def __init__(self, name, parameters):
         self.name = name
         self._parameters = parameters
+        self.self_contained = True
 
     @property
     def parameters(self):
diff --git a/src/main/python/systemds/scuro/representations/resnet.py b/src/main/python/systemds/scuro/representations/resnet.py
index bdfbfb17fc..f961cb4588 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/resnet.py
@@ -29,13 +29,7 @@ import torch
 import torchvision.models as models
 import numpy as np
 from systemds.scuro.modality.type import ModalityType
-
-if torch.backends.mps.is_available():
-    DEVICE = torch.device("mps")
-elif torch.cuda.is_available():
-    DEVICE = torch.device("cuda")
-else:
-    DEVICE = torch.device("cpu")
+from systemds.scuro.utils.static_variables import get_device
 
 
 @register_representation(
@@ -72,33 +66,33 @@ class ResNet(UnimodalRepresentation):
         if model_name == "ResNet18":
             self.model = (
                 models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
-                .to(DEVICE)
+                .to(get_device())
                 .to(self.data_type)
             )
 
         elif model_name == "ResNet34":
-            self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
-                DEVICE
+                get_device()
             )
             self.model = self.model.to(self.data_type)
         elif model_name == "ResNet50":
             self.model = (
                 models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
-                .to(DEVICE)
+                .to(get_device())
                 .to(self.data_type)
             )
 
         elif model_name == "ResNet101":
             self.model = (
                 models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
-                .to(DEVICE)
+                .to(get_device())
                 .to(self.data_type)
             )
 
         elif model_name == "ResNet152":
             self.model = (
                 models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
-                .to(DEVICE)
+                .to(get_device())
                 .to(self.data_type)
             )
 
@@ -129,7 +123,7 @@ class ResNet(UnimodalRepresentation):
         if next(self.model.parameters()).dtype != self.data_type:
             self.model = self.model.to(self.data_type)
 
-        dataset = CustomDataset(modality.data, self.data_type, DEVICE)
+        dataset = CustomDataset(modality.data, self.data_type, get_device())
         embeddings = {}
 
         res5c_output = None
diff --git a/src/main/python/systemds/scuro/representations/spectrogram.py b/src/main/python/systemds/scuro/representations/spectrogram.py
index 8daa9abb01..51b69d7d87 100644
--- a/src/main/python/systemds/scuro/representations/spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/spectrogram.py
@@ -32,7 +32,7 @@ from systemds.scuro.drsearch.operator_registry import register_representation
 class Spectrogram(UnimodalRepresentation):
     def __init__(self, hop_length=512, n_fft=2048):
         parameters = {"hop_length": [256, 512, 1024, 2048], "n_fft": [1024, 
2048, 4096]}
-        super().__init__("Spectrogram", ModalityType.TIMESERIES, parameters)
+        super().__init__("Spectrogram", ModalityType.TIMESERIES, parameters, 
False)
         self.hop_length = hop_length
         self.n_fft = n_fft
 
@@ -48,7 +48,7 @@ class Spectrogram(UnimodalRepresentation):
             ).astype(modality.data_type)
             S_dB = librosa.amplitude_to_db(np.abs(spectrogram))
 
-            result.append(S_dB.T.reshape(-1))
+            result.append(S_dB.T)
 
         transformed_modality.data = result
         return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/swin_video_transformer.py b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
index 19b2fd05c4..c0b7ab38ab 100644
--- a/src/main/python/systemds/scuro/representations/swin_video_transformer.py
+++ b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
@@ -18,7 +18,7 @@
 # under the License.
 #
 # -------------------------------------------------------------
-# from torchvision.models.video.swin_transformer import swin3d_t
+from torchvision.models.video.swin_transformer import swin3d_t
 
 from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
@@ -31,13 +31,7 @@ from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.drsearch.operator_registry import register_representation
 
 from systemds.scuro.utils.torch_dataset import CustomDataset
-
-if torch.backends.mps.is_available():
-    DEVICE = torch.device("mps")
-# elif torch.cuda.is_available():
-#     DEVICE = torch.device("cuda")
-else:
-    DEVICE = torch.device("cpu")
+from systemds.scuro.utils.static_variables import get_device
 
 
 # @register_representation([ModalityType.VIDEO])
@@ -55,16 +49,17 @@ class SwinVideoTransformer(UnimodalRepresentation):
                 "avgpool",
             ],
         }
+        self.data_type = torch.float
         super().__init__("SwinVideoTransformer", ModalityType.TIMESERIES, 
parameters)
         self.layer_name = layer_name
-        # self.model = swin3d_t(weights=models.video.Swin3D_T_Weights).to(DEVICE)
+        self.model = swin3d_t(weights=models.video.Swin3D_T_Weights.KINETICS400_V1).to(
+            get_device()
+        )
         self.model.eval()
         for param in self.model.parameters():
             param.requires_grad = False
 
     def transform(self, modality):
-        # model = swin3d_t(weights=models.video.Swin3D_T_Weights)
-
         embeddings = {}
         swin_output = None
 
@@ -82,11 +77,11 @@ class SwinVideoTransformer(UnimodalRepresentation):
                 if name == self.layer_name:
                     layer.register_forward_hook(get_features(name))
                     break
-        dataset = CustomDataset(modality.data)
+        dataset = CustomDataset(modality.data, self.data_type, get_device())
 
-        for instance in dataset:
-            video_id = instance["id"]
-            frames = instance["data"].to(DEVICE)
+        for instance in torch.utils.data.DataLoader(dataset):
+            video_id = instance["id"][0]
+            frames = instance["data"][0]
             embeddings[video_id] = []
 
             frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4)
@@ -95,15 +90,18 @@ class SwinVideoTransformer(UnimodalRepresentation):
             values = swin_output
             pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
 
-            embeddings[video_id].extend(torch.flatten(pooled, 1).detach().cpu().numpy())
+            embeddings[video_id].extend(
+                torch.flatten(pooled, 1)
+                .detach()
+                .cpu()
+                .numpy()
+                .astype(modality.data_type)
+            )
 
             embeddings[video_id] = np.array(embeddings[video_id])
 
         transformed_modality = TransformedModality(
-            self.output_modality_type,
-            "swinVideoTransformer",
-            modality.modality_id,
-            modality.metadata,
+            modality, self, self.output_modality_type
         )
 
         transformed_modality.data = list(embeddings.values())
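
    The transform above relies on a forward hook to capture intermediate Swin
    activations. A self-contained sketch of the same pattern on a random clip
    (weights=None avoids the checkpoint download here; the patch itself loads
    KINETICS400_V1, and the layer name is one of the parameter options above):

        import torch
        from torchvision.models.video import swin3d_t

        model = swin3d_t(weights=None).eval()
        features = {}

        def hook(module, inputs, output):
            # Stash the activation of the hooked layer during forward().
            features["avgpool"] = output

        model.avgpool.register_forward_hook(hook)

        # One clip: (batch, channels, frames, height, width).
        clip = torch.rand(1, 3, 8, 224, 224)
        with torch.no_grad():
            model(clip)

        # swin3d_t pools to (1, 768, 1, 1, 1); flattening gives one 768-d vector.
        print(torch.flatten(features["avgpool"], 1).shape)
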
diff --git a/src/main/python/systemds/scuro/representations/unimodal.py b/src/main/python/systemds/scuro/representations/unimodal.py
index 559eec1401..362888aa27 100644
--- a/src/main/python/systemds/scuro/representations/unimodal.py
+++ b/src/main/python/systemds/scuro/representations/unimodal.py
@@ -24,7 +24,9 @@ from systemds.scuro.representations.representation import Representation
 
 
 class UnimodalRepresentation(Representation):
-    def __init__(self, name: str, output_modality_type, parameters=None):
+    def __init__(
+        self, name: str, output_modality_type, parameters=None, self_contained=True
+    ):
         """
         Parent class for all unimodal representation types
         :param name: name of the representation
@@ -35,6 +37,7 @@ class UnimodalRepresentation(Representation):
         self.output_modality_type = output_modality_type
         if parameters is None:
             parameters = {}
+        self.self_contained = self_contained
 
     @abc.abstractmethod
     def transform(self, data):
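
    The new self_contained flag defaults to True; representations whose output
    is an intermediate result (such as Spectrogram above, which passes False)
    can opt out. A hypothetical subclass illustrating the extended constructor:

        from systemds.scuro.modality.type import ModalityType
        from systemds.scuro.representations.unimodal import UnimodalRepresentation

        class IdentityRepresentation(UnimodalRepresentation):
            """Hypothetical no-op representation, shown only to illustrate the API."""

            def __init__(self):
                # self_contained=False marks the output as intermediate rather
                # than a final, directly usable representation.
                super().__init__(
                    "Identity", ModalityType.TIMESERIES, parameters={}, self_contained=False
                )

            def transform(self, modality):
                return modality
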
diff --git a/src/main/python/systemds/scuro/representations/word2vec.py b/src/main/python/systemds/scuro/representations/word2vec.py
index 88d60ac828..06e082fb69 100644
--- a/src/main/python/systemds/scuro/representations/word2vec.py
+++ b/src/main/python/systemds/scuro/representations/word2vec.py
@@ -41,7 +41,7 @@ def get_embedding(sentence, model):
 
 @register_representation(ModalityType.TEXT)
 class W2V(UnimodalRepresentation):
-    def __init__(self, vector_size=3, min_count=2, window=2, output_file=None):
+    def __init__(self, vector_size=150, min_count=2, window=5, output_file=None):
         parameters = {
             "vector_size": [vector_size],
             "min_count": [min_count],
diff --git a/src/main/python/systemds/scuro/utils/static_variables.py b/src/main/python/systemds/scuro/utils/static_variables.py
index 8237cdf1b3..807287cd95 100644
--- a/src/main/python/systemds/scuro/utils/static_variables.py
+++ b/src/main/python/systemds/scuro/utils/static_variables.py
@@ -32,5 +32,6 @@ def get_device():
     return torch.device(
         "cuda:0"
         if torch.cuda.is_available()
-        else "mps" if torch.mps.is_available() else "cpu"
+        # else "mps" if torch.mps.is_available()
+        else "cpu"
     )
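
    With the mps branch disabled, get_device() resolves to CUDA when available
    and CPU otherwise, and the call sites above now query it lazily instead of
    reading a module-level DEVICE constant. Usage sketch:

        import torch
        from systemds.scuro.utils.static_variables import get_device

        # Resolving the device at call time keeps one code path working on
        # both CUDA and CPU-only machines.
        x = torch.rand(2, 3).to(get_device())
        print(x.device)  # cuda:0 if available, otherwise cpu
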
diff --git a/src/main/python/tests/scuro/data_generator.py b/src/main/python/tests/scuro/data_generator.py
index e57716fa99..4dcfa5a89c 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -95,10 +95,16 @@ class ModalityRandomDataGenerator:
 
     def create_audio_data(self, num_instances, max_audio_length):
         data = [
-            [random.random() for _ in range(random.randint(1, max_audio_length))]
+            [
+                random.random()
+                for _ in range(random.randint(int(max_audio_length * 0.9), max_audio_length))
+            ]
             for _ in range(num_instances)
         ]
 
+        for i in range(num_instances):
+            data[i] = np.array(data[i]).astype(self.data_type)
+
         metadata = {
             i: ModalityType.AUDIO.create_audio_metadata(16000, np.array(data[i]))
             for i in range(num_instances)
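
    The generator now draws per-instance lengths from [0.9 * max, max] instead
    of [1, max], so synthetic audio instances have comparable durations; note
    that random.randint requires integer bounds, hence the int(...) cast. A
    quick standalone sketch of the new length distribution:

        import random
        import numpy as np

        max_audio_length = 1000

        # Lengths fall in [900, 1000] rather than [1, 1000].
        lengths = [
            random.randint(int(max_audio_length * 0.9), max_audio_length)
            for _ in range(4)
        ]
        data = [np.random.rand(n).astype(np.float32) for n in lengths]
        print([d.shape for d in data])
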
diff --git a/src/main/python/tests/scuro/test_operator_registry.py b/src/main/python/tests/scuro/test_operator_registry.py
index 7f2a752722..a6941fe618 100644
--- a/src/main/python/tests/scuro/test_operator_registry.py
+++ b/src/main/python/tests/scuro/test_operator_registry.py
@@ -21,7 +21,14 @@
 
 import unittest
 
+from systemds.scuro.representations.covarep_audio_features import (
+    ZeroCrossing,
+    Spectral,
+    Pitch,
+    RMSE,
+)
 from systemds.scuro.representations.mfcc import MFCC
+from systemds.scuro.representations.swin_video_transformer import SwinVideoTransformer
 from systemds.scuro.representations.wav2vec import Wav2Vec
 from systemds.scuro.representations.window_aggregation import WindowAggregation
 from systemds.scuro.representations.bow import BoW
@@ -39,19 +46,29 @@ from systemds.scuro.representations.spectrogram import Spectrogram
 from systemds.scuro.representations.hadamard import Hadamard
 from systemds.scuro.representations.resnet import ResNet
 from systemds.scuro.representations.sum import Sum
+from systemds.scuro.representations.multimodal_attention_fusion import AttentionFusion
 
 
 class TestOperatorRegistry(unittest.TestCase):
     def test_audio_representations_in_registry(self):
         registry = Registry()
-        for representation in [Spectrogram, MelSpectrogram, Wav2Vec, MFCC]:
-            assert representation in registry.get_representations(
-                ModalityType.AUDIO
-            ), f"{representation} not in registry"
+        assert registry.get_representations(ModalityType.AUDIO) == [
+            MelSpectrogram,
+            MFCC,
+            Spectrogram,
+            Wav2Vec,
+            Spectral,
+            ZeroCrossing,
+            RMSE,
+            Pitch,
+        ]
 
     def test_video_representations_in_registry(self):
         registry = Registry()
-        assert registry.get_representations(ModalityType.VIDEO) == [ResNet]
+        assert registry.get_representations(ModalityType.VIDEO) == [
+            ResNet,
+            # SwinVideoTransformer,
+        ]
 
     def test_timeseries_representations_in_registry(self):
         registry = Registry()
@@ -70,17 +87,15 @@ class TestOperatorRegistry(unittest.TestCase):
 
     # def test_fusion_operator_in_registry(self):
     #     registry = Registry()
-    #     for fusion_operator in [
-    #         # RowMax,
-    #         Sum,
+    #     assert registry.get_fusion_operators() == [
     #         Average,
     #         Concatenation,
     #         LSTM,
-    #         Multiplication,
-    #     ]:
-    #         assert (
-    #             fusion_operator in registry.get_fusion_operators()
-    #         ), f"{fusion_operator} not in registry"
+    #         RowMax,
+    #         Hadamard,
+    #         Sum,
+    #         AttentionFusion,
+    #     ]
 
 
 if __name__ == "__main__":
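
    The rewritten audio assertion pins the exact registration order. If
    ordering is not meant to be part of the contract, a set comparison is a
    looser alternative (a sketch, not part of the patch):

        from systemds.scuro.drsearch.operator_registry import Registry
        from systemds.scuro.modality.type import ModalityType
        from systemds.scuro.representations.covarep_audio_features import (
            Pitch, RMSE, Spectral, ZeroCrossing,
        )
        from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
        from systemds.scuro.representations.mfcc import MFCC
        from systemds.scuro.representations.spectrogram import Spectrogram
        from systemds.scuro.representations.wav2vec import Wav2Vec

        registry = Registry()
        # Order-insensitive variant of the audio registry check.
        assert set(registry.get_representations(ModalityType.AUDIO)) == {
            MelSpectrogram, MFCC, Spectrogram, Wav2Vec,
            Spectral, ZeroCrossing, RMSE, Pitch,
        }
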
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py b/src/main/python/tests/scuro/test_unimodal_optimizer.py
index a73d7b5fcc..b5d2b266f6 100644
--- a/src/main/python/tests/scuro/test_unimodal_optimizer.py
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -30,11 +30,14 @@ from sklearn.model_selection import train_test_split
 from systemds.scuro.drsearch.operator_registry import Registry
 from systemds.scuro.models.model import Model
 from systemds.scuro.drsearch.task import Task
-from systemds.scuro.drsearch.unimodal_optimizer import (
-    UnimodalOptimizer,
-)
+from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
 
 from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.covarep_audio_features import (
+    ZeroCrossing,
+    Spectral,
+    Pitch,
+)
 from systemds.scuro.representations.word2vec import W2V
 from systemds.scuro.modality.unimodal_modality import UnimodalModality
 from systemds.scuro.representations.resnet import ResNet
@@ -176,7 +179,7 @@ class TestUnimodalRepresentationOptimizer(unittest.TestCase):
             "_representations",
             {
                 ModalityType.TEXT: [W2V],
-                ModalityType.AUDIO: [Spectrogram],
+                ModalityType.AUDIO: [Spectrogram, ZeroCrossing, Spectral, Pitch],
                 ModalityType.TIMESERIES: [ResNet],
                 ModalityType.VIDEO: [ResNet],
                 ModalityType.EMBEDDING: [],
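
    The hunk above narrows the registry's _representations mapping so the
    optimizer only enumerates a small, fast operator set during the test. A
    minimal sketch of that patching pattern (illustrative; the test's exact
    setup may differ):

        from unittest.mock import patch

        from systemds.scuro.drsearch.operator_registry import Registry
        from systemds.scuro.modality.type import ModalityType
        from systemds.scuro.representations.spectrogram import Spectrogram
        from systemds.scuro.representations.word2vec import W2V

        with patch.object(
            Registry(),
            "_representations",
            {ModalityType.TEXT: [W2V], ModalityType.AUDIO: [Spectrogram]},
        ):
            pass  # run the optimizer here under the patched registry
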
diff --git a/src/main/python/tests/scuro/test_unimodal_representations.py b/src/main/python/tests/scuro/test_unimodal_representations.py
index 2f2e64efd7..52bca501ac 100644
--- a/src/main/python/tests/scuro/test_unimodal_representations.py
+++ b/src/main/python/tests/scuro/test_unimodal_representations.py
@@ -22,8 +22,18 @@
 import os
 import shutil
 import unittest
+import copy
+import numpy as np
 
 from systemds.scuro.representations.bow import BoW
+from systemds.scuro.representations.covarep_audio_features import (
+    Spectral,
+    RMSE,
+    Pitch,
+    ZeroCrossing,
+)
+from systemds.scuro.representations.wav2vec import Wav2Vec
+from systemds.scuro.representations.spectrogram import Spectrogram
 from systemds.scuro.representations.word2vec import W2V
 from systemds.scuro.representations.tfidf import TfIdf
 from systemds.scuro.modality.unimodal_modality import UnimodalModality
@@ -31,8 +41,12 @@ from systemds.scuro.representations.bert import Bert
 from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
 from systemds.scuro.representations.mfcc import MFCC
 from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.swin_video_transformer import SwinVideoTransformer
-from tests.scuro.data_generator import setup_data
-
+from tests.scuro.data_generator import (
+    setup_data,
+    TestDataLoader,
+    ModalityRandomDataGenerator,
+)
 from systemds.scuro.dataloader.audio_loader import AudioLoader
 from systemds.scuro.dataloader.video_loader import VideoLoader
 from systemds.scuro.dataloader.text_loader import TextLoader
@@ -50,52 +64,70 @@ class TestUnimodalRepresentations(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        cls.test_file_path = "unimodal_test_data"
-
         cls.num_instances = 4
-        cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
-
-        cls.data_generator = setup_data(cls.mods, cls.num_instances, cls.test_file_path)
-        os.makedirs(f"{cls.test_file_path}/embeddings")
-
-    @classmethod
-    def tearDownClass(cls):
-        print("Cleaning up test data")
-        shutil.rmtree(cls.test_file_path)
+        cls.indices = np.array(range(cls.num_instances))
 
     def test_audio_representations(self):
-        audio_representations = [MFCC()]  # TODO: add FFT, TFN, 1DCNN
-        audio_data_loader = AudioLoader(
-            self.data_generator.get_modality_path(ModalityType.AUDIO),
-            self.data_generator.indices,
+        audio_representations = [
+            MFCC(),
+            MelSpectrogram(),
+            Spectrogram(),
+            Wav2Vec(),
+            Spectral(),
+            ZeroCrossing(),
+            RMSE(),
+            Pitch(),
+        ]  # TODO: add FFT, TFN, 1DCNN
+        audio_data, audio_md = ModalityRandomDataGenerator().create_audio_data(
+            self.num_instances, 1000
+        )
+
+        audio = UnimodalModality(
+            TestDataLoader(
+                self.indices, None, ModalityType.AUDIO, audio_data, np.float32, audio_md
+            )
         )
-        audio = UnimodalModality(audio_data_loader)
+
+        audio.extract_raw_data()
+        original_data = copy.deepcopy(audio.data)
 
         for representation in audio_representations:
             r = audio.apply_representation(representation)
             assert r.data is not None
             assert len(r.data) == self.num_instances
+            for i in range(self.num_instances):
+                assert (audio.data[i] == original_data[i]).all()
+            assert r.data[0].ndim == 2
 
     def test_video_representations(self):
-        video_representations = [ResNet()]  # Todo: add other video representations
-        video_data_loader = VideoLoader(
-            self.data_generator.get_modality_path(ModalityType.VIDEO),
-            self.data_generator.indices,
+        video_representations = [
+            ResNet(),
+            SwinVideoTransformer(),
+        ]  # Todo: add other video representations
+        video_data, video_md = ModalityRandomDataGenerator().create_visual_modality(
+            self.num_instances, 60
+        )
+        video = UnimodalModality(
+            TestDataLoader(
+                self.indices, None, ModalityType.VIDEO, video_data, np.float32, video_md
+            )
         )
-        video = UnimodalModality(video_data_loader)
         for representation in video_representations:
             r = video.apply_representation(representation)
             assert r.data is not None
             assert len(r.data) == self.num_instances
+            assert r.data[0].ndim == 2
 
     def test_text_representations(self):
         test_representations = [BoW(2, 2), W2V(5, 2, 2), TfIdf(2), Bert()]
-        text_data_loader = TextLoader(
-            self.data_generator.get_modality_path(ModalityType.TEXT),
-            self.data_generator.indices,
+        text_data, text_md = ModalityRandomDataGenerator().create_text_data(
+            self.num_instances
+        )
+        text = UnimodalModality(
+            TestDataLoader(
+                self.indices, None, ModalityType.TEXT, text_data, str, text_md
+            )
         )
-        text = UnimodalModality(text_data_loader)
-
         for representation in test_representations:
             r = text.apply_representation(representation)
             assert r.data is not None
@@ -103,12 +135,14 @@ class TestUnimodalRepresentations(unittest.TestCase):
 
     def test_chunked_video_representations(self):
         video_representations = [ResNet()]
-        video_data_loader = VideoLoader(
-            self.data_generator.get_modality_path(ModalityType.VIDEO),
-            self.data_generator.indices,
-            chunk_size=2,
+        video_data, video_md = ModalityRandomDataGenerator().create_visual_modality(
+            self.num_instances, 60
+        )
+        video = UnimodalModality(
+            TestDataLoader(
+                self.indices, None, ModalityType.VIDEO, video_data, np.float32, video_md
+            )
         )
-        video = UnimodalModality(video_data_loader)
         for representation in video_representations:
             r = video.apply_representation(representation)
             assert r.data is not None

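
    The reworked audio test also guards against representations mutating their
    input in place, by snapshotting the raw data with copy.deepcopy before the
    transforms run. The pattern in isolation (a generic sketch, not the test
    code itself):

        import copy
        import numpy as np

        def transform(instances):
            # A well-behaved representation returns new arrays ...
            return [x * 2.0 for x in instances]

        data = [np.random.rand(5).astype(np.float32) for _ in range(4)]
        snapshot = copy.deepcopy(data)
        _ = transform(data)

        # ... so the raw inputs are unchanged afterwards.
        for before, after in zip(snapshot, data):
            assert np.array_equal(before, after)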