This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e00aa1a81 [SYSTEMDS-3701] Add test suite for Scuro
4e00aa1a81 is described below

commit 4e00aa1a8124cdf5c9510986de577d4d73034caf
Author: Christina Dionysio <[email protected]>
AuthorDate: Thu Nov 21 10:37:13 2024 +0100

    [SYSTEMDS-3701] Add test suite for Scuro
    
    Closes #2143.
---
 src/main/python/systemds/scuro/__init__.py         |   7 +-
 src/main/python/systemds/scuro/aligner/task.py     |  41 ++++-
 .../python/systemds/scuro/representations/bert.py  |   6 +-
 .../systemds/scuro/representations/fusion.py       |   2 -
 .../{utils.py => representation_dataloader.py}     |  16 +-
 .../python/systemds/scuro/representations/utils.py |  73 --------
 src/main/python/tests/scuro/__init__.py            |  20 +++
 src/main/python/tests/scuro/data_generator.py      | 127 ++++++++++++++
 src/main/python/tests/scuro/test_data_loaders.py   | 117 +++++++++++++
 src/main/python/tests/scuro/test_dr_search.py      | 189 +++++++++++++++++++++
 10 files changed, 499 insertions(+), 99 deletions(-)

diff --git a/src/main/python/systemds/scuro/__init__.py b/src/main/python/systemds/scuro/__init__.py
index 04139b5283..84494a158e 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -30,7 +30,12 @@ from systemds.scuro.representations.resnet import ResNet
 from systemds.scuro.representations.bert import Bert
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
 from systemds.scuro.representations.lstm import LSTM
-from systemds.scuro.representations.utils import NPY, Pickle, HDF5, JSON
+from systemds.scuro.representations.representation_dataloader import (
+    NPY,
+    Pickle,
+    HDF5,
+    JSON,
+)
 from systemds.scuro.models.model import Model
 from systemds.scuro.models.discrete_model import DiscreteModel
 from systemds.scuro.modality.aligned_modality import AlignedModality
diff --git a/src/main/python/systemds/scuro/aligner/task.py b/src/main/python/systemds/scuro/aligner/task.py
index fcf0952840..f33546ae65 100644
--- a/src/main/python/systemds/scuro/aligner/task.py
+++ b/src/main/python/systemds/scuro/aligner/task.py
@@ -21,11 +21,19 @@
 from typing import List
 
 from systemds.scuro.models.model import Model
+import numpy as np
+from sklearn.model_selection import KFold
 
 
 class Task:
     def __init__(
-        self, name: str, model: Model, labels, train_indices: List, val_indices: List
+        self,
+        name: str,
+        model: Model,
+        labels,
+        train_indices: List,
+        val_indices: List,
+        kfold=5,
     ):
         """
         Parent class for the prediction task that is performed on top of the aligned representation
@@ -34,12 +42,15 @@ class Task:
         :param labels: Labels used for prediction
         :param train_indices: Indices to extract training data
         :param val_indices: Indices to extract validation data
+        :param kfold: Number of crossvalidation runs
+
         """
         self.name = name
         self.model = model
         self.labels = labels
         self.train_indices = train_indices
         self.val_indices = val_indices
+        self.kfold = kfold
 
     def get_train_test_split(self, data):
         X_train = [data[i] for i in self.train_indices]
@@ -51,9 +62,27 @@ class Task:
 
     def run(self, data):
         """
-        The run method need to be implemented by every task class
-        It handles the training and validation procedures for the specific task
-        :param data: The aligned data used in the prediction process
-        :return: the validation accuracy
+        The run method needs to be implemented by every task class
+         It handles the training and validation procedures for the specific task
+         :param data: The aligned data used in the prediction process
+         :return: the validation accuracy
         """
-        pass
+        skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
+        train_scores = []
+        test_scores = []
+        fold = 0
+        X, y, X_test, y_test = self.get_train_test_split(data)
+
+        for train, test in skf.split(X, y):
+            train_X = np.array(X)[train]
+            train_y = np.array(y)[train]
+
+            train_score = self.model.fit(train_X, train_y, X_test, y_test)
+            train_scores.append(train_score)
+
+            test_score = self.model.test(X_test, y_test)
+            test_scores.append(test_score)
+
+            fold += 1
+
+        return [np.mean(train_scores), np.mean(test_scores)]
diff --git a/src/main/python/systemds/scuro/representations/bert.py b/src/main/python/systemds/scuro/representations/bert.py
index 30bdc24a53..d68729a97e 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -56,7 +56,9 @@ class Bert(UnimodalRepresentation):
                 data = file.readlines()
 
         model_name = "bert-base-uncased"
-        tokenizer = BertTokenizer.from_pretrained(model_name)
+        tokenizer = BertTokenizer.from_pretrained(
+            model_name, clean_up_tokenization_spaces=True
+        )
 
         if self.avg_layers is not None:
             model = BertModel.from_pretrained(model_name, output_hidden_states=True)
@@ -89,7 +91,7 @@ class Bert(UnimodalRepresentation):
                 cls_embedding = torch.mean(torch.stack(cls_embedding), dim=0)
             else:
                 cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
-            embeddings.append(cls_embedding)
+            embeddings.append(cls_embedding.numpy())
 
         embeddings = np.array(embeddings)
         return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
diff --git a/src/main/python/systemds/scuro/representations/fusion.py b/src/main/python/systemds/scuro/representations/fusion.py
index 0d5cd34726..e84e59f666 100644
--- a/src/main/python/systemds/scuro/representations/fusion.py
+++ b/src/main/python/systemds/scuro/representations/fusion.py
@@ -20,8 +20,6 @@
 # -------------------------------------------------------------
 from typing import List
 
-from sklearn.preprocessing import StandardScaler
-
 from systemds.scuro.modality.modality import Modality
 from systemds.scuro.representations.representation import Representation
 
diff --git a/src/main/python/systemds/scuro/representations/utils.py b/src/main/python/systemds/scuro/representations/representation_dataloader.py
similarity index 89%
copy from src/main/python/systemds/scuro/representations/utils.py
copy to src/main/python/systemds/scuro/representations/representation_dataloader.py
index bccc3ac4b2..9d44844e44 100644
--- a/src/main/python/systemds/scuro/representations/utils.py
+++ b/src/main/python/systemds/scuro/representations/representation_dataloader.py
@@ -22,9 +22,8 @@
 
 import json
 import pickle
-
-import h5py
 import numpy as np
+import h5py
 
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
 
@@ -93,16 +92,3 @@ class JSON(UnimodalRepresentation):
     def parse_all(self, filepath, indices):
         with open(filepath) as file:
             return json.load(file)
-
-
-def pad_sequences(sequences, maxlen=None, dtype="float32", value=0):
-    if maxlen is None:
-        maxlen = max([len(seq) for seq in sequences])
-
-    result = np.full((len(sequences), maxlen), value, dtype=dtype)
-
-    for i, seq in enumerate(sequences):
-        data = seq[:maxlen]
-        result[i, : len(data)] = data
-
-    return result
diff --git a/src/main/python/systemds/scuro/representations/utils.py b/src/main/python/systemds/scuro/representations/utils.py
index bccc3ac4b2..e23dd89dd7 100644
--- a/src/main/python/systemds/scuro/representations/utils.py
+++ b/src/main/python/systemds/scuro/representations/utils.py
@@ -19,81 +19,8 @@
 #
 # -------------------------------------------------------------
 
-
-import json
-import pickle
-
-import h5py
 import numpy as np
 
-from systemds.scuro.representations.unimodal import UnimodalRepresentation
-
-
-class NPY(UnimodalRepresentation):
-    def __init__(self):
-        super().__init__("NPY")
-
-    def parse_all(self, filepath, indices, get_sequences=False):
-        data = np.load(filepath, allow_pickle=True)
-
-        if indices is not None:
-            return np.array([data[index] for index in indices])
-        else:
-            return np.array([data[index] for index in data])
-
-
-class Pickle(UnimodalRepresentation):
-    def __init__(self):
-        super().__init__("Pickle")
-
-    def parse_all(self, file_path, indices, get_sequences=False):
-        with open(file_path, "rb") as f:
-            data = pickle.load(f)
-
-        embeddings = []
-        for n, idx in enumerate(indices):
-            embeddings.append(data[idx])
-
-        return np.array(embeddings)
-
-
-class HDF5(UnimodalRepresentation):
-    def __init__(self):
-        super().__init__("HDF5")
-
-    def parse_all(self, filepath, indices=None, get_sequences=False):
-        data = h5py.File(filepath)
-
-        if get_sequences:
-            max_emb = 0
-            for index in indices:
-                if max_emb < len(data[index][()]):
-                    max_emb = len(data[index][()])
-
-            emb = []
-            if indices is not None:
-                for index in indices:
-                    emb_i = data[index].tolist()
-                    for i in range(len(emb_i), max_emb):
-                        emb_i.append([0 for x in range(0, len(emb_i[0]))])
-                    emb.append(emb_i)
-
-                return np.array(emb)
-        else:
-            if indices is not None:
-                return np.array([np.mean(data[index], axis=0) for index in indices])
-            else:
-                return np.array([np.mean(data[index][()], axis=0) for index in data])
-
-
-class JSON(UnimodalRepresentation):
-    def __init__(self):
-        super().__init__("JSON")
-
-    def parse_all(self, filepath, indices):
-        with open(filepath) as file:
-            return json.load(file)
-
 
 def pad_sequences(sequences, maxlen=None, dtype="float32", value=0):
     if maxlen is None:
diff --git a/src/main/python/tests/scuro/__init__.py b/src/main/python/tests/scuro/__init__.py
new file mode 100644
index 0000000000..e66abb4646
--- /dev/null
+++ b/src/main/python/tests/scuro/__init__.py
@@ -0,0 +1,20 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
diff --git a/src/main/python/tests/scuro/data_generator.py b/src/main/python/tests/scuro/data_generator.py
new file mode 100644
index 0000000000..9ded5316d9
--- /dev/null
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -0,0 +1,127 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import cv2
+import numpy as np
+from scipy.io.wavfile import write
+import random
+import os
+
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.modality.text_modality import TextModality
+
+
+class TestDataGenerator:
+    def __init__(self, modalities, path, balanced=True):
+        self.modalities = modalities
+        self.path = path
+        self.balanced = balanced
+
+        for modality in modalities:
+            mod_path = f"{self.path}/{modality.name.lower()}"
+            os.mkdir(mod_path)
+            modality.file_path = mod_path
+        self.labels = []
+        self.label_path = f"{path}/labels.npy"
+
+    def create_multimodal_data(self, num_instances, duration=2, seed=42):
+        speed_fast = 0
+        speed_slow = 0
+        for idx in range(num_instances):
+            np.random.seed(seed)
+            if self.balanced:
+                inst_half = int(num_instances / 2)
+                if speed_slow < inst_half and speed_fast < inst_half:
+                    speed_factor = random.uniform(0.5, 1.5)
+                elif speed_fast >= inst_half:
+                    speed_factor = random.uniform(0.5, 0.99)
+                else:
+                    speed_factor = random.uniform(1, 1.5)
+
+            else:
+                if speed_fast >= int(num_instances * 0.9):
+                    speed_factor = random.uniform(0.5, 0.99)
+                elif speed_slow >= int(num_instances * 0.9):
+                    speed_factor = random.uniform(0.5, 1.5)
+                else:
+                    speed_factor = random.uniform(1, 1.5)
+
+            self.labels.append(1 if speed_factor >= 1 else 0)
+
+            if speed_factor >= 1:
+                speed_fast += 1
+            else:
+                speed_slow += 1
+
+            for modality in self.modalities:
+                if isinstance(modality, VideoModality):
+                    self.__create_video_data(idx, duration, 30, speed_factor)
+                if isinstance(modality, AudioModality):
+                    self.__create_audio_data(idx, duration, speed_factor)
+                if isinstance(modality, TextModality):
+                    self.__create_text_data(idx, speed_factor)
+
+        np.save(f"{self.path}/labels.npy", np.array(self.labels))
+
+    def __create_video_data(self, idx, duration, fps, speed_factor):
+        path = f"{self.path}/video/{idx}.mp4"
+
+        width, height = 160, 120
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        out = cv2.VideoWriter(path, fourcc, fps, (width, height))
+
+        num_frames = duration * fps
+        ball_radius = 20
+        center_x = width // 2
+
+        amplitude = random.uniform(0.5, 1.5) * (height // 3)
+
+        for i in range(num_frames):
+            frame = np.ones((height, width, 3), dtype=np.uint8) * 255
+            center_y = int(
+                height // 2
+                + amplitude * np.sin(speed_factor * 2 * np.pi * i / num_frames)
+            )
+            frame = cv2.circle(
+                frame, (center_x, center_y), ball_radius, (0, 255, 0), -1
+            )
+            out.write(frame)
+
+        out.release()
+
+    def __create_text_data(self, idx, speed_factor):
+        path = f"{self.path}/text/{idx}.txt"
+
+        with open(path, "w") as f:
+            f.write(f"The ball moves at speed factor {speed_factor:.2f}.")
+
+    def __create_audio_data(self, idx, duration, speed_factor):
+        path = f"{self.path}/audio/{idx}.wav"
+        sample_rate = 44100
+
+        t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
+        frequency_variation = random.uniform(200.0, 500.0)
+        frequency = 440.0 + frequency_variation * np.sin(
+            speed_factor * 2 * np.pi * np.linspace(0, 1, len(t))
+        )
+        audio_data = 0.5 * np.sin(2 * np.pi * frequency * t)
+
+        write(path, sample_rate, audio_data)
diff --git a/src/main/python/tests/scuro/test_data_loaders.py b/src/main/python/tests/scuro/test_data_loaders.py
new file mode 100644
index 0000000000..cbbeafab8a
--- /dev/null
+++ b/src/main/python/tests/scuro/test_data_loaders.py
@@ -0,0 +1,117 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import os
+import shutil
+import unittest
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.modality.text_modality import TextModality
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.representation_dataloader import HDF5, NPY, Pickle
+from tests.scuro.data_generator import TestDataGenerator
+
+
+class TestDataLoaders(unittest.TestCase):
+    test_file_path = None
+    mods = None
+    text = None
+    audio = None
+    video = None
+    data_generator = None
+    num_instances = 0
+    indizes = []
+
+    @classmethod
+    def setUpClass(cls):
+        cls.test_file_path = "test_data"
+
+        if os.path.isdir(cls.test_file_path):
+            shutil.rmtree(cls.test_file_path)
+
+        os.makedirs(f"{cls.test_file_path}/embeddings")
+
+        cls.num_instances = 2
+        cls.indizes = [str(i) for i in range(0, cls.num_instances)]
+        cls.video = VideoModality(
+            "", ResNet(f"{cls.test_file_path}/embeddings/resnet_embeddings.hdf5")
+        )
+        cls.audio = AudioModality(
+            "",
+            MelSpectrogram(
+                output_file=f"{cls.test_file_path}/embeddings/mel_sp_embeddings.npy"
+            ),
+        )
+        cls.text = TextModality(
+            "",
+            Bert(
+                avg_layers=4,
+                output_file=f"{cls.test_file_path}/embeddings/bert_embeddings.pkl",
+            ),
+        )
+        cls.mods = [cls.video, cls.audio, cls.text]
+        cls.data_generator = TestDataGenerator(cls.mods, cls.test_file_path)
+        cls.data_generator.create_multimodal_data(cls.num_instances)
+        cls.text.read_all(cls.indizes)
+        cls.audio.read_all(cls.indizes)
+        cls.video.read_all([i for i in range(0, cls.num_instances)])
+
+    @classmethod
+    def tearDownClass(cls):
+        print("Cleaning up test data")
+        shutil.rmtree(cls.test_file_path)
+
+    def test_load_audio_data_from_file(self):
+        load_audio = AudioModality(
+            f"{self.test_file_path}/embeddings/mel_sp_embeddings.npy", NPY()
+        )
+        load_audio.read_all(self.indizes)
+
+        for i in range(0, self.num_instances):
+            assert round(sum(self.audio.data[i]), 4) == round(
+                sum(load_audio.data[i]), 4
+            )
+
+    def test_load_video_data_from_file(self):
+        load_video = VideoModality(
+            f"{self.test_file_path}/embeddings/resnet_embeddings.hdf5", HDF5()
+        )
+        load_video.read_all(self.indizes)
+
+        for i in range(0, self.num_instances):
+            assert round(sum(self.video.data[i]), 4) == round(
+                sum(load_video.data[i]), 4
+            )
+
+    def test_load_text_data_from_file(self):
+        load_text = TextModality(
+            f"{self.test_file_path}/embeddings/bert_embeddings.pkl", Pickle()
+        )
+        load_text.read_all(self.indizes)
+
+        for i in range(0, self.num_instances):
+            assert round(sum(self.text.data[i]), 4) == round(sum(load_text.data[i]), 4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/src/main/python/tests/scuro/test_dr_search.py b/src/main/python/tests/scuro/test_dr_search.py
new file mode 100644
index 0000000000..eac4a77641
--- /dev/null
+++ b/src/main/python/tests/scuro/test_dr_search.py
@@ -0,0 +1,189 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import os
+import shutil
+import unittest
+
+import numpy as np
+from sklearn import svm
+from sklearn.metrics import classification_report
+from sklearn.model_selection import train_test_split, KFold
+from sklearn.preprocessing import MinMaxScaler
+
+from systemds.scuro.aligner.dr_search import DRSearch
+from systemds.scuro.aligner.task import Task
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.modality.text_modality import TextModality
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.models.model import Model
+from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.lstm import LSTM
+from systemds.scuro.representations.max import RowMax
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.multiplication import Multiplication
+from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.sum import Sum
+from tests.scuro.data_generator import TestDataGenerator
+
+import warnings
+
+warnings.filterwarnings("always")
+
+
+class TestSVM(Model):
+    def __init__(self):
+        super().__init__("Test")
+
+    def fit(self, X, y, X_test, y_test):
+        self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+        self.clf = self.clf.fit(X, np.array(y))
+        y_pred = self.clf.predict(X)
+
+        return classification_report(
+            y, y_pred, output_dict=True, digits=3, zero_division=1
+        )["accuracy"]
+
+    def test(self, test_X: np.ndarray, test_y: np.ndarray):
+        y_pred = self.clf.predict(np.array(test_X))  # noqa
+
+        return classification_report(
+            np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+        )["accuracy"]
+
+
+def scale_data(data, train_indizes):
+    scaler = MinMaxScaler(feature_range=(0, 1))
+    scaler.fit(data[train_indizes])
+    return scaler.transform(data)
+
+
+class TestDataLoaders(unittest.TestCase):
+    train_indizes = None
+    val_indizes = None
+    test_file_path = None
+    mods = None
+    text = None
+    audio = None
+    video = None
+    data_generator = None
+    num_instances = 0
+    indizes = []
+    representations = None
+
+    @classmethod
+    def setUpClass(cls):
+        cls.test_file_path = "test_data_dr_search"
+
+        if os.path.isdir(cls.test_file_path):
+            shutil.rmtree(cls.test_file_path)
+
+        os.makedirs(f"{cls.test_file_path}/embeddings")
+
+        cls.num_instances = 8
+        cls.indizes = [str(i) for i in range(0, cls.num_instances)]
+        cls.video = VideoModality(
+            "", ResNet(f"{cls.test_file_path}/embeddings/resnet_embeddings.hdf5")
+        )
+        cls.audio = AudioModality(
+            "",
+            MelSpectrogram(
+                output_file=f"{cls.test_file_path}/embeddings/mel_sp_embeddings.npy"
+            ),
+        )
+        cls.text = TextModality(
+            "",
+            Bert(
+                avg_layers=4,
+                output_file=f"{cls.test_file_path}/embeddings/bert_embeddings.pkl",
+            ),
+        )
+        cls.mods = [cls.video, cls.audio, cls.text]
+        cls.data_generator = TestDataGenerator(cls.mods, cls.test_file_path)
+        cls.data_generator.create_multimodal_data(cls.num_instances)
+        cls.text.read_all(cls.indizes)
+        cls.audio.read_all(cls.indizes)
+        cls.video.read_all([i for i in range(0, cls.num_instances)])
+
+        split = train_test_split(
+            cls.indizes, cls.data_generator.labels, test_size=0.2, random_state=42
+        )
+        cls.train_indizes, cls.val_indizes = [int(i) for i in split[0]], [
+            int(i) for i in split[1]
+        ]
+
+        for m in cls.mods:
+            m.data = scale_data(m.data, [int(i) for i in cls.train_indizes])
+
+        cls.representations = [
+            Concatenation(),
+            Average(),
+            RowMax(),
+            Multiplication(),
+            Sum(),
+            LSTM(width=256, depth=3),
+        ]
+
+    @classmethod
+    def tearDownClass(cls):
+        print("Cleaning up test data")
+        shutil.rmtree(cls.test_file_path)
+
+    def test_enumerate_all(self):
+        task = Task(
+            "TestTask",
+            TestSVM(),
+            self.data_generator.labels,
+            self.train_indizes,
+            self.val_indizes,
+        )
+        dr_search = DRSearch(self.mods, task, self.representations)
+        best_representation, best_score, best_modalities = dr_search.fit_enumerate_all()
+
+        for r in dr_search.scores.values():
+            for scores in r.values():
+                assert scores[1] <= best_score
+
+    def test_enumerate_all_vs_random(self):
+        task = Task(
+            "TestTask",
+            TestSVM(),
+            self.data_generator.labels,
+            self.train_indizes,
+            self.val_indizes,
+        )
+        dr_search = DRSearch(self.mods, task, self.representations)
+        best_representation_enum, best_score_enum, best_modalities_enum = (
+            dr_search.fit_enumerate_all()
+        )
+
+        dr_search.reset_best_params()
+
+        best_representation_rand, best_score_rand, best_modalities_rand = (
+            dr_search.fit_random(seed=42)
+        )
+
+        assert best_score_rand <= best_score_enum
+
+
+if __name__ == "__main__":
+    unittest.main()

Reply via email to