This is an automated email from the ASF dual-hosted git repository.
cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c236a25078 [SYSTEMDS-3835] Improve memory efficiency of text context
operations
c236a25078 is described below
commit c236a25078947137a4882a5c8868455e51b9894d
Author: Christina Dionysio <[email protected]>
AuthorDate: Wed Jan 28 12:35:43 2026 +0100
[SYSTEMDS-3835] Improve memory efficiency of text context operations
This patch uses an index-based method to split text into multiple chunks and
stores, for each instance, the list of start and end indices of its chunks in
the instance metadata.
---
.../systemds/scuro/drsearch/operator_registry.py | 2 +
.../systemds/scuro/drsearch/unimodal_optimizer.py | 107 ++++++++++++---------
.../python/systemds/scuro/modality/transformed.py | 9 +-
src/main/python/systemds/scuro/modality/type.py | 19 ++++
.../systemds/scuro/modality/unimodal_modality.py | 24 ++++-
.../representations/aggregated_representation.py | 2 +-
.../python/systemds/scuro/representations/bert.py | 35 ++-----
.../python/systemds/scuro/representations/elmo.py | 26 +----
.../systemds/scuro/representations/text_context.py | 4 +-
.../representations/text_context_with_indices.py | 26 +++--
.../python/systemds/scuro/utils/torch_dataset.py | 42 ++++++++
.../python/tests/scuro/test_operator_registry.py | 10 +-
.../tests/scuro/test_text_context_operators.py | 32 +++---
.../python/tests/scuro/test_unimodal_optimizer.py | 5 +-
14 files changed, 199 insertions(+), 144 deletions(-)
diff --git a/src/main/python/systemds/scuro/drsearch/operator_registry.py
b/src/main/python/systemds/scuro/drsearch/operator_registry.py
index bf9547ddbf..e9c302ba90 100644
--- a/src/main/python/systemds/scuro/drsearch/operator_registry.py
+++ b/src/main/python/systemds/scuro/drsearch/operator_registry.py
@@ -97,6 +97,8 @@ class Registry:
return reps
def get_context_operators(self, modality_type):
+ if modality_type not in self._context_operators.keys():
+ return []
return self._context_operators[modality_type]
def get_dimensionality_reduction_operators(self, modality_type):
diff --git a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
index c555c2b677..5b03147ec1 100644
--- a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
@@ -356,7 +356,8 @@ class UnimodalOptimizer:
operator.__class__, [leaf_id], operator.get_current_parameters()
)
current_node_id = rep_node_id
- dags.append(builder.build(current_node_id))
+ rep_dag = builder.build(current_node_id)
+ dags.append(rep_dag)
dimensionality_reduction_dags =
self.add_dimensionality_reduction_operators(
builder, current_node_id
@@ -387,11 +388,6 @@ class UnimodalOptimizer:
[context_node_id],
operator.get_current_parameters(),
)
- dimensionality_reduction_dags =
self.add_dimensionality_reduction_operators(
- builder, context_rep_node_id
- ) # TODO: check if this is correctly using the 3d approach of
the dimensionality reduction operator
- if dimensionality_reduction_dags is not None:
- dags.extend(dimensionality_reduction_dags)
agg_operator = AggregatedRepresentation()
context_agg_node_id = builder.create_operation_node(
@@ -409,64 +405,88 @@ class UnimodalOptimizer:
not_self_contained_reps = [
rep for rep in not_self_contained_reps if rep !=
operator.__class__
]
+ rep_id = current_node_id
- for combination in self._combination_operators:
- current_node_id = rep_node_id
- for other_rep in not_self_contained_reps:
- other_rep_id = builder.create_operation_node(
- other_rep, [leaf_id], other_rep().parameters
- )
-
+ for rep in not_self_contained_reps:
+ other_rep_id = builder.create_operation_node(
+ rep, [leaf_id], rep().parameters
+ )
+ for combination in self._combination_operators:
combine_id = builder.create_operation_node(
combination.__class__,
- [current_node_id, other_rep_id],
+ [rep_id, other_rep_id],
combination.get_current_parameters(),
)
- dags.append(builder.build(combine_id))
- current_node_id = combine_id
- if modality.modality_type in [
- ModalityType.EMBEDDING,
- ModalityType.IMAGE,
- ModalityType.AUDIO,
- ]:
- dags.extend(
- self.default_context_operators(
- modality, builder, leaf_id, current_node_id
+ rep_dag = builder.build(combine_id)
+ dags.append(rep_dag)
+ if modality.modality_type in [
+ ModalityType.EMBEDDING,
+ ModalityType.IMAGE,
+ ModalityType.AUDIO,
+ ]:
+ dags.extend(
+ self.default_context_operators(
+ modality, builder, leaf_id, rep_dag, False
+ )
)
- )
- elif modality.modality_type == ModalityType.TIMESERIES:
- dags.extend(
- self.temporal_context_operators(
- modality, builder, leaf_id, current_node_id
+ elif modality.modality_type == ModalityType.TIMESERIES:
+ dags.extend(
+ self.temporal_context_operators(
+ modality,
+ builder,
+ leaf_id,
+ )
)
- )
+ rep_id = combine_id
+
+ if rep_dag.nodes[-1].operation().output_modality_type in [
+ ModalityType.EMBEDDING
+ ]:
+ dags.extend(
+ self.default_context_operators(
+ modality, builder, leaf_id, rep_dag, True
+ )
+ )
+
+ if (
+ modality.modality_type == ModalityType.TIMESERIES
+ or modality.modality_type == ModalityType.AUDIO
+ ):
+ dags.extend(self.temporal_context_operators(modality, builder,
leaf_id))
return dags
- def default_context_operators(self, modality, builder, leaf_id,
current_node_id):
+ def default_context_operators(
+ self, modality, builder, leaf_id, rep_dag, apply_context_to_leaf=False
+ ):
dags = []
- context_operators = self._get_context_operators(modality.modality_type)
- for context_op in context_operators:
+ if apply_context_to_leaf:
if (
modality.modality_type != ModalityType.TEXT
and modality.modality_type != ModalityType.VIDEO
):
- context_node_id = builder.create_operation_node(
- context_op,
- [leaf_id],
- context_op().get_current_parameters(),
- )
- dags.append(builder.build(context_node_id))
+ context_operators =
self._get_context_operators(modality.modality_type)
+ for context_op in context_operators:
+ context_node_id = builder.create_operation_node(
+ context_op,
+ [leaf_id],
+ context_op().get_current_parameters(),
+ )
+ dags.append(builder.build(context_node_id))
+ context_operators = self._get_context_operators(
+ rep_dag.nodes[-1].operation().output_modality_type
+ )
+ for context_op in context_operators:
context_node_id = builder.create_operation_node(
context_op,
- [current_node_id],
+ [rep_dag.nodes[-1].node_id],
context_op().get_current_parameters(),
)
dags.append(builder.build(context_node_id))
return dags
- def temporal_context_operators(self, modality, builder, leaf_id,
current_node_id):
+ def temporal_context_operators(self, modality, builder, leaf_id):
aggregators =
self.operator_registry.get_representations(modality.modality_type)
context_operators = self._get_context_operators(modality.modality_type)
@@ -561,12 +581,11 @@ class UnimodalResults:
results = results[: self.k]
sorted_indices = sorted_indices[: self.k]
-
task_cache = self.cache.get(modality.modality_id,
{}).get(task.model.name, None)
if not task_cache:
cache = [
- list(task_results[i].dag.execute([modality]).values())[-1]
- for i in sorted_indices
+ list(results[i].dag.execute([modality]).values())[-1]
+ for i in range(len(results))
]
elif isinstance(task_cache, list):
cache = task_cache
diff --git a/src/main/python/systemds/scuro/modality/transformed.py
b/src/main/python/systemds/scuro/modality/transformed.py
index 078b65f0bc..a443f5a313 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -31,7 +31,12 @@ import copy
class TransformedModality(Modality):
def __init__(
- self, modality, transformation, new_modality_type=None,
self_contained=True
+ self,
+ modality,
+ transformation,
+ new_modality_type=None,
+ self_contained=True,
+ set_data=False,
):
"""
Parent class of the different Modalities (unimodal & multimodal)
@@ -49,6 +54,8 @@ class TransformedModality(Modality):
modality.data_type,
modality.transform_time,
)
+ if set_data:
+ self.data = modality.data
self.transformation = None
self.self_contained = (
self_contained and transformation.self_contained
diff --git a/src/main/python/systemds/scuro/modality/type.py
b/src/main/python/systemds/scuro/modality/type.py
index 23d97e869b..85f4d04e9b 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/modality/type.py
@@ -210,6 +210,25 @@ class ModalityType(Flag):
def get_schema(self):
return ModalitySchemas.get(self.name)
+ def has_field(self, md, field):
+ for value in md.values():
+ if field in value:
+ return True
+ else:
+ return False
+ return False
+
+ def get_field_for_instances(self, md, field):
+ data = []
+ for items in md.values():
+ data.append(self.get_field(items, field))
+ return data
+
+ def get_field(self, md, field):
+ if field in md:
+ return md[field]
+ return None
+
def update_metadata(self, md, data):
return ModalitySchemas.update_metadata(self.name, md, data)
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py
b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 4efaa7d733..89d95810e0 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -91,9 +91,14 @@ class UnimodalModality(Modality):
if not self.has_data():
self.extract_raw_data()
- transformed_modality = TransformedModality(self, context_operator)
-
- transformed_modality.data = context_operator.execute(self)
+ transformed_modality = TransformedModality(
+ self, context_operator, set_data=True
+ )
+ d = context_operator.execute(transformed_modality)
+ if d is not None:
+ transformed_modality.data = d
+ else:
+ transformed_modality.data = self.data
transformed_modality.transform_time += time.time() - start
return transformed_modality
@@ -212,14 +217,23 @@ class UnimodalModality(Modality):
mode="constant",
constant_values=0,
)
- else:
+ elif len(embeddings.shape) == 2:
padded = np.pad(
embeddings,
((0, padding_needed), (0, 0)),
mode="constant",
constant_values=0,
)
- padded_embeddings.append(padded)
+ elif len(embeddings.shape) == 3:
+ padded = np.pad(
+ embeddings,
+ ((0, padding_needed), (0, 0), (0, 0)),
+ mode="constant",
+ constant_values=0,
+ )
+ padded_embeddings.append(padded)
+ else:
+ raise ValueError(f"Unsupported shape:
{embeddings.shape}")
else:
padded_embeddings.append(embeddings)
diff --git
a/src/main/python/systemds/scuro/representations/aggregated_representation.py
b/src/main/python/systemds/scuro/representations/aggregated_representation.py
index bcc36f4621..cad1a4a448 100644
---
a/src/main/python/systemds/scuro/representations/aggregated_representation.py
+++
b/src/main/python/systemds/scuro/representations/aggregated_representation.py
@@ -38,7 +38,7 @@ class AggregatedRepresentation(Representation):
aggregated_modality = TransformedModality(
modality, self, self_contained=modality.self_contained
)
+ aggregated_modality.data = self.aggregation.execute(modality)
end = time.perf_counter()
aggregated_modality.transform_time += end - start
- aggregated_modality.data = self.aggregation.execute(modality)
return aggregated_modality
diff --git a/src/main/python/systemds/scuro/representations/bert.py
b/src/main/python/systemds/scuro/representations/bert.py
index be579c0dd6..6f4d3705a1 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -28,35 +28,12 @@ from systemds.scuro.modality.type import ModalityType
from systemds.scuro.drsearch.operator_registry import register_representation
from systemds.scuro.utils.static_variables import get_device
import os
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import DataLoader
+from systemds.scuro.utils.torch_dataset import TextDataset, TextSpanDataset
os.environ["TOKENIZERS_PARALLELISM"] = "false"
-class TextDataset(Dataset):
- def __init__(self, texts):
-
- self.texts = []
- if isinstance(texts, list):
- self.texts = texts
- else:
- for text in texts:
- if text is None:
- self.texts.append("")
- elif isinstance(text, np.ndarray):
- self.texts.append(str(text.item()) if text.size == 1 else
str(text))
- elif not isinstance(text, str):
- self.texts.append(str(text))
- else:
- self.texts.append(text)
-
- def __len__(self):
- return len(self.texts)
-
- def __getitem__(self, idx):
- return self.texts[idx]
-
-
class BertFamily(UnimodalRepresentation):
def __init__(
self,
@@ -96,10 +73,12 @@ class BertFamily(UnimodalRepresentation):
layer.register_forward_hook(get_activation(name))
break
- if isinstance(modality.data[0], list):
+ if ModalityType.TEXT.has_field(modality.metadata, "text_spans"):
+ dataset = TextSpanDataset(modality.data, modality.metadata)
embeddings = []
- for d in modality.data:
- embeddings.append(self.create_embeddings(d, self.model,
tokenizer))
+ for text in dataset:
+ embedding = self.create_embeddings(text, self.model, tokenizer)
+ embeddings.append(embedding)
else:
embeddings = self.create_embeddings(modality.data, self.model,
tokenizer)
diff --git a/src/main/python/systemds/scuro/representations/elmo.py
b/src/main/python/systemds/scuro/representations/elmo.py
index ba2a99f8e1..33e4f74141 100644
--- a/src/main/python/systemds/scuro/representations/elmo.py
+++ b/src/main/python/systemds/scuro/representations/elmo.py
@@ -29,34 +29,10 @@ from systemds.scuro.modality.type import ModalityType
from systemds.scuro.utils.static_variables import get_device
from flair.embeddings import ELMoEmbeddings
from flair.data import Sentence
-from torch.utils.data import Dataset
+from systemds.scuro.utils.torch_dataset import TextDataset
from torch.utils.data import DataLoader
-class TextDataset(Dataset):
- def __init__(self, texts):
-
- self.texts = []
- if isinstance(texts, list):
- self.texts = texts
- else:
- for text in texts:
- if text is None:
- self.texts.append("")
- elif isinstance(text, np.ndarray):
- self.texts.append(str(text.item()) if text.size == 1 else
str(text))
- elif not isinstance(text, str):
- self.texts.append(str(text))
- else:
- self.texts.append(text)
-
- def __len__(self):
- return len(self.texts)
-
- def __getitem__(self, idx):
- return self.texts[idx]
-
-
# @register_representation([ModalityType.TEXT])
class ELMoRepresentation(UnimodalRepresentation):
def __init__(
diff --git a/src/main/python/systemds/scuro/representations/text_context.py
b/src/main/python/systemds/scuro/representations/text_context.py
index b98b90e187..b4f82bda19 100644
--- a/src/main/python/systemds/scuro/representations/text_context.py
+++ b/src/main/python/systemds/scuro/representations/text_context.py
@@ -72,7 +72,7 @@ def _extract_text(instance: Any) -> str:
return text
-@register_context_operator(ModalityType.TEXT)
+# @register_context_operator(ModalityType.TEXT)
class SentenceBoundarySplit(Context):
"""
Splits text at sentence boundaries while respecting maximum word count.
@@ -154,7 +154,7 @@ class SentenceBoundarySplit(Context):
return chunked_data
-@register_context_operator(ModalityType.TEXT)
+# @register_context_operator(ModalityType.TEXT)
class OverlappingSplit(Context):
"""
Splits text with overlapping chunks using a sliding window approach.
diff --git
a/src/main/python/systemds/scuro/representations/text_context_with_indices.py
b/src/main/python/systemds/scuro/representations/text_context_with_indices.py
index 7daf93855f..5a3c3b34e0 100644
---
a/src/main/python/systemds/scuro/representations/text_context_with_indices.py
+++
b/src/main/python/systemds/scuro/representations/text_context_with_indices.py
@@ -134,7 +134,7 @@ class WordCountSplitIndices(Context):
return chunked_data
-# @register_context_operator(ModalityType.TEXT)
+@register_context_operator(ModalityType.TEXT)
class SentenceBoundarySplitIndices(Context):
"""
Splits text at sentence boundaries while respecting maximum word count.
@@ -162,18 +162,17 @@ class SentenceBoundarySplitIndices(Context):
Returns:
List of lists, where each inner list contains text chunks (strings)
"""
- chunked_data = []
- for instance in modality.data:
+ for instance, metadata in zip(modality.data,
modality.metadata.values()):
text = _extract_text(instance)
if not text:
- chunked_data.append((0, 0))
+ ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)])
continue
sentences = _split_into_sentences(text)
if not sentences:
- chunked_data.append((0, len(text)))
+ ModalityType.TEXT.add_field(metadata, "text_spans", [(0,
len(text))])
continue
chunks = []
@@ -225,12 +224,12 @@ class SentenceBoundarySplitIndices(Context):
if not chunks:
chunks = [(0, len(text))]
- chunked_data.append(chunks)
+ ModalityType.TEXT.add_field(metadata, "text_spans", chunks)
- return chunked_data
+ return None
-# @register_context_operator(ModalityType.TEXT)
+@register_context_operator(ModalityType.TEXT)
class OverlappingSplitIndices(Context):
"""
Splits text with overlapping chunks using a sliding window approach.
@@ -263,18 +262,17 @@ class OverlappingSplitIndices(Context):
Returns:
List of tuples, where each tuple contains start and end index to
the text chunks
"""
- chunked_data = []
- for instance in modality.data:
+ for instance, metadata in zip(modality.data,
modality.metadata.values()):
text = _extract_text(instance)
if not text:
- chunked_data.append((0, 0))
+ ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)])
continue
words = _split_into_words(text)
if len(words) <= self.max_words:
- chunked_data.append((0, len(text)))
+ ModalityType.TEXT.add_field(metadata, "text_spans", [(0,
len(text))])
continue
chunks = []
@@ -295,6 +293,6 @@ class OverlappingSplitIndices(Context):
if not chunks:
chunks = [(0, len(text))]
- chunked_data.append(chunks)
+ ModalityType.TEXT.add_field(metadata, "text_spans", chunks)
- return chunked_data
+ return None
diff --git a/src/main/python/systemds/scuro/utils/torch_dataset.py
b/src/main/python/systemds/scuro/utils/torch_dataset.py
index 9c462e3675..ba3e24a317 100644
--- a/src/main/python/systemds/scuro/utils/torch_dataset.py
+++ b/src/main/python/systemds/scuro/utils/torch_dataset.py
@@ -24,6 +24,8 @@ import numpy as np
import torch
import torchvision.transforms as transforms
+from systemds.scuro.modality.type import ModalityType
+
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data, data_type, device, size=None, tf=None):
@@ -78,3 +80,43 @@ class CustomDataset(torch.utils.data.Dataset):
def __len__(self) -> int:
return len(self.data)
+
+
+class TextDataset(torch.utils.data.Dataset):
+ def __init__(self, texts):
+
+ self.texts = []
+ if isinstance(texts, list):
+ self.texts = texts
+ else:
+ for text in texts:
+ if text is None:
+ self.texts.append("")
+ elif isinstance(text, np.ndarray):
+ self.texts.append(str(text.item()) if text.size == 1 else
str(text))
+ elif not isinstance(text, str):
+ self.texts.append(str(text))
+ else:
+ self.texts.append(text)
+
+ def __len__(self):
+ return len(self.texts)
+
+ def __getitem__(self, idx):
+ return self.texts[idx]
+
+
+class TextSpanDataset(torch.utils.data.Dataset):
+ def __init__(self, full_texts, metadata):
+ self.full_texts = full_texts
+ self.spans_per_text = ModalityType.TEXT.get_field_for_instances(
+ metadata, "text_spans"
+ )
+
+ def __len__(self):
+ return len(self.full_texts)
+
+ def __getitem__(self, idx):
+ text = self.full_texts[idx]
+ spans = self.spans_per_text[idx]
+ return [text[s:e] for (s, e) in spans]
diff --git a/src/main/python/tests/scuro/test_operator_registry.py
b/src/main/python/tests/scuro/test_operator_registry.py
index 189e3e44d7..443cc039d6 100644
--- a/src/main/python/tests/scuro/test_operator_registry.py
+++ b/src/main/python/tests/scuro/test_operator_registry.py
@@ -21,9 +21,9 @@
import unittest
-from systemds.scuro.representations.text_context import (
- SentenceBoundarySplit,
- OverlappingSplit,
+from systemds.scuro.representations.text_context_with_indices import (
+ SentenceBoundarySplitIndices,
+ OverlappingSplitIndices,
)
from systemds.scuro.representations.covarep_audio_features import (
@@ -134,8 +134,8 @@ class TestOperatorRegistry(unittest.TestCase):
DynamicWindow,
]
assert registry.get_context_operators(ModalityType.TEXT) == [
- SentenceBoundarySplit,
- OverlappingSplit,
+ SentenceBoundarySplitIndices,
+ OverlappingSplitIndices,
]
# def test_fusion_operator_in_registry(self):
diff --git a/src/main/python/tests/scuro/test_text_context_operators.py
b/src/main/python/tests/scuro/test_text_context_operators.py
index 1f04165407..ffa702b7c8 100644
--- a/src/main/python/tests/scuro/test_text_context_operators.py
+++ b/src/main/python/tests/scuro/test_text_context_operators.py
@@ -36,6 +36,7 @@ from tests.scuro.data_generator import (
)
from systemds.scuro.modality.unimodal_modality import UnimodalModality
from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.bert import Bert
class TestTextContextOperator(unittest.TestCase):
@@ -80,33 +81,30 @@ class TestTextContextOperator(unittest.TestCase):
def test_sentence_boundary_split_indices(self):
sentence_boundary_split = SentenceBoundarySplitIndices(10, min_words=4)
- chunks = sentence_boundary_split.execute(self.text_modality)
- for i in range(0, len(chunks)):
- for chunk in chunks[i]:
- text = self.text_modality.data[i][chunk[0] : chunk[1]].split("
")
+ sentence_boundary_split.execute(self.text_modality)
+ for instance, md in zip(
+ self.text_modality.data, self.text_modality.metadata.values()
+ ):
+ for chunk in md["text_spans"]:
+ text = instance[chunk[0] : chunk[1]].split(" ")
assert len(text) <= 10 and (
text[-1][-1] == "." or text[-1][-1] == "!" or text[-1][-1]
== "?"
)
def test_overlapping_split_indices(self):
overlapping_split = OverlappingSplitIndices(40, 0.1)
- chunks = overlapping_split.execute(self.text_modality)
- for i in range(len(chunks)):
+ overlapping_split.execute(self.text_modality)
+ for instance, md in zip(
+ self.text_modality.data, self.text_modality.metadata.values()
+ ):
prev_chunk = (0, 0)
- for j, chunk in enumerate(chunks[i]):
+ for j, chunk in enumerate(md["text_spans"]):
if j > 0:
- prev_words = self.text_modality.data[i][
- prev_chunk[0] : prev_chunk[1]
- ].split(" ")
- curr_words = self.text_modality.data[i][chunk[0] :
chunk[1]].split(
- " "
- )
+ prev_words = instance[prev_chunk[0] :
prev_chunk[1]].split(" ")
+ curr_words = instance[chunk[0] : chunk[1]].split(" ")
assert prev_words[-4:] == curr_words[:4]
prev_chunk = chunk
- assert (
- len(self.text_modality.data[i][chunk[0] :
chunk[1]].split(" "))
- <= 40
- )
+ assert len(instance[chunk[0] : chunk[1]].split(" ")) <= 40
if __name__ == "__main__":
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py
b/src/main/python/tests/scuro/test_unimodal_optimizer.py
index 0d8ae90177..7fa606d835 100644
--- a/src/main/python/tests/scuro/test_unimodal_optimizer.py
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -36,6 +36,7 @@ from systemds.scuro.representations.covarep_audio_features
import (
)
from systemds.scuro.representations.word2vec import W2V
from systemds.scuro.representations.bow import BoW
+from systemds.scuro.representations.bert import Bert
from systemds.scuro.modality.unimodal_modality import UnimodalModality
from systemds.scuro.representations.resnet import ResNet
from tests.scuro.data_generator import (
@@ -124,7 +125,7 @@ class
TestUnimodalRepresentationOptimizer(unittest.TestCase):
):
registry = Registry()
- unimodal_optimizer = UnimodalOptimizer([modality], self.tasks,
False)
+ unimodal_optimizer = UnimodalOptimizer([modality], self.tasks,
False, k=1)
unimodal_optimizer.optimize()
assert (
@@ -133,7 +134,7 @@ class
TestUnimodalRepresentationOptimizer(unittest.TestCase):
)
assert len(unimodal_optimizer.operator_performance.task_names) == 2
result, cached =
unimodal_optimizer.operator_performance.get_k_best_results(
- modality, 1, self.tasks[0], "accuracy"
+ modality, self.tasks[0], "accuracy"
)
assert len(result) == 1
assert len(cached) == 1