This is an automated email from the ASF dual-hosted git repository.
christinadionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c60f75cc59 [SYSTEMDS-3944] Modality Alignment, Contrastive Learning,
new Data Loaders
c60f75cc59 is described below
commit c60f75cc59e7b5024518cff95bb15816a88d0490
Author: b-enedict <[email protected]>
AuthorDate: Tue Apr 28 10:10:57 2026 +0200
[SYSTEMDS-3944] Modality Alignment, Contrastive Learning, new Data Loaders
Summary
This PR introduces new functionality for multimodal learning in Scuro,
including a contrastive learning operator, a modality alignment operator, and
additional data loaders.
Changes
Contrastive Learning Operator
- Pairs every instance of the first modality with every instance of the
second modality (Cartesian product)
- Labels each pair as positive or negative with a user-defined function
applied to the two instances' metadata
- Enables dynamic generation of contrastive samples
Modality Alignment Operator
- Aligns previously unaligned modalities using feature-based similarity
(e.g., ORB, perceptual hashing)
- Outputs a matching between a primary and secondary modality
- Matching is applied after representation learning and before fusion
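Data Loaders
- PDF loader: renders each document page as a NumPy image array for OpenCV
processing
- Transcript loader: transcribes audio files to text using faster-whisper,
producing one text instance per transcribed segment
Metadata
- Loader and modality metadata is now stored as a list aligned with the data
instead of a dict keyed by file name or index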
Closes #2461
---
.../systemds/scuro/dataloader/audio_loader.py | 9 +-
.../systemds/scuro/dataloader/base_loader.py | 7 +-
.../systemds/scuro/dataloader/image_loader.py | 4 +-
.../systemds/scuro/dataloader/json_loader.py | 4 +-
.../python/systemds/scuro/dataloader/pdf_loader.py | 70 +++++++++++
.../systemds/scuro/dataloader/text_loader.py | 4 +-
.../systemds/scuro/dataloader/timeseries_loader.py | 15 ++-
.../systemds/scuro/dataloader/transcript_loader.py | 59 +++++++++
.../systemds/scuro/dataloader/video_loader.py | 6 +-
src/main/python/systemds/scuro/modality/joined.py | 17 ++-
.../python/systemds/scuro/modality/modality.py | 41 +++----
.../python/systemds/scuro/modality/transformed.py | 9 +-
src/main/python/systemds/scuro/modality/type.py | 8 +-
.../systemds/scuro/modality/unimodal_modality.py | 9 +-
.../ orb_alignment.py} | 41 ++++---
.../systemds/scuro/representations/alignment.py | 136 +++++++++++++++++++++
.../python/systemds/scuro/representations/clip.py | 3 +
.../scuro/representations/concatenation.py | 8 +-
.../scuro/representations/contrastive_learning.py | 81 ++++++++++++
.../representations/covarep_audio_features.py | 4 +-
.../python/systemds/scuro/representations/lstm.py | 2 +-
.../scuro/representations/mel_spectrogram.py | 3 +-
.../python/systemds/scuro/representations/mfcc.py | 2 +-
.../pHash_alignment.py} | 39 +++---
.../python/systemds/scuro/representations/sum.py | 8 +-
.../representations/text_context_with_indices.py | 4 +-
.../representations/timeseries_representations.py | 2 +-
.../systemds/scuro/representations/wav2vec.py | 2 +-
.../scuro/representations/window_aggregation.py | 2 +-
.../python/systemds/scuro/utils/schema_helpers.py | 2 +-
src/main/python/tests/scuro/data_generator.py | 62 +++++-----
.../tests/scuro/test_text_context_operators.py | 4 +-
32 files changed, 504 insertions(+), 163 deletions(-)
diff --git a/src/main/python/systemds/scuro/dataloader/audio_loader.py
b/src/main/python/systemds/scuro/dataloader/audio_loader.py
index e0b9e61a07..1c7ae4f3a8 100644
--- a/src/main/python/systemds/scuro/dataloader/audio_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/audio_loader.py
@@ -63,18 +63,17 @@ class AudioLoader(BaseLoader):
if not self.load_data_from_file:
import numpy as np
- self.metadata[file] = self.modality_type.create_metadata(
- 1000, np.array([0])
- )
+ audio = np.array([0])
+ sr = 1000
else:
audio, sr = librosa.load(file, dtype=self._data_type)
if self.normalize:
audio = librosa.util.normalize(audio)
- self.metadata[file] = self.modality_type.create_metadata(sr, audio)
+ self.metadata.append(self.modality_type.create_metadata(sr, audio))
- self.data.append(audio)
+ self.data.append(audio)
def get_stats(self, source_path: str):
sampling_rate = 0
diff --git a/src/main/python/systemds/scuro/dataloader/base_loader.py
b/src/main/python/systemds/scuro/dataloader/base_loader.py
index 88decd641f..9b89c77394 100644
--- a/src/main/python/systemds/scuro/dataloader/base_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/base_loader.py
@@ -44,9 +44,7 @@ class BaseLoader(ABC):
(otherwise please provide your own Dataloader that knows about the
file name convention)
"""
self.data = []
- self.metadata = (
- {}
- ) # TODO: check what the index should be for storing the metadata
(file_name, counter, ...)
+ self.metadata = []
self.source_path = source_path
self.indices = indices
self.modality_type = modality_type
@@ -87,7 +85,7 @@ class BaseLoader(ABC):
def reset(self):
self._next_chunk = 0
self.data = []
- self.metadata = {}
+ self.metadata = []
def load(self):
"""
@@ -134,6 +132,7 @@ class BaseLoader(ABC):
Loads the next chunk of data
"""
self.data = []
+ # TODO: Handle metadata correctly
next_chunk_indices = self.indices[
self._next_chunk
* self._chunk_size : (self._next_chunk + 1)
diff --git a/src/main/python/systemds/scuro/dataloader/image_loader.py
b/src/main/python/systemds/scuro/dataloader/image_loader.py
index 498ae77a89..25e8690cf5 100644
--- a/src/main/python/systemds/scuro/dataloader/image_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/image_loader.py
@@ -71,8 +71,8 @@ class ImageLoader(BaseLoader):
image = image.astype(np.uint8, copy=False)
- self.metadata[file] = self.modality_type.create_metadata(
- width, height, channels
+ self.metadata.append(
+ self.modality_type.create_metadata(width, height, channels)
)
self.data.append(image)
diff --git a/src/main/python/systemds/scuro/dataloader/json_loader.py
b/src/main/python/systemds/scuro/dataloader/json_loader.py
index f5ffd89ea3..adb3f6aaf6 100644
--- a/src/main/python/systemds/scuro/dataloader/json_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/json_loader.py
@@ -69,7 +69,9 @@ class JSONLoader(BaseLoader):
text = " ".join(text) if isinstance(text, list) else text
self.data.append(text)
- self.metadata[idx] =
self.modality_type.create_metadata(len(text), text)
+ self.metadata.append(
+ self.modality_type.create_metadata(len(text), text) |
json_file[idx]
+ )
def get_stats(self, source_path: str):
self.file_sanity_check(source_path)
diff --git a/src/main/python/systemds/scuro/dataloader/pdf_loader.py b/src/main/python/systemds/scuro/dataloader/pdf_loader.py
new file mode 100644
index 0000000000..add02e5045
--- /dev/null
+++ b/src/main/python/systemds/scuro/dataloader/pdf_loader.py
@@ -0,0 +1,89 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from typing import List, Optional, Union
+import pymupdf
+
+import numpy as np
+
+from systemds.scuro.dataloader.base_loader import BaseLoader
+import cv2
+from systemds.scuro.modality.type import ModalityType
+
+
+class PdfLoader(BaseLoader):
+    """
+    Loads every page of a PDF document as an RGB image.
+
+    Each page is rasterized with PyMuPDF, decoded with OpenCV and
+    appended to ``self.data`` as a ``uint8`` array of shape
+    (height, width, 3), together with one IMAGE metadata entry per
+    page.
+    """
+
+    def __init__(
+        self,
+        source_path: str,
+        indices: List[str],
+        data_type: Union[np.dtype, str] = np.float16,
+        chunk_size: Optional[int] = None,
+        load=True,
+        ext=".pdf",
+    ):
+        super().__init__(
+            source_path, indices, data_type, chunk_size, ModalityType.IMAGE, ext
+        )
+        # NOTE(review): stored for parity with other loaders; extract()
+        # does not read it.
+        self.load_data_from_file = load
+
+    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
+        """
+        Rasterize every page of ``file`` and append it to ``self.data``.
+
+        :param file: path to the PDF document
+        :param index: not used by this loader
+        :raises ValueError: if a rendered page cannot be decoded
+        """
+        self.file_sanity_check(file)
+
+        # The context manager closes the document even if a page fails to
+        # decode part-way through.
+        with pymupdf.open(file) as doc:
+            for page_number, page in enumerate(doc.pages()):
+                # PNG is lossless, so no JPEG compression artifacts end up in
+                # the pixels handed to downstream feature extraction.
+                image_bytes = page.get_pixmap().tobytes("png")
+                np_buffer = np.frombuffer(image_bytes, dtype=np.uint8)
+
+                image = cv2.imdecode(np_buffer, cv2.IMREAD_COLOR)
+                if image is None:
+                    raise ValueError(
+                        f"Could not decode page {page_number} of PDF file {file}"
+                    )
+                # IMREAD_COLOR always yields a 3-channel 8-bit BGR image,
+                # so the shape unpacking below cannot fail.
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                height, width, channels = image.shape
+
+                self.metadata.append(
+                    self.modality_type.create_metadata(width, height, channels)
+                )
+                self.data.append(image)
diff --git a/src/main/python/systemds/scuro/dataloader/text_loader.py
b/src/main/python/systemds/scuro/dataloader/text_loader.py
index 8b987f6845..f734a080b1 100644
--- a/src/main/python/systemds/scuro/dataloader/text_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/text_loader.py
@@ -56,8 +56,8 @@ class TextLoader(BaseLoader):
if self.prefix:
line = re.sub(self.prefix, "", line)
line = line.replace("\n", "")
- self.metadata[file] = self.modality_type.create_metadata(
- len(line.split()), line
+ self.metadata.append(
+ self.modality_type.create_metadata(len(line.split()), line)
)
self.data.append(line)
diff --git a/src/main/python/systemds/scuro/dataloader/timeseries_loader.py
b/src/main/python/systemds/scuro/dataloader/timeseries_loader.py
index 6e40e8eb08..7131b55db1 100644
--- a/src/main/python/systemds/scuro/dataloader/timeseries_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/timeseries_loader.py
@@ -81,15 +81,20 @@ class TimeseriesLoader(BaseLoader):
data = self._normalize_signals(data)
if file:
- self.metadata[index] = self.modality_type.create_metadata(
- self.signal_names, data, self.sampling_rate
+ self.metadata.append(
+ self.modality_type.create_metadata(
+ self.signal_names, data, self.sampling_rate
+ )
)
+ self.data.append(data)
else:
for i, index in enumerate(self.indices):
- self.metadata[str(index)] = self.modality_type.create_metadata(
- self.signal_names, data[i], self.sampling_rate
+ self.metadata.append(
+ self.modality_type.create_metadata(
+ self.signal_names, data[i], self.sampling_rate
+ )
)
- self.data.append(data)
+ self.data.append(data[i])
def _normalize_signals(self, data: np.ndarray) -> np.ndarray:
if data.ndim == 1:
diff --git a/src/main/python/systemds/scuro/dataloader/transcript_loader.py b/src/main/python/systemds/scuro/dataloader/transcript_loader.py
new file mode 100644
index 0000000000..67166f2c32
--- /dev/null
+++ b/src/main/python/systemds/scuro/dataloader/transcript_loader.py
@@ -0,0 +1,81 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from typing import List, Optional, Union
+from faster_whisper import WhisperModel
+import numpy as np
+
+from systemds.scuro.dataloader.base_loader import BaseLoader
+from systemds.scuro.modality.type import ModalityType
+
+
+class TranscriptLoader(BaseLoader):
+    """
+    Loads audio files as TEXT modality instances by transcribing them with
+    faster-whisper.
+
+    Every transcribed segment of a file becomes one data instance (its text)
+    with one metadata entry that also records the segment's start and end
+    timestamps as reported by faster-whisper.
+    """
+
+    def __init__(
+        self,
+        source_path: str,
+        indices: List[str],
+        data_type: Union[np.dtype, str] = np.float32,
+        chunk_size: Optional[int] = None,
+        normalize: bool = True,
+        transcribe_model_size: str = "medium",
+        load=True,
+    ):
+        super().__init__(
+            source_path, indices, data_type, chunk_size, ModalityType.TEXT
+        )
+        # The Whisper model is created once here and reused for every file.
+        self.model = WhisperModel(
+            transcribe_model_size, device="cpu", compute_type="int8"
+        )
+        # NOTE(review): ``normalize`` and ``load`` are stored but extract()
+        # does not read them.
+        self.normalize = normalize
+        self.load_data_from_file = load
+
+    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
+        """
+        Transcribe ``file`` and append one text instance per segment.
+
+        :param file: path to the audio file
+        :param index: not used by this loader
+        """
+        self.file_sanity_check(file)
+        # vad_filter=True lets faster-whisper drop non-speech audio first.
+        segments, _ = self.model.transcribe(file, vad_filter=True)
+
+        for seg in segments:
+            md = self.modality_type.create_metadata(
+                len(seg.text.split()), seg.text
+            )
+            md["timestamp_start"] = seg.start
+            md["timestamp_end"] = seg.end
+            md["text"] = seg.text
+
+            self.metadata.append(md)
+            self.data.append(seg.text)
diff --git a/src/main/python/systemds/scuro/dataloader/video_loader.py
b/src/main/python/systemds/scuro/dataloader/video_loader.py
index e57f685e03..a60b7acc60 100644
--- a/src/main/python/systemds/scuro/dataloader/video_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/video_loader.py
@@ -87,8 +87,10 @@ class VideoLoader(BaseLoader):
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
num_channels = 3
- self.metadata[file] = self.modality_type.create_metadata(
- self.fps, length, width, height, num_channels
+ self.metadata.append(
+ self.modality_type.create_metadata(
+ self.fps, length, width, height, num_channels
+ )
)
frames = []
diff --git a/src/main/python/systemds/scuro/modality/joined.py
b/src/main/python/systemds/scuro/modality/joined.py
index 335d1959fd..124c7952fd 100644
--- a/src/main/python/systemds/scuro/modality/joined.py
+++ b/src/main/python/systemds/scuro/modality/joined.py
@@ -77,9 +77,8 @@ class JoinedModality(Modality):
)
for i in range(start, end):
- idx_1 = list(self.left_modality.metadata.values())[i +
starting_idx][
- self.condition.leftField
- ]
+ left_meta_idx = i if self.chunk_left else i + starting_idx
+ idx_1 =
self.left_modality.metadata[left_meta_idx][self.condition.leftField]
if (
self.condition.alignment is None and self.condition.join_type
== "<"
): # TODO compute correct alignment timestamps/spatial params
@@ -90,9 +89,7 @@ class JoinedModality(Modality):
if self.chunk_left:
i = i + starting_idx
- idx_2 = list(self.right_modality.metadata.values())[i][
- self.condition.rightField
- ]
+ idx_2 = self.right_modality.metadata[i][self.condition.rightField]
self.joined_right.data.append([])
c = 0
@@ -228,8 +225,8 @@ class JoinedModality(Modality):
def _apply_representation_chunked(
self, left_modality, right_modality, chunk_right, representation
):
- new_left = Modality(left_modality.modality_type, {})
- new_right = Modality(right_modality.modality_type, {})
+ new_left = Modality(left_modality.modality_type)
+ new_right = Modality(right_modality.modality_type)
for _ in left_modality.iter_raw_data_chunks(reset=True):
if chunk_right:
@@ -246,11 +243,11 @@ class JoinedModality(Modality):
self.joined_right, representation
)
new_right.data.extend(right_transformed.data)
- new_right.metadata.update(right_transformed.metadata)
+ new_right.metadata.extend(right_transformed.metadata)
left_transformed = self._apply_representation(left_modality,
representation)
new_left.data.extend(left_transformed.data)
- new_left.metadata.update(left_transformed.metadata)
+ new_left.metadata.extend(left_transformed.metadata)
new_left.update_metadata()
new_right.update_metadata()
diff --git a/src/main/python/systemds/scuro/modality/modality.py
b/src/main/python/systemds/scuro/modality/modality.py
index 1bc8180e19..477e4e45f3 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -33,7 +33,7 @@ class Modality:
self,
modalityType: ModalityType,
modality_id=-1,
- metadata={},
+ metadata=[],
data_type=None,
transform_time=0,
):
@@ -91,10 +91,10 @@ class Modality:
):
return
- for i, (md_k, md_v) in enumerate(self.metadata.items()):
+ for i, md_v in enumerate(self.metadata):
md_v = selective_copy_metadata(md_v)
updated_md = self.modality_type.update_metadata(md_v, self.data[i])
- self.metadata[md_k] = updated_md
+ self.metadata[i] = updated_md
if i == 0:
self.data_type = updated_md["data_layout"]["type"]
@@ -160,13 +160,10 @@ class Modality:
if self.has_metadata():
attention_mask = np.zeros(maxlen, dtype=np.int8)
attention_mask[: len(data)] = 1
- md_key = list(self.metadata.keys())[i]
- if "attention_mask" in self.metadata[md_key]:
- self.metadata[md_key]["attention_mask"] =
attention_mask
+ if "attention_mask" in self.metadata[i]:
+ self.metadata[i]["attention_mask"] = attention_mask
else:
- self.metadata[md_key].update(
- {"attention_mask": attention_mask}
- )
+ self.metadata[i].update({"attention_mask":
attention_mask})
elif (
isinstance(first, list)
and len(first) > 0
@@ -190,13 +187,10 @@ class Modality:
if self.has_metadata():
attention_mask = np.zeros(maxlen, dtype=np.int8)
attention_mask[: len(data)] = 1
- md_key = list(self.metadata.keys())[i]
- if "attention_mask" in self.metadata[md_key]:
- self.metadata[md_key]["attention_mask"] =
attention_mask
+ if "attention_mask" in self.metadata[i]:
+ self.metadata[i]["attention_mask"] = attention_mask
else:
- self.metadata[md_key].update(
- {"attention_mask": attention_mask}
- )
+ self.metadata[i].update({"attention_mask":
attention_mask})
else:
maxlen = (
max([len(seq) for seq in self.data]) if max_len is None
else max_len
@@ -214,19 +208,16 @@ class Modality:
if self.has_metadata():
attention_mask = np.zeros(result.shape[1],
dtype=np.int8)
attention_mask[: len(data)] = 1
- md_key = list(self.metadata.keys())[i]
- if "attention_mask" in self.metadata[md_key]:
- self.metadata[md_key]["attention_mask"] =
attention_mask
+ if "attention_mask" in self.metadata[i]:
+ self.metadata[i]["attention_mask"] = attention_mask
else:
- self.metadata[md_key].update(
- {"attention_mask": attention_mask}
- )
+ self.metadata[i].update({"attention_mask":
attention_mask})
# TODO: this might need to be a new modality (otherwise we loose the
original data)
self.data = result
def get_data_layout(self):
if self.has_metadata():
- return
list(self.metadata.values())[0]["data_layout"]["representation"]
+ return self.metadata[0]["data_layout"]["representation"]
return None
@@ -234,14 +225,14 @@ class Modality:
return self.data is not None and len(self.data) != 0
def has_metadata(self):
- return self.metadata is not None and self.metadata != {}
+ return self.metadata is not None and len(self.metadata) != 0
def is_aligned(self, other_modality):
aligned = True
for i in range(len(self.data)):
if (
- list(self.metadata.values())[i]["data_layout"]["shape"]
- !=
list(other_modality.metadata.values())[i]["data_layout"]["shape"]
+ self.metadata[i]["data_layout"]["shape"]
+ != other_modality.metadata[i]["data_layout"]["shape"]
):
aligned = False
break
diff --git a/src/main/python/systemds/scuro/modality/transformed.py
b/src/main/python/systemds/scuro/modality/transformed.py
index e185deb7c9..eaaa7a2032 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -66,9 +66,9 @@ class TransformedModality(Modality):
self.aggregate_dim = aggregate_dim
if modality.__class__.__name__ == "UnimodalModality":
- for k, v in self.metadata.items():
- if "attention_masks" in v:
- del self.metadata[k]["attention_masks"]
+ for m in self.metadata:
+ if "attention_masks" in m:
+ del m["attention_masks"]
def copy_from_instance(self):
"""
@@ -82,8 +82,7 @@ class TransformedModality(Modality):
data_bytes += self._estimate_data_bytes(instance)
md_bytes = 0
- for key, value in self.metadata.items():
- md_bytes += self._estimate_data_bytes(key)
+ for value in self.metadata:
md_bytes += self._estimate_data_bytes(value)
total_bytes = (
diff --git a/src/main/python/systemds/scuro/modality/type.py
b/src/main/python/systemds/scuro/modality/type.py
index 9c883efec3..0493edf5bd 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/modality/type.py
@@ -212,7 +212,7 @@ class ModalityType(Flag):
return ModalitySchemas.get(self.name)
def has_field(self, md, field):
- for value in md.values():
+ for value in md:
if field in value:
return True
else:
@@ -221,7 +221,7 @@ class ModalityType(Flag):
def get_field_for_instances(self, md, field):
data = []
- for items in md.values():
+ for items in md:
data.append(self.get_field(items, field))
return data
@@ -242,8 +242,8 @@ class ModalityType(Flag):
return md
def add_field_for_instances(self, md, field, data):
- for key, value in zip(md.keys(), data):
- md[key].update({field: value})
+ for i, value in enumerate(data):
+ md[i].update({field: value})
return md
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py
b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 319599d680..84204ac570 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -41,7 +41,7 @@ class UnimodalModality(Modality):
super().__init__(
data_loader.modality_type,
Identifier().new_id(),
- {},
+ [],
data_loader.data_type,
)
self.data_loader = data_loader
@@ -56,10 +56,11 @@ class UnimodalModality(Modality):
def get_metadata_at_position(self, position: int):
if self.data_loader.chunk_size:
return self.metadata[
- self.data_loader.chunk_size * self.data_loader.next_chunk +
position
+ (self.data_loader.next_chunk - 1) * self.data_loader.chunk_size
+ + position
]
- return self.metadata[self.dataIndex][position]
+ return self.metadata[position]
def get_stats(self):
return self.stats
@@ -165,7 +166,7 @@ class UnimodalModality(Modality):
].data.extend(transformed_chunk.data)
transformed_modalities_per_representation[
representation.name
- ].metadata.update(transformed_chunk.metadata)
+ ].metadata.extend(transformed_chunk.metadata)
for d in transformed_chunk.data:
original_lengths_per_representation[representation.name].append(
d.shape[0]
diff --git a/src/main/python/systemds/scuro/utils/schema_helpers.py b/src/main/python/systemds/scuro/representations/orb_alignment.py
similarity index 53%
copy from src/main/python/systemds/scuro/utils/schema_helpers.py
copy to src/main/python/systemds/scuro/representations/orb_alignment.py
index 3d1fbf4d71..5594f3bc54 100644
--- a/src/main/python/systemds/scuro/utils/schema_helpers.py
+++ b/src/main/python/systemds/scuro/representations/orb_alignment.py
@@ -18,29 +18,53 @@
 # under the License.
 #
 # -------------------------------------------------------------
+from systemds.scuro.representations.alignment import Alignment
+from dataclasses import dataclass
 import numpy as np
+import cv2 as cv


-def create_timestamps(frequency, sample_length, start_datetime=None):
-    start_time = (
-        start_datetime
-        if start_datetime is not None
-        else np.datetime64("1970-01-01T00:00:00.000000")
-    )
-    time_increment = 1 / frequency
-    time_increments_array = np.arange(sample_length) * np.timedelta64(
-        int(time_increment * 1e6)
-    )
-    timestamps = start_time + time_increments_array
+@dataclass
+class OrbDescriptor:
+    """ORB keypoints and descriptors computed for one image."""
+
+    kp: object
+    desc: object

-    return timestamps.astype(np.int64)

+class OrbAlignment(Alignment):
+    """
+    Alignment that matches images by comparing their ORB feature descriptors.
+
+    Descriptors are matched with a brute-force Hamming matcher using
+    cross-checking; the distance between two images is the median distance
+    of the matches whose distance is below ``max_match_distance``.
+    """
+
+    def __init__(self, max_match_distance=40):
+        self.orb = cv.ORB_create()
+        self.bfm = cv.BFMatcher(cv.NORM_HAMMING, crossCheck=True)
+        # Matches with a Hamming distance at or above this value are ignored.
+        self.max_match_distance = max_match_distance
+        super().__init__("OrbAlignment")

-def calculate_new_frequency(new_length, old_length, old_frequency):
-    duration = old_length / old_frequency
-    new_frequency = new_length / duration
-    return new_frequency
+    def compute_descriptor(self, segment):
+        """Return a one-element list holding the ORB descriptor of ``segment``."""
+        # detectAndCompute returns a (keypoints, descriptors) tuple.
+        kp, desc = self.orb.detectAndCompute(segment, None)
+        return [OrbDescriptor(kp, desc)]

+    def compare(self, a, b):
+        """
+        Return the median Hamming distance of the good matches between ``a``
+        and ``b``, or ``inf`` if either lacks descriptors or nothing matches.
+        """
+        if a.desc is None or b.desc is None:
+            return float("inf")
+        matches = self.bfm.match(a.desc, b.desc)
+        good_matches = [m for m in matches if m.distance < self.max_match_distance]

-def get_shape(metadata):
-    return len(list(metadata.values())[0]["data_layout"]["shape"])
+        if not good_matches:
+            return float("inf")
+
+        return np.median([m.distance for m in good_matches])
diff --git a/src/main/python/systemds/scuro/representations/alignment.py b/src/main/python/systemds/scuro/representations/alignment.py
new file mode 100644
index 0000000000..42e6df0b20
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/alignment.py
@@ -0,0 +1,182 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import Optional
+import copy
+
+
+@dataclass
+class Match:
+    """Result of aligning one primary instance to a secondary instance."""
+
+    primary: int
+    # None when no secondary instance matched the primary instance.
+    secondary: Optional[int]
+    distance: float
+
+
+class Alignment(ABC):
+    """
+    Base class for aligning the instances of two unaligned modalities.
+
+    Subclasses implement ``compute_descriptor`` (a collection of descriptors
+    per instance) and ``compare`` (a distance between two descriptors,
+    smaller meaning more similar).
+    """
+
+    def __init__(self, name):
+        self.name = name
+
+    def execute(self, primary_modality, secondary_modality):
+        """
+        Match every primary instance to one secondary instance.
+
+        Each descriptor of a primary instance votes for the secondary instance
+        that holds its closest descriptor. The secondary instance with the
+        most votes wins; ties are broken by the smaller total distance and
+        then by the smaller best distance.
+
+        :return: one ``Match`` per primary instance, in primary order. Its
+            distance is the mean distance of the winning votes. A primary
+            instance without any finite-distance vote yields
+            ``Match(p, None, inf)``.
+        """
+        primary_descriptor_collections = self._batch_compute_descriptors(
+            primary_modality
+        )
+        secondary_descriptor_collections = self._batch_compute_descriptors(
+            secondary_modality
+        )
+
+        matches = []
+
+        for p, p_collection in enumerate(primary_descriptor_collections):
+            # Votes per secondary index: count, summed and best distance.
+            stats = defaultdict(
+                lambda: {
+                    "count": 0,
+                    "total_distance": 0.0,
+                    "best_distance": float("inf"),
+                }
+            )
+
+            for p_desc in p_collection:
+                best_secondary = None
+                best_dist = float("inf")
+
+                for s, s_collection in enumerate(secondary_descriptor_collections):
+                    for s_desc in s_collection:
+                        dist = self.compare(p_desc, s_desc)
+
+                        if dist < best_dist:
+                            best_dist = dist
+                            best_secondary = s
+
+                if best_secondary is not None:
+                    stats[best_secondary]["count"] += 1
+                    stats[best_secondary]["total_distance"] += best_dist
+                    stats[best_secondary]["best_distance"] = min(
+                        stats[best_secondary]["best_distance"], best_dist
+                    )
+
+            if not stats:
+                matches.append(Match(p, None, float("inf")))
+                continue
+
+            best_match = min(
+                stats.items(),
+                key=lambda item: (
+                    -item[1]["count"],  # more votes is better
+                    item[1]["total_distance"],  # smaller total distance is better
+                    item[1]["best_distance"],  # final tie-breaker
+                ),
+            )[0]
+
+            result_distance = (
+                stats[best_match]["total_distance"] / stats[best_match]["count"]
+            )
+
+            matches.append(Match(p, best_match, result_distance))
+
+        return matches
+
+    @staticmethod
+    def apply_matching(alignment, secondary_modality):
+        """
+        Reorder ``secondary_modality`` so that its instance ``i`` is the
+        secondary instance matched to primary instance ``i``.
+
+        :param alignment: list of ``Match`` objects as returned by ``execute``
+        :param secondary_modality: modality whose data and metadata are reordered
+        :return: a deep copy of ``secondary_modality`` with reordered data and
+            metadata; positions of unmatched primary instances hold ``None``
+        """
+        aligned_modality = copy.deepcopy(secondary_modality)
+        aligned_modality.data = [None] * len(alignment)
+        aligned_modality.metadata = [None] * len(alignment)
+
+        for match in alignment:
+            # execute() emits Match(p, None, inf) for unmatched primaries;
+            # indexing the data with None would raise a TypeError.
+            if match.secondary is None:
+                continue
+            aligned_modality.data[match.primary] = secondary_modality.data[
+                match.secondary
+            ]
+            aligned_modality.metadata[match.primary] = secondary_modality.metadata[
+                match.secondary
+            ]
+
+        return aligned_modality
+
+    def _batch_compute_descriptors(self, modality):
+        """
+        Return one descriptor collection per instance of ``modality``.
+
+        Modalities with a chunked data loader are re-read chunk by chunk from
+        the first chunk; modalities without a data loader are used as-is.
+        """
+        descriptors = []
+        data_loader = getattr(modality, "data_loader", None)
+
+        if data_loader is not None and data_loader.chunk_size:
+            data_loader.reset()
+            while data_loader.next_chunk < data_loader.num_chunks:
+                modality.extract_raw_data()
+                for d in modality.data:
+                    descriptors.append(self.compute_descriptor(d))
+        else:
+            if data_loader is not None and not modality.has_data():
+                modality.extract_raw_data()
+            for d in modality.data:
+                descriptors.append(self.compute_descriptor(d))
+
+        return descriptors
+
+    @abstractmethod
+    def compute_descriptor(self, segment):
+        """Return the list of descriptors describing one instance."""
+
+    @abstractmethod
+    def compare(self, a, b):
+        """Return the distance between two descriptors (smaller is closer)."""
diff --git a/src/main/python/systemds/scuro/representations/clip.py
b/src/main/python/systemds/scuro/representations/clip.py
index b13d2dfeb8..518cc1eb5d 100644
--- a/src/main/python/systemds/scuro/representations/clip.py
+++ b/src/main/python/systemds/scuro/representations/clip.py
@@ -243,6 +243,9 @@ class CLIPVisual(UnimodalRepresentation):
with torch.no_grad():
output = self.model.get_image_features(**inputs)
+ if hasattr(output, "pooler_output"):
+ output = output.pooler_output
+
if len(output.shape) > 2:
output = torch.nn.functional.adaptive_avg_pool2d(output,
(1, 1))
diff --git a/src/main/python/systemds/scuro/representations/concatenation.py
b/src/main/python/systemds/scuro/representations/concatenation.py
index a7ca905f47..5d53690317 100644
--- a/src/main/python/systemds/scuro/representations/concatenation.py
+++ b/src/main/python/systemds/scuro/representations/concatenation.py
@@ -44,9 +44,7 @@ class Concatenation(Fusion):
if len(modalities) == 1:
return np.asarray(
modalities[0].data,
-
dtype=modalities[0].metadata[list(modalities[0].metadata.keys())[0]][
- "data_layout"
- ]["type"],
+ dtype=modalities[0].metadata[0]["data_layout"]["type"],
)
max_emb_size = self.get_max_embedding_size(modalities)
@@ -64,9 +62,7 @@ class Concatenation(Fusion):
data,
np.asarray(
other_modality,
-
dtype=modality.metadata[list(modality.metadata.keys())[0]][
- "data_layout"
- ]["type"],
+ dtype=modality.metadata[0]["data_layout"]["type"],
),
],
axis=-1,
diff --git a/src/main/python/systemds/scuro/representations/contrastive_learning.py b/src/main/python/systemds/scuro/representations/contrastive_learning.py
new file mode 100644
index 0000000000..3599079e6f
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/contrastive_learning.py
@@ -0,0 +1,116 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import copy
+
+
+class ContrastiveLearning:
+    """
+    Builds contrastive training pairs from two modalities.
+
+    Every instance of the first modality is paired with every instance of
+    the second modality (Cartesian product). A user-supplied predicate on
+    the two metadata entries labels each pair as positive (True) or
+    negative (False).
+    """
+
+    @staticmethod
+    def execute(
+        input_first_modality,
+        input_second_modality,
+        input_first_extensions,
+        input_second_extensions,
+        metadata_matching_function,
+    ):
+        """
+        Return the paired modalities and one label per pair.
+
+        :param input_first_modality: modality providing the first pair element
+        :param input_second_modality: modality providing the second pair element
+        :param input_first_extensions: list of modalities with the same
+            instance order as ``input_first_modality``; their instances are
+            repeated alongside it. ``None`` is treated as an empty list.
+        :param input_second_extensions: same, for ``input_second_modality``
+        :param metadata_matching_function: called as ``f(first_md, second_md)``;
+            a truthy result marks the pair as positive
+        :return: ``([first, *first_ext], [second, *second_ext], labels)``;
+            entry ``k`` of every returned modality belongs to pair ``k`` and
+            ``labels[k]`` is that pair's label
+        :raises ValueError: if an extension does not have the same number of
+            instances as the modality it extends
+        """
+        input_first_extensions = list(input_first_extensions or [])
+        input_second_extensions = list(input_second_extensions or [])
+        ContrastiveLearning._check_extensions(
+            input_first_modality, input_first_extensions
+        )
+        ContrastiveLearning._check_extensions(
+            input_second_modality, input_second_extensions
+        )
+
+        # Outputs are deep copies of the inputs with emptied data/metadata.
+        first_targets = [ContrastiveLearning._empty_copy(input_first_modality)]
+        first_targets += ContrastiveLearning._empty_copy(input_first_extensions)
+        second_targets = [ContrastiveLearning._empty_copy(input_second_modality)]
+        second_targets += ContrastiveLearning._empty_copy(input_second_extensions)
+
+        first_sources = [input_first_modality] + input_first_extensions
+        second_sources = [input_second_modality] + input_second_extensions
+        labels = []
+
+        for i in range(len(input_first_modality.data)):
+            for j in range(len(input_second_modality.data)):
+                for target, source in zip(first_targets, first_sources):
+                    target.data.append(source.data[i])
+                    target.metadata.append(source.metadata[i])
+                for target, source in zip(second_targets, second_sources):
+                    target.data.append(source.data[j])
+                    target.metadata.append(source.metadata[j])
+
+                is_positive = metadata_matching_function(
+                    input_first_modality.metadata[i], input_second_modality.metadata[j]
+                )
+                labels.append(bool(is_positive))
+
+        return first_targets, second_targets, labels
+
+    @staticmethod
+    def _empty_copy(input_modality):
+        """Deep-copy a modality (or list of modalities) and clear its data."""
+        modality = copy.deepcopy(input_modality)
+        for m in modality if isinstance(modality, list) else [modality]:
+            m.data = []
+            m.metadata = []
+        return modality
+
+    @staticmethod
+    def _check_extensions(modality, extensions):
+        """Raise ValueError unless every extension matches ``modality`` in length."""
+        expected = len(modality.data)
+        for k, extension in enumerate(extensions):
+            if len(extension.data) != expected:
+                raise ValueError(
+                    f"Extension {k} has {len(extension.data)} instances, "
+                    f"expected {expected}"
+                )
+
+
+# Backward-compatible alias for the original, misspelled class name.
+ContrasitveLearning = ContrastiveLearning
diff --git
a/src/main/python/systemds/scuro/representations/covarep_audio_features.py
b/src/main/python/systemds/scuro/representations/covarep_audio_features.py
index 28dd6fc297..01098ef4ee 100644
--- a/src/main/python/systemds/scuro/representations/covarep_audio_features.py
+++ b/src/main/python/systemds/scuro/representations/covarep_audio_features.py
@@ -47,7 +47,7 @@ class Spectral(UnimodalRepresentation):
)
result = []
for i, y in enumerate(modality.data):
- sr = list(modality.metadata.values())[i]["frequency"]
+ sr = modality.metadata[i]["frequency"]
spectral_centroid = librosa.feature.spectral_centroid(
y=y, sr=sr, hop_length=self.hop_length
@@ -222,7 +222,7 @@ class Pitch(UnimodalRepresentation):
)
result = []
for i, y in enumerate(modality.data):
- sr = list(modality.metadata.values())[i]["frequency"]
+ sr = modality.metadata[i]["frequency"]
pitches, magnitudes = librosa.piptrack(
y=y, sr=sr, hop_length=self.hop_length
diff --git a/src/main/python/systemds/scuro/representations/lstm.py
b/src/main/python/systemds/scuro/representations/lstm.py
index e067e2ccf9..104f2727e6 100644
--- a/src/main/python/systemds/scuro/representations/lstm.py
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -92,7 +92,7 @@ class LSTM(Fusion):
data = np.array(modality.data)
except:
max_len = -1
- for md in modality.metadata.values():
+ for md in modality.metadata:
if max_len < md["data_layout"]["shape"][0]:
max_len = md["data_layout"]["shape"][0]
data = np.zeros((len(modality.data), max_len))
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
index 67ff955a98..46e5045b2e 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
@@ -57,10 +57,9 @@ class MelSpectrogram(UnimodalRepresentation):
modality, self, self.output_modality_type
)
result = []
- metadata_values = list(modality.metadata.values())
for i, sample in enumerate(modality.data):
- sr = metadata_values[i]["frequency"]
+ sr = modality.metadata[i]["frequency"]
computed_feature = self.compute_feature(sample, sr)
result.append(computed_feature)
diff --git a/src/main/python/systemds/scuro/representations/mfcc.py
b/src/main/python/systemds/scuro/representations/mfcc.py
index 5e5b54b079..737a3dffe9 100644
--- a/src/main/python/systemds/scuro/representations/mfcc.py
+++ b/src/main/python/systemds/scuro/representations/mfcc.py
@@ -63,7 +63,7 @@ class MFCC(UnimodalRepresentation):
result = []
for i, sample in enumerate(modality.data):
- sr = list(modality.metadata.values())[i]["frequency"]
+ sr = modality.metadata[i]["frequency"]
computed_feature = self.compute_feature(sample, sr)
result.append(computed_feature)
diff --git a/src/main/python/systemds/scuro/utils/schema_helpers.py
b/src/main/python/systemds/scuro/representations/pHash_alignment.py
similarity index 55%
copy from src/main/python/systemds/scuro/utils/schema_helpers.py
copy to src/main/python/systemds/scuro/representations/pHash_alignment.py
index 3d1fbf4d71..f70a12a770 100644
--- a/src/main/python/systemds/scuro/utils/schema_helpers.py
+++ b/src/main/python/systemds/scuro/representations/pHash_alignment.py
@@ -18,29 +18,26 @@
# under the License.
#
# -------------------------------------------------------------
+from systemds.scuro.representations.alignment import Alignment
+import cv2 as cv
import numpy as np
-def create_timestamps(frequency, sample_length, start_datetime=None):
- start_time = (
- start_datetime
- if start_datetime is not None
- else np.datetime64("1970-01-01T00:00:00.000000")
- )
- time_increment = 1 / frequency
- time_increments_array = np.arange(sample_length) * np.timedelta64(
- int(time_increment * 1e6)
- )
- timestamps = start_time + time_increments_array
+class PHashAlignment(Alignment):  # Alignment operator that matches segments via OpenCV perceptual-hash (PHash) descriptors
+ def __init__(self):  # registers the operator name and creates one reusable PHash hasher
+ super().__init__("pHashAlignment")
+ self.hasher = cv.img_hash.PHash_create()
- return timestamps.astype(np.int64)
+ def compute_descriptor(self, segment):  # returns a list of hashes: one for an image (ndim=3), one per frame for a video (ndim=4)
+ if segment.ndim == 3:
+ return [self.hasher.compute(segment)]  # NOTE(review): passed through unconverted, unlike video frames below — presumably images already arrive as uint8; confirm
+ if segment.ndim == 4: # For videos
+ descriptors = []
+ for s in segment:
+ frame = (s * 255).astype(np.uint8, copy=True)  # presumably float frames in [0, 1]; rescaled to 8-bit before hashing
+ descriptors.append(self.hasher.compute(frame))
+ return descriptors
+ raise NotImplementedError(f"PHashAlignment is only implemented for ndim=3 or ndim=4, got ndim={segment.ndim}")  # raising a bare str is itself a TypeError in Python 3
-
-def calculate_new_frequency(new_length, old_length, old_frequency):
- duration = old_length / old_frequency
- new_frequency = new_length / duration
- return new_frequency
-
-
-def get_shape(metadata):
- return len(list(metadata.values())[0]["data_layout"]["shape"])
+ def compare(self, a, b):  # distance between two hashes as computed by the OpenCV hasher (lower = more similar)
+ return self.hasher.compare(a, b)
diff --git a/src/main/python/systemds/scuro/representations/sum.py
b/src/main/python/systemds/scuro/representations/sum.py
index e6187b4050..d6c4fe659b 100644
--- a/src/main/python/systemds/scuro/representations/sum.py
+++ b/src/main/python/systemds/scuro/representations/sum.py
@@ -43,17 +43,13 @@ class Sum(Fusion):
def execute(self, modalities: List[Modality]):
data = np.asarray(
modalities[0].data,
-
dtype=modalities[0].metadata[list(modalities[0].metadata.keys())[0]][
- "data_layout"
- ]["type"],
+ dtype=modalities[0].metadata[0]["data_layout"]["type"],  # dtype read from the first instance's metadata entry; NOTE(review): if data is already an ndarray of this dtype, np.asarray returns it uncopied and the in-place += below mutates modalities[0].data — confirm intended
)
for m in range(1, len(modalities)):
data += np.asarray(
modalities[m].data,
-
dtype=modalities[m].metadata[list(modalities[m].metadata.keys())[0]][
- "data_layout"
- ]["type"],
+ dtype=modalities[m].metadata[0]["data_layout"]["type"],  # each added modality is cast to its own first-instance dtype before summing
)
return data
diff --git
a/src/main/python/systemds/scuro/representations/text_context_with_indices.py
b/src/main/python/systemds/scuro/representations/text_context_with_indices.py
index 095e9f8106..1a341af1e3 100644
---
a/src/main/python/systemds/scuro/representations/text_context_with_indices.py
+++
b/src/main/python/systemds/scuro/representations/text_context_with_indices.py
@@ -208,7 +208,7 @@ class SentenceBoundarySplitIndices(Context):
List of lists, where each inner list contains text chunks (strings)
"""
- for instance, metadata in zip(modality.data,
modality.metadata.values()):
+ for instance, metadata in zip(modality.data, modality.metadata):
text = _extract_text(instance)
if not text:
ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)])
@@ -344,7 +344,7 @@ class OverlappingSplitIndices(Context):
List of tuples, where each tuple contains start and end index to
the text chunks
"""
- for instance, metadata in zip(modality.data,
modality.metadata.values()):
+ for instance, metadata in zip(modality.data, modality.metadata):
text = _extract_text(instance)
if not text:
ModalityType.TEXT.add_field(metadata, "text_spans", [(0, 0)])
diff --git
a/src/main/python/systemds/scuro/representations/timeseries_representations.py
b/src/main/python/systemds/scuro/representations/timeseries_representations.py
index bb4ea4f49c..80f6880a0b 100644
---
a/src/main/python/systemds/scuro/representations/timeseries_representations.py
+++
b/src/main/python/systemds/scuro/representations/timeseries_representations.py
@@ -52,7 +52,7 @@ class TimeSeriesRepresentation(UnimodalRepresentation):
result.append(feature)
transformed_modality.data = np.vstack(np.array(result)).astype(
-
modality.metadata[list(modality.metadata.keys())[0]]["data_layout"]["type"]
+ modality.metadata[0]["data_layout"]["type"]
)
return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/wav2vec.py
b/src/main/python/systemds/scuro/representations/wav2vec.py
index c599550be1..38dcb84843 100644
--- a/src/main/python/systemds/scuro/representations/wav2vec.py
+++ b/src/main/python/systemds/scuro/representations/wav2vec.py
@@ -52,7 +52,7 @@ class Wav2Vec(UnimodalRepresentation):
result = []
for i, sample in enumerate(modality.data):
- sr = list(modality.metadata.values())[i]["frequency"]
+ sr = modality.metadata[i]["frequency"]
audio_resampled = librosa.resample(
np.array(sample), orig_sr=sr, target_sr=16000
)
diff --git
a/src/main/python/systemds/scuro/representations/window_aggregation.py
b/src/main/python/systemds/scuro/representations/window_aggregation.py
index c9622f1a2a..a34b6ebe4c 100644
--- a/src/main/python/systemds/scuro/representations/window_aggregation.py
+++ b/src/main/python/systemds/scuro/representations/window_aggregation.py
@@ -224,7 +224,7 @@ class WindowAggregation(Window):
)
windowed_data = np.array(padded_features)
- data_type =
list(modality.metadata.values())[0]["data_layout"]["type"]
+ data_type = modality.metadata[0]["data_layout"]["type"]
if data_type != "str":
windowed_data = windowed_data.astype(data_type)
diff --git a/src/main/python/systemds/scuro/utils/schema_helpers.py
b/src/main/python/systemds/scuro/utils/schema_helpers.py
index 3d1fbf4d71..929e9c7f4f 100644
--- a/src/main/python/systemds/scuro/utils/schema_helpers.py
+++ b/src/main/python/systemds/scuro/utils/schema_helpers.py
@@ -43,4 +43,4 @@ def calculate_new_frequency(new_length, old_length,
old_frequency):
def get_shape(metadata):
- return len(list(metadata.values())[0]["data_layout"]["shape"])
+ return len(metadata[0]["data_layout"]["shape"])  # number of dimensions of the first instance's data; metadata is now indexed as a per-instance list (raises IndexError if empty)
diff --git a/src/main/python/tests/scuro/data_generator.py
b/src/main/python/tests/scuro/data_generator.py
index b78ea31483..a0b43fc859 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -51,6 +51,7 @@ class TestDataLoader(BaseLoader):
def __init__(self, indices, chunk_size, modality_type, data, data_type,
metadata):
super().__init__("", indices, data_type, chunk_size, modality_type)
+ self._full_metadata = metadata
self.metadata = metadata
self.test_data = data
if modality_type == ModalityType.TEXT:
@@ -110,8 +111,10 @@ class TestDataLoader(BaseLoader):
def extract(self, file, indices):
if isinstance(self.test_data, list):
self.data = [self.test_data[i] for i in indices]
+ self.metadata = [self._full_metadata[i] for i in indices]  # keep metadata aligned with the selected data indices
else:
self.data = self.test_data[indices]
+ self.metadata = [self._full_metadata[i] for i in indices]  # identical to the list branch; could be hoisted out of the if/else
class ModalityRandomDataGenerator:
@@ -120,7 +123,7 @@ class ModalityRandomDataGenerator:
np.random.seed(4)
self.modality_id = 0
self.modality_type = None
- self.metadata = {}
+ self.metadata = []
self.data_type = np.float32
self.transform_time = 0
self.stats = None
@@ -186,21 +189,22 @@ class ModalityRandomDataGenerator:
# TODO: write a dummy method to create the same metadata for all
instances to avoid the for loop
self.modality_type = modality_type
+ self.metadata = []
for i in range(num_instances):
if modality_type == ModalityType.AUDIO:
- self.metadata[i] = modality_type.create_metadata(
+ self.metadata.append(modality_type.create_metadata(
num_features / 10, data[i]
- )
+ ))
elif modality_type == ModalityType.TEXT:
- self.metadata[i] = modality_type.create_metadata(
+ self.metadata.append(modality_type.create_metadata(
num_features / 10, data[i]
- )
+ ))
elif modality_type == ModalityType.VIDEO:
- self.metadata[i] = modality_type.create_metadata(
+ self.metadata.append(modality_type.create_metadata(
num_features / 30, 10, 0, 0, 1
- )
+ ))
elif modality_type == ModalityType.TIMESERIES:
- self.metadata[i] = modality_type.create_metadata(["test"],
data[i])
+ self.metadata.append(modality_type.create_metadata(["test"],
data[i]))
else:
raise NotImplementedError
@@ -222,10 +226,10 @@ class ModalityRandomDataGenerator:
for i in range(num_instances):
data[i] = np.array(data[i]).astype(self.data_type)
- self.metadata = {
- i: self.modality_type.create_metadata(16000, np.array(data[i]))
+ self.metadata = [
+ self.modality_type.create_metadata(16000, np.array(data[i]))
for i in range(num_instances)
- }
+ ]
return data, self.metadata
@@ -237,12 +241,12 @@ class ModalityRandomDataGenerator:
]
if num_features == 1:
data = [d.squeeze(-1) for d in data]
- self.metadata = {
- i: self.modality_type.create_metadata(
+ self.metadata = [
+ self.modality_type.create_metadata(
[f"feature_{j}" for j in range(num_features)], data[i]
)
for i in range(num_instances)
- }
+ ]
return data, self.metadata
def create_text_data(self, num_instances, num_sentences_per_instance=1):
@@ -308,10 +312,10 @@ class ModalityRandomDataGenerator:
sentence += f" {verb} {obj}{punct}"
sentences.append(sentence)
- self.metadata = {
- i: self.modality_type.create_metadata(len(sentences[i]),
sentences[i])
+ self.metadata = [
+ self.modality_type.create_metadata(len(sentences[i]), sentences[i])
for i in range(num_instances)
- }
+ ]
return sentences, self.metadata
@@ -321,9 +325,9 @@ class ModalityRandomDataGenerator:
np.random.rand(dims[0], dims[1], dims[2]).astype(self.data_type)
for _ in range(num_instances)
]
- self.metadata = {
- i: self.modality_type.create_metadata(data[i]) for i in
range(num_instances)
- }
+ self.metadata = [
+ self.modality_type.create_metadata(data[i]) for i in
range(num_instances)
+ ]
return data, self.metadata
def create_2d_modality(self, num_instances, dims=(100, 28)):
@@ -332,9 +336,9 @@ class ModalityRandomDataGenerator:
np.random.rand(dims[0], dims[1]).astype(self.data_type)
for _ in range(num_instances)
]
- self.metadata = {
- i: self.modality_type.create_metadata(data[i]) for i in
range(num_instances)
- }
+ self.metadata = [
+ self.modality_type.create_metadata(data[i]) for i in
range(num_instances)
+ ]
return data, self.metadata
def create_visual_modality(
@@ -356,12 +360,12 @@ class ModalityRandomDataGenerator:
for _ in range(num_instances)
]
- self.metadata = {
- i: self.modality_type.create_metadata(
+ self.metadata = [
+ self.modality_type.create_metadata(
30, data[i].shape[0], width, height, color_channels
)
for i in range(num_instances)
- }
+ ]
else:
self.modality_type = ModalityType.IMAGE
data = [
@@ -373,10 +377,10 @@ class ModalityRandomDataGenerator:
)
for _ in range(num_instances)
]
- self.metadata = {
- i: self.modality_type.create_metadata(width, height,
color_channels)
+ self.metadata = [
+ self.modality_type.create_metadata(width, height,
color_channels)
for i in range(num_instances)
- }
+ ]
return data, self.metadata
diff --git a/src/main/python/tests/scuro/test_text_context_operators.py
b/src/main/python/tests/scuro/test_text_context_operators.py
index ffa702b7c8..0b8ad36c16 100644
--- a/src/main/python/tests/scuro/test_text_context_operators.py
+++ b/src/main/python/tests/scuro/test_text_context_operators.py
@@ -83,7 +83,7 @@ class TestTextContextOperator(unittest.TestCase):
sentence_boundary_split = SentenceBoundarySplitIndices(10, min_words=4)
sentence_boundary_split.execute(self.text_modality)
for instance, md in zip(
- self.text_modality.data, self.text_modality.metadata.values()
+ self.text_modality.data, self.text_modality.metadata
):
for chunk in md["text_spans"]:
text = instance[chunk[0] : chunk[1]].split(" ")
@@ -95,7 +95,7 @@ class TestTextContextOperator(unittest.TestCase):
overlapping_split = OverlappingSplitIndices(40, 0.1)
overlapping_split.execute(self.text_modality)
for instance, md in zip(
- self.text_modality.data, self.text_modality.metadata.values()
+ self.text_modality.data, self.text_modality.metadata
):
prev_chunk = (0, 0)
for j, chunk in enumerate(md["text_spans"]):