This is an automated email from the ASF dual-hosted git repository.
cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 23b3d08eb1 [SYSTEMDS-3887] Create representation optimizer
23b3d08eb1 is described below
commit 23b3d08eb18305612ce038edd0fbd50a011aba9b
Author: Christina Dionysio <[email protected]>
AuthorDate: Wed May 28 10:52:52 2025 +0200
[SYSTEMDS-3887] Create representation optimizer
This patch adds an initial version of the representation optimizer for the
Scuro library. It is a two-stage optimization: in the first step, the best
unimodal representation for the given raw modalities is found; in the second
step, the k-best unimodal representations are combined into multimodal
representations and evaluated against the target downstream task. Additionally,
this patch adds tests for each stage of the optimizer.
Closes #2267
---
src/main/python/systemds/scuro/__init__.py | 83 ++++--
.../python/systemds/scuro/aligner/alignment.py | 48 ----
.../systemds/scuro/dataloader/audio_loader.py | 11 +-
.../scuro/{aligner => drsearch}/__init__.py | 0
.../scuro/{aligner => drsearch}/dr_search.py | 4 +-
.../systemds/scuro/drsearch/fusion_optimizer.py | 295 +++++++++++++++++++++
.../scuro/drsearch/hyperparameter_tuner.py | 106 ++++++++
.../systemds/scuro/drsearch/operator_registry.py | 107 ++++++++
.../systemds/scuro/drsearch/optimization_data.py | 164 ++++++++++++
.../scuro/drsearch/representation_cache.py | 127 +++++++++
.../{aligner => drsearch}/similarity_measures.py | 0
.../systemds/scuro/{aligner => drsearch}/task.py | 22 +-
.../drsearch/unimodal_representation_optimizer.py | 271 +++++++++++++++++++
src/main/python/systemds/scuro/main.py | 4 +-
src/main/python/systemds/scuro/modality/joined.py | 6 +-
.../python/systemds/scuro/modality/modality.py | 2 +-
.../systemds/scuro/modality/modality_identifier.py | 7 -
.../python/systemds/scuro/modality/transformed.py | 5 +-
.../systemds/scuro/modality/unimodal_modality.py | 1 -
.../systemds/scuro/representations/aggregate.py | 30 ++-
.../aggregated_representation.py} | 29 +-
.../systemds/scuro/representations/average.py | 5 +
.../python/systemds/scuro/representations/bert.py | 11 +-
.../python/systemds/scuro/representations/bow.py | 2 +
.../scuro/representations/concatenation.py | 3 +
.../systemds/scuro/representations/context.py | 1 -
.../python/systemds/scuro/representations/glove.py | 4 +-
.../python/systemds/scuro/representations/lstm.py | 3 +
.../python/systemds/scuro/representations/max.py | 3 +
.../scuro/representations/mel_spectrogram.py | 13 +-
.../{mel_spectrogram.py => mfcc.py} | 38 ++-
.../scuro/representations/multiplication.py | 3 +
.../systemds/scuro/representations/optical_flow.py | 79 ++++++
.../systemds/scuro/representations/resnet.py | 84 ++----
.../systemds/scuro/representations/rowmax.py | 78 ------
.../{mel_spectrogram.py => spectrogram.py} | 23 +-
.../python/systemds/scuro/representations/sum.py | 3 +
.../representations/swin_video_transformer.py | 111 ++++++++
.../python/systemds/scuro/representations/tfidf.py | 2 +
.../{mel_spectrogram.py => wav2vec.py} | 52 ++--
.../systemds/scuro/representations/window.py | 6 +-
.../systemds/scuro/representations/word2vec.py | 6 +-
.../scuro/representations/{resnet.py => x3d.py} | 123 +++------
.../python/systemds/scuro/utils/schema_helpers.py | 1 -
.../python/systemds/scuro/utils/torch_dataset.py | 63 +++++
src/main/python/tests/scuro/data_generator.py | 12 +-
src/main/python/tests/scuro/test_dr_search.py | 4 +-
.../python/tests/scuro/test_multimodal_fusion.py | 202 ++++++++++++++
.../python/tests/scuro/test_multimodal_join.py | 2 -
.../python/tests/scuro/test_operator_registry.py | 87 ++++++
.../python/tests/scuro/test_unimodal_optimizer.py | 203 ++++++++++++++
51 files changed, 2133 insertions(+), 416 deletions(-)
diff --git a/src/main/python/systemds/scuro/__init__.py
b/src/main/python/systemds/scuro/__init__.py
index 53b68d430f..4b2185316a 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -24,27 +24,55 @@ from systemds.scuro.dataloader.video_loader import
VideoLoader
from systemds.scuro.dataloader.text_loader import TextLoader
from systemds.scuro.dataloader.json_loader import JSONLoader
from systemds.scuro.representations.representation import Representation
+from systemds.scuro.representations.aggregate import Aggregation
+from systemds.scuro.representations.aggregated_representation import (
+ AggregatedRepresentation,
+)
from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.bow import BoW
from systemds.scuro.representations.concatenation import Concatenation
-from systemds.scuro.representations.sum import Sum
+from systemds.scuro.representations.context import Context
+from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.representations.glove import GloVe
+from systemds.scuro.representations.lstm import LSTM
from systemds.scuro.representations.max import RowMax
-from systemds.scuro.representations.multiplication import Multiplication
from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.mfcc import MFCC
+from systemds.scuro.representations.multiplication import Multiplication
+from systemds.scuro.representations.optical_flow import OpticalFlow
+from systemds.scuro.representations.representation import Representation
+from systemds.scuro.representations.representation_dataloader import NPY
+from systemds.scuro.representations.representation_dataloader import JSON
+from systemds.scuro.representations.representation_dataloader import Pickle
from systemds.scuro.representations.resnet import ResNet
-from systemds.scuro.representations.bert import Bert
-from systemds.scuro.representations.lstm import LSTM
-from systemds.scuro.representations.bow import BoW
-from systemds.scuro.representations.glove import GloVe
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.sum import Sum
+from systemds.scuro.representations.swin_video_transformer import
SwinVideoTransformer
from systemds.scuro.representations.tfidf import TfIdf
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.representations.wav2vec import Wav2Vec
+from systemds.scuro.representations.window import WindowAggregation
from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.representations.x3d import X3D
from systemds.scuro.models.model import Model
from systemds.scuro.models.discrete_model import DiscreteModel
+from systemds.scuro.modality.joined import JoinedModality
+from systemds.scuro.modality.joined_transformed import
JoinedTransformedModality
from systemds.scuro.modality.modality import Modality
-from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.modality.modality_identifier import ModalityIdentifier
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.modality.type import ModalityType
-from systemds.scuro.aligner.dr_search import DRSearch
-from systemds.scuro.aligner.task import Task
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.drsearch.dr_search import DRSearch
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.fusion_optimizer import FusionOptimizer
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.drsearch.optimization_data import OptimizationData
+from systemds.scuro.drsearch.representation_cache import RepresentationCache
+from systemds.scuro.drsearch.unimodal_representation_optimizer import (
+ UnimodalRepresentationOptimizer,
+)
__all__ = [
@@ -53,25 +81,50 @@ __all__ = [
"VideoLoader",
"TextLoader",
"Representation",
+ "Aggregation",
+ "AggregatedRepresentation",
"Average",
+ "Bert",
+ "BoW",
"Concatenation",
- "Sum",
+ "Context",
+ "Fusion",
+ "GloVe",
+ "LSTM",
"RowMax",
- "Multiplication",
"MelSpectrogram",
+ "MFCC",
+ "Multiplication",
+ "OpticalFlow",
+ "Representation",
+ "NPY",
+ "JSON",
+ "Pickle",
"ResNet",
- "Bert",
- "LSTM",
+ "Spectrogram",
+ "Sum",
"BoW",
- "GloVe",
+ "SwinVideoTransformer",
"TfIdf",
+ "UnimodalRepresentation",
+ "Wav2Vec",
+ "WindowAggregation",
"W2V",
+ "X3D",
"Model",
"DiscreteModel",
+ "JoinedModality",
+ "JoinedTransformedModality",
"Modality",
- "UnimodalModality",
+ "ModalityIdentifier",
"TransformedModality",
"ModalityType",
+ "UnimodalModality",
"DRSearch",
"Task",
+ "FusionOptimizer",
+ "Registry",
+ "OptimizationData",
+ "RepresentationCache",
+ "UnimodalRepresentationOptimizer",
]
diff --git a/src/main/python/systemds/scuro/aligner/alignment.py
b/src/main/python/systemds/scuro/aligner/alignment.py
deleted file mode 100644
index 62f88a272b..0000000000
--- a/src/main/python/systemds/scuro/aligner/alignment.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# -------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-# -------------------------------------------------------------
-from aligner.alignment_strategy import AlignmentStrategy
-from modality.modality import Modality
-from modality.representation import Representation
-from aligner.similarity_measures import Measure
-
-
-class Alignment:
- def __init__(
- self,
- modality_a: Modality,
- modality_b: Modality,
- strategy: AlignmentStrategy,
- similarity_measure: Measure,
- ):
- """
- Defines the core of the library where the alignment of two modalities
is performed
- :param modality_a: first modality
- :param modality_b: second modality
- :param strategy: the alignment strategy used in the alignment process
- :param similarity_measure: the similarity measure used to check the
score of the alignment
- """
- self.modality_a = modality_a
- self.modality_b = modality_b
- self.strategy = strategy
- self.similarity_measure = similarity_measure
-
- def align_modalities(self) -> Modality:
- return Modality(Representation())
diff --git a/src/main/python/systemds/scuro/dataloader/audio_loader.py
b/src/main/python/systemds/scuro/dataloader/audio_loader.py
index a6a164b4fb..a008962680 100644
--- a/src/main/python/systemds/scuro/dataloader/audio_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/audio_loader.py
@@ -27,13 +27,22 @@ from systemds.scuro.modality.type import ModalityType
class AudioLoader(BaseLoader):
    """
    Loader for the AUDIO modality.

    Each file is read with ``librosa.load`` using librosa's default settings;
    when ``normalize`` is enabled the signal is passed through
    ``librosa.util.normalize`` before it is stored.
    """

    def __init__(
        self,
        source_path: str,
        indices: List[str],
        chunk_size: Optional[int] = None,
        normalize: bool = True,
    ):
        """
        :param source_path: location of the audio files (forwarded to BaseLoader)
        :param indices: identifiers of the files to load (forwarded to BaseLoader)
        :param chunk_size: optional chunk size (forwarded to BaseLoader)
        :param normalize: whether to apply ``librosa.util.normalize`` to every signal
        """
        super().__init__(source_path, indices, chunk_size, ModalityType.AUDIO)
        self.normalize = normalize

    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
        """
        Load one audio file, record its metadata and append the signal to ``self.data``.

        :param file: path of the audio file
        :param index: unused here; kept for the BaseLoader interface
        """
        self.file_sanity_check(file)
        signal, sample_rate = librosa.load(file)

        # Optional normalization step, controlled by the constructor flag.
        signal = librosa.util.normalize(signal) if self.normalize else signal

        self.metadata[file] = self.modality_type.create_audio_metadata(
            sample_rate, signal
        )
        self.data.append(signal)
diff --git a/src/main/python/systemds/scuro/aligner/__init__.py
b/src/main/python/systemds/scuro/drsearch/__init__.py
similarity index 100%
rename from src/main/python/systemds/scuro/aligner/__init__.py
rename to src/main/python/systemds/scuro/drsearch/__init__.py
diff --git a/src/main/python/systemds/scuro/aligner/dr_search.py
b/src/main/python/systemds/scuro/drsearch/dr_search.py
similarity index 98%
rename from src/main/python/systemds/scuro/aligner/dr_search.py
rename to src/main/python/systemds/scuro/drsearch/dr_search.py
index b46139dff3..2000608a1d 100644
--- a/src/main/python/systemds/scuro/aligner/dr_search.py
+++ b/src/main/python/systemds/scuro/drsearch/dr_search.py
@@ -22,7 +22,7 @@ import itertools
import random
from typing import List
-from systemds.scuro.aligner.task import Task
+from systemds.scuro.drsearch.task import Task
from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.representation import Representation
@@ -111,7 +111,7 @@ class DRSearch:
representation = random.choice(self.representations)
modality = modality_combination[0].combine(
- modality_combination[1:], representation
+ list(modality_combination[1:]), representation
)
scores = self.task.run(modality.data)
diff --git a/src/main/python/systemds/scuro/drsearch/fusion_optimizer.py
b/src/main/python/systemds/scuro/drsearch/fusion_optimizer.py
new file mode 100644
index 0000000000..7247720f55
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/fusion_optimizer.py
@@ -0,0 +1,295 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import time
+import copy
+import pickle
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.drsearch.optimization_data import (
+ OptimizationResult,
+ OptimizationStatistics,
+)
+from systemds.scuro.drsearch.representation_cache import RepresentationCache
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.representations.aggregate import Aggregation
+from systemds.scuro.representations.context import Context
+
+
def extract_names(operator_chain):
    """
    Return the ``name`` of every operator in ``operator_chain``, preserving order.

    :param operator_chain: iterable of operator objects exposing a ``name`` attribute
    :return: list of operator names
    """
    return [op.name for op in operator_chain]
+
+
class FusionOptimizer:
    """
    Second stage of the representation optimizer.

    Keeps the ``num_best_candidates`` best unimodal candidates per modality
    (ranked by ``test_accuracy``), then recursively fuses them with every
    registered fusion operator and scores each fused representation on the
    downstream ``task``.
    """

    def __init__(
        self,
        modalities,
        task: Task,
        unimodal_representations_candidates,
        representation_cache: RepresentationCache,
        num_best_candidates=4,
        max_chain_depth=5,
        debug=False,
    ):
        """
        :param modalities: modalities whose unimodal candidates are fused
        :param task: downstream task used to score fused representations
        :param unimodal_representations_candidates: mapping
            ``modality_id -> model name -> list of candidate results``; each
            candidate exposes ``test_accuracy``, ``operator_chain`` and ``parameters``
        :param representation_cache: cache used to load and store candidate representations
        :param num_best_candidates: number of best candidates kept per modality
        :param max_chain_depth: maximum recursion depth of a fusion chain
        :param debug: print every evaluated chain and pickle the statistics
        """
        self.modalities = modalities
        self.task = task
        self.unimodal_representations_candidates = (
            unimodal_representations_candidates
        )
        self.num_best_candidates = num_best_candidates
        self.k_best_candidates, self.candidates_per_modality = (
            self.get_k_best_results(num_best_candidates)
        )
        self.operator_registry = Registry()
        self.max_chain_depth = max_chain_depth
        self.debug = debug
        self.evaluated_candidates = set()
        self.cache = representation_cache
        self.optimization_statistics = OptimizationStatistics(
            self.k_best_candidates
        )
        self.optimization_results = []

    def optimize(self):
        """
        Find different ways to combine the candidate representations and
        evaluate the fused representations against the given task. It can fuse
        different representations of the same modality as well as
        representations from different modalities.

        Results are accumulated in ``self.optimization_results`` and
        ``self.optimization_statistics``; nothing is returned.
        """
        # TODO: add an aligned representation for all modalities with a temporal dimension
        # TODO: keep a map of operator chains so that we don't evaluate them
        #  multiple times in different orders (if it does not make a difference)
        for candidate in self.k_best_candidates:
            modality = self._materialize_candidate(candidate)
            remaining_candidates = [
                c for c in self.k_best_candidates if c != candidate
            ]
            self._optimize_candidate(modality, candidate, remaining_candidates, 1)

        if self.debug:
            with open(f"fusion_statistics_{self._file_suffix()}.pkl", "wb") as fp:
                pickle.dump(
                    self.optimization_statistics,
                    fp,
                    protocol=pickle.HIGHEST_PROTOCOL,
                )

        # NOTE(review): confirm whether this dump should also be gated on self.debug.
        self._dump_results(f"fusion_results_{self._file_suffix()}.pkl")

        self.optimization_statistics.print_statistics()

    def _file_suffix(self):
        """Suffix shared by the pickled statistics/results file names."""
        return (
            f"{self.task.model.name}_{self.num_best_candidates}_{self.max_chain_depth}"
        )

    def _dump_results(self, path):
        """
        Pickle copies of the optimization results whose (arbitrarily nested)
        operator chains are flattened to a list of operator names.
        """
        named_results = []
        for opt_res in self.optimization_results:
            # Shallow copy is enough: only operator_chain is replaced on the copy.
            named_result = copy.copy(opt_res)
            # flatten_and_join handles any nesting depth; chains produced at
            # chain_depth >= 3 are nested deeper than three levels.
            named_result.operator_chain = flatten_and_join(opt_res.operator_chain)
            named_results.append(named_result)
        with open(path, "wb") as fp:
            pickle.dump(named_results, fp, protocol=pickle.HIGHEST_PROTOCOL)

    def get_k_best_results(self, k: int):
        """
        Get the k best results per modality.

        :param k: number of best results kept per modality
        :return: tuple ``(best_results, candidate_for_modality)`` where
            ``best_results`` lists the selected candidates of all modalities and
            ``candidate_for_modality`` maps ``str(candidate)`` to its modality
        """
        best_results = []
        candidate_for_modality = {}
        for modality in self.modalities:
            k_results = sorted(
                self.unimodal_representations_candidates[modality.modality_id][
                    self.task.model.name
                ],
                key=lambda x: x.test_accuracy,
                reverse=True,
            )[:k]
            for k_result in k_results:
                candidate_for_modality[str(k_result)] = modality
            best_results.extend(k_results)

        return best_results, candidate_for_modality

    def _materialize_candidate(self, candidate):
        """
        Return the representation produced by ``candidate``'s operator chain.

        Operators already available in the cache are reused; the remaining ones
        are applied on top and the result is written back to the cache.
        """
        modality = self.candidates_per_modality[str(candidate)]
        cached_representation, representation_ops, used_op_names = (
            self.cache.load_from_cache(modality, candidate.operator_chain)
        )
        if cached_representation is not None:
            modality = cached_representation

        store = False
        for representation in representation_ops:
            if representation.name == "Aggregation":
                # Rebuild the aggregation with the function chosen for this candidate.
                params = candidate.parameters[representation.name]
                representation = Aggregation(
                    aggregation_function=params["aggregation"]
                )
            if isinstance(representation, Context):
                modality = modality.context(representation)
            elif isinstance(representation, Aggregation):
                modality = representation.execute(modality)
            elif representation.name == "RowWiseConcatenation":
                modality = modality.flatten(True)
            else:
                modality = modality.apply_representation(representation)
            store = True

        if store:
            self.cache.save_to_cache(modality, used_op_names, representation_ops)
        return modality

    def _optimize_candidate(
        self, modality, candidate, remaining_candidates, chain_depth
    ):
        """
        Optimize a single candidate by fusing it with others recursively.

        :param modality: representation produced by ``candidate``
        :param candidate: the current candidate (unimodal or a previous fusion result)
        :param remaining_candidates: candidates not yet part of this chain
        :param chain_depth: the current depth of the fusion chain
        """
        if chain_depth > self.max_chain_depth:
            return

        for other_candidate in remaining_candidates:
            other_modality = self._materialize_candidate(other_candidate)
            next_remaining = [c for c in remaining_candidates if c != other_candidate]

            # Last fused representation for this pair; context operators refine it.
            fused_representation = None
            for fusion_operator_cls in self.operator_registry.get_fusion_operators():
                fusion_operator = fusion_operator_cls()
                chain_key = self.create_identifier(
                    candidate, fusion_operator, other_candidate
                )

                representation_start = time.time()
                if isinstance(fusion_operator, Context):
                    if fused_representation is None:
                        # Nothing fused yet that a context operator could refine.
                        continue
                    fused_representation = fused_representation.context(
                        fusion_operator
                    )
                else:
                    fused_representation = modality.combine(
                        other_modality, fusion_operator
                    )
                representation_end = time.time()

                if chain_key in self.evaluated_candidates:
                    continue

                result = self._evaluate_fusion(
                    candidate,
                    fusion_operator,
                    other_candidate,
                    fused_representation,
                    representation_end - representation_start,
                )
                # Mark this chain as evaluated
                self.evaluated_candidates.add(chain_key)

                if self.debug:
                    print(
                        f"Evaluated chain: {candidate.operator_chain} + {fusion_operator.name} + {other_candidate.operator_chain} -> {result.test_accuracy}"
                    )

                # Recursively optimize further with this fused representation
                self._optimize_candidate(
                    fused_representation,
                    result,
                    next_remaining,
                    chain_depth + 1,
                )

    def _evaluate_fusion(
        self,
        candidate,
        fusion_operator,
        other_candidate,
        fused_representation,
        representation_time,
    ):
        """Score a fused representation on the task and record result and statistics."""
        score = self.task.run(fused_representation.data)
        result = OptimizationResult(
            operator_chain=[
                candidate.operator_chain,
                fusion_operator.name,
                other_candidate.operator_chain,
            ],
            parameters=[
                candidate.parameters,
                {fusion_operator.name: fusion_operator.parameters},
                other_candidate.parameters,
            ],
            train_accuracy=score[0],
            test_accuracy=score[1],
            training_runtime=self.task.training_time,
            inference_runtime=self.task.inference_time,
            representation_time=representation_time,
            output_shape=(1, 1),  # TODO
        )
        self.optimization_results.append(result)
        self.optimization_statistics.add_entry(
            [
                candidate.operator_chain,
                [fusion_operator.name],
                other_candidate.operator_chain,
            ],
            score[1],
        )
        return result

    def create_identifier(self, candidate, fusion, other_candidate):
        """Build the string key identifying ``candidate -> fusion -> other_candidate``."""
        return (
            "".join(flatten_and_join(candidate.operator_chain))
            + fusion.name
            + "".join(flatten_and_join(other_candidate.operator_chain))
        )
+
+
def flatten_and_join(data):
    """
    Flatten an arbitrarily nested list of operators into a flat list of names.

    Strings are kept verbatim; any other leaf contributes its ``name``
    attribute. Only ``list`` instances are descended into. Despite the name,
    the result is a list; callers join it themselves.
    """

    def _walk(items):
        for element in items:
            if isinstance(element, list):
                yield from _walk(element)
            elif isinstance(element, str):
                yield element
            else:
                yield element.name

    return list(_walk(data))
diff --git a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
new file mode 100644
index 0000000000..04a3fa4701
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
@@ -0,0 +1,106 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import itertools
+import time
+
+import numpy as np
+
+from systemds.scuro.drsearch.optimization_data import OptimizationResult
+from systemds.scuro.representations.context import Context
+
+
class HyperparameterTuner:
    """
    Exhaustive grid search over the parameters of a representation operator chain.

    Every combination of the operators' parameter grids is applied to the
    modality and scored with ``task.run``; the combination with the highest
    test score (``score[1]``) wins.
    """

    def __init__(self, task, n_trials=10, early_stopping_patience=5):
        """
        :param task: downstream task; ``task.run(data)`` returns a sequence whose
            first two entries are the train and test scores, and the task exposes
            ``training_time`` / ``inference_time`` after a run
        :param n_trials: stored, but not yet used by the search
        :param early_stopping_patience: stored, but not yet used by the search
        """
        self.task = task
        self.n_trials = n_trials
        self.early_stopping_patience = early_stopping_patience

    def tune_operator_chain(self, modality, operator_chain):
        """
        Evaluate every parameter combination of ``operator_chain`` on ``modality``.

        :param modality: modality the chain is applied to
        :param operator_chain: operators exposing ``name``, ``parameters`` (dict of
            parameter name -> candidate values) and ``set_parameters``
        :return: OptimizationResult of the best combination, or None if every
            combination failed. On return the operators are configured with the
            best combination found (not the last one tried).
        """
        best_result = None
        best_params = None
        best_score = -np.inf

        # NOTE: operators that share a name also share (overwrite) one grid entry.
        param_grids = {
            operator.name: operator.parameters for operator in operator_chain
        }

        param_combinations = self._generate_search_space(param_grids)

        for params in param_combinations:
            modified_modality = modality
            current_chain = []

            representation_start = time.time()
            try:
                for operator in operator_chain:
                    if operator.name in params:
                        operator.set_parameters(params[operator.name])

                    if isinstance(operator, Context):
                        modified_modality = modified_modality.context(operator)
                    else:
                        modified_modality = modified_modality.apply_representation(
                            operator
                        )

                    current_chain.append(operator)

                representation_end = time.time()

                score = self.task.run(modified_modality.data)

                if score[1] > best_score:
                    best_score = score[1]
                    best_params = params
                    best_result = OptimizationResult(
                        operator_chain=current_chain,
                        parameters=params,
                        train_accuracy=score[0],
                        test_accuracy=score[1],
                        training_runtime=self.task.training_time,
                        inference_runtime=self.task.inference_time,
                        representation_time=representation_end - representation_start,
                        output_shape=(1, 1),
                    )

            except Exception as e:
                # Deliberate best-effort search: report the failing combination and move on.
                print(f"Failed parameter combination {params}: {str(e)}")

        if best_params is not None:
            # The operators were mutated by every trial; leave them configured
            # with the winning combination so they match best_result.
            for operator in operator_chain:
                if operator.name in best_params:
                    operator.set_parameters(best_params[operator.name])

        return best_result

    def _generate_search_space(self, param_grids):
        """
        Expand per-operator parameter grids into the full cartesian product.

        :param param_grids: dict ``operator name -> {param name: iterable of values}``
        :return: list of dicts ``operator name -> {param name: value}``, one per
            combination (``[{}]`` for an empty input)
        """
        combinations = {}
        for operator_name, params in param_grids.items():
            combinations[operator_name] = [
                dict(zip(params.keys(), v)) for v in itertools.product(*params.values())
            ]

        keys = list(combinations.keys())
        values = [combinations[key] for key in keys]

        return [dict(zip(keys, combo)) for combo in itertools.product(*values)]
diff --git a/src/main/python/systemds/scuro/drsearch/operator_registry.py
b/src/main/python/systemds/scuro/drsearch/operator_registry.py
new file mode 100644
index 0000000000..942e5bb80e
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/operator_registry.py
@@ -0,0 +1,107 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from typing import Union, List
+
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.representation import Representation
+
+
class Registry:
    """
    Singleton registry of all operators known to the optimizer.

    Representations are grouped per ``ModalityType`` (an empty list for every
    member is created when the singleton is first instantiated); context and
    fusion operators are kept in flat lists. All state lives on the class, so
    every ``Registry()`` call observes the same collections.
    """

    _instance = None
    _representations = {}
    _context_operators = []
    _fusion_operators = []

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._representations.update({m_type: [] for m_type in ModalityType})
        return cls._instance

    def add_representation(
        self, representation: Representation, modality: ModalityType
    ):
        """Append ``representation`` to the list registered for ``modality``."""
        per_modality = self._representations[modality]
        per_modality.append(representation)

    def add_context_operator(self, context_operator):
        """Register a context operator."""
        self._context_operators.append(context_operator)

    def add_fusion_operator(self, fusion_operator):
        """Register a fusion operator."""
        self._fusion_operators.append(fusion_operator)

    def get_representations(self, modality: ModalityType):
        """Return the (live) list of representations registered for ``modality``."""
        return self._representations[modality]

    def get_context_operators(self):
        """Return the (live) list of registered context operators."""
        return self._context_operators

    def get_fusion_operators(self):
        """Return the (live) list of registered fusion operators."""
        return self._fusion_operators
+
+
def register_representation(modalities: Union[ModalityType, List[ModalityType]]):
    """
    Decorator to register a representation for one or more modalities.

    :param modalities: a single ModalityType or a list of ModalityTypes for
        which the decorated representation is registered
    :raises TypeError: when the decorator is applied and any entry is not a
        ModalityType (nothing is registered in that case)
    """
    if isinstance(modalities, ModalityType):
        modalities = [modalities]

    def decorator(cls):
        # Validate every entry before registering anything so that an invalid
        # entry does not leave the class registered for only some modalities.
        # (The original `raise f"..."` raised a str, which itself fails.)
        for modality in modalities:
            if not isinstance(modality, ModalityType):
                raise TypeError(
                    f"Modality {modality} is not a ModalityType; please add it to "
                    f"ModalityType in systemds/scuro/modality/type.py first!"
                )

        registry = Registry()
        for modality in modalities:
            registry.add_representation(cls, modality)
        return cls

    return decorator
+
+
def register_context_operator():
    """
    Build a class decorator that records the decorated class as a context
    operator in the shared ``Registry`` and returns the class unchanged.
    """

    def _register(operator_cls):
        registry = Registry()
        registry.add_context_operator(operator_cls)
        return operator_cls

    return _register
+
+
def register_fusion_operator():
    """
    Build a class decorator that records the decorated class as a fusion
    operator in the shared ``Registry`` and returns the class unchanged.
    """

    def _register(operator_cls):
        registry = Registry()
        registry.add_fusion_operator(operator_cls)
        return operator_cls

    return _register
diff --git a/src/main/python/systemds/scuro/drsearch/optimization_data.py
b/src/main/python/systemds/scuro/drsearch/optimization_data.py
new file mode 100644
index 0000000000..4ca54c10d3
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/optimization_data.py
@@ -0,0 +1,164 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from dataclasses import dataclass
+from typing import List, Dict, Any, Union
+
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.representations.representation import Representation
+
+
@dataclass
class OptimizationResult:
    """
    Stores the outcome of evaluating one operator chain on one task.

    Attributes:
        operator_chain (List[Union[Representation, str]]): the operators applied
            to produce the evaluated representation, in application order; may
            hold operator instances or operator names
        parameters (Union[Dict[str, Any], List[Any]]): parameters of the
            operators in the chain (keyed by operator name when a dict)
        train_accuracy (float): mean training score reported by the task
        test_accuracy (float): mean test score reported by the task
        training_runtime (float): training time reported by the task (the mean
            over folds when the task measures performance)
        inference_runtime (float): inference time reported by the task (the
            mean over folds when the task measures performance)
        representation_time (float): time spent computing the representation
        output_shape (tuple): shape of the data produced by the chain
            (NOTE(review): currently filled with a placeholder by callers)
    """

    # The string forward reference keeps this annotation from being evaluated
    # at class-creation time; it has no runtime effect on the dataclass.
    operator_chain: List[Union["Representation", str]]
    parameters: Union[Dict[str, Any], List[Any]]
    train_accuracy: float
    test_accuracy: float
    training_runtime: float
    inference_runtime: float
    representation_time: float
    output_shape: tuple
+
+
@dataclass
class OptimizationData:
    """
    Running score statistics for one representation (or fusion method) name.

    Attributes:
        representation_name: name the statistics are recorded under
        mean_accuracy: running mean of the recorded scores (0.0 before any entry)
        min_accuracy: smallest recorded score (1.0 before any entry)
        max_accuracy: largest recorded score (0.0 before any entry)
        num_times_used: number of recorded scores
    """

    representation_name: str
    # Annotated so they are real dataclass fields (shown in repr, compared by
    # __eq__); without annotations they were plain class attributes.
    mean_accuracy: float = 0.0
    min_accuracy: float = 1.0
    max_accuracy: float = 0.0
    num_times_used: int = 0

    def add_entry(self, score):
        """Record one score and update the running mean, minimum and maximum."""
        self.num_times_used += 1
        if self.num_times_used == 1:
            # The first score defines all statistics outright; comparing it
            # against the 1.0 / 0.0 placeholders would give wrong min/max for
            # scores outside [0, 1].
            self.mean_accuracy = score
            self.min_accuracy = score
            self.max_accuracy = score
            return
        self.min_accuracy = min(score, self.min_accuracy)
        self.max_accuracy = max(score, self.max_accuracy)
        # Incremental mean: no need to keep every score.
        self.mean_accuracy += (score - self.mean_accuracy) / self.num_times_used

    def __str__(self):
        return (
            f"Name: {self.representation_name} mean: {self.mean_accuracy} "
            f"max: {self.max_accuracy} min: {self.min_accuracy} "
            f"num_times: {self.num_times_used}"
        )
+
+
def extract_names(operator_chain):
    """
    Map an operator chain to a list of names.

    String entries are kept as they are; any other entry contributes its
    ``name`` attribute.
    """
    return [op if isinstance(op, str) else op.name for op in operator_chain]
+
+
class OptimizationStatistics:
    """
    Collects OptimizationData score statistics per representation chain name
    and per registered fusion operator name.
    """

    # Per-instance state, assigned in __init__ (annotations only, no shared
    # class-level values):
    #   optimization_data: concatenated operator names -> OptimizationData
    #   fusion_names: class names of the registered fusion operators
    optimization_data: Dict[str, "OptimizationData"]
    fusion_names: List[str]

    def __init__(self, candidates):
        """
        :param candidates: objects with an ``operator_chain``; an empty
            statistics entry is created for each chain, keyed by the
            concatenated operator names
        """
        # These used to be mutable class attributes, so every instance shared
        # one dict and fusion_names gained duplicates with each new instance.
        self.optimization_data = {}
        self.fusion_names = []

        for candidate in candidates:
            representation_name = "".join(extract_names(candidate.operator_chain))
            self.optimization_data[representation_name] = OptimizationData(
                representation_name
            )

        for fusion_method in Registry().get_fusion_operators():
            fusion_name = fusion_method.__name__
            self.optimization_data[fusion_name] = OptimizationData(fusion_name)
            self.fusion_names.append(fusion_name)

    def parse_representation_name(self, name):
        """
        Split a concatenated chain name at the known fusion operator names.

        :param name: concatenation of operator names
        :return: list of name parts; each fusion operator name is its own part
        """
        parts = []
        current_part = ""

        i = 0
        while i < len(name):
            found_fusion = False
            for fusion in self.fusion_names:
                if name[i:].startswith(fusion):
                    if current_part:
                        parts.append(current_part)
                    parts.append(fusion)
                    i += len(fusion)
                    found_fusion = True
                    break

            if not found_fusion:
                current_part += name[i]
                i += 1
            else:
                current_part = ""

        if current_part:
            parts.append(current_part)

        return parts

    def add_entry(self, representations, score):
        """
        Record ``score`` for every operator chain in ``representations``.

        :param representations: list whose items are either an operator chain
            (list of operators / names) or a list of such chains; each chain is
            recorded under its concatenated operator names
        :param score: the score to record
        """
        for rep in representations:
            # A list as first element means ``rep`` groups several chains.
            chains = rep if isinstance(rep[0], list) else [rep]
            for chain in chains:
                self._record("".join(extract_names(chain)), score)

    def _record(self, name, score):
        """Add ``score`` to the statistics of ``name``, creating them if needed."""
        data = self.optimization_data.get(name)
        if data is None:
            data = self.optimization_data[name] = OptimizationData(name)
        data.add_entry(score)

    def print_statistics(self):
        """Print every collected statistics entry."""
        for statistic in self.optimization_data.values():
            print(statistic)
+
+
def extract_operator_names(operators):
    """
    Concatenate the ``name`` of every operator, descending into nested lists.

    :param operators: operators, possibly nested in (sub)lists
    :return: all operator names joined without separator
    """
    return "".join(
        extract_operator_names(operator)
        if isinstance(operator, list)
        else operator.name
        for operator in operators
    )
diff --git a/src/main/python/systemds/scuro/drsearch/representation_cache.py
b/src/main/python/systemds/scuro/drsearch/representation_cache.py
new file mode 100644
index 0000000000..fc78167f2e
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/representation_cache.py
@@ -0,0 +1,127 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import copy
+import os
+import pickle
+import tempfile
+
+from systemds.scuro.modality.transformed import TransformedModality
+
+
class RepresentationCache:
    """
    Singleton disk cache for transformed modalities.

    Each entry is two pickle files in a private temporary directory:
    ``<modality_id><op names>.pkl`` (the modality data) and
    ``<modality_id><op names>.meta`` (its metadata). ``<op names>`` is the
    concatenation of the names of the operators that produced the data, so an
    operator chain can reuse the longest prefix computed before.

    NOTE: entries are read back with ``pickle.load``. That is acceptable only
    because the files are written by this class into its own temporary
    directory; never point it at untrusted files.
    """

    _instance = None
    _cache_dir = None
    debug = False

    def __new__(cls, debug=False):
        # Only the first construction initializes the singleton; the ``debug``
        # value passed to later calls is ignored.
        if not cls._instance:
            cls.debug = debug
            cls._instance = super().__new__(cls)
            # TemporaryDirectory removes itself when finalized (or at
            # interpreter exit).
            cls._cache_dir = tempfile.TemporaryDirectory()
        return cls._instance

    def _generate_cache_filename(self, modality_id, operators):
        """
        Build the cache path (without extension) for an operator chain.

        :param modality_id: string prefix of the file name
        :param operators: operators or operator names, in application order
        :return: ``(path, op_names)``; ``path`` is
            ``<cache dir>/<modality_id><name1><name2>...`` and ``op_names`` is
            the list of operator names
        """
        op_names = [op if isinstance(op, str) else op.name for op in operators]
        filename = modality_id + "".join(op_names)
        return os.path.join(self._cache_dir.name, filename), op_names

    def save_to_cache(self, modality, used_op_names, operators):
        """
        Pickle ``modality.data`` and ``modality.metadata`` unless an entry for
        this chain already exists.

        :param modality: modality whose data and metadata are stored
        :param used_op_names: concatenated names of the operators applied before
            ``operators`` (e.g. the prefix that was loaded from the cache)
        :param operators: operators applied on top of that prefix
        """
        filename, op_names = self._generate_cache_filename(
            str(modality.modality_id) + used_op_names, operators
        )
        # Entries live in <filename>.pkl / <filename>.meta. The bare path never
        # exists, so checking it caused every entry to be rewritten.
        if os.path.exists(f"{filename}.pkl"):
            return

        # Write the metadata first: load_from_cache treats an existing .pkl as
        # a complete entry and then reads the .meta file.
        with open(f"{filename}.meta", "wb") as f:
            pickle.dump(modality.metadata, f)
        with open(f"{filename}.pkl", "wb") as f:
            pickle.dump(modality.data, f)

        if self.debug:
            str_names = ", ".join(op_names)
            print(
                f"Saved data for operator {str(modality.modality_id)}{used_op_names}{str_names} to cache: {filename}"
            )

    def load_from_cache(self, modality, operators):
        """
        Load the result of the longest cached prefix of an operator chain.

        Operators are dropped from the end of the chain until a cache entry for
        the remaining prefix exists. The empty prefix is never loaded.

        :param modality: source modality (its ``modality_id`` and
            ``modality_type`` are used)
        :param operators: the full operator chain (deep-copied, not modified)
        :return: ``(transformed_modality, dropped_ops, op_names)``:
            the cached TransformedModality, or None if no prefix is cached;
            the operators still to apply, in application order;
            and the concatenated names of the cached prefix ("" if none)
        """
        ops = copy.deepcopy(operators)
        modality_id = str(modality.modality_id)
        filename, op_names = self._generate_cache_filename(modality_id, ops)
        dropped_ops = []
        # Shorten the chain from the end until a cached prefix is found. The
        # ``ops`` guard also stops an empty chain from popping an empty list.
        while ops and not os.path.exists(f"{filename}.pkl"):
            dropped_ops.append(ops.pop())
            filename, op_names = self._generate_cache_filename(modality_id, ops)

        dropped_ops.reverse()
        op_names = "".join(op_names)

        if not ops:
            return None, dropped_ops, op_names

        # Loop exit with a non-empty chain means the .pkl entry exists.
        with open(f"{filename}.meta", "rb") as f:
            metadata = pickle.load(f)

        transformed_modality = TransformedModality(
            modality.modality_type, op_names, modality.modality_id, metadata
        )
        with open(f"{filename}.pkl", "rb") as f:
            if self.debug:
                print(
                    f"Loaded cached data for operator '{str(modality.modality_id) + op_names}' from {filename}"
                )
            transformed_modality.data = pickle.load(f)
        return transformed_modality, dropped_ops, op_names
diff --git a/src/main/python/systemds/scuro/aligner/similarity_measures.py
b/src/main/python/systemds/scuro/drsearch/similarity_measures.py
similarity index 100%
rename from src/main/python/systemds/scuro/aligner/similarity_measures.py
rename to src/main/python/systemds/scuro/drsearch/similarity_measures.py
diff --git a/src/main/python/systemds/scuro/aligner/task.py
b/src/main/python/systemds/scuro/drsearch/task.py
similarity index 80%
rename from src/main/python/systemds/scuro/aligner/task.py
rename to src/main/python/systemds/scuro/drsearch/task.py
index f33546ae65..7e05a489e4 100644
--- a/src/main/python/systemds/scuro/aligner/task.py
+++ b/src/main/python/systemds/scuro/drsearch/task.py
@@ -18,6 +18,7 @@
# under the License.
#
# -------------------------------------------------------------
+import time
from typing import List
from systemds.scuro.models.model import Model
@@ -34,6 +35,7 @@ class Task:
train_indices: List,
val_indices: List,
kfold=5,
+ measure_performance=True,
):
"""
Parent class for the prediction task that is performed on top of the
aligned representation
@@ -51,6 +53,10 @@ class Task:
self.train_indices = train_indices
self.val_indices = val_indices
self.kfold = kfold
+ self.measure_performance = measure_performance
+ self.inference_time = []
+ self.training_time = []
+ self.expected_dim = 1
def get_train_test_split(self, data):
X_train = [data[i] for i in self.train_indices]
@@ -67,6 +73,8 @@ class Task:
:param data: The aligned data used in the prediction process
:return: the validation accuracy
"""
+ self.inference_time = []
+ self.training_time = []
skf = KFold(n_splits=self.kfold, shuffle=True, random_state=11)
train_scores = []
test_scores = []
@@ -76,13 +84,21 @@ class Task:
for train, test in skf.split(X, y):
train_X = np.array(X)[train]
train_y = np.array(y)[train]
-
+ train_start = time.time()
train_score = self.model.fit(train_X, train_y, X_test, y_test)
+ train_end = time.time()
+ self.training_time.append(train_end - train_start)
train_scores.append(train_score)
-
- test_score = self.model.test(X_test, y_test)
+ test_start = time.time()
+ test_score = self.model.test(np.array(X_test), y_test)
+ test_end = time.time()
+ self.inference_time.append(test_end - test_start)
test_scores.append(test_score)
fold += 1
+ if self.measure_performance:
+ self.inference_time = np.mean(self.inference_time)
+ self.training_time = np.mean(self.training_time)
+
return [np.mean(train_scores), np.mean(test_scores)]
diff --git
a/src/main/python/systemds/scuro/drsearch/unimodal_representation_optimizer.py
b/src/main/python/systemds/scuro/drsearch/unimodal_representation_optimizer.py
new file mode 100644
index 0000000000..e59ddbe9be
--- /dev/null
+++
b/src/main/python/systemds/scuro/drsearch/unimodal_representation_optimizer.py
@@ -0,0 +1,271 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import copy
+import os
+import pickle
+import time
+from typing import List
+
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.drsearch.optimization_data import OptimizationResult
+from systemds.scuro.drsearch.representation_cache import RepresentationCache
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.aggregate import Aggregation
+from systemds.scuro.representations.context import Context
+
+
class UnimodalRepresentationOptimizer:
    """
    First stage of the representation search.

    For every modality, chains of registered unimodal representations and
    context operators are built (up to ``max_chain_depth`` operators).
    Each chain is applied, reusing cached chain prefixes via
    RepresentationCache, and the result is evaluated on every task. One
    OptimizationResult per (modality, task model, chain) is stored in
    ``optimization_results[modality_id][task.model.name]``.
    """

    def __init__(
        self,
        modalities: List[Modality],
        tasks: List[Task],
        max_chain_depth=5,
        debug=False,
        folder_name=None,
    ):
        """
        :param modalities: modalities to optimize
        :param tasks: downstream tasks used to score each representation
        :param max_chain_depth: maximum number of operators in a chain
        :param debug: print progress and pickle the results to ``folder_name``
        :param folder_name: output directory for debug results; it is passed to
            ``os.makedirs``, so it must be given when ``debug`` is True
        """
        self.optimization_results = {}
        self.modalities = modalities
        self.tasks = tasks
        self.operator_registry = Registry()
        self.initialize_optimization_results()
        self.max_chain_depth = max_chain_depth
        self.debug = debug
        self.cache = RepresentationCache(self.debug)
        if self.debug:
            self.folder_name = folder_name
            os.makedirs(self.folder_name, exist_ok=True)

    def initialize_optimization_results(self):
        """Create an empty result list for every (modality, task model) pair."""
        for modality in self.modalities:
            self.optimization_results[modality.modality_id] = {}
            for task in self.tasks:
                self.optimization_results[modality.modality_id][task.model.name] = []

    def optimize(self):
        """
        Find unimodal representations for all given modalities.

        Results are stored in ``self.optimization_results``. In debug mode, a
        copy with operator names instead of operator objects is also pickled
        per task model.
        """
        for modality in self.modalities:
            self._optimize_modality(modality)
            # The name-only copy is only used for the debug dump; building it
            # unconditionally deep-copied every result for nothing.
            if self.debug:
                self._dump_results(modality)

    def _dump_results(self, modality):
        """Pickle the results of ``modality`` (operators replaced by names) per task model."""
        copy_results = copy.deepcopy(self.optimization_results[modality.modality_id])
        for model, model_results in copy_results.items():
            for result in model_results:
                result.operator_chain = [
                    op if isinstance(op, str) else op.name
                    for op in result.operator_chain
                ]
            with open(
                f"{self.folder_name}/results_{model}_{modality.modality_type.name}.p",
                "wb",
            ) as fp:
                pickle.dump(model_results, fp, protocol=pickle.HIGHEST_PROTOCOL)

    def get_k_best_results(self, modality: Modality, k: int, task: Task):
        """
        Get the k best results for the given modality and task.

        :param modality: modality to get the best results for
        :param k: number of best results
        :param task: task whose model's results are ranked
        :return: up to ``k`` OptimizationResults, by descending test accuracy
        """
        results = sorted(
            self.optimization_results[modality.modality_id][task.model.name],
            key=lambda x: x.test_accuracy,
            reverse=True,
        )[:k]

        return results

    def _optimize_modality(self, modality: Modality):
        """
        Build and evaluate operator chains for a single modality.

        Every compatible operator is tried as the chain start. Each chain is
        extended recursively and stored as optimization results.

        :param modality: modality to optimize
        """
        representations = self._get_compatible_operators(modality.modality_type, [])

        for rep in representations:
            self._build_operator_chain(modality, [rep()], 1)

    def _get_compatible_operators(self, modality_type, used_operators):
        """
        :param modality_type: modality type the next operator is applied to
        :param used_operators: class names of the operators already in the chain
        :return: operator classes that can extend the chain. These are the
            registered representations for ``modality_type`` that are not used
            yet, plus every context operator except the one applied last.
        """
        next_operators = [
            operator
            for operator in self.operator_registry.get_representations(modality_type)
            if operator.__name__ not in used_operators
        ]

        last_operator = used_operators[-1] if used_operators else None
        for context_operator in self.operator_registry.get_context_operators():
            # Exact comparison. The former ``name not in used_operators[-1]`` was
            # a substring test on a string; it also rejected context operators
            # whose name merely occurs inside the last operator's name.
            if context_operator.__name__ != last_operator:
                next_operators.append(context_operator)

        return next_operators

    def _build_operator_chain(self, modality, current_operator_chain, depth):
        """
        Evaluate ``current_operator_chain``, then extend it recursively with
        every compatible operator while ``depth`` <= ``max_chain_depth``.
        """
        if depth > self.max_chain_depth:
            return

        self._apply_operator_chain(modality, current_operator_chain)

        # The next operator sees the output type of the last operator in the
        # chain that declares one.
        current_modality_type = modality.modality_type
        for operator in current_operator_chain:
            if hasattr(operator, "output_modality_type"):
                current_modality_type = operator.output_modality_type

        used_operator_names = [type(op).__name__ for op in current_operator_chain]
        for next_rep in self._get_compatible_operators(
            current_modality_type, used_operator_names
        ):
            self._build_operator_chain(
                modality, current_operator_chain + [next_rep()], depth + 1
            )

    @staticmethod
    def _chain_name(operator_chain):
        """Concatenate the class names of the operators in a chain (debug output)."""
        return "".join(type(operator).__name__ for operator in operator_chain)

    def _evaluate_with_flattened_data(
        self, modality, operator_chain, op_params, representation_time, task
    ):
        """
        Evaluate ``modality`` on ``task`` once per available aggregation
        function. Each run appends an AggregatedRepresentation to the chain.

        :return: one OptimizationResult per aggregation function
        """
        # Local import, as in the original module layout.
        from systemds.scuro.representations.aggregated_representation import (
            AggregatedRepresentation,
        )

        results = []
        for aggregation in Aggregation().get_aggregation_functions():
            start = time.time()
            agg_operator = AggregatedRepresentation(Aggregation(aggregation, True))
            agg_modality = agg_operator.transform(modality)
            end = time.time()

            agg_operator_chain = operator_chain + [agg_operator]
            agg_params = {**op_params, agg_operator.name: agg_operator.parameters}

            score = task.run(agg_modality.data)
            results.append(
                OptimizationResult(
                    operator_chain=agg_operator_chain,
                    parameters=agg_params,
                    train_accuracy=score[0],
                    test_accuracy=score[1],
                    training_runtime=task.training_time,
                    inference_runtime=task.inference_time,
                    representation_time=representation_time + end - start,
                    output_shape=(1, 1),  # TODO
                )
            )

            if self.debug:
                op_name = self._chain_name(agg_operator_chain)
                print(f"{task.name} {task.model.name} {op_name}: {score[1]}")

        return results

    def _evaluate_operator_chain(
        self, modality, operator_chain, op_params, representation_time
    ):
        """
        Evaluate the transformed ``modality`` on every task and store the results.

        Tasks are skipped when the first data item is a string. When a task
        expects 1-D input but the items are multi-dimensional arrays, the data
        is evaluated once per aggregation function instead.
        """
        for task in self.tasks:
            if isinstance(modality.data[0], str):
                continue

            task_results = self.optimization_results[modality.modality_id][
                task.model.name
            ]
            if (
                task.expected_dim == 1
                and not isinstance(modality.data[0], list)
                and modality.data[0].ndim > 1
            ):
                task_results.extend(
                    self._evaluate_with_flattened_data(
                        modality, operator_chain, op_params, representation_time, task
                    )
                )
            else:
                score = task.run(modality.data)
                task_results.append(
                    OptimizationResult(
                        operator_chain=operator_chain,
                        parameters=op_params,
                        train_accuracy=score[0],
                        test_accuracy=score[1],
                        training_runtime=task.training_time,
                        inference_runtime=task.inference_time,
                        representation_time=representation_time,
                        output_shape=(1, 1),  # TODO
                    )
                )
                if self.debug:
                    op_name = self._chain_name(operator_chain)
                    print(f"{task.name} {task.model.name} - {op_name}: {score[1]}")

    def _apply_operator_chain(self, current_modality, operator_chain):
        """
        Apply ``operator_chain`` to ``current_modality`` and evaluate the result.

        The longest cached prefix is reused, the remaining operators are
        applied, and the new result is cached. Failures are reported and
        skipped so that one broken chain does not abort the whole search.
        """
        op_params = {}
        modified_modality = current_modality

        representation_start = time.time()
        try:
            cached_representation, representation_ops, used_op_names = (
                self.cache.load_from_cache(
                    modified_modality, copy.deepcopy(operator_chain)
                )
            )
            if cached_representation is not None:
                modified_modality = cached_representation

            store = False
            for operator in representation_ops:
                if isinstance(operator, Context):
                    modified_modality = modified_modality.context(operator)
                else:
                    modified_modality = modified_modality.apply_representation(
                        operator
                    )
                    # NOTE(review): the patch's indentation is lost. Here only
                    # non-context steps mark the result for caching; confirm
                    # against the source file.
                    store = True
                op_params[operator.name] = operator.get_current_parameters()
            if store:
                self.cache.save_to_cache(
                    modified_modality, used_op_names, representation_ops
                )
            representation_end = time.time()

            self._evaluate_operator_chain(
                modified_modality,
                operator_chain,
                op_params,
                representation_end - representation_start,
            )
        except Exception as e:
            # Best effort: report and continue with the next chain.
            print(f"Failed to evaluate chain {operator_chain}: {str(e)}")
            return
diff --git a/src/main/python/systemds/scuro/main.py
b/src/main/python/systemds/scuro/main.py
index 8a51e098cc..f88e211157 100644
--- a/src/main/python/systemds/scuro/main.py
+++ b/src/main/python/systemds/scuro/main.py
@@ -25,8 +25,8 @@ from systemds.scuro.representations.average import Average
from systemds.scuro.representations.concatenation import Concatenation
from systemds.scuro.modality.unimodal_modality import UnimodalModality
from systemds.scuro.models.discrete_model import DiscreteModel
-from systemds.scuro.aligner.task import Task
-from systemds.scuro.aligner.dr_search import DRSearch
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.dr_search import DRSearch
from systemds.scuro.dataloader.audio_loader import AudioLoader
from systemds.scuro.dataloader.text_loader import TextLoader
diff --git a/src/main/python/systemds/scuro/modality/joined.py
b/src/main/python/systemds/scuro/modality/joined.py
index c1aa26abf6..1a58df9256 100644
--- a/src/main/python/systemds/scuro/modality/joined.py
+++ b/src/main/python/systemds/scuro/modality/joined.py
@@ -18,13 +18,13 @@
# under the License.
#
# -------------------------------------------------------------
+import importlib
import sys
import numpy as np
from systemds.scuro.modality.joined_transformed import
JoinedTransformedModality
from systemds.scuro.modality.modality import Modality
-from systemds.scuro.representations.aggregate import Aggregation
from systemds.scuro.representations.utils import pad_sequences
@@ -167,7 +167,9 @@ class JoinedModality(Modality):
def aggregate(
self, aggregation_function, field_name
): # TODO: use the filed name to extract data entries from modalities
- self.aggregation = Aggregation(aggregation_function, field_name)
+ module =
importlib.import_module("systemds.scuro.representations.aggregate")
+
+ self.aggregation = module.Aggregation(aggregation_function, field_name)
if not self.chunked_execution and self.joined_right:
return self.aggregation.aggregate(self.joined_right)
diff --git a/src/main/python/systemds/scuro/modality/modality.py
b/src/main/python/systemds/scuro/modality/modality.py
index c110a24eba..c16db00172 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -23,7 +23,7 @@ from typing import List
import numpy as np
-from systemds.scuro.modality.type import ModalityType, DataLayout
+from systemds.scuro.modality.type import ModalityType
from systemds.scuro.representations import utils
diff --git a/src/main/python/systemds/scuro/modality/modality_identifier.py
b/src/main/python/systemds/scuro/modality/modality_identifier.py
index 95668c6e58..5eeee7dc13 100644
--- a/src/main/python/systemds/scuro/modality/modality_identifier.py
+++ b/src/main/python/systemds/scuro/modality/modality_identifier.py
@@ -18,13 +18,6 @@
# under the License.
#
# -------------------------------------------------------------
-import os
-import pickle
-from typing import List, Dict, Any, Union
-import tempfile
-from systemds.scuro.representations.representation import Representation
-
-
class ModalityIdentifier:
""" """
diff --git a/src/main/python/systemds/scuro/modality/transformed.py
b/src/main/python/systemds/scuro/modality/transformed.py
index 2b4b049ef4..aba59c1efb 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -100,7 +100,10 @@ class TransformedModality(Modality):
self.metadata,
)
modalities = [self]
- modalities.extend(other)
+ if isinstance(other, list):
+ modalities.extend(other)
+ else:
+ modalities.append(other)
fused_modality.data = fusion_method.transform(modalities)
return fused_modality
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py
b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 6173237e0a..714fe42c33 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -26,7 +26,6 @@ from systemds.scuro.dataloader.base_loader import BaseLoader
from systemds.scuro.modality.modality import Modality
from systemds.scuro.modality.joined import JoinedModality
from systemds.scuro.modality.transformed import TransformedModality
-from systemds.scuro.modality.type import ModalityType
from systemds.scuro.modality.modality_identifier import ModalityIdentifier
diff --git a/src/main/python/systemds/scuro/representations/aggregate.py
b/src/main/python/systemds/scuro/representations/aggregate.py
index 4b4545ef47..756e6271ea 100644
--- a/src/main/python/systemds/scuro/representations/aggregate.py
+++ b/src/main/python/systemds/scuro/representations/aggregate.py
@@ -20,7 +20,6 @@
# -------------------------------------------------------------
import numpy as np
-from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations import utils
@@ -48,21 +47,28 @@ class Aggregation:
"sum": _sum_agg.__func__,
}
- def __init__(self, aggregation_function="mean", pad_modality=False):
+ def __init__(self, aggregation_function="mean", pad_modality=False,
params=None):
+ if params is not None:
+ aggregation_function = params["aggregation_function"]
+ pad_modality = params["pad_modality"]
+
if aggregation_function not in self._aggregation_function.keys():
raise ValueError("Invalid aggregation function")
+
self._aggregation_func =
self._aggregation_function[aggregation_function]
self.name = "Aggregation"
self.pad_modality = pad_modality
+ self.parameters = {
+ "aggregation_function": aggregation_function,
+ "pad_modality": pad_modality,
+ }
+
def execute(self, modality):
- aggregated_modality = Modality(
- modality.modality_type, modality.modality_id, modality.metadata
- )
- aggregated_modality.data = []
+ data = []
max_len = 0
for i, instance in enumerate(modality.data):
- aggregated_modality.data.append([])
+ data.append([])
if isinstance(instance, np.ndarray):
aggregated_data = self._aggregation_func(instance)
else:
@@ -70,22 +76,22 @@ class Aggregation:
for entry in instance:
aggregated_data.append(self._aggregation_func(entry))
max_len = max(max_len, len(aggregated_data))
- aggregated_modality.data[i] = aggregated_data
+ data[i] = aggregated_data
if self.pad_modality:
- for i, instance in enumerate(aggregated_modality.data):
+ for i, instance in enumerate(data):
if isinstance(instance, np.ndarray):
if len(instance) < max_len:
padded_data = np.zeros(max_len, dtype=instance.dtype)
padded_data[: len(instance)] = instance
- aggregated_modality.data[i] = padded_data
+ data[i] = padded_data
else:
padded_data = []
for entry in instance:
padded_data.append(utils.pad_sequences(entry, max_len))
- aggregated_modality.data[i] = padded_data
+ data[i] = padded_data
- return aggregated_modality
+ return data
def transform(self, modality):
return self.execute(modality)
diff --git a/src/main/python/systemds/scuro/aligner/alignment_strategy.py
b/src/main/python/systemds/scuro/representations/aggregated_representation.py
similarity index 59%
rename from src/main/python/systemds/scuro/aligner/alignment_strategy.py
rename to
src/main/python/systemds/scuro/representations/aggregated_representation.py
index 698a6d0d98..46e6b8bed2 100644
--- a/src/main/python/systemds/scuro/aligner/alignment_strategy.py
+++
b/src/main/python/systemds/scuro/representations/aggregated_representation.py
@@ -18,23 +18,18 @@
# under the License.
#
# -------------------------------------------------------------
-from aligner.similarity_measures import Measure
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.representation import Representation
-class AlignmentStrategy:
- def __init__(self):
- pass
+class AggregatedRepresentation(Representation):
+ def __init__(self, aggregation):
+ super().__init__("AggregatedRepresentation", aggregation.parameters)
+ self.aggregation = aggregation
- def align_chunk(self, chunk_a, chunk_b, similarity_measure: Measure):
- raise "Not implemented error"
-
-
-class ChunkedCrossCorrelation(AlignmentStrategy):
- def __init__(self):
- super().__init__()
-
- def align_chunk(self, chunk_a, chunk_b, similarity_measure: Measure):
- raise "Not implemented error"
-
-
-# TODO: Add additional alignment methods
+ def transform(self, modality):
+ aggregated_modality = TransformedModality(
+ modality.modality_type, self.name, modality.modality_id,
modality.metadata
+ )
+ aggregated_modality.data = self.aggregation.execute(modality)
+ return aggregated_modality
diff --git a/src/main/python/systemds/scuro/representations/average.py
b/src/main/python/systemds/scuro/representations/average.py
index db44050e9e..4c6b0e1787 100644
--- a/src/main/python/systemds/scuro/representations/average.py
+++ b/src/main/python/systemds/scuro/representations/average.py
@@ -27,8 +27,10 @@ from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.utils import pad_sequences
from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
+@register_fusion_operator()
class Average(Fusion):
def __init__(self):
"""
@@ -37,6 +39,9 @@ class Average(Fusion):
super().__init__("Average")
def transform(self, modalities: List[Modality]):
+ for modality in modalities:
+ modality.flatten()
+
max_emb_size = self.get_max_embedding_size(modalities)
padded_modalities = []
diff --git a/src/main/python/systemds/scuro/representations/bert.py
b/src/main/python/systemds/scuro/representations/bert.py
index 6395d0b9e6..802d7e3d0b 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -19,16 +19,16 @@
#
# -------------------------------------------------------------
-import numpy as np
-
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
import torch
from transformers import BertTokenizer, BertModel
from systemds.scuro.representations.utils import save_embeddings
from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+@register_representation(ModalityType.TEXT)
class Bert(UnimodalRepresentation):
def __init__(self, model_name="bert", output_file=None):
parameters = {"model_name": "bert"}
@@ -49,7 +49,7 @@ class Bert(UnimodalRepresentation):
model = BertModel.from_pretrained(model_name)
embeddings = self.create_embeddings(modality.data, model, tokenizer)
- embeddings = [embeddings[i : i + 1] for i in
range(embeddings.shape[0])]
+
if self.output_file is not None:
save_embeddings(embeddings, self.output_file)
@@ -65,7 +65,6 @@ class Bert(UnimodalRepresentation):
outputs = model(**inputs)
cls_embedding = outputs.last_hidden_state[:, 0,
:].squeeze().numpy()
- embeddings.append(cls_embedding)
+ embeddings.append(cls_embedding.reshape(1, -1))
- embeddings = np.array(embeddings)
- return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
+ return embeddings
diff --git a/src/main/python/systemds/scuro/representations/bow.py
b/src/main/python/systemds/scuro/representations/bow.py
index 52fddc7d3f..e2bc94041f 100644
--- a/src/main/python/systemds/scuro/representations/bow.py
+++ b/src/main/python/systemds/scuro/representations/bow.py
@@ -26,8 +26,10 @@ from systemds.scuro.representations.unimodal import
UnimodalRepresentation
from systemds.scuro.representations.utils import save_embeddings
from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+@register_representation(ModalityType.TEXT)
class BoW(UnimodalRepresentation):
def __init__(self, ngram_range=2, min_df=2, output_file=None):
parameters = {"ngram_range": [ngram_range], "min_df": [min_df]}
diff --git a/src/main/python/systemds/scuro/representations/concatenation.py
b/src/main/python/systemds/scuro/representations/concatenation.py
index fd9293d399..1265563b6c 100644
--- a/src/main/python/systemds/scuro/representations/concatenation.py
+++ b/src/main/python/systemds/scuro/representations/concatenation.py
@@ -28,7 +28,10 @@ from systemds.scuro.representations.utils import
pad_sequences
from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
+
+@register_fusion_operator()
class Concatenation(Fusion):
def __init__(self, padding=True):
"""
diff --git a/src/main/python/systemds/scuro/representations/context.py
b/src/main/python/systemds/scuro/representations/context.py
index 4cbcf54f8e..54f22633cc 100644
--- a/src/main/python/systemds/scuro/representations/context.py
+++ b/src/main/python/systemds/scuro/representations/context.py
@@ -19,7 +19,6 @@
#
# -------------------------------------------------------------
import abc
-from typing import List
from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.representation import Representation
diff --git a/src/main/python/systemds/scuro/representations/glove.py
b/src/main/python/systemds/scuro/representations/glove.py
index 7bb586dc99..66a6847a94 100644
--- a/src/main/python/systemds/scuro/representations/glove.py
+++ b/src/main/python/systemds/scuro/representations/glove.py
@@ -23,8 +23,9 @@ from gensim.utils import tokenize
from systemds.scuro.representations.unimodal import UnimodalRepresentation
-from systemds.scuro.representations.utils import read_data_from_file,
save_embeddings
+from systemds.scuro.representations.utils import save_embeddings
from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
def load_glove_embeddings(file_path):
@@ -38,6 +39,7 @@ def load_glove_embeddings(file_path):
return embeddings
+# @register_representation(ModalityType.TEXT)
class GloVe(UnimodalRepresentation):
def __init__(self, glove_path, output_file=None):
super().__init__("GloVe", ModalityType.TEXT)
diff --git a/src/main/python/systemds/scuro/representations/lstm.py
b/src/main/python/systemds/scuro/representations/lstm.py
index 6f06e762a5..a82a1e2500 100644
--- a/src/main/python/systemds/scuro/representations/lstm.py
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -28,7 +28,10 @@ import numpy as np
from systemds.scuro.modality.modality import Modality
from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
+
+@register_fusion_operator()
class LSTM(Fusion):
def __init__(self, width=128, depth=1, dropout_rate=0.1):
"""
diff --git a/src/main/python/systemds/scuro/representations/max.py
b/src/main/python/systemds/scuro/representations/max.py
index 194b20801e..5a787dcf0c 100644
--- a/src/main/python/systemds/scuro/representations/max.py
+++ b/src/main/python/systemds/scuro/representations/max.py
@@ -28,7 +28,10 @@ from systemds.scuro.representations.utils import
pad_sequences
from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
+
+@register_fusion_operator()
class RowMax(Fusion):
def __init__(self, split=4):
"""
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
index dfff4f3b7e..4095ceead0 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
@@ -25,8 +25,10 @@ from systemds.scuro.modality.type import ModalityType
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
+@register_representation(ModalityType.AUDIO)
class MelSpectrogram(UnimodalRepresentation):
def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
parameters = {
@@ -45,8 +47,15 @@ class MelSpectrogram(UnimodalRepresentation):
)
result = []
max_length = 0
- for sample in modality.data:
- S = librosa.feature.melspectrogram(y=sample, sr=22050)
+ for i, sample in enumerate(modality.data):
+ sr = list(modality.metadata.values())[i]["frequency"]
+ S = librosa.feature.melspectrogram(
+ y=sample,
+ sr=sr,
+ n_mels=self.n_mels,
+ hop_length=self.hop_length,
+ n_fft=self.n_fft,
+ )
S_dB = librosa.power_to_db(S, ref=np.max)
if S_dB.shape[-1] > max_length:
max_length = S_dB.shape[-1]
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
b/src/main/python/systemds/scuro/representations/mfcc.py
similarity index 60%
copy from src/main/python/systemds/scuro/representations/mel_spectrogram.py
copy to src/main/python/systemds/scuro/representations/mfcc.py
index dfff4f3b7e..75cc00d62d 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mfcc.py
@@ -25,19 +25,23 @@ from systemds.scuro.modality.type import ModalityType
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
-class MelSpectrogram(UnimodalRepresentation):
- def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
+@register_representation(ModalityType.AUDIO)
+class MFCC(UnimodalRepresentation):
+ def __init__(self, n_mfcc=12, dct_type=2, n_mels=128, hop_length=512):
parameters = {
- "n_mels": [20, 32, 64, 128],
+ "n_mfcc": [x for x in range(10, 26)],
+ "dct_type": [1, 2, 3],
"hop_length": [256, 512, 1024, 2048],
- "n_fft": [1024, 2048, 4096],
- }
- super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
+ "n_mels": [20, 32, 64, 128],
+ } # TODO
+ super().__init__("MFCC", ModalityType.TIMESERIES, parameters)
+ self.n_mfcc = n_mfcc
+ self.dct_type = dct_type
self.n_mels = n_mels
self.hop_length = hop_length
- self.n_fft = n_fft
def transform(self, modality):
transformed_modality = TransformedModality(
@@ -45,12 +49,20 @@ class MelSpectrogram(UnimodalRepresentation):
)
result = []
max_length = 0
- for sample in modality.data:
- S = librosa.feature.melspectrogram(y=sample, sr=22050)
- S_dB = librosa.power_to_db(S, ref=np.max)
- if S_dB.shape[-1] > max_length:
- max_length = S_dB.shape[-1]
- result.append(S_dB.T)
+ for i, sample in enumerate(modality.data):
+ sr = list(modality.metadata.values())[i]["frequency"]
+ mfcc = librosa.feature.mfcc(
+ y=sample,
+ sr=sr,
+ n_mfcc=self.n_mfcc,
+ dct_type=self.dct_type,
+ hop_length=self.hop_length,
+ n_mels=self.n_mels,
+ )
+ mfcc = (mfcc - np.mean(mfcc)) / np.std(mfcc)
+ if mfcc.shape[-1] > max_length: # TODO: check if this needs to be
done
+ max_length = mfcc.shape[-1]
+ result.append(mfcc.T)
transformed_modality.data = result
return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/multiplication.py
b/src/main/python/systemds/scuro/representations/multiplication.py
index 2934fe5b3c..8d1e7f8c90 100644
--- a/src/main/python/systemds/scuro/representations/multiplication.py
+++ b/src/main/python/systemds/scuro/representations/multiplication.py
@@ -28,7 +28,10 @@ from systemds.scuro.representations.utils import
pad_sequences
from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
+
+@register_fusion_operator()
class Multiplication(Fusion):
def __init__(self):
"""
diff --git a/src/main/python/systemds/scuro/representations/optical_flow.py
b/src/main/python/systemds/scuro/representations/optical_flow.py
new file mode 100644
index 0000000000..1fb922d7a3
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/optical_flow.py
@@ -0,0 +1,79 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import cv2
+
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from typing import Callable, Dict, Tuple, Any
+import torch.utils.data
+import torch
+import torchvision.models as models
+import numpy as np
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+from systemds.scuro.utils.torch_dataset import CustomDataset
+
+if torch.backends.mps.is_available():
+ DEVICE = torch.device("mps")
+# elif torch.cuda.is_available():
+# DEVICE = torch.device("cuda")
+else:
+ DEVICE = torch.device("cpu")
+
+
+# @register_representation([ModalityType.VIDEO])
+class OpticalFlow(UnimodalRepresentation):
+ def __init__(self):
+ parameters = {}
+ super().__init__("OpticalFlow", ModalityType.TIMESERIES, parameters)
+
+ def transform(self, modality):
+ transformed_modality = TransformedModality(
+ self.output_modality_type,
+ "opticalFlow",
+ modality.modality_id,
+ modality.metadata,
+ )
+
+ for video_id, instance in enumerate(modality.data):
+ transformed_modality.data.append([])
+
+ previous_gray = cv2.cvtColor(instance[0], cv2.COLOR_BGR2GRAY)
+ for frame_id in range(1, len(instance)):
+ gray = cv2.cvtColor(instance[frame_id], cv2.COLOR_BGR2GRAY)
+
+ flow = cv2.calcOpticalFlowFarneback(
+ previous_gray,
+ gray,
+ None,
+ pyr_scale=0.5,
+ levels=3,
+ winsize=15,
+ iterations=3,
+ poly_n=5,
+ poly_sigma=1.1,
+ flags=0,
+ )
+
+ transformed_modality.data[video_id].append(flow)
+ transformed_modality.update_metadata()
+ return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/resnet.py
b/src/main/python/systemds/scuro/representations/resnet.py
index 60eed9ea12..68771eccdd 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/resnet.py
@@ -18,14 +18,14 @@
# under the License.
#
# -------------------------------------------------------------
-
+from systemds.scuro.utils.torch_dataset import CustomDataset
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from typing import Callable, Dict, Tuple, Any
+from systemds.scuro.drsearch.operator_registry import register_representation
import torch.utils.data
import torch
import torchvision.models as models
-import torchvision.transforms as transforms
import numpy as np
from systemds.scuro.modality.type import ModalityType
@@ -37,17 +37,19 @@ else:
DEVICE = torch.device("cpu")
+@register_representation(
+ [ModalityType.IMAGE, ModalityType.VIDEO, ModalityType.TIMESERIES]
+)
class ResNet(UnimodalRepresentation):
def __init__(self, layer="avgpool", model_name="ResNet18",
output_file=None):
self.model_name = model_name
parameters = self._get_parameters()
super().__init__(
"ResNet", ModalityType.TIMESERIES, parameters
- ) # TODO: TIMESERIES only for videos - images would be handled as
EMBEDDIGN
+ ) # TODO: TIMESERIES only for videos - images would be handled as
EMBEDDING
self.output_file = output_file
self.layer_name = layer
- self.model = model_name
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
@@ -59,29 +61,30 @@ class ResNet(UnimodalRepresentation):
self.model.fc = Identity()
@property
- def model(self):
- return self._model
-
- @model.setter
- def model(self, model):
- if model == "ResNet18":
- self._model =
models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(
+ def model_name(self):
+ return self._model_name
+
+ @model_name.setter
+ def model_name(self, model_name):
+ self._model_name = model_name
+ if model_name == "ResNet18":
+ self.model =
models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(
DEVICE
)
- elif model == "ResNet34":
- self._model =
models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
+ elif model_name == "ResNet34":
+ self.model =
models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
DEVICE
)
- elif model == "ResNet50":
- self._model =
models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(
+ elif model_name == "ResNet50":
+ self.model =
models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(
DEVICE
)
- elif model == "ResNet101":
- self._model =
models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(
+ elif model_name == "ResNet101":
+ self.model =
models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(
DEVICE
)
- elif model == "ResNet152":
- self._model =
models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(
+ elif model_name == "ResNet152":
+ self.model =
models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(
DEVICE
)
else:
@@ -107,20 +110,7 @@ class ResNet(UnimodalRepresentation):
return parameters
def transform(self, modality):
-
- t = transforms.Compose(
- [
- transforms.ToPILImage(),
- transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
- ),
- ]
- )
-
- dataset = ResNetDataset(modality.data, t)
+ dataset = CustomDataset(modality.data)
embeddings = {}
res5c_output = None
@@ -168,31 +158,3 @@ class ResNet(UnimodalRepresentation):
transformed_modality.data = list(embeddings.values())
return transformed_modality
-
-
-class ResNetDataset(torch.utils.data.Dataset):
- def __init__(self, data: str, tf: Callable = None):
- self.data = data
- self.tf = tf
-
- def __getitem__(self, index) -> Dict[str, object]:
- data = self.data[index]
- if type(data) is np.ndarray:
- output = torch.empty((1, 3, 224, 224))
- d = torch.tensor(data)
- d = d.repeat(3, 1, 1)
- output[0] = self.tf(d)
- else:
- output = torch.empty((len(data), 3, 224, 224))
-
- for i, d in enumerate(data):
- if data[0].ndim < 3:
- d = torch.tensor(d)
- d = d.repeat(3, 1, 1)
-
- output[i] = self.tf(d)
-
- return {"id": index, "data": output}
-
- def __len__(self) -> int:
- return len(self.data)
diff --git a/src/main/python/systemds/scuro/representations/rowmax.py
b/src/main/python/systemds/scuro/representations/rowmax.py
deleted file mode 100644
index 3152782026..0000000000
--- a/src/main/python/systemds/scuro/representations/rowmax.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# -------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-# -------------------------------------------------------------
-import itertools
-from typing import List
-
-import numpy as np
-
-from systemds.scuro.modality.modality import Modality
-from systemds.scuro.representations.utils import pad_sequences
-
-from systemds.scuro.representations.fusion import Fusion
-
-
-class RowMax(Fusion):
- def __init__(self, split=1):
- """
- Combines modalities by computing the outer product of a modality
combination and
- taking the row max
- """
- super().__init__("RowMax")
- self.split = split
-
- def transform(self, modalities: List[Modality]):
- if len(modalities) < 2:
- return np.array(modalities)
-
- max_emb_size = self.get_max_embedding_size(modalities)
-
- padded_modalities = []
- for modality in modalities:
- d = pad_sequences(modality.data, maxlen=max_emb_size,
dtype="float32")
- padded_modalities.append(d)
-
- split_rows = int(len(modalities[0].data) / self.split)
-
- data = []
-
- for combination in itertools.combinations(padded_modalities, 2):
- combined = None
- for i in range(0, self.split):
- start = split_rows * i
- end = (
- split_rows * (i + 1)
- if i < (self.split - 1)
- else len(modalities[0].data)
- )
- m = np.einsum(
- "bi,bo->bio", combination[0][start:end],
combination[1][start:end]
- )
- m = m.max(axis=2)
- if combined is None:
- combined = m
- else:
- combined = np.concatenate((combined, m), axis=0)
- data.append(combined)
-
- data = np.stack(data)
- data = data.max(axis=0)
-
- return np.array(data)
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
b/src/main/python/systemds/scuro/representations/spectrogram.py
similarity index 72%
copy from src/main/python/systemds/scuro/representations/mel_spectrogram.py
copy to src/main/python/systemds/scuro/representations/spectrogram.py
index dfff4f3b7e..b5558b1b26 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/spectrogram.py
@@ -25,17 +25,14 @@ from systemds.scuro.modality.type import ModalityType
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
-class MelSpectrogram(UnimodalRepresentation):
- def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
- parameters = {
- "n_mels": [20, 32, 64, 128],
- "hop_length": [256, 512, 1024, 2048],
- "n_fft": [1024, 2048, 4096],
- }
- super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
- self.n_mels = n_mels
+@register_representation(ModalityType.AUDIO)
+class Spectrogram(UnimodalRepresentation):
+ def __init__(self, hop_length=512, n_fft=2048):
+ parameters = {"hop_length": [256, 512, 1024, 2048], "n_fft": [1024,
2048, 4096]}
+ super().__init__("Spectrogram", ModalityType.TIMESERIES, parameters)
self.hop_length = hop_length
self.n_fft = n_fft
@@ -45,9 +42,11 @@ class MelSpectrogram(UnimodalRepresentation):
)
result = []
max_length = 0
- for sample in modality.data:
- S = librosa.feature.melspectrogram(y=sample, sr=22050)
- S_dB = librosa.power_to_db(S, ref=np.max)
+ for i, sample in enumerate(modality.data):
+ spectrogram = librosa.stft(
+ y=sample, hop_length=self.hop_length, n_fft=self.n_fft
+ )
+ S_dB = librosa.amplitude_to_db(np.abs(spectrogram))
if S_dB.shape[-1] > max_length:
max_length = S_dB.shape[-1]
result.append(S_dB.T)
diff --git a/src/main/python/systemds/scuro/representations/sum.py
b/src/main/python/systemds/scuro/representations/sum.py
index 0608338a0f..46d93f2eda 100644
--- a/src/main/python/systemds/scuro/representations/sum.py
+++ b/src/main/python/systemds/scuro/representations/sum.py
@@ -27,7 +27,10 @@ from systemds.scuro.representations.utils import
pad_sequences
from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.drsearch.operator_registry import register_fusion_operator
+
+@register_fusion_operator()
class Sum(Fusion):
def __init__(self):
"""
diff --git
a/src/main/python/systemds/scuro/representations/swin_video_transformer.py
b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
new file mode 100644
index 0000000000..19b2fd05c4
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
@@ -0,0 +1,111 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+# from torchvision.models.video.swin_transformer import swin3d_t
+
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from typing import Callable, Dict, Tuple, Any
+import torch.utils.data
+import torch
+import torchvision.models as models
+import numpy as np
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+from systemds.scuro.utils.torch_dataset import CustomDataset
+
+if torch.backends.mps.is_available():
+ DEVICE = torch.device("mps")
+# elif torch.cuda.is_available():
+# DEVICE = torch.device("cuda")
+else:
+ DEVICE = torch.device("cpu")
+
+
+# @register_representation([ModalityType.VIDEO])
+class SwinVideoTransformer(UnimodalRepresentation):
+ def __init__(self, layer_name="avgpool"):
+ parameters = {
+ "layer_name": [
+ "features",
+ "features.1",
+ "features.2",
+ "features.3",
+ "features.4",
+ "features.5",
+ "features.6",
+ "avgpool",
+ ],
+ }
+ super().__init__("SwinVideoTransformer", ModalityType.TIMESERIES,
parameters)
+ self.layer_name = layer_name
+ # self.model =
swin3d_t(weights=models.video.Swin3D_T_Weights).to(DEVICE)
+ self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ def transform(self, modality):
+ # model = swin3d_t(weights=models.video.Swin3D_T_Weights)
+
+ embeddings = {}
+ swin_output = None
+
+ def get_features(name_):
+ def hook(
+ _module: torch.nn.Module, input_: Tuple[torch.Tensor], output:
Any
+ ):
+ nonlocal swin_output
+ swin_output = output
+
+ return hook
+
+ if self.layer_name:
+ for name, layer in self.model.named_modules():
+ if name == self.layer_name:
+ layer.register_forward_hook(get_features(name))
+ break
+ dataset = CustomDataset(modality.data)
+
+ for instance in dataset:
+ video_id = instance["id"]
+ frames = instance["data"].to(DEVICE)
+ embeddings[video_id] = []
+
+ frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4)
+
+ _ = self.model(frames)
+ values = swin_output
+ pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
+
+ embeddings[video_id].extend(torch.flatten(pooled,
1).detach().cpu().numpy())
+
+ embeddings[video_id] = np.array(embeddings[video_id])
+
+ transformed_modality = TransformedModality(
+ self.output_modality_type,
+ "swinVideoTransformer",
+ modality.modality_id,
+ modality.metadata,
+ )
+
+ transformed_modality.data = list(embeddings.values())
+
+ return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/tfidf.py
b/src/main/python/systemds/scuro/representations/tfidf.py
index 30a6655150..c17527b476 100644
--- a/src/main/python/systemds/scuro/representations/tfidf.py
+++ b/src/main/python/systemds/scuro/representations/tfidf.py
@@ -26,8 +26,10 @@ from systemds.scuro.representations.unimodal import
UnimodalRepresentation
from systemds.scuro.representations.utils import save_embeddings
from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+@register_representation(ModalityType.TEXT)
class TfIdf(UnimodalRepresentation):
def __init__(self, min_df=2, output_file=None):
parameters = {"min_df": [min_df]}
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
b/src/main/python/systemds/scuro/representations/wav2vec.py
similarity index 50%
copy from src/main/python/systemds/scuro/representations/mel_spectrogram.py
copy to src/main/python/systemds/scuro/representations/wav2vec.py
index dfff4f3b7e..bf251b101c 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/wav2vec.py
@@ -18,39 +18,51 @@
# under the License.
#
# -------------------------------------------------------------
-import librosa
import numpy as np
-
+from transformers import Wav2Vec2Processor, Wav2Vec2Model
+import librosa
+import torch
from systemds.scuro.modality.type import ModalityType
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+import warnings
+warnings.filterwarnings("ignore", message="Some weights of")
-class MelSpectrogram(UnimodalRepresentation):
- def __init__(self, n_mels=128, hop_length=512, n_fft=2048):
- parameters = {
- "n_mels": [20, 32, 64, 128],
- "hop_length": [256, 512, 1024, 2048],
- "n_fft": [1024, 2048, 4096],
- }
- super().__init__("MelSpectrogram", ModalityType.TIMESERIES, parameters)
- self.n_mels = n_mels
- self.hop_length = hop_length
- self.n_fft = n_fft
+
+@register_representation(ModalityType.AUDIO)
+class Wav2Vec(UnimodalRepresentation):
+ def __init__(self):
+ super().__init__("Wav2Vec", ModalityType.TIMESERIES, {})
+ self.processor = Wav2Vec2Processor.from_pretrained(
+ "facebook/wav2vec2-base-960h"
+ )
+ self.model = Wav2Vec2Model.from_pretrained(
+ "facebook/wav2vec2-base-960h"
+ ).float()
def transform(self, modality):
transformed_modality = TransformedModality(
self.output_modality_type, self, modality.modality_id,
modality.metadata
)
+
result = []
- max_length = 0
- for sample in modality.data:
- S = librosa.feature.melspectrogram(y=sample, sr=22050)
- S_dB = librosa.power_to_db(S, ref=np.max)
- if S_dB.shape[-1] > max_length:
- max_length = S_dB.shape[-1]
- result.append(S_dB.T)
+ for i, sample in enumerate(modality.data):
+ sr = list(modality.metadata.values())[i]["frequency"]
+ audio_resampled = librosa.resample(sample, orig_sr=sr,
target_sr=16000)
+ input = self.processor(
+ audio_resampled, sampling_rate=16000, return_tensors="pt",
padding=True
+ )
+ input.input_values = input.input_values.float()
+ input.data["input_values"] = input.data["input_values"].float()
+ with torch.no_grad():
+ outputs = self.model(**input)
+ features = outputs.extract_features
+ # TODO: check how to get intermediate representations
+ result.append(torch.flatten(features.mean(dim=1),
1).detach().cpu().numpy())
transformed_modality.data = result
return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/window.py
b/src/main/python/systemds/scuro/representations/window.py
index 264d40ca42..bff63729c7 100644
--- a/src/main/python/systemds/scuro/representations/window.py
+++ b/src/main/python/systemds/scuro/representations/window.py
@@ -23,12 +23,12 @@ import math
from systemds.scuro.modality.type import DataLayout
-# from systemds.scuro.drsearch.operator_registry import
register_context_operator
+from systemds.scuro.drsearch.operator_registry import register_context_operator
from systemds.scuro.representations.aggregate import Aggregation
from systemds.scuro.representations.context import Context
-# @register_context_operator()
+@register_context_operator()
class WindowAggregation(Context):
def __init__(self, window_size=10, aggregation_function="mean"):
parameters = {
@@ -65,6 +65,8 @@ class WindowAggregation(Context):
return windowed_data
def window_aggregate_single_level(self, instance, new_length):
+ if isinstance(instance, str):
+ return instance
num_cols = instance.shape[1] if instance.ndim > 1 else 1
result = np.empty((new_length, num_cols))
for i in range(0, new_length):
diff --git a/src/main/python/systemds/scuro/representations/word2vec.py
b/src/main/python/systemds/scuro/representations/word2vec.py
index 929dbd4415..e1d1669d9b 100644
--- a/src/main/python/systemds/scuro/representations/word2vec.py
+++ b/src/main/python/systemds/scuro/representations/word2vec.py
@@ -26,10 +26,9 @@ from gensim.models import Word2Vec
from gensim.utils import tokenize
from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
import nltk
-nltk.download("punkt_tab")
-
def get_embedding(sentence, model):
vectors = []
@@ -40,6 +39,7 @@ def get_embedding(sentence, model):
return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)
+@register_representation(ModalityType.TEXT)
class W2V(UnimodalRepresentation):
def __init__(self, vector_size=3, min_count=2, window=2, output_file=None):
parameters = {
@@ -71,5 +71,5 @@ class W2V(UnimodalRepresentation):
if self.output_file is not None:
save_embeddings(np.array(embeddings), self.output_file)
- transformed_modality.data = np.array(embeddings)
+ transformed_modality.data = embeddings
return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/resnet.py
b/src/main/python/systemds/scuro/representations/x3d.py
similarity index 50%
copy from src/main/python/systemds/scuro/representations/resnet.py
copy to src/main/python/systemds/scuro/representations/x3d.py
index 60eed9ea12..bb5d1ec5ed 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/x3d.py
@@ -18,36 +18,36 @@
# under the License.
#
# -------------------------------------------------------------
-
+from systemds.scuro.utils.torch_dataset import CustomDataset
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from typing import Callable, Dict, Tuple, Any
import torch.utils.data
import torch
+from torchvision.models.video import r3d_18, s3d
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
if torch.backends.mps.is_available():
DEVICE = torch.device("mps")
-elif torch.cuda.is_available():
- DEVICE = torch.device("cuda")
+# elif torch.cuda.is_available():
+# DEVICE = torch.device("cuda")
else:
DEVICE = torch.device("cpu")
-class ResNet(UnimodalRepresentation):
- def __init__(self, layer="avgpool", model_name="ResNet18",
output_file=None):
+# @register_representation([ModalityType.VIDEO])
+class X3D(UnimodalRepresentation):
+ def __init__(self, layer="avgpool", model_name="r3d", output_file=None):
self.model_name = model_name
parameters = self._get_parameters()
- super().__init__(
- "ResNet", ModalityType.TIMESERIES, parameters
- ) # TODO: TIMESERIES only for videos - images would be handled as
EMBEDDIGN
+ super().__init__("X3D", ModalityType.TIMESERIES, parameters)
self.output_file = output_file
self.layer_name = layer
- self.model = model_name
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
@@ -59,37 +59,22 @@ class ResNet(UnimodalRepresentation):
self.model.fc = Identity()
@property
- def model(self):
- return self._model
-
- @model.setter
- def model(self, model):
- if model == "ResNet18":
- self._model =
models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(
- DEVICE
- )
- elif model == "ResNet34":
- self._model =
models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
- DEVICE
- )
- elif model == "ResNet50":
- self._model =
models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(
- DEVICE
- )
- elif model == "ResNet101":
- self._model =
models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(
- DEVICE
- )
- elif model == "ResNet152":
- self._model =
models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(
- DEVICE
- )
+ def model_name(self):
+ return self._model_name
+
+ @model_name.setter
+ def model_name(self, model_name):
+ self._model_name = model_name
+ if model_name == "r3d":
+ self.model = r3d_18(pretrained=True).to(DEVICE)
+ elif model_name == "s3d":
+ self.model =
s3d(weights=models.video.S3D_Weights.DEFAULT).to(DEVICE)
else:
raise NotImplementedError
def _get_parameters(self, high_level=True):
parameters = {"model_name": [], "layer_name": []}
- for m in ["ResNet18", "ResNet34", "ResNet50", "ResNet101",
"ResNet152"]:
+ for m in ["r3d", "s3d"]:
parameters["model_name"].append(m)
if high_level:
@@ -107,20 +92,7 @@ class ResNet(UnimodalRepresentation):
return parameters
def transform(self, modality):
-
- t = transforms.Compose(
- [
- transforms.ToPILImage(),
- transforms.Resize(256),
- transforms.CenterCrop(224),
- transforms.ToTensor(),
- transforms.Normalize(
- mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
- ),
- ]
- )
-
- dataset = ResNetDataset(modality.data, t)
+ dataset = CustomDataset(modality.data)
embeddings = {}
res5c_output = None
@@ -140,59 +112,24 @@ class ResNet(UnimodalRepresentation):
layer.register_forward_hook(get_features(name))
break
- for instance in torch.utils.data.DataLoader(dataset):
- video_id = instance["id"][0]
- frames = instance["data"][0].to(DEVICE)
+ for instance in dataset:
+ video_id = instance["id"]
+ frames = instance["data"].to(DEVICE)
embeddings[video_id] = []
- batch_size = 64
-
- for start_index in range(0, len(frames), batch_size):
- end_index = min(start_index + batch_size, len(frames))
- frame_ids_range = range(start_index, end_index)
- frame_batch = frames[frame_ids_range]
- _ = self.model(frame_batch)
- values = res5c_output
- pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1,
1))
+ frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4)
+ _ = self.model(frames)
+ values = res5c_output
+ pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
- embeddings[video_id].extend(
- torch.flatten(pooled, 1).detach().cpu().numpy()
- )
+ embeddings[video_id].extend(torch.flatten(pooled,
1).detach().cpu().numpy())
embeddings[video_id] = np.array(embeddings[video_id])
transformed_modality = TransformedModality(
- self.output_modality_type, "resnet", modality.modality_id,
modality.metadata
+ self.output_modality_type, "x3d", modality.modality_id,
modality.metadata
)
transformed_modality.data = list(embeddings.values())
return transformed_modality
-
-
-class ResNetDataset(torch.utils.data.Dataset):
- def __init__(self, data: str, tf: Callable = None):
- self.data = data
- self.tf = tf
-
- def __getitem__(self, index) -> Dict[str, object]:
- data = self.data[index]
- if type(data) is np.ndarray:
- output = torch.empty((1, 3, 224, 224))
- d = torch.tensor(data)
- d = d.repeat(3, 1, 1)
- output[0] = self.tf(d)
- else:
- output = torch.empty((len(data), 3, 224, 224))
-
- for i, d in enumerate(data):
- if data[0].ndim < 3:
- d = torch.tensor(d)
- d = d.repeat(3, 1, 1)
-
- output[i] = self.tf(d)
-
- return {"id": index, "data": output}
-
- def __len__(self) -> int:
- return len(self.data)
diff --git a/src/main/python/systemds/scuro/utils/schema_helpers.py
b/src/main/python/systemds/scuro/utils/schema_helpers.py
index a88e81f716..28af476cca 100644
--- a/src/main/python/systemds/scuro/utils/schema_helpers.py
+++ b/src/main/python/systemds/scuro/utils/schema_helpers.py
@@ -18,7 +18,6 @@
# under the License.
#
# -------------------------------------------------------------
-import math
import numpy as np
diff --git a/src/main/python/systemds/scuro/utils/torch_dataset.py
b/src/main/python/systemds/scuro/utils/torch_dataset.py
new file mode 100644
index 0000000000..a0f3d88b6a
--- /dev/null
+++ b/src/main/python/systemds/scuro/utils/torch_dataset.py
@@ -0,0 +1,63 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from typing import Dict
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+
+
+class CustomDataset(torch.utils.data.Dataset):
+ def __init__(self, data):
+ self.data = data
+ self.tf = transforms.Compose(
+ [
+ transforms.ToPILImage(),
+ transforms.Resize(256),
+ transforms.CenterCrop(224),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ ),
+ ]
+ )
+
+ def __getitem__(self, index) -> Dict[str, object]:
+ data = self.data[index]
+ if type(data) is np.ndarray:
+ output = torch.empty((1, 3, 224, 224))
+ d = torch.tensor(data)
+ d = d.repeat(3, 1, 1)
+ output[0] = self.tf(d)
+ else:
+ output = torch.empty((len(data), 3, 224, 224))
+
+ for i, d in enumerate(data):
+ if data[0].ndim < 3:
+ d = torch.tensor(d)
+ d = d.repeat(3, 1, 1)
+
+ output[i] = self.tf(d)
+
+ return {"id": index, "data": output}
+
+ def __len__(self) -> int:
+ return len(self.data)
diff --git a/src/main/python/tests/scuro/data_generator.py
b/src/main/python/tests/scuro/data_generator.py
index 48ff208e43..e31887ff83 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -26,13 +26,11 @@ from scipy.io.wavfile import write
import random
import os
-from systemds.scuro import (
- VideoLoader,
- AudioLoader,
- TextLoader,
- UnimodalModality,
- TransformedModality,
-)
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.text_loader import TextLoader
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.modality.type import ModalityType
diff --git a/src/main/python/tests/scuro/test_dr_search.py
b/src/main/python/tests/scuro/test_dr_search.py
index 0959c246e0..521ff3f468 100644
--- a/src/main/python/tests/scuro/test_dr_search.py
+++ b/src/main/python/tests/scuro/test_dr_search.py
@@ -29,8 +29,8 @@ from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from systemds.scuro.modality.type import ModalityType
-from systemds.scuro.aligner.dr_search import DRSearch
-from systemds.scuro.aligner.task import Task
+from systemds.scuro.drsearch.dr_search import DRSearch
+from systemds.scuro.drsearch.task import Task
from systemds.scuro.models.model import Model
from systemds.scuro.representations.average import Average
from systemds.scuro.representations.bert import Bert
diff --git a/src/main/python/tests/scuro/test_multimodal_fusion.py
b/src/main/python/tests/scuro/test_multimodal_fusion.py
new file mode 100644
index 0000000000..8456279c3d
--- /dev/null
+++ b/src/main/python/tests/scuro/test_multimodal_fusion.py
@@ -0,0 +1,202 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+
+import shutil
+import unittest
+
+import numpy as np
+from sklearn import svm
+from sklearn.metrics import classification_report
+from sklearn.model_selection import train_test_split
+
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.average import Average
+from systemds.scuro.drsearch.fusion_optimizer import FusionOptimizer
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.models.model import Model
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.unimodal_representation_optimizer import (
+ UnimodalRepresentationOptimizer,
+)
+
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.representations.resnet import ResNet
+from tests.scuro.data_generator import setup_data
+
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.dataloader.text_loader import TextLoader
+from systemds.scuro.modality.type import ModalityType
+
+from unittest.mock import patch
+
+
+class TestSVM(Model):
+ def __init__(self):
+ super().__init__("TestSVM")
+
+ def fit(self, X, y, X_test, y_test):
+ if X.ndim > 2:
+ X = X.reshape(X.shape[0], -1)
+ self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+ self.clf = self.clf.fit(X, np.array(y))
+ y_pred = self.clf.predict(X)
+
+ return classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+
+ def test(self, test_X: np.ndarray, test_y: np.ndarray):
+ if test_X.ndim > 2:
+ test_X = test_X.reshape(test_X.shape[0], -1)
+ y_pred = self.clf.predict(np.array(test_X)) # noqa
+
+ return classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3,
zero_division=1
+ )["accuracy"]
+
+
+class TestCNN(Model):
+ def __init__(self):
+ super().__init__("TestCNN")
+
+ def fit(self, X, y, X_test, y_test):
+ if X.ndim > 2:
+ X = X.reshape(X.shape[0], -1)
+ self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+ self.clf = self.clf.fit(X, np.array(y))
+ y_pred = self.clf.predict(X)
+
+ return classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+
+ def test(self, test_X: np.ndarray, test_y: np.ndarray):
+ if test_X.ndim > 2:
+ test_X = test_X.reshape(test_X.shape[0], -1)
+ y_pred = self.clf.predict(np.array(test_X)) # noqa
+
+ return classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3,
zero_division=1
+ )["accuracy"]
+
+
+class TestMultimodalRepresentationOptimizer(unittest.TestCase):
+ test_file_path = None
+ data_generator = None
+ num_instances = 0
+
+ @classmethod
+ def setUpClass(cls):
+ cls.test_file_path = "fusion_optimizer_test_data"
+
+ cls.num_instances = 10
+ cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
+
+ cls.data_generator = setup_data(cls.mods, cls.num_instances,
cls.test_file_path)
+ split = train_test_split(
+ cls.data_generator.indices,
+ cls.data_generator.labels,
+ test_size=0.2,
+ random_state=42,
+ )
+ cls.train_indizes, cls.val_indizes = [int(i) for i in split[0]], [
+ int(i) for i in split[1]
+ ]
+
+ cls.tasks = [
+ Task(
+ "UnimodalRepresentationTask1",
+ TestSVM(),
+ cls.data_generator.labels,
+ cls.train_indizes,
+ cls.val_indizes,
+ ),
+ Task(
+ "UnimodalRepresentationTask2",
+ TestCNN(),
+ cls.data_generator.labels,
+ cls.train_indizes,
+ cls.val_indizes,
+ ),
+ ]
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.test_file_path)
+
+ def test_multimodal_fusion(self):
+ task = Task(
+ "UnimodalRepresentationTask1",
+ TestSVM(),
+ self.data_generator.labels,
+ self.train_indizes,
+ self.val_indizes,
+ )
+ audio_data_loader = AudioLoader(
+ self.data_generator.get_modality_path(ModalityType.AUDIO),
+ self.data_generator.indices,
+ )
+ audio = UnimodalModality(audio_data_loader)
+
+ text_data_loader = TextLoader(
+ self.data_generator.get_modality_path(ModalityType.TEXT),
+ self.data_generator.indices,
+ )
+ text = UnimodalModality(text_data_loader)
+
+ video_data_loader = VideoLoader(
+ self.data_generator.get_modality_path(ModalityType.VIDEO),
+ self.data_generator.indices,
+ )
+ video = UnimodalModality(video_data_loader)
+
+ with patch.object(
+ Registry,
+ "_representations",
+ {
+ ModalityType.TEXT: [W2V],
+ ModalityType.AUDIO: [Spectrogram],
+ ModalityType.TIMESERIES: [ResNet],
+ ModalityType.VIDEO: [ResNet],
+ ModalityType.EMBEDDING: [],
+ },
+ ):
+ registry = Registry()
+ registry._fusion_operators = [Average, Concatenation]
+ unimodal_optimizer = UnimodalRepresentationOptimizer(
+ [text, audio, video], [task], max_chain_depth=2
+ )
+ unimodal_optimizer.optimize()
+
+ multimodal_optimizer = FusionOptimizer(
+ [audio, text, video],
+ task,
+ unimodal_optimizer.optimization_results,
+ unimodal_optimizer.cache,
+ 2,
+ 2,
+ debug=False,
+ )
+ multimodal_optimizer.optimize()
diff --git a/src/main/python/tests/scuro/test_multimodal_join.py
b/src/main/python/tests/scuro/test_multimodal_join.py
index 8388829f30..a5e3a7caf9 100644
--- a/src/main/python/tests/scuro/test_multimodal_join.py
+++ b/src/main/python/tests/scuro/test_multimodal_join.py
@@ -24,8 +24,6 @@ import shutil
import unittest
from systemds.scuro.modality.joined import JoinCondition
-from systemds.scuro.representations.aggregate import Aggregation
-from systemds.scuro.representations.window import WindowAggregation
from systemds.scuro.modality.unimodal_modality import UnimodalModality
from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
from systemds.scuro.representations.resnet import ResNet
diff --git a/src/main/python/tests/scuro/test_operator_registry.py
b/src/main/python/tests/scuro/test_operator_registry.py
new file mode 100644
index 0000000000..aaecde2991
--- /dev/null
+++ b/src/main/python/tests/scuro/test_operator_registry.py
@@ -0,0 +1,87 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+
+from systemds.scuro.representations.mfcc import MFCC
+from systemds.scuro.representations.wav2vec import Wav2Vec
+from systemds.scuro.representations.window import WindowAggregation
+from systemds.scuro.representations.bow import BoW
+from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.representations.tfidf import TfIdf
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.lstm import LSTM
+from systemds.scuro.representations.max import RowMax
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.multiplication import Multiplication
+from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.sum import Sum
+
+
+class TestOperatorRegistry(unittest.TestCase):
+ def test_audio_representations_in_registry(self):
+ registry = Registry()
+ for representation in [Spectrogram, MelSpectrogram, Wav2Vec, MFCC]:
+ assert representation in registry.get_representations(
+ ModalityType.AUDIO
+ ), f"{representation} not in registry"
+
+ def test_video_representations_in_registry(self):
+ registry = Registry()
+ assert registry.get_representations(ModalityType.VIDEO) == [ResNet]
+
+ def test_timeseries_representations_in_registry(self):
+ registry = Registry()
+ assert registry.get_representations(ModalityType.TIMESERIES) ==
[ResNet]
+
+ def test_text_representations_in_registry(self):
+ registry = Registry()
+ for representation in [BoW, TfIdf, W2V, Bert]:
+ assert representation in registry.get_representations(
+ ModalityType.TEXT
+ ), f"{representation} not in registry"
+
+ def test_context_operator_in_registry(self):
+ registry = Registry()
+ assert registry.get_context_operators() == [WindowAggregation]
+
+ # def test_fusion_operator_in_registry(self):
+ # registry = Registry()
+ # for fusion_operator in [
+ # # RowMax,
+ # Sum,
+ # Average,
+ # Concatenation,
+ # LSTM,
+ # Multiplication,
+ # ]:
+ # assert (
+ # fusion_operator in registry.get_fusion_operators()
+ # ), f"{fusion_operator} not in registry"
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py
b/src/main/python/tests/scuro/test_unimodal_optimizer.py
new file mode 100644
index 0000000000..bfc52f0103
--- /dev/null
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -0,0 +1,203 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+
+import shutil
+import unittest
+
+import numpy as np
+from sklearn import svm
+from sklearn.metrics import classification_report
+from sklearn.model_selection import train_test_split
+
+from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.models.model import Model
+from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.unimodal_representation_optimizer import (
+ UnimodalRepresentationOptimizer,
+)
+
+from systemds.scuro.representations.spectrogram import Spectrogram
+from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.representations.resnet import ResNet
+from tests.scuro.data_generator import setup_data
+
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.dataloader.text_loader import TextLoader
+from systemds.scuro.modality.type import ModalityType
+
+
+class TestSVM(Model):
+ def __init__(self):
+ super().__init__("TestSVM")
+
+ def fit(self, X, y, X_test, y_test):
+ if X.ndim > 2:
+ X = X.reshape(X.shape[0], -1)
+ self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+ self.clf = self.clf.fit(X, np.array(y))
+ y_pred = self.clf.predict(X)
+
+ return classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+
+ def test(self, test_X: np.ndarray, test_y: np.ndarray):
+ if test_X.ndim > 2:
+ test_X = test_X.reshape(test_X.shape[0], -1)
+ y_pred = self.clf.predict(np.array(test_X)) # noqa
+
+ return classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3,
zero_division=1
+ )["accuracy"]
+
+
+class TestCNN(Model):
+ def __init__(self):
+ super().__init__("TestCNN")
+
+ def fit(self, X, y, X_test, y_test):
+ if X.ndim > 2:
+ X = X.reshape(X.shape[0], -1)
+ self.clf = svm.SVC(C=1, gamma="scale", kernel="rbf", verbose=False)
+ self.clf = self.clf.fit(X, np.array(y))
+ y_pred = self.clf.predict(X)
+
+ return classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+
+ def test(self, test_X: np.ndarray, test_y: np.ndarray):
+ if test_X.ndim > 2:
+ test_X = test_X.reshape(test_X.shape[0], -1)
+ y_pred = self.clf.predict(np.array(test_X)) # noqa
+
+ return classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3,
zero_division=1
+ )["accuracy"]
+
+
+from unittest.mock import patch
+
+
+class TestUnimodalRepresentationOptimizer(unittest.TestCase):
+ test_file_path = None
+ data_generator = None
+ num_instances = 0
+
+ @classmethod
+ def setUpClass(cls):
+ cls.test_file_path = "unimodal_optimizer_test_data"
+
+ cls.num_instances = 10
+ cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
+
+ cls.data_generator = setup_data(cls.mods, cls.num_instances,
cls.test_file_path)
+ split = train_test_split(
+ cls.data_generator.indices,
+ cls.data_generator.labels,
+ test_size=0.2,
+ random_state=42,
+ )
+ cls.train_indizes, cls.val_indizes = [int(i) for i in split[0]], [
+ int(i) for i in split[1]
+ ]
+
+ cls.tasks = [
+ Task(
+ "UnimodalRepresentationTask1",
+ TestSVM(),
+ cls.data_generator.labels,
+ cls.train_indizes,
+ cls.val_indizes,
+ ),
+ Task(
+ "UnimodalRepresentationTask2",
+ TestCNN(),
+ cls.data_generator.labels,
+ cls.train_indizes,
+ cls.val_indizes,
+ ),
+ ]
+
+ @classmethod
+ def tearDownClass(cls):
+ shutil.rmtree(cls.test_file_path)
+
+ def test_unimodal_optimizer_for_audio_modality(self):
+ audio_data_loader = AudioLoader(
+ self.data_generator.get_modality_path(ModalityType.AUDIO),
+ self.data_generator.indices,
+ )
+ audio = UnimodalModality(audio_data_loader)
+
+ self.optimize_unimodal_representation_for_modality(audio)
+
+ def test_unimodal_optimizer_for_text_modality(self):
+ text_data_loader = TextLoader(
+ self.data_generator.get_modality_path(ModalityType.TEXT),
+ self.data_generator.indices,
+ )
+ text = UnimodalModality(text_data_loader)
+ self.optimize_unimodal_representation_for_modality(text)
+
+ def test_unimodal_optimizer_for_video_modality(self):
+ video_data_loader = VideoLoader(
+ self.data_generator.get_modality_path(ModalityType.VIDEO),
+ self.data_generator.indices,
+ )
+ video = UnimodalModality(video_data_loader)
+ self.optimize_unimodal_representation_for_modality(video)
+
+ def optimize_unimodal_representation_for_modality(self, modality):
+ with patch.object(
+ Registry,
+ "_representations",
+ {
+ ModalityType.TEXT: [W2V],
+ ModalityType.AUDIO: [Spectrogram],
+ ModalityType.TIMESERIES: [ResNet],
+ ModalityType.VIDEO: [ResNet],
+ ModalityType.EMBEDDING: [],
+ },
+ ):
+ registry = Registry()
+
+ unimodal_optimizer = UnimodalRepresentationOptimizer(
+ [modality], self.tasks, max_chain_depth=2
+ )
+ unimodal_optimizer.optimize()
+
+ assert (
+ list(unimodal_optimizer.optimization_results.keys())[0]
+ == modality.modality_id
+ )
+ assert
len(list(unimodal_optimizer.optimization_results.values())[0]) == 2
+ assert (
+ len(
+ unimodal_optimizer.get_k_best_results(modality, 1,
self.tasks[0])[
+ 0
+ ].operator_chain
+ )
+ >= 1
+ )