This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4dd5ab3e25 [SYSTEMDS-3830] Add join operator to Scuro
4dd5ab3e25 is described below
commit 4dd5ab3e2585baef339a870f181a4bba1ff8e9dd
Author: Christina Dionysio <[email protected]>
AuthorDate: Fri Feb 7 18:03:03 2025 +0100
[SYSTEMDS-3830] Add join operator to Scuro
This patch adds a new join operator to Scuro. The join operation takes
two modalities as well as a join condition as input and joins the two
modalities on their common dimension (temporal for now). This includes
two new modality classes (JoinedModality and JoinedTransformedModality)
and the ability to apply new representations on top of a joined modality.
In the future, the join operator will also serve as a simple alignment
operator that joins two modalities with a given offset.
Closes #2220
---
.github/workflows/python.yml | 6 +-
src/main/python/systemds/__init__.py | 5 +-
.../systemds/scuro/dataloader/audio_loader.py | 9 +-
.../systemds/scuro/dataloader/base_loader.py | 55 ++++-
.../systemds/scuro/dataloader/json_loader.py | 6 +-
.../systemds/scuro/dataloader/text_loader.py | 5 +-
.../systemds/scuro/dataloader/video_loader.py | 21 +-
src/main/python/systemds/scuro/modality/joined.py | 274 +++++++++++++++++++++
.../{transformed.py => joined_transformed.py} | 44 +++-
.../python/systemds/scuro/modality/modality.py | 79 +++++-
.../python/systemds/scuro/modality/transformed.py | 54 +++-
src/main/python/systemds/scuro/modality/type.py | 93 ++++++-
.../systemds/scuro/modality/unimodal_modality.py | 56 ++++-
.../systemds/scuro/representations/aggregate.py | 51 ++++
.../python/systemds/scuro/representations/bert.py | 29 +--
.../python/systemds/scuro/representations/bow.py | 11 +-
.../python/systemds/scuro/representations/glove.py | 5 +-
.../python/systemds/scuro/representations/lstm.py | 6 +-
.../scuro/representations/mel_spectrogram.py | 50 ++--
.../systemds/scuro/representations/resnet.py | 126 ++++++----
.../python/systemds/scuro/representations/tfidf.py | 12 +-
.../systemds/scuro/representations/window.py | 49 ++++
.../systemds/scuro/representations/word2vec.py | 19 +-
.../scuro/{modality/type.py => utils/__init__.py} | 11 -
.../json_loader.py => utils/schema_helpers.py} | 38 +--
src/main/python/tests/scuro/data_generator.py | 56 ++++-
src/main/python/tests/scuro/test_data_loaders.py | 63 +++--
src/main/python/tests/scuro/test_dr_search.py | 58 ++---
.../python/tests/scuro/test_multimodal_join.py | 135 ++++++++++
.../tests/scuro/test_unimodal_representations.py | 120 +++++++++
30 files changed, 1279 insertions(+), 267 deletions(-)
diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index 9f39f07ecb..54da49f8fb 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -114,10 +114,10 @@ jobs:
torch \
librosa \
h5py \
- nltk \
gensim \
- black
-
+ black \
+ opt-einsum
+
- name: Build Python Package
run: |
cd src/main/python
diff --git a/src/main/python/systemds/__init__.py
b/src/main/python/systemds/__init__.py
index 443b5d23d9..a618ff6e9d 100644
--- a/src/main/python/systemds/__init__.py
+++ b/src/main/python/systemds/__init__.py
@@ -26,13 +26,14 @@ from systemds import examples
__all__ = ["context", "operator", "examples"]
required_packages = [
- ("torch", "2.4.1"),
- ("torchvision", "0.19.1"),
+ ("torch", "2.5.1"),
+ ("torchvision", "0.20.1"),
("librosa", "0.10.2"),
("opencv-python", "4.10.0.84"),
("opt-einsum", "3.3.0"),
("h5py", "3.11.0"),
("transformers", "4.46.3"),
+ ("nltk", "3.9.1"),
("gensim", "4.3.3"),
]
diff --git a/src/main/python/systemds/scuro/dataloader/audio_loader.py
b/src/main/python/systemds/scuro/dataloader/audio_loader.py
index f85b1b80fa..f7319fe191 100644
--- a/src/main/python/systemds/scuro/dataloader/audio_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/audio_loader.py
@@ -18,10 +18,11 @@
# under the License.
#
# -------------------------------------------------------------
-from typing import List, Optional
+from typing import List, Optional, Union
import librosa
from systemds.scuro.dataloader.base_loader import BaseLoader
+from systemds.scuro.utils.schema_helpers import create_timestamps
class AudioLoader(BaseLoader):
@@ -33,7 +34,11 @@ class AudioLoader(BaseLoader):
):
super().__init__(source_path, indices, chunk_size)
- def extract(self, file: str):
+ def extract(self, file: str, index: Optional[Union[str, List[str]]] =
None):
self.file_sanity_check(file)
audio, sr = librosa.load(file)
+ self.metadata[file] = {"sample_rate": sr, "length": audio.shape[0]}
+ self.metadata[file]["timestamp"] = create_timestamps(
+ self.metadata[file]["sample_rate"], self.metadata[file]["length"]
+ )
self.data.append(audio)
diff --git a/src/main/python/systemds/scuro/dataloader/base_loader.py
b/src/main/python/systemds/scuro/dataloader/base_loader.py
index 2ef60677c6..5cdf63f584 100644
--- a/src/main/python/systemds/scuro/dataloader/base_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/base_loader.py
@@ -35,32 +35,68 @@ class BaseLoader(ABC):
(otherwise please provide your own Dataloader that knows about the
file name convention)
"""
self.data = []
+ self.metadata = (
+ {}
+ ) # TODO: check what the index should be for storing the metadata
(file_name, counter, ...)
self.source_path = source_path
self.indices = indices
- self.chunk_size = chunk_size
- self.next_chunk = 0
+ self._next_chunk = 0
+ self._num_chunks = 1
+ self._chunk_size = None
- if self.chunk_size:
- self.num_chunks = int(len(self.indices) / self.chunk_size)
+ if chunk_size:
+ self.chunk_size = chunk_size
+
+ @property
+ def chunk_size(self):
+ return self._chunk_size
+
+ @chunk_size.setter
+ def chunk_size(self, value):
+ self._chunk_size = value
+ self._num_chunks = int(len(self.indices) / self._chunk_size)
+
+ @property
+ def num_chunks(self):
+ return self._num_chunks
+
+ @property
+ def next_chunk(self):
+ return self._next_chunk
def load(self):
"""
Takes care of loading the raw data either chunk wise (if chunk size is
defined) or all at once
"""
- if self.chunk_size:
+ if self._chunk_size:
return self._load_next_chunk()
return self._load(self.indices)
+ def update_chunk_sizes(self, other):
+ if not self._chunk_size and not other.chunk_size:
+ return
+
+ if (
+ self._chunk_size
+ and not other.chunk_size
+ or self._chunk_size < other.chunk_size
+ ):
+ other.chunk_size = self.chunk_size
+ else:
+ self.chunk_size = other.chunk_size
+
def _load_next_chunk(self):
"""
Loads the next chunk of data
"""
self.data = []
next_chunk_indices = self.indices[
- self.next_chunk * self.chunk_size : (self.next_chunk + 1) *
self.chunk_size
+ self._next_chunk
+ * self._chunk_size : (self._next_chunk + 1)
+ * self._chunk_size
]
- self.next_chunk += 1
+ self._next_chunk += 1
return self._load(next_chunk_indices)
def _load(self, indices: List[str]):
@@ -73,13 +109,14 @@ class BaseLoader(ABC):
else:
self.extract(self.source_path, indices)
- return self.data
+ return self.data, self.metadata
@abstractmethod
def extract(self, file: str, index: Optional[Union[str, List[str]]] =
None):
pass
- def file_sanity_check(self, file):
+ @staticmethod
+ def file_sanity_check(file):
"""
Checks if the file can be found is not empty
"""
diff --git a/src/main/python/systemds/scuro/dataloader/json_loader.py
b/src/main/python/systemds/scuro/dataloader/json_loader.py
index c4e3b95611..ac37545188 100644
--- a/src/main/python/systemds/scuro/dataloader/json_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/json_loader.py
@@ -21,7 +21,7 @@
import json
from systemds.scuro.dataloader.base_loader import BaseLoader
-from typing import Optional, List
+from typing import Optional, List, Union
class JSONLoader(BaseLoader):
@@ -35,9 +35,9 @@ class JSONLoader(BaseLoader):
super().__init__(source_path, indices, chunk_size)
self.field = field
- def extract(self, file: str, indices: List[str]):
+ def extract(self, file: str, index: Optional[Union[str, List[str]]] =
None):
self.file_sanity_check(file)
with open(file) as f:
json_file = json.load(f)
- for idx in indices:
+ for idx in index:
self.data.append(json_file[idx][self.field])
diff --git a/src/main/python/systemds/scuro/dataloader/text_loader.py
b/src/main/python/systemds/scuro/dataloader/text_loader.py
index f614472bce..bf34cf85c7 100644
--- a/src/main/python/systemds/scuro/dataloader/text_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/text_loader.py
@@ -19,7 +19,7 @@
#
# -------------------------------------------------------------
from systemds.scuro.dataloader.base_loader import BaseLoader
-from typing import Optional, Pattern, List
+from typing import Optional, Pattern, List, Union
import re
@@ -34,11 +34,12 @@ class TextLoader(BaseLoader):
super().__init__(source_path, indices, chunk_size)
self.prefix = prefix
- def extract(self, file: str):
+ def extract(self, file: str, index: Optional[Union[str, List[str]]] =
None):
self.file_sanity_check(file)
with open(file) as text_file:
for i, line in enumerate(text_file):
if self.prefix:
line = re.sub(self.prefix, "", line)
line = line.replace("\n", "")
+ self.metadata[file] = {"length": len(line.split())}
self.data.append(line)
diff --git a/src/main/python/systemds/scuro/dataloader/video_loader.py
b/src/main/python/systemds/scuro/dataloader/video_loader.py
index 6da20b3475..807a43b21c 100644
--- a/src/main/python/systemds/scuro/dataloader/video_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/video_loader.py
@@ -18,11 +18,12 @@
# under the License.
#
# -------------------------------------------------------------
-from typing import List, Optional
+from typing import List, Optional, Union
import numpy as np
from systemds.scuro.dataloader.base_loader import BaseLoader
+from systemds.scuro.utils.schema_helpers import create_timestamps
import cv2
@@ -35,9 +36,25 @@ class VideoLoader(BaseLoader):
):
super().__init__(source_path, indices, chunk_size)
- def extract(self, file: str):
+ def extract(self, file: str, index: Optional[Union[str, List[str]]] =
None):
self.file_sanity_check(file)
cap = cv2.VideoCapture(file)
+
+ if not cap.isOpened():
+ raise f"Could not read video at path: {file}"
+
+ self.metadata[file] = {
+ "fps": cap.get(cv2.CAP_PROP_FPS),
+ "length": int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
+ "width": int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
+ "height": int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
+ "num_channels": 3,
+ }
+
+ self.metadata[file]["timestamp"] = create_timestamps(
+ self.metadata[file]["fps"], self.metadata[file]["length"]
+ )
+
frames = []
while cap.isOpened():
ret, frame = cap.read()
diff --git a/src/main/python/systemds/scuro/modality/joined.py
b/src/main/python/systemds/scuro/modality/joined.py
new file mode 100644
index 0000000000..acdf4fb94f
--- /dev/null
+++ b/src/main/python/systemds/scuro/modality/joined.py
@@ -0,0 +1,274 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import sys
+
+import numpy as np
+
+from systemds.scuro.modality.joined_transformed import
JoinedTransformedModality
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.aggregate import Aggregation
+from systemds.scuro.representations.utils import pad_sequences
+
+
+class JoinCondition:
+ def __init__(self, leftField, rightField, joinType, alignment=None):
+ self.leftField = leftField
+ self.rightField = rightField
+ self.join_type = joinType
+ self.alignment = alignment
+
+
+class JoinedModality(Modality):
+
+ def __init__(
+ self,
+ modality_type,
+ left_modality,
+ right_modality,
+ join_condition: JoinCondition,
+ chunked_execution=False,
+ ):
+ """
+ TODO
+ :param modality_type: Type of the original modality(ies)
+ """
+ super().__init__(modality_type)
+ self.aggregation = None
+ self.joined_right = None
+ self.left_modality = left_modality
+ self.right_modality = right_modality
+ self.condition = join_condition
+ self.chunked_execution = (
+ chunked_execution # TODO: maybe move this into parent class
+ )
+ self.left_type = type(left_modality)
+ self.right_type = type(right_modality)
+ self.chunk_left = False
+ if self.chunked_execution and
self.left_type.__name__.__contains__("Unimodal"):
+ self.chunk_left = left_modality.data_loader.chunk_size is not None
+
+ def execute(self, starting_idx=0):
+ self.joined_right = self.right_modality.copy_from_instance()
+
+ start, end = 0, len(self.left_modality.data)
+ if self.chunked_execution and not self.chunk_left:
+ start = starting_idx
+ end = (
+ self.right_modality.data_loader.chunk_size
+ * self.right_modality.data_loader.next_chunk
+ )
+
+ for i in range(start, end):
+ idx_1 = list(self.left_modality.metadata.values())[i +
starting_idx][
+ self.condition.leftField
+ ]
+ if (
+ self.condition.alignment is None and self.condition.join_type
== "<"
+ ): # TODO compute correct alignment timestamps/spatial params
+ nextIdx = np.zeros(len(idx_1), dtype=int)
+ nextIdx[:-1] = idx_1[1:]
+ nextIdx[-1] = sys.maxsize
+
+ if self.chunk_left:
+ i = i + starting_idx
+
+ idx_2 = list(self.right_modality.metadata.values())[i][
+ self.condition.rightField
+ ]
+ self.joined_right.data.append([])
+
+ c = 0
+ # Assumes ordered lists (temporal)
+ # TODO: need to extract the shape of the data from the metadata
+ # video: list of lists of numpy array
+ # audio: list of numpy array
+ for j in range(0, len(idx_1)):
+ self.joined_right.data[i - starting_idx].append([])
+ right = np.array([])
+ if self.condition.join_type == "<":
+ while c < len(idx_2) and idx_2[c] < nextIdx[j]:
+ if right.size == 0:
+ right = self.right_modality.data[i][c]
+ if right.ndim == 1:
+ right = right[np.newaxis, :]
+ else:
+ if self.right_modality.data[i][c].ndim == 1:
+ right = np.concatenate(
+ [
+ right,
+
self.right_modality.data[i][c][np.newaxis, :],
+ ],
+ axis=0,
+ )
+ else:
+ right = np.concatenate(
+ [right, self.right_modality.data[i][c]],
+ axis=0,
+ )
+ c = c + 1
+ else:
+ while c < len(idx_2) and idx_2[c] <= idx_1[j]:
+ if idx_2[c] == idx_1[j]:
+ right.append(self.right_modality.data[i][c])
+ c = c + 1
+
+ if (
+ len(right) == 0
+ ): # Audio and video length sometimes do not match so we add
the average all audio samples for this specific frame
+ right = np.mean(self.right_modality.data[i][c - 1 : c],
axis=0)
+ if right.ndim == 1:
+ right = right[
+ np.newaxis, :
+ ] # TODO: check correct loading for all data layouts,
this is similar to missing data, add a different operation for this
+
+ self.joined_right.data[i - starting_idx][j] = right
+
+ def apply_representation(self, representation, aggregation):
+ self.aggregation = aggregation
+ if self.chunked_execution:
+ return self._handle_chunked_execution(representation)
+ elif self.left_type.__name__.__contains__("Unimodal"):
+ self.left_modality.extract_raw_data()
+ if self.left_type == self.right_type:
+ self.right_modality.extract_raw_data()
+ elif self.right_type.__name__.__contains__("Unimodal"):
+ self.right_modality.extract_raw_data()
+
+ self.execute()
+ left_transformed = self._apply_representation(
+ self.left_modality, representation
+ )
+ right_transformed = self._apply_representation(
+ self.joined_right, representation
+ )
+ left_transformed.update_metadata()
+ right_transformed.update_metadata()
+ return JoinedTransformedModality(
+ left_transformed, right_transformed,
f"joined_{representation.name}"
+ )
+
+ def aggregate(
+ self, aggregation_function, field_name
+ ): # TODO: use the filed name to extract data entries from modalities
+ self.aggregation = Aggregation(aggregation_function, field_name)
+
+ if not self.chunked_execution and self.joined_right:
+ return self.aggregation.aggregate(self.joined_right)
+
+ return self
+
+ def combine(self, fusion_method):
+ """
+ Combines two or more modalities with each other using a dedicated
fusion method
+ :param other: The modality to be combined
+ :param fusion_method: The fusion method to be used to combine
modalities
+ """
+ modalities = [self.left_modality, self.right_modality]
+ self.data = []
+ reshape = False
+ if self.left_modality.get_data_shape() !=
self.joined_right.get_data_shape():
+ reshape = True
+ for i in range(0, len(self.left_modality.data)):
+ self.data.append([])
+ for j in range(0, len(self.left_modality.data[i])):
+ self.data[i].append([])
+ if reshape:
+ self.joined_right.data[i][j] =
self.joined_right.data[i][j].reshape(
+ self.left_modality.get_data_shape()
+ )
+ fused = np.concatenate(
+ [self.left_modality.data[i][j],
self.joined_right.data[i][j]],
+ axis=0,
+ )
+ self.data[i][j] = fused
+ # self.data = fusion_method.transform(modalities)
+
+ for i, instance in enumerate(
+ self.data
+ ): # TODO: only if the layout is list_of_lists_of_numpy_array
+ r = []
+ [r.extend(l) for l in instance]
+ self.data[i] = np.array(r)
+ self.data = pad_sequences(self.data)
+ return self
+
+ def _handle_chunked_execution(self, representation):
+ if self.left_type == self.right_type:
+ return self._apply_representation_chunked(
+ self.left_modality, self.right_modality, True, representation
+ )
+ elif self.chunk_left:
+ return self._apply_representation_chunked(
+ self.left_modality, self.right_modality, False, representation
+ )
+ else: # TODO: refactor this approach (it is changing the way the
modalities are joined)
+ return self._apply_representation_chunked(
+ self.right_modality, self.left_modality, False, representation
+ )
+
+ def _apply_representation_chunked(
+ self, left_modality, right_modality, chunk_right, representation
+ ):
+ new_left = Modality(left_modality.modality_type, {})
+ new_right = Modality(right_modality.modality_type, {})
+
+ while (
+ left_modality.data_loader.next_chunk <
left_modality.data_loader.num_chunks
+ ):
+ if chunk_right:
+ right_modality.extract_raw_data()
+ starting_idx = 0
+ else:
+ starting_idx = (
+ left_modality.data_loader.next_chunk
+ * left_modality.data_loader.chunk_size
+ )
+ left_modality.extract_raw_data()
+
+ self.execute(starting_idx)
+
+ right_transformed = self._apply_representation(
+ self.joined_right, representation
+ )
+ new_right.data.extend(right_transformed.data)
+ new_right.metadata.update(right_transformed.metadata)
+
+ left_transformed = self._apply_representation(left_modality,
representation)
+ new_left.data.extend(left_transformed.data)
+ new_left.metadata.update(left_transformed.metadata)
+
+ new_left.update_metadata()
+ new_right.update_metadata()
+ return JoinedTransformedModality(
+ new_left, new_right, f"joined_{representation.name}"
+ )
+
+ def _apply_representation(self, modality, representation):
+ transformed = representation.transform(modality)
+ if self.aggregation:
+ aggregated_data_left = self.aggregation.window(transformed)
+ transformed = Modality(
+ transformed.modality_type,
+ transformed.metadata,
+ )
+ transformed.data = aggregated_data_left
+
+ return transformed
diff --git a/src/main/python/systemds/scuro/modality/transformed.py
b/src/main/python/systemds/scuro/modality/joined_transformed.py
similarity index 54%
copy from src/main/python/systemds/scuro/modality/transformed.py
copy to src/main/python/systemds/scuro/modality/joined_transformed.py
index 61c327e469..e2b53671aa 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/joined_transformed.py
@@ -21,32 +21,50 @@
from functools import reduce
from operator import or_
+import numpy as np
+
from systemds.scuro.modality.modality import Modality
-from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.utils import pad_sequences
-class TransformedModality(Modality):
+class JoinedTransformedModality(Modality):
- def __init__(self, modality_type: ModalityType, transformation):
+ def __init__(self, left_modality, right_modality, transformation):
"""
Parent class of the different Modalities (unimodal & multimodal)
- :param modality_type: Type of the original modality(ies)
:param transformation: Representation to be applied on the modality
"""
- super().__init__(modality_type)
+ super().__init__(
+ reduce(or_, [left_modality.modality_type],
right_modality.modality_type)
+ )
self.transformation = transformation
+ self.left_modality = left_modality
+ self.right_modality = right_modality
- def combine(self, other, fusion_method):
+ def combine(self, fusion_method):
"""
Combines two or more modalities with each other using a dedicated
fusion method
:param other: The modality to be combined
:param fusion_method: The fusion method to be used to combine
modalities
"""
- fused_modality = TransformedModality(
- reduce(or_, (o.type for o in other), self.type), fusion_method
- )
- modalities = [self]
- modalities.extend(other)
- fused_modality.data = fusion_method.transform(modalities)
+ modalities = [self.left_modality, self.right_modality]
+ self.data = []
+ for i in range(0, len(self.left_modality.data)):
+ self.data.append([])
+ for j in range(0, len(self.left_modality.data[i])):
+ self.data[i].append([])
+ fused = np.concatenate(
+ [self.left_modality.data[i][j],
self.right_modality.data[i][j]],
+ axis=0,
+ )
+ self.data[i][j] = fused
+ # self.data = fusion_method.transform(modalities)
- return fused_modality
+ for i, instance in enumerate(
+ self.data
+ ): # TODO: only if the layout is list_of_lists_of_numpy_array
+ r = []
+ [r.extend(l) for l in instance]
+ self.data[i] = np.array(r)
+ self.data = pad_sequences(self.data)
+ return self
diff --git a/src/main/python/systemds/scuro/modality/modality.py
b/src/main/python/systemds/scuro/modality/modality.py
index 9a3d1b148d..cce26eee01 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -20,25 +20,94 @@
# -------------------------------------------------------------
from typing import List
+import numpy as np
+
from systemds.scuro.modality.type import ModalityType
class Modality:
- def __init__(self, modality_type: ModalityType):
+ def __init__(self, modalityType: ModalityType, metadata=None):
"""
Parent class of the different Modalities (unimodal & multimodal)
:param modality_type: Type of the modality
"""
- self.type = modality_type
- self.data = None
+ self.modality_type = modalityType
+ self.schema = modalityType.get_schema()
+ self.data = []
self.data_type = None
self.cost = None
self.shape = None
- self.schema = {}
+ self.dataIndex = None
+ self.metadata = metadata
def get_modality_names(self) -> List[str]:
"""
Extracts the individual unimodal modalities for a given transformed
modality.
"""
- return [modality.name for modality in ModalityType if modality in
self.type]
+ return [
+ modality.name for modality in ModalityType if modality in
self.modality_type
+ ]
+
+ def copy_from_instance(self):
+ return type(self)(self.modality_type, self.metadata)
+
+ def update_metadata(self):
+ md_copy = self.metadata
+ self.metadata = {}
+ for i, (md_k, md_v) in enumerate(md_copy.items()):
+ updated_md = self.modality_type.update_metadata(md_v, self.data[i])
+ self.metadata[md_k] = updated_md
+
+ def get_metadata_at_position(self, position: int):
+ return self.metadata[self.dataIndex][position]
+
+ def flatten(self):
+ for num_instance, instance in enumerate(self.data):
+ if type(instance) is np.ndarray:
+ self.data[num_instance] = instance.flatten()
+ elif type(instance) is list:
+ self.data[num_instance] = np.array(
+ [item for sublist in instance for item in sublist]
+ )
+
+ self.data = np.array(self.data)
+ return self
+
+ def get_data_layout(self):
+ if not self.data:
+ return self.data
+
+ if isinstance(self.data[0], list):
+ return "list_of_lists_of_numpy_array"
+ elif isinstance(self.data[0], np.ndarray):
+ return "list_of_numpy_array"
+
+ def get_data_shape(self):
+ layout = self.get_data_layout()
+ if not layout:
+ return None
+
+ if layout == "list_of_lists_of_numpy_array":
+ return self.data[0][0].shape
+ elif layout == "list_of_numpy_array":
+ return self.data[0].shape
+
+ def get_data_dtype(self):
+ layout = self.get_data_layout()
+ if not layout:
+ return None
+
+ if layout == "list_of_lists_of_numpy_array":
+ return self.data[0][0].dtype
+ elif layout == "list_of_numpy_array":
+ return self.data[0].dtype
+
+ def update_data_layout(self):
+ if not self.data:
+ return
+
+ self.schema["data_layout"]["representation"] = self.get_data_layout()
+
+ self.shape = self.get_data_shape()
+ self.schema["data_layout"]["type"] = self.get_data_dtype()
diff --git a/src/main/python/systemds/scuro/modality/transformed.py
b/src/main/python/systemds/scuro/modality/transformed.py
index 61c327e469..64bfba0819 100644
--- a/src/main/python/systemds/scuro/modality/transformed.py
+++ b/src/main/python/systemds/scuro/modality/transformed.py
@@ -21,20 +21,64 @@
from functools import reduce
from operator import or_
+from systemds.scuro.modality.joined import JoinedModality
from systemds.scuro.modality.modality import Modality
-from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.window import WindowAggregation
class TransformedModality(Modality):
- def __init__(self, modality_type: ModalityType, transformation):
+ def __init__(self, modality_type, transformation, metadata):
"""
Parent class of the different Modalities (unimodal & multimodal)
:param modality_type: Type of the original modality(ies)
:param transformation: Representation to be applied on the modality
"""
- super().__init__(modality_type)
+ super().__init__(modality_type, metadata)
self.transformation = transformation
+ self.data = []
+
+ def copy_from_instance(self):
+ return type(self)(self.modality_type, self.transformation,
self.metadata)
+
+ def join(self, right, join_condition):
+ chunked_execution = False
+ if type(right).__name__.__contains__("Unimodal"):
+ if right.data_loader.chunk_size:
+ chunked_execution = True
+ elif right.data is None or len(right.data) == 0:
+ right.extract_raw_data()
+
+ joined_modality = JoinedModality(
+ reduce(or_, [right.modality_type], self.modality_type),
+ self,
+ right,
+ join_condition,
+ chunked_execution,
+ )
+
+ if not chunked_execution:
+ joined_modality.execute(0)
+
+ return joined_modality
+
+ def window(self, windowSize, aggregationFunction, fieldName=None):
+ transformed_modality = TransformedModality(
+ self.modality_type, "window", self.metadata
+ )
+ w = WindowAggregation(windowSize, aggregationFunction)
+ transformed_modality.data = w.window(self)
+
+ return transformed_modality
+
+ def apply_representation(self, representation, aggregation):
+ new_modality = representation.transform(self)
+
+ if aggregation:
+ new_modality.data = aggregation.window(new_modality)
+
+ new_modality.update_metadata()
+ return new_modality
def combine(self, other, fusion_method):
"""
@@ -43,7 +87,9 @@ class TransformedModality(Modality):
:param fusion_method: The fusion method to be used to combine
modalities
"""
fused_modality = TransformedModality(
- reduce(or_, (o.type for o in other), self.type), fusion_method
+ reduce(or_, (o.modality_type for o in other), self.modality_type),
+ fusion_method,
+ self.metadata,
)
modalities = [self]
modalities.extend(other)
diff --git a/src/main/python/systemds/scuro/modality/type.py
b/src/main/python/systemds/scuro/modality/type.py
index c451eea6f1..197ad23c54 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/modality/type.py
@@ -18,7 +18,90 @@
# under the License.
#
# -------------------------------------------------------------
-from enum import Enum, Flag, auto
+from enum import Flag, auto
+from systemds.scuro.utils.schema_helpers import (
+ calculate_new_frequency,
+ create_timestamps,
+)
+
+
+# TODO: needs a way to define if data comes from a dataset with multiple
instances or is like a streaming scenario where we only have one instance
+# right now it is a list of instances (if only one instance the list would
contain only a single item)
+class ModalitySchemas:
+ TEXT_SCHEMA = {"type": "string", "length": "int"}
+
+ AUDIO_SCHEMA = {
+ "timestamp": "array",
+ "data_layout": {"type": "?", "representation": "?"},
+ "sample_rate": "integer",
+ "length": "integer",
+ }
+
+ VIDEO_SCHEMA = {
+ "timestamp": "array",
+ "data_layout": {"type": "?", "representation": "?"},
+ "fps": "integer",
+ "length": "integer",
+ "width": "integer",
+ "height": "integer",
+ "num_channels": "integer",
+ }
+
+ _metadata_handlers = {}
+
+ @classmethod
+ def get(cls, name):
+ return getattr(cls, f"{name}_SCHEMA", None)
+
+ @classmethod
+ def add_schema(cls, name, schema):
+ setattr(cls, f"{name}_SCHEMA", schema)
+
+ @classmethod
+ def register_metadata_handler(cls, name):
+ def decorator(metadata_handler):
+ cls._metadata_handlers[name] = metadata_handler
+ return metadata_handler
+
+ return decorator
+
+ @classmethod
+ def update_metadata(cls, name, md, data):
+ mdHandler = cls._metadata_handlers.get(name)
+ if mdHandler:
+ return mdHandler(md, data)
+
+ def extract_data(self, data, index):
+ if self.get("data_layout").get("representation") == "list_array":
+ return data[index]
+ else:
+ return data[index]
+
+
[email protected]_metadata_handler("AUDIO")
+def handle_audio_metadata(md, data):
+ new_frequency = calculate_new_frequency(len(data), md["length"],
md["sample_rate"])
+ md.update(
+ {
+ "length": len(data),
+ "sample_rate": new_frequency,
+ "timestamp": create_timestamps(new_frequency, len(data)),
+ }
+ )
+ return md
+
+
[email protected]_metadata_handler("VIDEO")
+def handle_video_metadata(md, data):
+ new_frequency = calculate_new_frequency(len(data), md["length"], md["fps"])
+ md.update(
+ {
+ "length": len(data),
+ "fps": new_frequency,
+ "timestamp": create_timestamps(new_frequency, len(data)),
+ }
+ )
+ return md
class ModalityType(Flag):
@@ -26,6 +109,8 @@ class ModalityType(Flag):
AUDIO = auto()
VIDEO = auto()
- # def __init__(self, value, name):
- # self._value_ = value
- # self.name = name
+ def get_schema(self):
+ return ModalitySchemas.get(self.name)
+
+ def update_metadata(self, md, data):
+ return ModalitySchemas.update_metadata(self.name, md, data)
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py
b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 976d4194d4..ae33b6605b 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -18,8 +18,13 @@
# under the License.
#
# -------------------------------------------------------------
+from functools import reduce
+from operator import or_
+
+
from systemds.scuro.dataloader.base_loader import BaseLoader
from systemds.scuro.modality.modality import Modality
+from systemds.scuro.modality.joined import JoinedModality
from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.modality.type import ModalityType
@@ -32,28 +37,65 @@ class UnimodalModality(Modality):
:param data_loader: Defines how the raw data should be loaded
:param modality_type: Type of the modality
"""
- super().__init__(modality_type)
+ super().__init__(modality_type, None)
self.data_loader = data_loader
+ def copy_from_instance(self):
+ new_instance = type(self)(self.data_loader, self.modality_type)
+ if self.metadata:
+ new_instance.metadata = self.metadata.copy()
+ return new_instance
+
+ def get_metadata_at_position(self, position: int):
+ if self.data_loader.chunk_size:
+ return self.metadata[
+ self.data_loader.chunk_size * self.data_loader.next_chunk +
position
+ ]
+
+ return self.metadata[self.dataIndex][position]
+
def extract_raw_data(self):
"""
Uses the data loader to read the raw data from a specified location
and stores the data in the data location.
- TODO: schema
"""
- self.data = self.data_loader.load()
+ self.data, self.metadata = self.data_loader.load()
+
+ def join(self, other, join_condition):
+ if isinstance(other, UnimodalModality):
+ self.data_loader.update_chunk_sizes(other.data_loader)
+
+ joined_modality = JoinedModality(
+ reduce(or_, [other.modality_type], self.modality_type),
+ self,
+ other,
+ join_condition,
+ self.data_loader.chunk_size is not None,
+ )
- def apply_representation(self, representation):
- new_modality = TransformedModality(self.type, representation)
+ return joined_modality
+
+ def apply_representation(self, representation, aggregation=None):
+ new_modality = TransformedModality(
+ self.modality_type, representation.name,
self.data_loader.metadata.copy()
+ )
new_modality.data = []
if self.data_loader.chunk_size:
while self.data_loader.next_chunk < self.data_loader.num_chunks:
self.extract_raw_data()
- new_modality.data.extend(representation.transform(self.data))
+ transformed_chunk = representation.transform(self)
+ if aggregation:
+ transformed_chunk.data =
aggregation.window(transformed_chunk)
+ new_modality.data.extend(transformed_chunk.data)
+ new_modality.metadata.update(transformed_chunk.metadata)
else:
if not self.data:
self.extract_raw_data()
- new_modality.data = representation.transform(self.data)
+ new_modality = representation.transform(self)
+
+ if aggregation:
+ new_modality.data = aggregation.window(new_modality)
+ new_modality.update_metadata()
return new_modality
diff --git a/src/main/python/systemds/scuro/representations/aggregate.py
b/src/main/python/systemds/scuro/representations/aggregate.py
new file mode 100644
index 0000000000..7c8d1c68d1
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/aggregate.py
@@ -0,0 +1,51 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import numpy as np
+
+from systemds.scuro.modality.modality import Modality
+
+
+# TODO: make this a Representation and add a fusion method that fuses two
modalities with each other
+
+
+class Aggregation:
+ def __init__(self, aggregation_function, field_name):
+ self.aggregation_function = aggregation_function
+ self.field_name = field_name
+
+ def aggregate(self, modality):
+ aggregated_modality = Modality(modality.modality_type,
modality.metadata)
+ aggregated_modality.data = []
+ for i, instance in enumerate(modality.data):
+ aggregated_modality.data.append([])
+ for j, entry in enumerate(instance):
+ if self.aggregation_function == "sum":
+ aggregated_modality.data[i].append(np.sum(entry, axis=0))
+ elif self.aggregation_function == "mean":
+ aggregated_modality.data[i].append(np.mean(entry, axis=0))
+ elif self.aggregation_function == "min":
+ aggregated_modality.data[i].append(np.min(entry, axis=0))
+ elif self.aggregation_function == "max":
+ aggregated_modality.data[i].append(np.max(entry, axis=0))
+ else:
+ raise ValueError("Invalid aggregation function")
+
+ return aggregated_modality
diff --git a/src/main/python/systemds/scuro/representations/bert.py
b/src/main/python/systemds/scuro/representations/bert.py
index 0fcf1e8d28..bfaaa22642 100644
--- a/src/main/python/systemds/scuro/representations/bert.py
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -21,6 +21,7 @@
import numpy as np
+from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
import torch
from transformers import BertTokenizer, BertModel
@@ -28,30 +29,29 @@ from systemds.scuro.representations.utils import
save_embeddings
class Bert(UnimodalRepresentation):
- def __init__(self, avg_layers=None, output_file=None):
+ def __init__(self, output_file=None):
super().__init__("Bert")
- self.avg_layers = avg_layers
self.output_file = output_file
- def transform(self, data):
-
+ def transform(self, modality):
+ transformed_modality = TransformedModality(
+ modality.modality_type, self, modality.metadata
+ )
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(
model_name, clean_up_tokenization_spaces=True
)
- if self.avg_layers is not None:
- model = BertModel.from_pretrained(model_name,
output_hidden_states=True)
- else:
- model = BertModel.from_pretrained(model_name)
+ model = BertModel.from_pretrained(model_name)
- embeddings = self.create_embeddings(data, model, tokenizer)
+ embeddings = self.create_embeddings(modality.data, model, tokenizer)
if self.output_file is not None:
save_embeddings(embeddings, self.output_file)
- return embeddings
+ transformed_modality.data = embeddings
+ return transformed_modality
def create_embeddings(self, data, model, tokenizer):
embeddings = []
@@ -61,15 +61,8 @@ class Bert(UnimodalRepresentation):
with torch.no_grad():
outputs = model(**inputs)
- if self.avg_layers is not None:
- cls_embedding = [
- outputs.hidden_states[i][:, 0, :]
- for i in range(-self.avg_layers, 0)
- ]
- cls_embedding = torch.mean(torch.stack(cls_embedding),
dim=0).numpy()
- else:
cls_embedding = outputs.last_hidden_state[:, 0,
:].squeeze().numpy()
- embeddings.append(cls_embedding)
+ embeddings.append(cls_embedding)
embeddings = np.array(embeddings)
return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
diff --git a/src/main/python/systemds/scuro/representations/bow.py
b/src/main/python/systemds/scuro/representations/bow.py
index bd54654a5c..f16f6ec04d 100644
--- a/src/main/python/systemds/scuro/representations/bow.py
+++ b/src/main/python/systemds/scuro/representations/bow.py
@@ -21,6 +21,7 @@
from sklearn.feature_extraction.text import CountVectorizer
+from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from systemds.scuro.representations.utils import save_embeddings
@@ -32,14 +33,18 @@ class BoW(UnimodalRepresentation):
self.min_df = min_df
self.output_file = output_file
- def transform(self, data):
+ def transform(self, modality):
+ transformed_modality = TransformedModality(
+ modality.modality_type, self, modality.metadata
+ )
vectorizer = CountVectorizer(
ngram_range=(1, self.ngram_range), min_df=self.min_df
)
- X = vectorizer.fit_transform(data).toarray()
+ X = vectorizer.fit_transform(modality.data).toarray()
if self.output_file is not None:
save_embeddings(X, self.output_file)
- return X
+ transformed_modality.data = X
+ return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/glove.py
b/src/main/python/systemds/scuro/representations/glove.py
index cf13c717d2..767fc8d375 100644
--- a/src/main/python/systemds/scuro/representations/glove.py
+++ b/src/main/python/systemds/scuro/representations/glove.py
@@ -19,7 +19,8 @@
#
# -------------------------------------------------------------
import numpy as np
-from nltk import word_tokenize
+from gensim.utils import tokenize
+
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from systemds.scuro.representations.utils import read_data_from_file,
save_embeddings
@@ -47,7 +48,7 @@ class GloVe(UnimodalRepresentation):
embeddings = []
for sentences in data:
- tokens = word_tokenize(sentences.lower())
+ tokens = list(tokenize(sentences.lower()))
embeddings.append(
np.mean(
[
diff --git a/src/main/python/systemds/scuro/representations/lstm.py
b/src/main/python/systemds/scuro/representations/lstm.py
index 649b81117b..6f06e762a5 100644
--- a/src/main/python/systemds/scuro/representations/lstm.py
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -46,11 +46,11 @@ class LSTM(Fusion):
result = np.zeros((size, 0))
for modality in modalities:
- if modality.type in self.unimodal_embeddings.keys():
- out = self.unimodal_embeddings.get(modality.type)
+ if modality.modality_type in self.unimodal_embeddings.keys():
+ out = self.unimodal_embeddings.get(modality.modality_type)
else:
out = self.run_lstm(modality.data)
- self.unimodal_embeddings[modality.type] = out
+ self.unimodal_embeddings[modality.modality_type] = out
result = np.concatenate([result, out], axis=-1)
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
index 57a7fab83e..483ea181b8 100644
--- a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
+++ b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
@@ -18,44 +18,40 @@
# under the License.
#
# -------------------------------------------------------------
-
-import pickle
-
import librosa
import numpy as np
-from systemds.scuro.representations.utils import pad_sequences
+from systemds.scuro.modality.transformed import TransformedModality
+
+# import matplotlib.pyplot as plt
from systemds.scuro.representations.unimodal import UnimodalRepresentation
class MelSpectrogram(UnimodalRepresentation):
- def __init__(self, avg=True, output_file=None):
+ def __init__(self):
super().__init__("MelSpectrogram")
- self.avg = avg
- self.output_file = output_file
- def transform(self, data):
+ def transform(self, modality):
+ transformed_modality = TransformedModality(
+ modality.modality_type, self, modality.metadata
+ )
result = []
max_length = 0
- for sample in data:
- S = librosa.feature.melspectrogram(y=sample)
+ for sample in modality.data:
+ S = librosa.feature.melspectrogram(y=sample, sr=22050)
S_dB = librosa.power_to_db(S, ref=np.max)
if S_dB.shape[-1] > max_length:
max_length = S_dB.shape[-1]
- result.append(S_dB)
-
- r = []
- for elem in result:
- d = pad_sequences(elem, maxlen=max_length, dtype="float32")
- r.append(d)
-
- np_array_r = np.array(r) if not self.avg else np.mean(np.array(r),
axis=1)
-
- if self.output_file is not None:
- data = []
- for i in range(0, np_array_r.shape[0]):
- data.append(np_array_r[i])
- with open(self.output_file, "wb") as file:
- pickle.dump(data, file)
-
- return np_array_r
+ result.append(S_dB.T)
+
+ transformed_modality.data = result
+ return transformed_modality
+
+ # def plot_spectrogram(self, spectrogram):
+ # plt.figure(figsize=(10, 4))
+ # librosa.display.specshow(
+ # spectrogram, x_axis="time", y_axis="mel", sr=22050,
cmap="viridis"
+ # )
+ # plt.colorbar(format="%+2.0f dB")
+ # plt.title("Mel Spectrogram")
+ # plt.savefig("spectrogram.jpg")
diff --git a/src/main/python/systemds/scuro/representations/resnet.py
b/src/main/python/systemds/scuro/representations/resnet.py
index 1c1bfa1d5e..ff63e6766b 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/resnet.py
@@ -19,9 +19,7 @@
#
# -------------------------------------------------------------
-
-import h5py
-
+from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from typing import Callable, Dict, Tuple, Any
import torch.utils.data
@@ -30,23 +28,61 @@ import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
-DEVICE = "cpu"
+if torch.backends.mps.is_available():
+ DEVICE = torch.device("mps")
+elif torch.cuda.is_available():
+ DEVICE = torch.device("cuda")
+else:
+ DEVICE = torch.device("cpu")
class ResNet(UnimodalRepresentation):
- def __init__(self, layer="avgpool", output_file=None):
+ def __init__(self, layer="avgpool", model_name="ResNet18",
output_file=None):
super().__init__("ResNet")
self.output_file = output_file
self.layer_name = layer
+ self.model = model_name
+ self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
- def transform(self, data):
-
- resnet =
models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(DEVICE)
- resnet.eval()
+ class Identity(torch.nn.Module):
+ def forward(self, input_: torch.Tensor) -> torch.Tensor:
+ return input_
- for param in resnet.parameters():
- param.requires_grad = False
+ self.model.fc = Identity()
+
+ @property
+ def model(self):
+ return self._model
+
+ @model.setter
+ def model(self, model):
+ if model == "ResNet18":
+ self._model =
models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(
+ DEVICE
+ )
+ elif model == "ResNet34":
+ self._model =
models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
+ DEVICE
+ )
+ elif model == "ResNet50":
+ self._model =
models.resnet50(weights=models.ResNet50_Weights.DEFAULT).to(
+ DEVICE
+ )
+ elif model == "ResNet101":
+ self._model =
models.resnet101(weights=models.ResNet101_Weights.DEFAULT).to(
+ DEVICE
+ )
+ elif model == "ResNet152":
+ self._model =
models.resnet152(weights=models.ResNet152_Weights.DEFAULT).to(
+ DEVICE
+ )
+ else:
+ raise NotImplementedError
+
+ def transform(self, modality):
t = transforms.Compose(
[
@@ -60,15 +96,9 @@ class ResNet(UnimodalRepresentation):
]
)
- dataset = ResNetDataset(data, t)
+ dataset = ResNetDataset(modality.data, t)
embeddings = {}
- class Identity(torch.nn.Module):
- def forward(self, input_: torch.Tensor) -> torch.Tensor:
- return input_
-
- resnet.fc = Identity()
-
res5c_output = None
def get_features(name_):
@@ -81,14 +111,14 @@ class ResNet(UnimodalRepresentation):
return hook
if self.layer_name:
- for name, layer in resnet.named_modules():
+ for name, layer in self.model.named_modules():
if name == self.layer_name:
layer.register_forward_hook(get_features(name))
break
for instance in torch.utils.data.DataLoader(dataset):
video_id = instance["id"][0]
- frames = instance["frames"][0].to(DEVICE)
+ frames = instance["data"][0].to(DEVICE)
embeddings[video_id] = []
batch_size = 64
@@ -97,32 +127,21 @@ class ResNet(UnimodalRepresentation):
frame_ids_range = range(start_index, end_index)
frame_batch = frames[frame_ids_range]
- _ = resnet(frame_batch)
+ _ = self.model(frame_batch)
values = res5c_output
+ pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1,
1))
- if self.layer_name == "avgpool" or self.layer_name ==
"maxpool":
- embeddings[video_id].extend(
- torch.flatten(values, 1).detach().cpu().numpy()
- )
-
- else:
- pooled = torch.nn.functional.adaptive_avg_pool2d(values,
(1, 1))
-
- embeddings[video_id].extend(
- torch.flatten(pooled, 1).detach().cpu().numpy()
- )
+ embeddings[video_id].extend(
+ torch.flatten(pooled, 1).detach().cpu().numpy()
+ )
- if self.output_file is not None:
- with h5py.File(self.output_file, "w") as hdf:
- for key, value in embeddings.items():
- hdf.create_dataset(key, data=value)
-
- emb = []
-
- for video in embeddings.values():
- emb.append(np.array(video).mean(axis=0).tolist())
+ transformed_modality = TransformedModality(
+ modality.modality_type, "resnet", modality.metadata
+ )
+ transformed_modality.data = list(embeddings.values())
+ transformed_modality.update_data_layout()
- return np.array(emb)
+ return transformed_modality
class ResNetDataset(torch.utils.data.Dataset):
@@ -131,12 +150,23 @@ class ResNetDataset(torch.utils.data.Dataset):
self.tf = tf
def __getitem__(self, index) -> Dict[str, object]:
- video = self.data[index]
- frames = torch.empty((len(video), 3, 224, 224))
-
- for i, frame in enumerate(video):
- frames[i] = self.tf(frame)
- return {"id": index, "frames": frames}
+ data = self.data[index]
+ if type(data) is np.ndarray:
+ output = torch.empty((1, 3, 224, 224))
+ d = torch.tensor(data)
+ d = d.repeat(3, 1, 1)
+ output[0] = self.tf(d)
+ else:
+ output = torch.empty((len(data), 3, 224, 224))
+
+ for i, d in enumerate(data):
+ if data[0].ndim < 3:
+ d = torch.tensor(d)
+ d = d.repeat(3, 1, 1)
+
+ output[i] = self.tf(d)
+
+ return {"id": index, "data": output}
def __len__(self) -> int:
return len(self.data)
diff --git a/src/main/python/systemds/scuro/representations/tfidf.py
b/src/main/python/systemds/scuro/representations/tfidf.py
index 4849aba136..02cfb927c7 100644
--- a/src/main/python/systemds/scuro/representations/tfidf.py
+++ b/src/main/python/systemds/scuro/representations/tfidf.py
@@ -20,7 +20,7 @@
# -------------------------------------------------------------
from sklearn.feature_extraction.text import TfidfVectorizer
-
+from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from systemds.scuro.representations.utils import read_data_from_file,
save_embeddings
@@ -31,13 +31,17 @@ class TfIdf(UnimodalRepresentation):
self.min_df = min_df
self.output_file = output_file
- def transform(self, data):
+ def transform(self, modality):
+ transformed_modality = TransformedModality(
+ modality.modality_type, self, modality.metadata
+ )
vectorizer = TfidfVectorizer(min_df=self.min_df)
- X = vectorizer.fit_transform(data)
+ X = vectorizer.fit_transform(modality.data)
X = X.toarray()
if self.output_file is not None:
save_embeddings(X, self.output_file)
- return X
+ transformed_modality.data = X
+ return transformed_modality
diff --git a/src/main/python/systemds/scuro/representations/window.py
b/src/main/python/systemds/scuro/representations/window.py
new file mode 100644
index 0000000000..af0301d0e3
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/window.py
@@ -0,0 +1,49 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import numpy as np
+import math
+
+
+# TODO: move this into the aggregation class and add an aggregate() and a
window(window_size) function there so they can use the same functionality.
+class WindowAggregation:
+ def __init__(self, window_size, aggregation_function):
+ self.window_size = window_size
+ self.aggregation_function = aggregation_function
+
+ def window(self, modality):
+ # data is a 2d array
+ windowed_data = []
+ for instance in modality.data:
+ window_length = math.ceil(len(instance) / self.window_size)
+ result = [[] for _ in range(0, window_length)]
+ # if modality.schema["data_layout"]["representation"] ==
"list_of_lists_of_numpy_array":
+ data = np.stack(instance)
+ for i in range(0, window_length):
+ result[i] = np.mean(
+ data[
+ i * self.window_size : i * self.window_size +
self.window_size
+ ],
+ axis=0,
+ ) # TODO: add actual aggregation function here
+
+ windowed_data.append(result)
+
+ return windowed_data
diff --git a/src/main/python/systemds/scuro/representations/word2vec.py
b/src/main/python/systemds/scuro/representations/word2vec.py
index 209091648d..b68a9fd3eb 100644
--- a/src/main/python/systemds/scuro/representations/word2vec.py
+++ b/src/main/python/systemds/scuro/representations/word2vec.py
@@ -19,11 +19,11 @@
#
# -------------------------------------------------------------
import numpy as np
-
+from systemds.scuro.modality.transformed import TransformedModality
from systemds.scuro.representations.unimodal import UnimodalRepresentation
from systemds.scuro.representations.utils import save_embeddings
from gensim.models import Word2Vec
-from nltk.tokenize import word_tokenize
+from gensim.utils import tokenize
def get_embedding(sentence, model):
@@ -43,8 +43,11 @@ class W2V(UnimodalRepresentation):
self.window = window
self.output_file = output_file
- def transform(self, data):
- t = [word_tokenize(s.lower()) for s in data]
+ def transform(self, modality):
+ transformed_modality = TransformedModality(
+ modality.modality_type, self, modality.metadata
+ )
+ t = [list(tokenize(s.lower())) for s in modality.data]
model = Word2Vec(
sentences=t,
vector_size=self.vector_size,
@@ -52,11 +55,11 @@ class W2V(UnimodalRepresentation):
min_count=self.min_count,
)
embeddings = []
- for sentences in data:
- tokens = word_tokenize(sentences.lower())
+ for sentences in modality.data:
+ tokens = list(tokenize(sentences.lower()))
embeddings.append(get_embedding(tokens, model))
if self.output_file is not None:
save_embeddings(np.array(embeddings), self.output_file)
-
- return np.array(embeddings)
+ transformed_modality.data = np.array(embeddings)
+ return transformed_modality
diff --git a/src/main/python/systemds/scuro/modality/type.py
b/src/main/python/systemds/scuro/utils/__init__.py
similarity index 80%
copy from src/main/python/systemds/scuro/modality/type.py
copy to src/main/python/systemds/scuro/utils/__init__.py
index c451eea6f1..e66abb4646 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/utils/__init__.py
@@ -18,14 +18,3 @@
# under the License.
#
# -------------------------------------------------------------
-from enum import Enum, Flag, auto
-
-
-class ModalityType(Flag):
- TEXT = auto()
- AUDIO = auto()
- VIDEO = auto()
-
- # def __init__(self, value, name):
- # self._value_ = value
- # self.name = name
diff --git a/src/main/python/systemds/scuro/dataloader/json_loader.py
b/src/main/python/systemds/scuro/utils/schema_helpers.py
similarity index 58%
copy from src/main/python/systemds/scuro/dataloader/json_loader.py
copy to src/main/python/systemds/scuro/utils/schema_helpers.py
index c4e3b95611..a88e81f716 100644
--- a/src/main/python/systemds/scuro/dataloader/json_loader.py
+++ b/src/main/python/systemds/scuro/utils/schema_helpers.py
@@ -18,26 +18,26 @@
# under the License.
#
# -------------------------------------------------------------
-import json
+import math
+import numpy as np
-from systemds.scuro.dataloader.base_loader import BaseLoader
-from typing import Optional, List
+def create_timestamps(frequency, sample_length, start_datetime=None):
+ start_time = (
+ start_datetime
+ if start_datetime is not None
+ else np.datetime64("1970-01-01T00:00:00.000000")
+ )
+ time_increment = 1 / frequency
+ time_increments_array = np.arange(sample_length) * np.timedelta64(
+ int(time_increment * 1e6)
+ )
+ timestamps = start_time + time_increments_array
-class JSONLoader(BaseLoader):
- def __init__(
- self,
- source_path: str,
- indices: List[str],
- field: str,
- chunk_size: Optional[int] = None,
- ):
- super().__init__(source_path, indices, chunk_size)
- self.field = field
+ return timestamps.astype(np.int64)
- def extract(self, file: str, indices: List[str]):
- self.file_sanity_check(file)
- with open(file) as f:
- json_file = json.load(f)
- for idx in indices:
- self.data.append(json_file[idx][self.field])
+
+def calculate_new_frequency(new_length, old_length, old_frequency):
+ duration = old_length / old_frequency
+ new_frequency = new_length / duration
+ return new_frequency
diff --git a/src/main/python/tests/scuro/data_generator.py
b/src/main/python/tests/scuro/data_generator.py
index 6856ee7044..ec0783df9c 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -18,30 +18,78 @@
# under the License.
#
# -------------------------------------------------------------
+import shutil
+
import cv2
import numpy as np
from scipy.io.wavfile import write
import random
import os
+
+from systemds.scuro import VideoLoader, AudioLoader, TextLoader,
UnimodalModality
from systemds.scuro.modality.type import ModalityType
+def setup_data(modalities, num_instances, path):
+ if os.path.isdir(path):
+ shutil.rmtree(path)
+
+ os.makedirs(path)
+
+ indizes = [str(i) for i in range(0, num_instances)]
+
+ modalities_to_create = []
+ for modality in modalities:
+ mod_path = path + "/" + modality.name + "/"
+
+ if modality == ModalityType.VIDEO:
+ data_loader = VideoLoader(mod_path, indizes)
+ elif modality == ModalityType.AUDIO:
+ data_loader = AudioLoader(mod_path, indizes)
+ elif modality == ModalityType.TEXT:
+ data_loader = TextLoader(mod_path, indizes)
+ else:
+ raise "Modality not supported in DataGenerator"
+
+ modalities_to_create.append(UnimodalModality(data_loader, modality))
+
+ data_generator = TestDataGenerator(modalities_to_create, path)
+ data_generator.create_multimodal_data(num_instances)
+ return data_generator
+
+
class TestDataGenerator:
def __init__(self, modalities, path, balanced=True):
+
self.modalities = modalities
+ self.modalities_by_type = {}
+ for modality in modalities:
+ self.modalities_by_type[modality.modality_type] = modality
+
+ self._indices = None
self.path = path
self.balanced = balanced
for modality in modalities:
- mod_path = f"{self.path}/{modality.type.name}/"
+ mod_path = f"{self.path}/{modality.modality_type.name}/"
os.mkdir(mod_path)
modality.file_path = mod_path
self.labels = []
self.label_path = f"{path}/labels.npy"
+ def get_modality_path(self, modality_type):
+ return self.modalities_by_type[modality_type].data_loader.source_path
+
+ @property
+ def indices(self):
+ if self._indices is None:
+ raise "No indices available, please call setup_data first"
+ return self._indices
+
def create_multimodal_data(self, num_instances, duration=2, seed=42):
speed_fast = 0
speed_slow = 0
+ self._indices = [str(i) for i in range(0, num_instances)]
for idx in range(num_instances):
np.random.seed(seed)
if self.balanced:
@@ -69,11 +117,11 @@ class TestDataGenerator:
speed_slow += 1
for modality in self.modalities:
- if modality.type == ModalityType.VIDEO:
+ if modality.modality_type == ModalityType.VIDEO:
self.__create_video_data(idx, duration, 30, speed_factor)
- if modality.type == ModalityType.AUDIO:
+ if modality.modality_type == ModalityType.AUDIO:
self.__create_audio_data(idx, duration, speed_factor)
- if modality.type == ModalityType.TEXT:
+ if modality.modality_type == ModalityType.TEXT:
self.__create_text_data(idx, speed_factor)
np.save(f"{self.path}/labels.npy", np.array(self.labels))
diff --git a/src/main/python/tests/scuro/test_data_loaders.py
b/src/main/python/tests/scuro/test_data_loaders.py
index 55704b8d8a..4ca77b205d 100644
--- a/src/main/python/tests/scuro/test_data_loaders.py
+++ b/src/main/python/tests/scuro/test_data_loaders.py
@@ -26,7 +26,7 @@ from systemds.scuro.modality.unimodal_modality import
UnimodalModality
from systemds.scuro.representations.bert import Bert
from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
from systemds.scuro.representations.resnet import ResNet
-from tests.scuro.data_generator import TestDataGenerator
+from tests.scuro.data_generator import setup_data
from systemds.scuro.dataloader.audio_loader import AudioLoader
from systemds.scuro.dataloader.video_loader import VideoLoader
@@ -42,39 +42,25 @@ class TestDataLoaders(unittest.TestCase):
video = None
data_generator = None
num_instances = 0
- indizes = []
@classmethod
def setUpClass(cls):
cls.test_file_path = "test_data"
-
- if os.path.isdir(cls.test_file_path):
- shutil.rmtree(cls.test_file_path)
-
- os.makedirs(f"{cls.test_file_path}/embeddings")
-
cls.num_instances = 2
- cls.indizes = [str(i) for i in range(0, cls.num_instances)]
+ cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
+ cls.data_generator = setup_data(cls.mods, cls.num_instances,
cls.test_file_path)
- cls.video_path = cls.test_file_path + "/" + ModalityType.VIDEO.name +
"/"
- cls.audio_path = cls.test_file_path + "/" + ModalityType.AUDIO.name +
"/"
- cls.text_path = cls.test_file_path + "/" + ModalityType.TEXT.name + "/"
-
- video_data_loader = VideoLoader(cls.video_path, cls.indizes)
- audio_data_loader = AudioLoader(cls.audio_path, cls.indizes)
- text_data_loader = TextLoader(cls.text_path, cls.indizes)
-
- # Load modalities (audio, video, text)
- video = UnimodalModality(video_data_loader, ModalityType.VIDEO)
- audio = UnimodalModality(audio_data_loader, ModalityType.AUDIO)
- text = UnimodalModality(text_data_loader, ModalityType.TEXT)
+ os.makedirs(f"{cls.test_file_path}/embeddings")
- cls.mods = [video, audio, text]
- cls.data_generator = TestDataGenerator(cls.mods, cls.test_file_path)
- cls.data_generator.create_multimodal_data(cls.num_instances)
- cls.text_ref = text.apply_representation(Bert())
- cls.audio_ref = audio.apply_representation(MelSpectrogram())
- cls.video_ref = video.apply_representation(ResNet())
+ cls.text_ref = cls.data_generator.modalities_by_type[
+ ModalityType.TEXT
+ ].apply_representation(Bert())
+ cls.audio_ref = cls.data_generator.modalities_by_type[
+ ModalityType.AUDIO
+ ].apply_representation(MelSpectrogram())
+ cls.video_ref = cls.data_generator.modalities_by_type[
+ ModalityType.VIDEO
+ ].apply_representation(ResNet())
@classmethod
def tearDownClass(cls):
@@ -82,25 +68,38 @@ class TestDataLoaders(unittest.TestCase):
shutil.rmtree(cls.test_file_path)
def test_load_audio_data_from_file(self):
- audio_data_loader = AudioLoader(self.audio_path, self.indizes)
+ audio_data_loader = AudioLoader(
+ self.data_generator.get_modality_path(ModalityType.AUDIO),
+ self.data_generator.indices,
+ )
audio = UnimodalModality(
audio_data_loader, ModalityType.AUDIO
).apply_representation(MelSpectrogram())
for i in range(0, self.num_instances):
- assert round(sum(self.audio_ref.data[i]), 4) ==
round(sum(audio.data[i]), 4)
+ assert round(sum(sum(self.audio_ref.data[i])), 4) == round(
+ sum(sum(audio.data[i])), 4
+ )
def test_load_video_data_from_file(self):
- video_data_loader = VideoLoader(self.video_path, self.indizes)
+ video_data_loader = VideoLoader(
+ self.data_generator.get_modality_path(ModalityType.VIDEO),
+ self.data_generator.indices,
+ )
video = UnimodalModality(
video_data_loader, ModalityType.VIDEO
).apply_representation(ResNet())
for i in range(0, self.num_instances):
- assert round(sum(self.video_ref.data[i]), 4) ==
round(sum(video.data[i]), 4)
+ assert round(sum(sum(self.video_ref.data[i])), 4) == round(
+ sum(sum(video.data[i])), 4
+ )
def test_load_text_data_from_file(self):
- text_data_loader = TextLoader(self.text_path, self.indizes)
+ text_data_loader = TextLoader(
+ self.data_generator.get_modality_path(ModalityType.TEXT),
+ self.data_generator.indices,
+ )
text = UnimodalModality(
text_data_loader, ModalityType.TEXT
).apply_representation(Bert())
diff --git a/src/main/python/tests/scuro/test_dr_search.py
b/src/main/python/tests/scuro/test_dr_search.py
index d0d7ef5077..f2ba9d2d79 100644
--- a/src/main/python/tests/scuro/test_dr_search.py
+++ b/src/main/python/tests/scuro/test_dr_search.py
@@ -25,14 +25,10 @@ import unittest
import numpy as np
from sklearn import svm
from sklearn.metrics import classification_report
-from sklearn.model_selection import train_test_split, KFold
+from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
-from systemds.scuro.modality.unimodal_modality import UnimodalModality
from systemds.scuro.modality.type import ModalityType
-from systemds.scuro.dataloader.text_loader import TextLoader
-from systemds.scuro.dataloader.audio_loader import AudioLoader
-from systemds.scuro.dataloader.video_loader import VideoLoader
from systemds.scuro.aligner.dr_search import DRSearch
from systemds.scuro.aligner.task import Task
from systemds.scuro.models.model import Model
@@ -45,7 +41,7 @@ from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
from systemds.scuro.representations.multiplication import Multiplication
from systemds.scuro.representations.resnet import ResNet
from systemds.scuro.representations.sum import Sum
-from tests.scuro.data_generator import TestDataGenerator
+from tests.scuro.data_generator import setup_data
import warnings
@@ -89,56 +85,54 @@ class TestDataLoaders(unittest.TestCase):
video = None
data_generator = None
num_instances = 0
- indizes = []
representations = None
@classmethod
def setUpClass(cls):
cls.test_file_path = "test_data_dr_search"
+ cls.num_instances = 20
+        modalities = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
- if os.path.isdir(cls.test_file_path):
- shutil.rmtree(cls.test_file_path)
-
+ cls.data_generator = setup_data(
+ modalities, cls.num_instances, cls.test_file_path
+ )
os.makedirs(f"{cls.test_file_path}/embeddings")
- cls.num_instances = 8
- cls.indizes = [str(i) for i in range(0, cls.num_instances)]
+        # TODO: adapt the representation so they return non aggregated values. Apply windowing operation instead
- video_data_loader = VideoLoader(
-            cls.test_file_path + "/" + ModalityType.VIDEO.name + "/", cls.indizes
- )
- audio_data_loader = AudioLoader(
-            cls.test_file_path + "/" + ModalityType.AUDIO.name + "/", cls.indizes
+ cls.bert = cls.data_generator.modalities_by_type[
+ ModalityType.TEXT
+ ].apply_representation(Bert())
+ cls.mel_spe = (
+ cls.data_generator.modalities_by_type[ModalityType.AUDIO]
+ .apply_representation(MelSpectrogram())
+ .flatten()
)
- text_data_loader = TextLoader(
-            cls.test_file_path + "/" + ModalityType.TEXT.name + "/", cls.indizes
+ cls.resnet = (
+ cls.data_generator.modalities_by_type[ModalityType.VIDEO]
+ .apply_representation(ResNet())
+ .window(10, "avg")
+ .flatten()
)
- video = UnimodalModality(video_data_loader, ModalityType.VIDEO)
- audio = UnimodalModality(audio_data_loader, ModalityType.AUDIO)
- text = UnimodalModality(text_data_loader, ModalityType.TEXT)
-        cls.data_generator = TestDataGenerator([video, audio, text], cls.test_file_path)
- cls.data_generator.create_multimodal_data(cls.num_instances)
-
- cls.bert = text.apply_representation(Bert())
- cls.mel_spe = audio.apply_representation(MelSpectrogram())
- cls.resnet = video.apply_representation(ResNet())
-
cls.mods = [cls.bert, cls.mel_spe, cls.resnet]
split = train_test_split(
-            cls.indizes, cls.data_generator.labels, test_size=0.2, random_state=42
+ cls.data_generator.indices,
+ cls.data_generator.labels,
+ test_size=0.2,
+ random_state=42,
)
cls.train_indizes, cls.val_indizes = [int(i) for i in split[0]], [
int(i) for i in split[1]
]
for m in cls.mods:
- m.data = scale_data(m.data, [int(i) for i in cls.train_indizes])
+ m.data = scale_data(m.data, cls.train_indizes)
cls.representations = [
Concatenation(),
Average(),
- RowMax(),
+ RowMax(100),
Multiplication(),
Sum(),
LSTM(width=256, depth=3),
diff --git a/src/main/python/tests/scuro/test_multimodal_join.py b/src/main/python/tests/scuro/test_multimodal_join.py
new file mode 100644
index 0000000000..c48f5f56b2
--- /dev/null
+++ b/src/main/python/tests/scuro/test_multimodal_join.py
@@ -0,0 +1,135 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# TODO: Test edge cases: unequal number of audio-video timestamps (should still work and add the average over all audio/video samples)
+
+import shutil
+import unittest
+
+from systemds.scuro.modality.joined import JoinCondition
+from systemds.scuro.representations.window import WindowAggregation
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.resnet import ResNet
+from tests.scuro.data_generator import setup_data
+
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.modality.type import ModalityType
+
+
+class TestMultimodalJoin(unittest.TestCase):
+ test_file_path = None
+ mods = None
+ text = None
+ audio = None
+ video = None
+ data_generator = None
+ num_instances = 0
+ indizes = []
+
+ @classmethod
+ def setUpClass(cls):
+ cls.test_file_path = "join_test_data"
+ cls.num_instances = 4
+ cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO]
+
+        cls.data_generator = setup_data(cls.mods, cls.num_instances, cls.test_file_path)
+
+ @classmethod
+ def tearDownClass(cls):
+ print("Cleaning up test data")
+ shutil.rmtree(cls.test_file_path)
+
+ def test_video_audio_join(self):
+ self._execute_va_join()
+
+ def test_chunked_video_audio_join(self):
+ self._execute_va_join(2)
+
+ def test_video_chunked_audio_join(self):
+ self._execute_va_join(None, 2)
+
+ def test_chunked_video_chunked_audio_join(self):
+ self._execute_va_join(2, 2)
+
+ def test_audio_video_join(self):
+ # Audio has a much higher frequency than video, hence we would need to
+ # duplicate or interpolate frames to match them to the audio frequency
+ self._execute_av_join()
+
+ # TODO
+ # def test_chunked_audio_video_join(self):
+ # self._execute_av_join(2)
+
+ # TODO
+ # def test_chunked_audio_chunked_video_join(self):
+ # self._execute_av_join(2, 2)
+
+ def _execute_va_join(self, l_chunk_size=None, r_chunk_size=None):
+ video, audio = self._prepare_data(l_chunk_size, r_chunk_size)
+ self._join(video, audio, 2)
+
+ def _execute_av_join(self, l_chunk_size=None, r_chunk_size=None):
+ video, audio = self._prepare_data(l_chunk_size, r_chunk_size)
+ self._join(audio, video, 2)
+
+ def _prepare_data(self, l_chunk_size=None, r_chunk_size=None):
+ video_data_loader = VideoLoader(
+ self.data_generator.get_modality_path(ModalityType.VIDEO),
+ self.data_generator.indices,
+ chunk_size=l_chunk_size,
+ )
+ video = UnimodalModality(video_data_loader, ModalityType.VIDEO)
+
+ audio_data_loader = AudioLoader(
+ self.data_generator.get_modality_path(ModalityType.AUDIO),
+ self.data_generator.indices,
+ r_chunk_size,
+ )
+ audio = UnimodalModality(audio_data_loader, ModalityType.AUDIO)
+
+ mel_audio = audio.apply_representation(MelSpectrogram())
+
+ return video, mel_audio
+
+ def _join(self, left_modality, right_modality, window_size):
+ resnet_modality = (
+ left_modality.join(
+ right_modality, JoinCondition("timestamp", "timestamp", "<")
+ )
+ .apply_representation(
+ ResNet(layer="layer1.0.conv2", model_name="ResNet50"),
+                WindowAggregation(window_size=window_size, aggregation_function="mean"),
+ )
+ .combine("concat")
+ )
+
+ assert resnet_modality.left_modality is not None
+ assert resnet_modality.right_modality is not None
+ assert len(resnet_modality.left_modality.data) == self.num_instances
+ assert len(resnet_modality.right_modality.data) == self.num_instances
+ assert resnet_modality.data is not None
+
+ return resnet_modality
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/main/python/tests/scuro/test_unimodal_representations.py b/src/main/python/tests/scuro/test_unimodal_representations.py
new file mode 100644
index 0000000000..d566830697
--- /dev/null
+++ b/src/main/python/tests/scuro/test_unimodal_representations.py
@@ -0,0 +1,120 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import os
+import shutil
+import unittest
+
+from systemds.scuro.representations.bow import BoW
+from systemds.scuro.representations.word2vec import W2V
+from systemds.scuro.representations.tfidf import TfIdf
+from systemds.scuro.modality.unimodal_modality import UnimodalModality
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.resnet import ResNet
+from tests.scuro.data_generator import setup_data
+
+from systemds.scuro.dataloader.audio_loader import AudioLoader
+from systemds.scuro.dataloader.video_loader import VideoLoader
+from systemds.scuro.dataloader.text_loader import TextLoader
+from systemds.scuro.modality.type import ModalityType
+
+
+class TestUnimodalRepresentations(unittest.TestCase):
+ test_file_path = None
+ mods = None
+ text = None
+ audio = None
+ video = None
+ data_generator = None
+ num_instances = 0
+ indizes = []
+
+ @classmethod
+ def setUpClass(cls):
+ cls.test_file_path = "unimodal_test_data"
+
+ cls.num_instances = 4
+ cls.mods = [ModalityType.VIDEO, ModalityType.AUDIO, ModalityType.TEXT]
+
+        cls.data_generator = setup_data(cls.mods, cls.num_instances, cls.test_file_path)
+ os.makedirs(f"{cls.test_file_path}/embeddings")
+
+ @classmethod
+ def tearDownClass(cls):
+ print("Cleaning up test data")
+ shutil.rmtree(cls.test_file_path)
+
+ def test_audio_representations(self):
+ audio_representations = [MelSpectrogram()] # TODO: add FFT, TFN, 1DCNN
+ audio_data_loader = AudioLoader(
+ self.data_generator.get_modality_path(ModalityType.AUDIO),
+ self.data_generator.indices,
+ )
+ audio = UnimodalModality(audio_data_loader, ModalityType.AUDIO)
+
+ for representation in audio_representations:
+ r = audio.apply_representation(representation)
+ assert r.data is not None
+ assert len(r.data) == self.num_instances
+
+ def test_video_representations(self):
+        video_representations = [ResNet()]  # Todo: add other video representations
+ video_data_loader = VideoLoader(
+ self.data_generator.get_modality_path(ModalityType.VIDEO),
+ self.data_generator.indices,
+ )
+ video = UnimodalModality(video_data_loader, ModalityType.VIDEO)
+ for representation in video_representations:
+ r = video.apply_representation(representation)
+ assert r.data is not None
+ assert len(r.data) == self.num_instances
+
+ def test_text_representations(self):
+ # TODO: check params fro BOW, W2V, TfIdf
+ test_representations = [BoW(2, 2), W2V(5, 2, 2), TfIdf(2), Bert()]
+ text_data_loader = TextLoader(
+ self.data_generator.get_modality_path(ModalityType.TEXT),
+ self.data_generator.indices,
+ )
+ text = UnimodalModality(text_data_loader, ModalityType.TEXT)
+
+ for representation in test_representations:
+ r = text.apply_representation(representation)
+ assert r.data is not None
+ assert len(r.data) == self.num_instances
+
+ def test_chunked_video_representations(self):
+ video_representations = [ResNet()]
+ video_data_loader = VideoLoader(
+ self.data_generator.get_modality_path(ModalityType.VIDEO),
+ self.data_generator.indices,
+ chunk_size=2,
+ )
+ video = UnimodalModality(video_data_loader, ModalityType.VIDEO)
+ for representation in video_representations:
+ r = video.apply_representation(representation)
+ assert r.data is not None
+ assert len(r.data) == self.num_instances
+
+
+if __name__ == "__main__":
+ unittest.main()