This is an automated email from the ASF dual-hosted git repository.
cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c40f95ecf0 [SYSTEMDS-3937] Add score ranking function to Scuro
c40f95ecf0 is described below
commit c40f95ecf02afdf12623b6f6f09d08cd3ff01386
Author: Christina Dionysio <[email protected]>
AuthorDate: Fri Nov 28 12:32:28 2025 +0100
[SYSTEMDS-3937] Add score ranking function to Scuro
This patch adds new functionality to rank representations in Scuro by
different metrics (runtime, performance metric).
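
For illustration, a minimal usage sketch of the new ranking helper. The Entry
dataclass below is a hypothetical stand-in for the patched ResultEntry /
OptimizationResult classes; only rank_by_tradeoff and its keyword arguments
come from this patch:

    from dataclasses import dataclass
    from systemds.scuro.drsearch.ranking import rank_by_tradeoff

    @dataclass
    class Entry:                      # hypothetical stand-in for ResultEntry
        val_score: dict               # e.g. {"accuracy": 0.91}
        representation_time: float    # seconds to build the representation
        task_time: float              # seconds to train/evaluate the task
        tradeoff_score: float = 0.0   # filled in by rank_by_tradeoff

    entries = [
        Entry({"accuracy": 0.91}, representation_time=40.0, task_time=5.0),
        Entry({"accuracy": 0.89}, representation_time=2.0, task_time=1.0),
    ]

    # Normalizes the validation score and the (inverted) runtime to [0, 1],
    # combines them with the given weights, caches the combined score in
    # tradeoff_score, and returns the entries sorted best-first.
    ranked = rank_by_tradeoff(
        entries, weights=(0.7, 0.3), performance_metric_name="accuracy"
    )
    print([(e.val_score["accuracy"], round(e.tradeoff_score, 2)) for e in ranked])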
---
.../systemds/scuro/dataloader/json_loader.py | 1 +
.../scuro/drsearch/hyperparameter_tuner.py | 8 +-
.../scuro/drsearch/multimodal_optimizer.py | 20 ++--
src/main/python/systemds/scuro/drsearch/ranking.py | 90 +++++++++++++++++
src/main/python/systemds/scuro/drsearch/task.py | 58 +++++++++--
.../systemds/scuro/drsearch/unimodal_optimizer.py | 106 +++++++++++++--------
.../representations/timeseries_representations.py | 4 +-
.../python/systemds/scuro/utils/torch_dataset.py | 6 +-
src/main/python/tests/scuro/test_hp_tuner.py | 34 ++++---
.../python/tests/scuro/test_multimodal_fusion.py | 52 ++++++----
.../python/tests/scuro/test_unimodal_optimizer.py | 42 ++++----
11 files changed, 311 insertions(+), 110 deletions(-)
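
The patch also changes the model/Task contract: fit() and test() in the updated
tests now return a dict of metric scores (plus a second value that is not used
when collecting fold scores), and Task aggregates the per-fold scores in the new
PerformanceMeasure class. A minimal sketch of that aggregation, with made-up
fold scores and the default "accuracy" metric:

    from systemds.scuro.drsearch.task import PerformanceMeasure

    # One PerformanceMeasure per split; Task appends one score dict per fold.
    val = PerformanceMeasure("val", "accuracy")
    val.add_scores({"accuracy": 0.90})  # fold 1
    val.add_scores({"accuracy": 0.86})  # fold 2
    val.compute_averages()
    print(val.average_scores)           # mean over folds, approximately {'accuracy': 0.88}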
diff --git a/src/main/python/systemds/scuro/dataloader/json_loader.py b/src/main/python/systemds/scuro/dataloader/json_loader.py
index 89ba6b43d5..ed15448597 100644
--- a/src/main/python/systemds/scuro/dataloader/json_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/json_loader.py
@@ -55,5 +55,6 @@ class JSONLoader(BaseLoader):
except:
text = json_file[self.field]
+ text = " ".join(text)
self.data.append(text)
self.metadata[idx] = self.modality_type.create_metadata(len(text), text)
diff --git a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
index 15136ac28f..8c5e4c24e1 100644
--- a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
+++ b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
@@ -174,9 +174,13 @@ class HyperparameterTuner:
all_results.append(result)
if self.maximize_metric:
- best_params, best_score = max(all_results, key=lambda x: x[1])
+ best_params, best_score = max(
+ all_results, key=lambda x: x[1].scores[self.scoring_metric]
+ )
else:
- best_params, best_score = min(all_results, key=lambda x: x[1])
+ best_params, best_score = min(
+ all_results, key=lambda x: x[1].scores[self.scoring_metric]
+ )
tuning_time = time.time() - start_time
diff --git a/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
index bb44703d5c..9d0088a976 100644
--- a/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
@@ -24,7 +24,7 @@ import itertools
import threading
from dataclasses import dataclass
from typing import List, Dict, Any, Generator
-from systemds.scuro.drsearch.task import Task
+from systemds.scuro.drsearch.task import Task, PerformanceMeasure
from systemds.scuro.drsearch.representation_dag import (
RepresentationDag,
RepresentationDAGBuilder,
@@ -87,7 +87,8 @@ def _evaluate_dag_worker(dag_pickle, task_pickle, modalities_pickle, debug=False
val_score=scores[1],
runtime=total_time,
task_name=task_copy.model.name,
- evaluation_time=eval_time,
+ task_time=eval_time,
+ representation_time=total_time - eval_time,
)
except Exception:
if debug:
@@ -390,8 +391,9 @@ class MultimodalOptimizer:
train_score=scores[0],
val_score=scores[1],
runtime=total_time,
+ representation_time=total_time - eval_time,
task_name=task_copy.model.name,
- evaluation_time=eval_time,
+ task_time=eval_time,
)
except Exception as e:
@@ -475,8 +477,10 @@ class MultimodalOptimizer:
@dataclass
class OptimizationResult:
dag: RepresentationDag
- train_score: float
- val_score: float
- runtime: float
- task_name: str
- evaluation_time: float = 0.0
+ train_score: PerformanceMeasure = None
+ val_score: PerformanceMeasure = None
+ runtime: float = 0.0
+ task_time: float = 0.0
+ representation_time: float = 0.0
+ task_name: str = ""
+ tradeoff_score: float = 0.0
diff --git a/src/main/python/systemds/scuro/drsearch/ranking.py b/src/main/python/systemds/scuro/drsearch/ranking.py
new file mode 100644
index 0000000000..831a059eb8
--- /dev/null
+++ b/src/main/python/systemds/scuro/drsearch/ranking.py
@@ -0,0 +1,90 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+from dataclasses import replace
+from typing import Callable, Iterable, List, Optional
+
+
+def rank_by_tradeoff(
+ entries: Iterable,
+ *,
+ weights=(0.7, 0.3),
+ performance_metric_name: str = "accuracy",
+ runtime_accessor: Optional[Callable[[object], float]] = None,
+ cache_scores: bool = True,
+ score_attr: str = "tradeoff_score",
+) -> List:
+ entries = list(entries)
+ if not entries:
+ return []
+
+ performance_score_accessor = lambda entry: getattr(entry, "val_score")[
+ performance_metric_name
+ ]
+ if runtime_accessor is None:
+
+ def runtime_accessor(entry):
+ if hasattr(entry, "runtime"):
+ return getattr(entry, "runtime")
+ rep = getattr(entry, "representation_time", 0.0)
+ task = getattr(entry, "task_time", 0.0)
+ return rep + task
+
+ performance = [float(performance_score_accessor(e)) for e in entries]
+ runtimes = [float(runtime_accessor(e)) for e in entries]
+
+ perf_min, perf_max = min(performance), max(performance)
+ run_min, run_max = min(runtimes), max(runtimes)
+
+ def safe_normalize(values, vmin, vmax):
+ if vmax - vmin == 0.0:
+ return [1.0] * len(values)
+ return [(v - vmin) / (vmax - vmin) for v in values]
+
+ norm_perf = safe_normalize(performance, perf_min, perf_max)
+ norm_run = safe_normalize(runtimes, run_min, run_max)
+ norm_run = [1.0 - r for r in norm_run]
+
+ acc_w, run_w = weights
+ total_w = (acc_w or 0.0) + (run_w or 0.0)
+ if total_w == 0.0:
+ acc_w = 1.0
+ run_w = 0.0
+ else:
+ acc_w /= total_w
+ run_w /= total_w
+
+ scores = [acc_w * a + run_w * r for a, r in zip(norm_perf, norm_run)]
+
+ if cache_scores:
+ for entry, score in zip(entries, scores):
+ if hasattr(entry, score_attr):
+ try:
+ new_entry = replace(entry, **{score_attr: score})
+ entries[entries.index(entry)] = new_entry
+ except TypeError:
+ setattr(entry, score_attr, score)
+ else:
+ setattr(entry, score_attr, score)
+
+ return sorted(
+ entries, key=lambda entry: getattr(entry, score_attr, 0.0), reverse=True
+ )
diff --git a/src/main/python/systemds/scuro/drsearch/task.py b/src/main/python/systemds/scuro/drsearch/task.py
index 0dedc7ede3..bfd1f16ab3 100644
--- a/src/main/python/systemds/scuro/drsearch/task.py
+++ b/src/main/python/systemds/scuro/drsearch/task.py
@@ -28,6 +28,37 @@ import numpy as np
from sklearn.model_selection import KFold
+class PerformanceMeasure:
+ def __init__(self, name, metrics, higher_is_better=True):
+ self.average_scores = None
+ self.name = name
+ self.metrics = metrics
+ self.higher_is_better = higher_is_better
+ self.scores = {}
+
+ if isinstance(metrics, list):
+ for metric in metrics:
+ self.scores[metric] = []
+ else:
+ self.scores[metrics] = []
+
+ def add_scores(self, scores):
+ if isinstance(self.metrics, list):
+ for metric in self.metrics:
+ self.scores[metric].append(scores[metric])
+ else:
+ self.scores[self.metrics].append(scores[self.metrics])
+
+ def compute_averages(self):
+ self.average_scores = {}
+ if isinstance(self.metrics, list):
+ for metric in self.metrics:
+ self.average_scores[metric] = np.mean(self.scores[metric])
+ else:
+ self.average_scores[self.metrics] = np.mean(self.scores[self.metrics])
+ return self
+
+
class Task:
def __init__(
self,
@@ -38,6 +69,7 @@ class Task:
val_indices: List,
kfold=5,
measure_performance=True,
+ performance_measures="accuracy",
):
"""
Parent class for the prediction task that is performed on top of the
aligned representation
@@ -59,8 +91,9 @@ class Task:
self.inference_time = []
self.training_time = []
self.expected_dim = 1
- self.train_scores = []
- self.val_scores = []
+ self.performance_measures = performance_measures
+ self.train_scores = PerformanceMeasure("train", performance_measures)
+ self.val_scores = PerformanceMeasure("val", performance_measures)
def create_model(self):
"""
@@ -74,8 +107,12 @@ class Task:
def get_train_test_split(self, data):
X_train = [data[i] for i in self.train_indices]
y_train = [self.labels[i] for i in self.train_indices]
- X_test = [data[i] for i in self.val_indices]
- y_test = [self.labels[i] for i in self.val_indices]
+ if self.val_indices is None:
+ X_test = None
+ y_test = None
+ else:
+ X_test = [data[i] for i in self.val_indices]
+ y_test = [self.labels[i] for i in self.val_indices]
return X_train, y_train, X_test, y_test
@@ -101,25 +138,28 @@ class Task:
self._run_fold(model, train_X, train_y, test_X, test_y)
fold += 1
- return [np.mean(self.train_scores), np.mean(self.val_scores)]
+ return [
+ self.train_scores.compute_averages(),
+ self.val_scores.compute_averages(),
+ ]
def _reset_params(self):
self.inference_time = []
self.training_time = []
- self.train_scores = []
- self.val_scores = []
+ self.train_scores = PerformanceMeasure("train", self.performance_measures)
+ self.val_scores = PerformanceMeasure("val", self.performance_measures)
def _run_fold(self, model, train_X, train_y, test_X, test_y):
train_start = time.time()
train_score = model.fit(train_X, train_y, test_X, test_y)
train_end = time.time()
self.training_time.append(train_end - train_start)
- self.train_scores.append(train_score)
+ self.train_scores.add_scores(train_score[0])
test_start = time.time()
test_score = model.test(np.array(test_X), test_y)
test_end = time.time()
self.inference_time.append(test_end - test_start)
- self.val_scores.append(test_score)
+ self.val_scores.add_scores(test_score[0])
def create_representation_and_run(
self,
diff --git a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
index 91d72dd35a..7735986c2e 100644
--- a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
@@ -27,6 +27,8 @@ from typing import List, Any
from functools import lru_cache
from systemds.scuro import ModalityType
+from systemds.scuro.drsearch.ranking import rank_by_tradeoff
+from systemds.scuro.drsearch.task import PerformanceMeasure
from systemds.scuro.representations.fusion import Fusion
from systemds.scuro.representations.concatenation import Concatenation
from systemds.scuro.representations.hadamard import Hadamard
@@ -47,10 +49,14 @@ from systemds.scuro.drsearch.representation_dag_visualizer import visualize_dag
class UnimodalOptimizer:
- def __init__(self, modalities, tasks, debug=True):
+ def __init__(
+ self, modalities, tasks, debug=True, save_all_results=False, result_path=None
+ ):
self.modalities = modalities
self.tasks = tasks
self.run = None
+ self.save_all_results = save_all_results
+ self.result_path = result_path
self.builders = {
modality.modality_id: RepresentationDAGBuilder() for modality in modalities
@@ -92,6 +98,7 @@ class UnimodalOptimizer:
timestr = time.strftime("%Y%m%d-%H%M%S")
file_name = "unimodal_optimizer" + timestr + ".pkl"
+ file_name = f"{self.result_path}/{file_name}"
with open(file_name, "wb") as f:
pickle.dump(self.operator_performance.results, f)
@@ -117,8 +124,8 @@ class UnimodalOptimizer:
def optimize_parallel(self, n_workers=None):
if n_workers is None:
n_workers = min(len(self.modalities), mp.cpu_count())
-
- with ProcessPoolExecutor(max_workers=n_workers) as executor:
+ ctx = mp.get_context("spawn")
+ with ProcessPoolExecutor(max_workers=n_workers, mp_context=ctx) as executor:
future_to_modality = {
executor.submit(self._process_modality, modality, True): modality
for modality in self.modalities
@@ -126,17 +133,34 @@ class UnimodalOptimizer:
for future in as_completed(future_to_modality):
modality = future_to_modality[future]
- results = future.result()
- self._merge_results(results)
+ try:
+ results = future.result()
+ self._merge_results(results)
+ except Exception as e:
+ print(f"Error processing modality {modality.modality_id}: {e}")
+ import traceback
+
+ traceback.print_exc()
+ continue
def optimize(self):
"""Optimize representations for each modality"""
for modality in self.modalities:
+ # try:
local_result = self._process_modality(modality, False)
+ if self.save_all_results:
+ self.store_results(f"{modality.modality_id}_unimodal_results.pkl")
+ # except Exception as e:
+ # print(f"Error processing modality {modality.modality_id}: {e}")
+ # if self.save_all_results:
+ # self.store_results(f"{modality.modality_id}_unimodal_results.pkl")
+ # continue
def _process_modality(self, modality, parallel):
if parallel:
- local_results = UnimodalResults([modality], self.tasks, debug=False)
+ local_results = UnimodalResults(
+ [modality], self.tasks, debug=False, store_cache=False
+ )
else:
local_results = self.operator_performance
@@ -162,6 +186,13 @@ class UnimodalOptimizer:
if self.debug:
visualize_dag(dag)
+ if self.save_all_results:
+ timestr = time.strftime("%Y%m%d-%H%M%S")
+ file_name = f"{modality.modality_id}_unimodal_results_{timestr}.pkl"
+
+ with open(file_name, "wb") as f:
+ pickle.dump(local_results.results, f)
+
return local_results
def _get_representation_chain(
@@ -185,10 +216,10 @@ class UnimodalOptimizer:
local_results.results[modality_id][task_name]
)
- for modality in self.modalities:
- for task_name in local_results.cache[modality]:
- for key, value in local_results.cache[modality][task_name].items():
- self.operator_performance.cache[modality][task_name][key] = value
+ # for modality in self.modalities:
+ # for task_name in local_results.cache[modality]:
+ # for key, value in local_results.cache[modality][task_name].items():
+ # self.operator_performance.cache[modality][task_name][key] = value
def _evaluate_local(self, modality, local_results, dag, combination=None):
if self._tasks_require_same_dims:
@@ -343,12 +374,13 @@ class UnimodalOptimizer:
class UnimodalResults:
- def __init__(self, modalities, tasks, debug=False, run=None):
+ def __init__(self, modalities, tasks, debug=False, store_cache=True):
self.modality_ids = [modality.modality_id for modality in modalities]
self.task_names = [task.model.name for task in tasks]
self.results = {}
self.debug = debug
self.cache = {}
+ self.store_cache = store_cache
for modality in self.modality_ids:
self.results[modality] = {task_name: [] for task_name in self.task_names}
@@ -356,8 +388,8 @@ class UnimodalResults:
def add_result(self, scores, modality, task_name, task_time, combination, dag):
entry = ResultEntry(
- train_score=scores[0],
- val_score=scores[1],
+ train_score=scores[0].average_scores,
+ val_score=scores[1].average_scores,
representation_time=modality.transform_time,
task_time=task_time,
combination=combination.name if combination else "",
@@ -365,13 +397,13 @@ class UnimodalResults:
)
self.results[modality.modality_id][task_name].append(entry)
-
- cache_key = (
- id(dag),
- scores[1],
- modality.transform_time,
- )
- self.cache[modality.modality_id][task_name][cache_key] = modality
+ if self.store_cache:
+ cache_key = (
+ id(dag),
+ scores[1],
+ modality.transform_time,
+ )
+ self.cache[modality.modality_id][task_name][cache_key] = modality
if self.debug:
print(f"{modality.modality_id}_{task_name}: {entry}")
@@ -388,30 +420,27 @@ class UnimodalResults:
:param modality: modality to get the best results for
:param k: number of best results
"""
+
task_results = self.results[modality.modality_id][task.model.name]
- results = sorted(task_results, key=lambda x: x.val_score, reverse=True)[:k]
+ results = rank_by_tradeoff(task_results)[:k]
sorted_indices = sorted(
range(len(task_results)),
- key=lambda x: task_results[x].val_score,
+ key=lambda x: task_results[x].tradeoff_score,
reverse=True,
)[:k]
- if not self.cache:
+
+ task_cache = self.cache.get(modality.modality_id, {}).get(task.model.name, None)
+ if not task_cache:
cache = [
list(task_results[i].dag.execute([modality]).values())[-1]
for i in sorted_indices
]
- elif isinstance(self.cache[modality.modality_id][task.model.name], list):
- cache = self.cache[modality.modality_id][
- task.model.name
- ] # used for precomputed cache
+ elif isinstance(task_cache, list):
+ cache = task_cache
else:
- cache_items = (
- list(self.cache[modality.modality_id][task.model.name].items())
- if self.cache[modality.modality_id][task.model.name]
- else []
- )
+ cache_items = list(task_cache.items()) if task_cache else []
cache = [cache_items[i][1] for i in sorted_indices if i < len(cache_items)]
return results, cache
@@ -419,9 +448,10 @@ class UnimodalResults:
@dataclass(frozen=True)
class ResultEntry:
- val_score: float
- train_score: float
- representation_time: float
- task_time: float
- combination: str
- dag: RepresentationDag
+ val_score: PerformanceMeasure = None
+ train_score: PerformanceMeasure = None
+ representation_time: float = 0.0
+ task_time: float = 0.0
+ combination: str = ""
+ dag: RepresentationDag = None
+ tradeoff_score: float = 0.0
diff --git a/src/main/python/systemds/scuro/representations/timeseries_representations.py b/src/main/python/systemds/scuro/representations/timeseries_representations.py
index 03464df7d4..d1dee67f86 100644
--- a/src/main/python/systemds/scuro/representations/timeseries_representations.py
+++ b/src/main/python/systemds/scuro/representations/timeseries_representations.py
@@ -46,7 +46,9 @@ class TimeSeriesRepresentation(UnimodalRepresentation):
feature = self.compute_feature(signal)
result.append(feature)
- transformed_modality.data = np.vstack(result)
+ transformed_modality.data = np.vstack(result).astype(
+ modality.metadata[list(modality.metadata.keys())[0]]["data_layout"]["type"]
+ )
return transformed_modality
diff --git a/src/main/python/systemds/scuro/utils/torch_dataset.py b/src/main/python/systemds/scuro/utils/torch_dataset.py
index 2a7ec1f963..19875f8802 100644
--- a/src/main/python/systemds/scuro/utils/torch_dataset.py
+++ b/src/main/python/systemds/scuro/utils/torch_dataset.py
@@ -32,12 +32,12 @@ class CustomDataset(torch.utils.data.Dataset):
self.device = device
self.size = size
if size is None:
- self.size = (256, 224)
+ self.size = (224, 224)
tf_default = transforms.Compose(
[
transforms.ToPILImage(),
- transforms.Resize(self.size[0]),
+ transforms.Resize(256),
transforms.CenterCrop(self.size[1]),
transforms.ToTensor(),
transforms.ConvertImageDtype(dtype=self.data_type),
@@ -55,7 +55,7 @@ class CustomDataset(torch.utils.data.Dataset):
def __getitem__(self, index) -> Dict[str, object]:
data = self.data[index]
output = torch.empty(
- (len(data), 3, self.size[1], self.size[1]),
+ (len(data), 3, self.size[1], self.size[0]),
dtype=self.data_type,
device=self.device,
)
diff --git a/src/main/python/tests/scuro/test_hp_tuner.py b/src/main/python/tests/scuro/test_hp_tuner.py
index 73aab4493d..802f737b0a 100644
--- a/src/main/python/tests/scuro/test_hp_tuner.py
+++ b/src/main/python/tests/scuro/test_hp_tuner.py
@@ -62,18 +62,22 @@ class TestSVM(Model):
self.clf = self.clf.fit(X, np.array(y))
y_pred = self.clf.predict(X)
- return classification_report(
- y, y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
def test(self, test_X: np.ndarray, test_y: np.ndarray):
if test_X.ndim > 2:
test_X = test_X.reshape(test_X.shape[0], -1)
- y_pred = self.clf.predict(np.array(test_X)) # noqa
+ y_pred = self.clf.predict(np.array(test_X)) # noqa]
- return classification_report(
- np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
class TestSVM2(Model):
@@ -87,18 +91,22 @@ class TestSVM2(Model):
self.clf = self.clf.fit(X, np.array(y))
y_pred = self.clf.predict(X)
- return classification_report(
- y, y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
def test(self, test_X: np.ndarray, test_y: np.ndarray):
if test_X.ndim > 2:
test_X = test_X.reshape(test_X.shape[0], -1)
y_pred = self.clf.predict(np.array(test_X)) # noqa
- return classification_report(
- np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
from unittest.mock import patch
diff --git a/src/main/python/tests/scuro/test_multimodal_fusion.py b/src/main/python/tests/scuro/test_multimodal_fusion.py
index 0f9c08d216..395a9cd862 100644
--- a/src/main/python/tests/scuro/test_multimodal_fusion.py
+++ b/src/main/python/tests/scuro/test_multimodal_fusion.py
@@ -60,18 +60,22 @@ class TestSVM(Model):
self.clf = self.clf.fit(X, np.array(y))
y_pred = self.clf.predict(X)
- return classification_report(
- y, y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
def test(self, test_X: np.ndarray, test_y: np.ndarray):
if test_X.ndim > 2:
test_X = test_X.reshape(test_X.shape[0], -1)
y_pred = self.clf.predict(np.array(test_X)) # noqa
- return classification_report(
- np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
class TestCNN(Model):
@@ -85,18 +89,22 @@ class TestCNN(Model):
self.clf = self.clf.fit(X, np.array(y))
y_pred = self.clf.predict(X)
- return classification_report(
- y, y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
def test(self, test_X: np.ndarray, test_y: np.ndarray):
if test_X.ndim > 2:
test_X = test_X.reshape(test_X.shape[0], -1)
y_pred = self.clf.predict(np.array(test_X)) # noqa
- return classification_report(
- np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
class TestMultimodalRepresentationOptimizer(unittest.TestCase):
@@ -178,10 +186,15 @@ class TestMultimodalRepresentationOptimizer(unittest.TestCase):
fusion_results = m_o.optimize()
best_results = sorted(
- fusion_results[task.model.name], key=lambda x: x.val_score, reverse=True
+ fusion_results[task.model.name],
+ key=lambda x: getattr(x, "val_score").average_scores["accuracy"],
+ reverse=True,
)[:2]
- assert best_results[0].val_score >= best_results[1].val_score
+ assert (
+ best_results[0].val_score.average_scores["accuracy"]
+ >= best_results[1].val_score.average_scores["accuracy"]
+ )
def test_parallel_multimodal_fusion(self):
task = Task(
@@ -238,18 +251,23 @@ class TestMultimodalRepresentationOptimizer(unittest.TestCase):
parallel_fusion_results = m_o.optimize_parallel(max_workers=4, batch_size=8)
best_results = sorted(
- fusion_results[task.model.name], key=lambda x: x.val_score, reverse=True
+ fusion_results[task.model.name],
+ key=lambda x: getattr(x, "val_score").average_scores["accuracy"],
+ reverse=True,
)
best_results_parallel = sorted(
parallel_fusion_results[task.model.name],
- key=lambda x: x.val_score,
+ key=lambda x: getattr(x, "val_score").average_scores["accuracy"],
reverse=True,
)
assert len(best_results) == len(best_results_parallel)
for i in range(len(best_results)):
- assert best_results[i].val_score == best_results_parallel[i].val_score
+ assert (
+ best_results[i].val_score.average_scores["accuracy"]
+ == best_results_parallel[i].val_score.average_scores["accuracy"]
+ )
if __name__ == "__main__":
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py b/src/main/python/tests/scuro/test_unimodal_optimizer.py
index 30ae725737..252dfe997a 100644
--- a/src/main/python/tests/scuro/test_unimodal_optimizer.py
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -29,8 +29,6 @@ from sklearn.model_selection import train_test_split
from systemds.scuro.representations.timeseries_representations import (
Mean,
- Max,
- Min,
ACF,
)
from systemds.scuro.drsearch.operator_registry import Registry
@@ -41,8 +39,6 @@ from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
from systemds.scuro.representations.spectrogram import Spectrogram
from systemds.scuro.representations.covarep_audio_features import (
ZeroCrossing,
- Spectral,
- Pitch,
)
from systemds.scuro.representations.word2vec import W2V
from systemds.scuro.representations.bow import BoW
@@ -64,18 +60,22 @@ class TestSVM(Model):
self.clf = self.clf.fit(X, np.array(y))
y_pred = self.clf.predict(X)
- return classification_report(
- y, y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
def test(self, test_X: np.ndarray, test_y: np.ndarray):
if test_X.ndim > 2:
test_X = test_X.reshape(test_X.shape[0], -1)
y_pred = self.clf.predict(np.array(test_X)) # noqa
- return classification_report(
- np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
class TestCNN(Model):
@@ -89,18 +89,22 @@ class TestCNN(Model):
self.clf = self.clf.fit(X, np.array(y))
y_pred = self.clf.predict(X)
- return classification_report(
- y, y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ y, y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
def test(self, test_X: np.ndarray, test_y: np.ndarray):
if test_X.ndim > 2:
test_X = test_X.reshape(test_X.shape[0], -1)
y_pred = self.clf.predict(np.array(test_X)) # noqa
- return classification_report(
- np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
- )["accuracy"]
+ return {
+ "accuracy": classification_report(
+ np.array(test_y), y_pred, output_dict=True, digits=3, zero_division=1
+ )["accuracy"]
+ }, 0
from unittest.mock import patch
@@ -197,8 +201,8 @@ class TestUnimodalRepresentationOptimizer(unittest.TestCase):
"_representations",
{
ModalityType.TEXT: [W2V, BoW],
- ModalityType.AUDIO: [Spectrogram, ZeroCrossing, Spectral, Pitch],
- ModalityType.TIMESERIES: [Mean, Max, Min, ACF],
+ ModalityType.AUDIO: [Spectrogram, ZeroCrossing],
+ ModalityType.TIMESERIES: [Mean, ACF],
ModalityType.VIDEO: [ResNet],
ModalityType.EMBEDDING: [],
},
@@ -206,7 +210,7 @@ class TestUnimodalRepresentationOptimizer(unittest.TestCase):
registry = Registry()
unimodal_optimizer = UnimodalOptimizer([modality], self.tasks, False)
- unimodal_optimizer.optimize()
+ unimodal_optimizer.optimize_parallel()
assert (
unimodal_optimizer.operator_performance.modality_ids[0]