This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4e00aa1a81 [SYSTEMDS-3701] Add test suite for Scuro
4e00aa1a81 is described below
commit 4e00aa1a8124cdf5c9510986de577d4d73034caf
Author: Christina Dionysio <[email protected]>
AuthorDate: Thu Nov 21 10:37:13 2024 +0100
[SYSTEMDS-3701] Add test suite for Scuro
Closes #2143.
---
src/main/python/systemds/scuro/__init__.py | 7 +-
src/main/python/systemds/scuro/aligner/task.py | 41 ++++-
.../python/systemds/scuro/representations/bert.py | 6 +-
.../systemds/scuro/representations/fusion.py | 2 -
.../{utils.py => representation_dataloader.py} | 16 +-
.../python/systemds/scuro/representations/utils.py | 73 --------
src/main/python/tests/scuro/__init__.py | 20 +++
src/main/python/tests/scuro/data_generator.py | 127 ++++++++++++++
src/main/python/tests/scuro/test_data_loaders.py | 117 +++++++++++++
src/main/python/tests/scuro/test_dr_search.py | 189 +++++++++++++++++++++
10 files changed, 499 insertions(+), 99 deletions(-)
diff --git a/src/main/python/systemds/scuro/__init__.py
b/src/main/python/systemds/scuro/__init__.py
index 04139b5283..84494a158e 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -30,7 +30,12 @@ from systemds.scuro.representations.resnet import ResNet
from systemds.scuro.representations.bert import Bert
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from systemds.scuro.representations.lstm import LSTM
-from systemds.scuro.representations.utils import NPY, Pickle, HDF5, JSON
+from systemds.scuro.representations.representation_dataloader import (
+ NPY,
+ Pickle,
+ HDF5,
+ JSON,
+)
from systemds.scuro.models.model import Model
from systemds.scuro.models.discrete_model import DiscreteModel
from systemds.scuro.modality.aligned_modality import AlignedModality
diff --git a/src/main/python/systemds/scuro/aligner/task.py
b/src/main/python/systemds/scuro/aligner/task.py
index fcf0952840..f33546ae65 100644
--- a/src/main/python/systemds/scuro/aligner/task.py
+++ b/src/main/python/systemds/scuro/aligner/task.py
@@ -21,11 +21,19 @@
from typing import List
from systemds.scuro.models.model import Model
+import numpy as np
+from sklearn.model_selection import KFold
class Task:
def __init__(
- self, name: str, model: Model, labels, train_indices: List, val_indices: List
+ self,
+ name: str,
+ model: Model,
+ labels,
+ train_indices: List,
+ val_indices: List,
+ kfold=5,
):
"""
Parent class for the prediction task that is performed on top of the aligned representation
@@ -34,12 +42,15 @@ class Task:
:param labels: Labels used for prediction
:param train_indices: Indices to extract training data
:param val_indices: Indices to extract validation data
+ :param kfold: Number of cross-validation folds
+
"""
self.name = name
self.model = model
self.labels = labels
self.train_indices = train_indices
self.val_indices = val_indices
+ self.kfold = kfold
def get_train_test_split(self, data):
X_train = [data[i] for i in self.train_indices]
@@ -51,9 +62,27 @@ class Task:
def run(self, data):
"""
- The run method need to be implemented by every task class
- It handles the training and validation procedures for the specific task
- :param data: The aligned data used in the prediction process
- :return: the validation accuracy
+ The run method needs to be implemented by every task class
+ It handles the training and validation procedures for the specific task
+ :param data: The aligned data used in the prediction process
+ :return: the validation accuracy
"""
- pass
+ skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
+ train_scores = []
+ test_scores = []
+ fold = 0
+ X, y, X_test, y_test = self.get_train_test_split(data)
+
+ for train, test in skf.split(X, y):
+ train_X = np.array(X)[train]
+ train_y = np.array(y)[train]
+
+ train_score = self.model.fit(train_X, train_y, X_test, y_test)
+ train_scores.append(train_score)
+
+ test_score = self.model.test(X_test, y_test)
+ test_scores.append(test_score)
+
+ fold += 1
+
+ return [np.mean(train_scores), np.mean(test_scores)]
diff --git a/src/main/python/systemds/scuro/representations/bert.py
b/src/main/python/systemds/scuro/representations/bert.py
index 30bdc24a53..d68729a97e 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -56,7 +56,9 @@ class Bert(UnimodalRepresentation):
data = file.readlines()
model_name = "bert-base-uncased"
- tokenizer = BertTokenizer.from_pretrained(model_name)
+ tokenizer = BertTokenizer.from_pretrained(
+ model_name, clean_up_tokenization_spaces=True
+ )
if self.avg_layers is not None:
model = BertModel.from_pretrained(model_name, output_hidden_states=True)
@@ -89,7 +91,7 @@ class Bert(UnimodalRepresentation):
cls_embedding = torch.mean(torch.stack(cls_embedding), dim=0)
else:
cls_embedding = outputs.last_hidden_state[:, 0, :].squeeze().numpy()
- embeddings.append(cls_embedding)
+ embeddings.append(cls_embedding.numpy())
embeddings = np.array(embeddings)
return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
diff --git a/src/main/python/systemds/scuro/representations/fusion.py
b/src/main/python/systemds/scuro/representations/fusion.py
index 0d5cd34726..e84e59f666 100644
--- a/src/main/python/systemds/scuro/representations/fusion.py
+++ b/src/main/python/systemds/scuro/representations/fusion.py
@@ -20,8 +20,6 @@
# -------------------------------------------------------------
from typing import List
-from sklearn.preprocessing import StandardScaler
-
from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.representation import Representation
diff --git a/src/main/python/systemds/scuro/representations/utils.py
b/src/main/python/systemds/scuro/representations/representation_dataloader.py
similarity index 89%
copy from src/main/python/systemds/scuro/representations/utils.py
copy to
src/main/python/systemds/scuro/representations/representation_dataloader.py
index bccc3ac4b2..9d44844e44 100644
--- a/src/main/python/systemds/scuro/representations/utils.py
+++
b/src/main/python/systemds/scuro/representations/representation_dataloader.py
@@ -22,9 +22,8 @@
import json
import pickle
-
-import h5py
import numpy as np
+import h5py
from systemds.scuro.representations.unimodal import UnimodalRepresentation
@@ -93,16 +92,3 @@ class JSON(UnimodalRepresentation):
def parse_all(self, filepath, indices):
with open(filepath) as file:
return json.load(file)
-
-
-def pad_sequences(sequences, maxlen=None, dtype="float32", value=0):
- if maxlen is None:
- maxlen = max([len(seq) for seq in sequences])
-
- result = np.full((len(sequences), maxlen), value, dtype=dtype)
-
- for i, seq in enumerate(sequences):
- data = seq[:maxlen]
- result[i, : len(data)] = data
-
- return result
diff --git a/src/main/python/systemds/scuro/representations/utils.py
b/src/main/python/systemds/scuro/representations/utils.py
index bccc3ac4b2..e23dd89dd7 100644
--- a/src/main/python/systemds/scuro/representations/utils.py
+++ b/src/main/python/systemds/scuro/representations/utils.py
@@ -19,81 +19,8 @@
#
# -------------------------------------------------------------
-
-import json
-import pickle
-
-import h5py
import numpy as np
-from systemds.scuro.representations.unimodal import UnimodalRepresentation
-
-
-class NPY(UnimodalRepresentation):
- def __init__(self):
- super().__init__("NPY")
-
- def parse_all(self, filepath, indices, get_sequences=False):
- data = np.load(filepath, allow_pickle=True)
-
- if indices is not None:
- return np.array([data[index] for index in indices])
- else:
- return np.array([data[index] for index in data])
-
-
-class Pickle(UnimodalRepresentation):
- def __init__(self):
- super().__init__("Pickle")
-
- def parse_all(self, file_path, indices, get_sequences=False):
- with open(file_path, "rb") as f:
- data = pickle.load(f)
-
- embeddings = []
- for n, idx in enumerate(indices):
- embeddings.append(data[idx])
-
- return np.array(embeddings)
-
-
-class HDF5(UnimodalRepresentation):
- def __init__(self):
- super().__init__("HDF5")
-
- def parse_all(self, filepath, indices=None, get_sequences=False):
- data = h5py.File(filepath)
-
- if get_sequences:
- max_emb = 0
- for index in indices:
- if max_emb < len(data[index][()]):
- max_emb = len(data[index][()])
-
- emb = []
- if indices is not None:
- for index in indices:
- emb_i = data[index].tolist()
- for i in range(len(emb_i), max_emb):
- emb_i.append([0 for x in range(0, len(emb_i[0]))])
- emb.append(emb_i)
-
- return np.array(emb)
- else:
- if indices is not None:
- return np.array([np.mean(data[index], axis=0) for index in
indices])
- else:
- return np.array([np.mean(data[index][()], axis=0) for index in
data])
-
-
-class JSON(UnimodalRepresentation):
- def __init__(self):
- super().__init__("JSON")
-
- def parse_all(self, filepath, indices):
- with open(filepath) as file:
- return json.load(file)
-
def pad_sequences(sequences, maxlen=None, dtype="float32", value=0):
if maxlen is None:
diff --git a/src/main/python/tests/scuro/__init__.py
b/src/main/python/tests/scuro/__init__.py
new file mode 100644
index 0000000000..e66abb4646
--- /dev/null
+++ b/src/main/python/tests/scuro/__init__.py
@@ -0,0 +1,20 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
diff --git a/src/main/python/tests/scuro/data_generator.py
b/src/main/python/tests/scuro/data_generator.py
new file mode 100644
index 0000000000..9ded5316d9
--- /dev/null
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -0,0 +1,127 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import cv2
+import numpy as np
+from scipy.io.wavfile import write
+import random
+import os
+
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.modality.text_modality import TextModality
+
+
+class TestDataGenerator:
+ def __init__(self, modalities, path, balanced=True):
+ self.modalities = modalities
+ self.path = path
+ self.balanced = balanced
+
+ for modality in modalities:
+ mod_path = f"{self.path}/{modality.name.lower()}"
+ os.mkdir(mod_path)
+ modality.file_path = mod_path
+ self.labels = []
+ self.label_path = f"{path}/labels.npy"
+
+ def create_multimodal_data(self, num_instances, duration=2, seed=42):
+ speed_fast = 0
+ speed_slow = 0
+ for idx in range(num_instances):
+ np.random.seed(seed)
+ if self.balanced:
+ inst_half = int(num_instances / 2)
+ if speed_slow < inst_half and speed_fast < inst_half:
+ speed_factor = random.uniform(0.5, 1.5)
+ elif speed_fast >= inst_half:
+ speed_factor = random.uniform(0.5, 0.99)
+ else:
+ speed_factor = random.uniform(1, 1.5)
+
+ else:
+ if speed_fast >= int(num_instances * 0.9):
+ speed_factor = random.uniform(0.5, 0.99)
+ elif speed_slow >= int(num_instances * 0.9):
+ speed_factor = random.uniform(0.5, 1.5)
+ else:
+ speed_factor = random.uniform(1, 1.5)
+
+ self.labels.append(1 if speed_factor >= 1 else 0)
+
+ if speed_factor >= 1:
+ speed_fast += 1
+ else:
+ speed_slow += 1
+
+ for modality in self.modalities:
+ if isinstance(modality, VideoModality):
+ self.__create_video_data(idx, duration, 30, speed_factor)
+ if isinstance(modality, AudioModality):
+ self.__create_audio_data(idx, duration, speed_factor)
+ if isinstance(modality, TextModality):
+ self.__create_text_data(idx, speed_factor)
+
+ np.save(f"{self.path}/labels.npy", np.array(self.labels))
+
+ def __create_video_data(self, idx, duration, fps, speed_factor):
+ path = f"{self.path}/video/{idx}.mp4"
+
+ width, height = 160, 120
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+ out = cv2.VideoWriter(path, fourcc, fps, (width, height))
+
+ num_frames = duration * fps
+ ball_radius = 20
+ center_x = width // 2
+
+ amplitude = random.uniform(0.5, 1.5) * (height // 3)
+
+ for i in range(num_frames):
+ frame = np.ones((height, width, 3), dtype=np.uint8) * 255
+ center_y = int(
+ height // 2
+ + amplitude * np.sin(speed_factor * 2 * np.pi * i / num_frames)
+ )
+ frame = cv2.circle(
+ frame, (center_x, center_y), ball_radius, (0, 255, 0), -1
+ )
+ out.write(frame)
+
+ out.release()
+
+ def __create_text_data(self, idx, speed_factor):
+ path = f"{self.path}/text/{idx}.txt"
+
+ with open(path, "w") as f:
+ f.write(f"The ball moves at speed factor {speed_factor:.2f}.")
+
+ def __create_audio_data(self, idx, duration, speed_factor):
+ path = f"{self.path}/audio/{idx}.wav"
+ sample_rate = 44100
+
+ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
+ frequency_variation = random.uniform(200.0, 500.0)
+ frequency = 440.0 + frequency_variation * np.sin(
+ speed_factor * 2 * np.pi * np.linspace(0, 1, len(t))
+ )
+ audio_data = 0.5 * np.sin(2 * np.pi * frequency * t)
+
+ write(path, sample_rate, audio_data)
diff --git a/src/main/python/tests/scuro/test_data_loaders.py
b/src/main/python/tests/scuro/test_data_loaders.py
new file mode 100644
index 0000000000..cbbeafab8a
--- /dev/null
+++ b/src/main/python/tests/scuro/test_data_loaders.py
@@ -0,0 +1,117 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import os
+import shutil
+import unittest
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.modality.text_modality import TextModality
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.representation_dataloader import HDF5, NPY, Pickle
+from tests.scuro.data_generator import TestDataGenerator
+
+
+class TestDataLoaders(unittest.TestCase):
+ test_file_path = None
+ mods = None
+ text = None
+ audio = None
+ video = None
+ data_generator = None
+ num_instances = 0
+ indizes = []
+
+ @classmethod
+ def setUpClass(cls):
+ cls.test_file_path = "test_data"
+
+ if os.path.isdir(cls.test_file_path):
+ shutil.rmtree(cls.test_file_path)
+
+ os.makedirs(f"{cls.test_file_path}/embeddings")
+
+ cls.num_instances = 2
+ cls.indizes = [str(i) for i in range(0, cls.num_instances)]
+ cls.video = VideoModality(
+ "",
ResNet(f"{cls.test_file_path}/embeddings/resnet_embeddings.hdf5")
+ )
+ cls.audio = AudioModality(
+ "",
+ MelSpectrogram(
+
output_file=f"{cls.test_file_path}/embeddings/mel_sp_embeddings.npy"
+ ),
+ )
+ cls.text = TextModality(
+ "",
+ Bert(
+ avg_layers=4,
+
output_file=f"{cls.test_file_path}/embeddings/bert_embeddings.pkl",
+ ),
+ )
+ cls.mods = [cls.video, cls.audio, cls.text]
+ cls.data_generator = TestDataGenerator(cls.mods, cls.test_file_path)
+ cls.data_generator.create_multimodal_data(cls.num_instances)
+ cls.text.read_all(cls.indizes)
+ cls.audio.read_all(cls.indizes)
+ cls.video.read_all([i for i in range(0, cls.num_instances)])
+
+ @classmethod
+ def tearDownClass(cls):
+ print("Cleaning up test data")
+ shutil.rmtree(cls.test_file_path)
+
+ def test_load_audio_data_from_file(self):
+ load_audio = AudioModality(
+ f"{self.test_file_path}/embeddings/mel_sp_embeddings.npy", NPY()
+ )
+ load_audio.read_all(self.indizes)
+
+ for i in range(0, self.num_instances):
+ assert round(sum(self.audio.data[i]), 4) == round(
+ sum(load_audio.data[i]), 4
+ )
+
+ def test_load_video_data_from_file(self):
+ load_video = VideoModality(
+ f"{self.test_file_path}/embeddings/resnet_embeddings.hdf5", HDF5()
+ )
+ load_video.read_all(self.indizes)
+
+ for i in range(0, self.num_instances):
+ assert round(sum(self.video.data[i]), 4) == round(
+ sum(load_video.data[i]), 4
+ )
+
+ def test_load_text_data_from_file(self):
+ load_text = TextModality(
+ f"{self.test_file_path}/embeddings/bert_embeddings.pkl", Pickle()
+ )
+ load_text.read_all(self.indizes)
+
+ for i in range(0, self.num_instances):
+ assert round(sum(self.text.data[i]), 4) == round(sum(load_text.data[i]), 4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/main/python/tests/scuro/test_dr_search.py
b/src/main/python/tests/scuro/test_dr_search.py
new file mode 100644
index 0000000000..eac4a77641
--- /dev/null
+++ b/src/main/python/tests/scuro/test_dr_search.py
@@ -0,0 +1,189 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import os
+import shutil
+import unittest
+
+import numpy as np
+from sklearn import svm
+from sklearn.metrics import classification_report
+from sklearn.model_selection import train_test_split, KFold
+from sklearn.preprocessing import MinMaxScaler
+
+from systemds.scuro.aligner.dr_search import DRSearch
+from systemds.scuro.aligner.task import Task
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.modality.text_modality import TextModality
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.models.model import Model
+from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.lstm import LSTM
+from systemds.scuro.representations.max import RowMax
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.multiplication import Multiplication
+from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.sum import Sum
+from tests.scuro.data_generator import TestDataGenerator
+
+import warnings
+
+warnings.filterwarnings("always")
+
+
+class TestSVM(Model):
+ def __init__(self):
+ super().__init__("Test")
+
+ def fit(self, X, y, X_test, y_test):
+ self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+ self.clf = self.clf.fit(X, np.array(y))
+ y_pred = self.clf.predict(X)
+
+ return classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+
+ def test(self, test_X: np.ndarray, test_y: np.ndarray):
+ y_pred = self.clf.predict(np.array(test_X)) # noqa
+
+ return classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+
+
+def scale_data(data, train_indizes):
+ scaler = MinMaxScaler(feature_range=(0, 1))
+ scaler.fit(data[train_indizes])
+ return scaler.transform(data)
+
+
+class TestDataLoaders(unittest.TestCase):
+ train_indizes = None
+ val_indizes = None
+ test_file_path = None
+ mods = None
+ text = None
+ audio = None
+ video = None
+ data_generator = None
+ num_instances = 0
+ indizes = []
+ representations = None
+
+ @classmethod
+ def setUpClass(cls):
+ cls.test_file_path = "test_data_dr_search"
+
+ if os.path.isdir(cls.test_file_path):
+ shutil.rmtree(cls.test_file_path)
+
+ os.makedirs(f"{cls.test_file_path}/embeddings")
+
+ cls.num_instances = 8
+ cls.indizes = [str(i) for i in range(0, cls.num_instances)]
+ cls.video = VideoModality(
+ "",
ResNet(f"{cls.test_file_path}/embeddings/resnet_embeddings.hdf5")
+ )
+ cls.audio = AudioModality(
+ "",
+ MelSpectrogram(
+
output_file=f"{cls.test_file_path}/embeddings/mel_sp_embeddings.npy"
+ ),
+ )
+ cls.text = TextModality(
+ "",
+ Bert(
+ avg_layers=4,
+
output_file=f"{cls.test_file_path}/embeddings/bert_embeddings.pkl",
+ ),
+ )
+ cls.mods = [cls.video, cls.audio, cls.text]
+ cls.data_generator = TestDataGenerator(cls.mods, cls.test_file_path)
+ cls.data_generator.create_multimodal_data(cls.num_instances)
+ cls.text.read_all(cls.indizes)
+ cls.audio.read_all(cls.indizes)
+ cls.video.read_all([i for i in range(0, cls.num_instances)])
+
+ split = train_test_split(
+ cls.indizes, cls.data_generator.labels, test_size=0.2, random_state=42
+ )
+ cls.train_indizes, cls.val_indizes = [int(i) for i in split[0]], [
+ int(i) for i in split[1]
+ ]
+
+ for m in cls.mods:
+ m.data = scale_data(m.data, [int(i) for i in cls.train_indizes])
+
+ cls.representations = [
+ Concatenation(),
+ Average(),
+ RowMax(),
+ Multiplication(),
+ Sum(),
+ LSTM(width=256, depth=3),
+ ]
+
+ @classmethod
+ def tearDownClass(cls):
+ print("Cleaning up test data")
+ shutil.rmtree(cls.test_file_path)
+
+ def test_enumerate_all(self):
+ task = Task(
+ "TestTask",
+ TestSVM(),
+ self.data_generator.labels,
+ self.train_indizes,
+ self.val_indizes,
+ )
+ dr_search = DRSearch(self.mods, task, self.representations)
+ best_representation, best_score, best_modalities = dr_search.fit_enumerate_all()
+
+ for r in dr_search.scores.values():
+ for scores in r.values():
+ assert scores[1] <= best_score
+
+ def test_enumerate_all_vs_random(self):
+ task = Task(
+ "TestTask",
+ TestSVM(),
+ self.data_generator.labels,
+ self.train_indizes,
+ self.val_indizes,
+ )
+ dr_search = DRSearch(self.mods, task, self.representations)
+ best_representation_enum, best_score_enum, best_modalities_enum = (
+ dr_search.fit_enumerate_all()
+ )
+
+ dr_search.reset_best_params()
+
+ best_representation_rand, best_score_rand, best_modalities_rand = (
+ dr_search.fit_random(seed=42)
+ )
+
+ assert best_score_rand <= best_score_enum
+
+
+if __name__ == "__main__":
+ unittest.main()