This is an automated email from the ASF dual-hosted git repository.

cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cd4c828f2a [SYSTEMDS-3835] Add additional visual representations
cd4c828f2a is described below

commit cd4c828f2af696fab3659f96377f168e97a381f2
Author: Christina Dionysio <[email protected]>
AuthorDate: Thu Nov 13 18:03:25 2025 +0100

    [SYSTEMDS-3835] Add additional visual representations
    
    This patch adds new visual (image, video) representations and a test
    utility for the image modality.
---
 .github/workflows/python.yml                       |   5 +-
 src/main/python/systemds/scuro/__init__.py         |   8 +-
 src/main/python/systemds/scuro/modality/type.py    |   9 ++
 .../systemds/scuro/modality/unimodal_modality.py   |   5 +-
 .../python/systemds/scuro/representations/clip.py  | 133 +++++++++++++++++++++
 .../scuro/representations/color_histogram.py       | 111 +++++++++++++++++
 .../systemds/scuro/representations/resnet.py       |  10 +-
 .../representations/swin_video_transformer.py      |   5 +-
 .../scuro/representations/{resnet.py => vgg.py}    | 107 +++++------------
 .../python/systemds/scuro/representations/x3d.py   | 117 ++++++++++--------
 .../python/systemds/scuro/utils/torch_dataset.py   |  28 +++--
 src/main/python/tests/scuro/data_generator.py      |  38 ++++--
 .../python/tests/scuro/test_operator_registry.py   |  49 +++++++-
 .../tests/scuro/test_unimodal_representations.py   |  33 ++++-
 14 files changed, 494 insertions(+), 164 deletions(-)

diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml
index bf20101e6c..fcd8bf8c84 100644
--- a/.github/workflows/python.yml
+++ b/.github/workflows/python.yml
@@ -171,7 +171,8 @@ jobs:
           h5py \
           gensim \
           opt-einsum \
-          nltk
+          nltk \
+          fvcore
         kill $KA 
         cd src/main/python
-        python -m unittest discover -s tests/scuro -p 'test_*.py' -v
\ No newline at end of file
+        python -m unittest discover -s tests/scuro -p 'test_*.py' -v
diff --git a/src/main/python/systemds/scuro/__init__.py b/src/main/python/systemds/scuro/__init__.py
index c1db4c3d49..e74ae53f36 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -77,6 +77,7 @@ from systemds.scuro.representations.window_aggregation import (
 )
 from systemds.scuro.representations.word2vec import W2V
 from systemds.scuro.representations.x3d import X3D
+from systemds.scuro.representations.color_histogram import ColorHistogram
 from systemds.scuro.models.model import Model
 from systemds.scuro.models.discrete_model import DiscreteModel
 from systemds.scuro.modality.joined import JoinedModality
@@ -97,7 +98,8 @@ from systemds.scuro.representations.covarep_audio_features import (
 )
 from systemds.scuro.drsearch.multimodal_optimizer import MultimodalOptimizer
 from systemds.scuro.drsearch.unimodal_optimizer import UnimodalOptimizer
-
+from systemds.scuro.representations.vgg import VGG19
+from systemds.scuro.representations.clip import CLIPText, CLIPVisual
 
 __all__ = [
     "BaseLoader",
@@ -120,6 +122,7 @@ __all__ = [
     "MFCC",
     "Hadamard",
     "OpticalFlow",
+    "ColorHistogram",
     "Representation",
     "NPY",
     "JSON",
@@ -169,4 +172,7 @@ __all__ = [
     "Quantile",
     "BandpowerFFT",
     "ZeroCrossingRate",
+    "VGG19",
+    "CLIPVisual",
+    "CLIPText",
 ]
diff --git a/src/main/python/systemds/scuro/modality/type.py b/src/main/python/systemds/scuro/modality/type.py
index ef1e0eeab2..382e2631ad 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/modality/type.py
@@ -254,7 +254,16 @@ class ModalityType(Flag):
         md["data_layout"]["representation"] = DataLayout.NESTED_LEVEL
         md["data_layout"]["type"] = float
         md["data_layout"]["shape"] = (width, height, num_channels)
+        return md
 
+    def create_image_metadata(self, width, height, num_channels):
+        md = deepcopy(self.get_schema())
+        md["width"] = width
+        md["height"] = height
+        md["num_channels"] = num_channels
+        md["data_layout"]["representation"] = DataLayout.SINGLE_LEVEL
+        md["data_layout"]["type"] = float
+        md["data_layout"]["shape"] = (width, height, num_channels)
         return md
 
 
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 373921e95c..4ae1067c62 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -165,8 +165,9 @@ class UnimodalModality(Modality):
                         padded = np.pad(
                             embeddings,
                             pad_width=(
-                                (0, padding_needed),
-                                (0, 0),
+                                (0, padding_needed)
+                                if len(embeddings.shape) == 1
+                                else ((0, padding_needed), (0, 0))
                             ),
                             mode="constant",
                             constant_values=0,
diff --git a/src/main/python/systemds/scuro/representations/clip.py b/src/main/python/systemds/scuro/representations/clip.py
new file mode 100644
index 0000000000..044d0f795a
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/clip.py
@@ -0,0 +1,133 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import numpy as np
+from torchvision import transforms
+
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+import torch
+from systemds.scuro.representations.utils import save_embeddings
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.drsearch.operator_registry import register_representation
+from transformers import CLIPProcessor, CLIPModel
+
+from systemds.scuro.utils.converter import numpy_dtype_to_torch_dtype
+from systemds.scuro.utils.static_variables import get_device
+from systemds.scuro.utils.torch_dataset import CustomDataset
+
+
+@register_representation(ModalityType.VIDEO)
+class CLIPVisual(UnimodalRepresentation):
+    def __init__(self, output_file=None):
+        parameters = {}
+        super().__init__("CLIPVisual", ModalityType.EMBEDDING, parameters)
+        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(
+            get_device()
+        )
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        self.output_file = output_file
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(modality, self)
+        self.data_type = numpy_dtype_to_torch_dtype(modality.data_type)
+        if next(self.model.parameters()).dtype != self.data_type:
+            self.model = self.model.to(self.data_type)
+
+        embeddings = self.create_visual_embeddings(modality)
+
+        if self.output_file is not None:
+            save_embeddings(embeddings, self.output_file)
+
+        transformed_modality.data = list(embeddings.values())
+        return transformed_modality
+
+    def create_visual_embeddings(self, modality):
+        tf = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
+        dataset = CustomDataset(
+            modality.data,
+            self.data_type,
+            get_device(),
+            (modality.metadata[0]["width"], modality.metadata[0]["height"]),
+            tf=tf,
+        )
+        embeddings = {}
+        for instance in torch.utils.data.DataLoader(dataset):
+            id = int(instance["id"][0])
+            frames = instance["data"][0]
+            embeddings[id] = []
+            batch_size = 64
+
+            for start_index in range(0, len(frames), batch_size):
+                end_index = min(start_index + batch_size, len(frames))
+                frame_ids_range = range(start_index, end_index)
+                frame_batch = frames[frame_ids_range]
+
+                inputs = self.processor(images=frame_batch, return_tensors="pt")
+                with torch.no_grad():
+                    output = self.model.get_image_features(**inputs)
+
+                if len(output.shape) > 2:
+                    output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
+
+                embeddings[id].extend(
+                    torch.flatten(output, 1)
+                    .detach()
+                    .cpu()
+                    .float()
+                    .numpy()
+                    .astype(modality.data_type)
+                )
+
+            embeddings[id] = np.array(embeddings[id])
+        return embeddings
+
+
+@register_representation(ModalityType.TEXT)
+class CLIPText(UnimodalRepresentation):
+    def __init__(self, output_file=None):
+        parameters = {}
+        super().__init__("CLIPText", ModalityType.EMBEDDING, parameters)
+        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(
+            get_device()
+        )
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+        self.output_file = output_file
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(modality, self)
+
+        embeddings = self.create_text_embeddings(modality.data, self.model)
+
+        if self.output_file is not None:
+            save_embeddings(embeddings, self.output_file)
+
+        transformed_modality.data = embeddings
+        return transformed_modality
+
+    def create_text_embeddings(self, data, model):
+        embeddings = []
+        for d in data:
+            inputs = self.processor(text=d, return_tensors="pt", padding=True)
+            with torch.no_grad():
+                text_embedding = model.get_text_features(**inputs)
+                embeddings.append(text_embedding.squeeze().numpy().reshape(1, -1))
+
+        return embeddings
diff --git a/src/main/python/systemds/scuro/representations/color_histogram.py b/src/main/python/systemds/scuro/representations/color_histogram.py
new file mode 100644
index 0000000000..6412b1979d
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/color_histogram.py
@@ -0,0 +1,111 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import numpy as np
+import cv2
+
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.modality.transformed import TransformedModality
+
+
+class ColorHistogram(UnimodalRepresentation):
+    def __init__(
+        self,
+        color_space="RGB",
+        bins=32,
+        normalize=True,
+        aggregation="mean",
+        output_file=None,
+    ):
+        super().__init__(
+            "ColorHistogram", ModalityType.EMBEDDING, self._get_parameters()
+        )
+        self.color_space = color_space
+        self.bins = bins
+        self.normalize = normalize
+        self.aggregation = aggregation
+        self.output_file = output_file
+
+    def _get_parameters(self):
+        return {
+            "color_space": ["RGB", "HSV", "GRAY"],
+            "bins": [8, 16, 32, 64, 128, 256, (8, 8, 8), (16, 16, 16)],
+            "normalize": [True, False],
+            "aggregation": ["mean", "max", "concat"],
+        }
+
+    def compute_histogram(self, image):
+        if self.color_space == "HSV":
+            img = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
+            channels = [0, 1, 2]
+        elif self.color_space == "GRAY":
+            img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
+            channels = [0]
+        else:
+            img = image
+            channels = [0, 1, 2]
+
+        hist = self._region_histogram(img, channels)
+        return hist
+
+    def _region_histogram(self, img, channels):
+        if isinstance(self.bins, tuple):
+            bins = self.bins
+        elif len(channels) > 1:
+            bins = [self.bins] * len(channels)
+        else:
+            bins = [self.bins]
+        hist = cv2.calcHist([img], channels, None, bins, [0, 256] * len(channels))
+        hist = hist.flatten()
+        if self.normalize:
+            hist_sum = np.sum(hist)
+            if hist_sum > 0:
+                hist /= hist_sum
+        return hist.astype(np.float32)
+
+    def transform(self, modality):
+        if modality.modality_type == ModalityType.IMAGE:
+            images = modality.data
+            hist_list = [self.compute_histogram(img) for img in images]
+            transformed_modality = TransformedModality(
+                modality, self, ModalityType.EMBEDDING
+            )
+            transformed_modality.data = hist_list
+            return transformed_modality
+        elif modality.modality_type == ModalityType.VIDEO:
+            embeddings = []
+            for vid in modality.data:
+                frame_hists = [self.compute_histogram(frame) for frame in vid]
+                if self.aggregation == "mean":
+                    hist = np.mean(frame_hists, axis=0)
+                elif self.aggregation == "max":
+                    hist = np.max(frame_hists, axis=0)
+                elif self.aggregation == "concat":
+                    hist = np.concatenate(frame_hists)
+                embeddings.append(hist)
+            transformed_modality = TransformedModality(
+                modality, self, ModalityType.EMBEDDING
+            )
+            transformed_modality.data = embeddings
+            return transformed_modality
+        else:
+            raise ValueError("Unsupported data format for HistogramRepresentation")
diff --git a/src/main/python/systemds/scuro/representations/resnet.py b/src/main/python/systemds/scuro/representations/resnet.py
index 7bb94d8bfd..f544e6a46f 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/resnet.py
@@ -144,17 +144,21 @@ class ResNet(UnimodalRepresentation):
             embeddings[video_id] = []
             batch_size = 64
 
+            if modality.modality_type == ModalityType.IMAGE:
+                frames = frames.unsqueeze(0)
+
             for start_index in range(0, len(frames), batch_size):
                 end_index = min(start_index + batch_size, len(frames))
                 frame_ids_range = range(start_index, end_index)
                 frame_batch = frames[frame_ids_range]
 
                 _ = self.model(frame_batch)
-                values = res5c_output
-                pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
+                output = res5c_output
+                if len(output.shape) > 2:
+                    output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
 
                 embeddings[video_id].extend(
-                    torch.flatten(pooled, 1)
+                    torch.flatten(output, 1)
                     .detach()
                     .cpu()
                     .float()
diff --git a/src/main/python/systemds/scuro/representations/swin_video_transformer.py b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
index c0b7ab38ab..e8511dd0cf 100644
--- a/src/main/python/systemds/scuro/representations/swin_video_transformer.py
+++ b/src/main/python/systemds/scuro/representations/swin_video_transformer.py
@@ -34,7 +34,7 @@ from systemds.scuro.utils.torch_dataset import CustomDataset
 from systemds.scuro.utils.static_variables import get_device
 
 
-# @register_representation([ModalityType.VIDEO])
+@register_representation([ModalityType.VIDEO])
 class SwinVideoTransformer(UnimodalRepresentation):
     def __init__(self, layer_name="avgpool"):
         parameters = {
@@ -50,7 +50,7 @@ class SwinVideoTransformer(UnimodalRepresentation):
             ],
         }
         self.data_type = torch.float
-        super().__init__("SwinVideoTransformer", ModalityType.TIMESERIES, parameters)
+        super().__init__("SwinVideoTransformer", ModalityType.EMBEDDING, parameters)
         self.layer_name = layer_name
         self.model = swin3d_t(weights=models.video.Swin3D_T_Weights.KINETICS400_V1).to(
             get_device()
@@ -95,6 +95,7 @@ class SwinVideoTransformer(UnimodalRepresentation):
                 .detach()
                 .cpu()
                 .numpy()
+                .flatten()
                 .astype(modality.data_type)
             )
 
diff --git a/src/main/python/systemds/scuro/representations/resnet.py b/src/main/python/systemds/scuro/representations/vgg.py
similarity index 57%
copy from src/main/python/systemds/scuro/representations/resnet.py
copy to src/main/python/systemds/scuro/representations/vgg.py
index 7bb94d8bfd..374586f2b9 100644
--- a/src/main/python/systemds/scuro/representations/resnet.py
+++ b/src/main/python/systemds/scuro/representations/vgg.py
@@ -26,6 +26,7 @@ from typing import Tuple, Any
 from systemds.scuro.drsearch.operator_registry import register_representation
 import torch.utils.data
 import torch
+import re
 import torchvision.models as models
 import numpy as np
 from systemds.scuro.modality.type import ModalityType
@@ -33,13 +34,12 @@ from systemds.scuro.utils.static_variables import get_device
 
 
 @register_representation([ModalityType.IMAGE, ModalityType.VIDEO])
-class ResNet(UnimodalRepresentation):
-    def __init__(self, model_name="ResNet18", layer="avgpool", output_file=None):
+class VGG19(UnimodalRepresentation):
+    def __init__(self, layer="classifier.0", output_file=None):
         self.data_type = torch.bfloat16
-        self.model_name = model_name
+        self.model = models.vgg19(weights=models.VGG19_Weights.DEFAULT).to(get_device())
         parameters = self._get_parameters()
-        super().__init__("ResNet", ModalityType.EMBEDDING, parameters)
-
+        super().__init__("VGG19", ModalityType.EMBEDDING, parameters)
         self.output_file = output_file
         self.layer_name = layer
         self.model.eval()
@@ -52,65 +52,16 @@ class ResNet(UnimodalRepresentation):
 
         self.model.fc = Identity()
 
-    @property
-    def model_name(self):
-        return self._model_name
-
-    @model_name.setter
-    def model_name(self, model_name):
-        self._model_name = model_name
-        if model_name == "ResNet18":
-            self.model = (
-                models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
-                .to(get_device())
-                .to(self.data_type)
-            )
-
-        elif model_name == "ResNet34":
-            self.model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(
-                get_device()
-            )
-            self.model = self.model.to(self.data_type)
-        elif model_name == "ResNet50":
-            self.model = (
-                models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
-                .to(get_device())
-                .to(self.data_type)
-            )
+    def _get_parameters(self):
+        parameters = {"layer_name": []}
 
-        elif model_name == "ResNet101":
-            self.model = (
-                models.resnet101(weights=models.ResNet101_Weights.DEFAULT)
-                .to(get_device())
-                .to(self.data_type)
-            )
+        parameters["layer_name"] = [
+            "features.35",
+            "classifier.0",
+            "classifier.3",
+            "classifier.6",
+        ]
 
-        elif model_name == "ResNet152":
-            self.model = (
-                models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
-                .to(get_device())
-                .to(self.data_type)
-            )
-        else:
-            raise NotImplementedError
-
-    def _get_parameters(self, high_level=True):
-        parameters = {"model_name": [], "layer_name": []}
-        for m in ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]:
-            parameters["model_name"].append(m)
-
-        if high_level:
-            parameters["layer_name"] = [
-                "conv1",
-                "layer1",
-                "layer2",
-                "layer3",
-                "layer4",
-                "avgpool",
-            ]
-        else:
-            for name, layer in self.model.named_modules():
-                parameters["layer_name"].append(name)
         return parameters
 
     def transform(self, modality):
@@ -121,28 +72,32 @@ class ResNet(UnimodalRepresentation):
         dataset = CustomDataset(modality.data, self.data_type, get_device())
         embeddings = {}
 
-        res5c_output = None
+        activations = {}
 
-        def get_features(name_):
+        def get_activation(name_):
             def hook(
                 _module: torch.nn.Module, input_: Tuple[torch.Tensor], output: Any
             ):
-                nonlocal res5c_output
-                res5c_output = output
+                activations[name_] = output
 
             return hook
 
-        if self.layer_name:
-            for name, layer in self.model.named_modules():
-                if name == self.layer_name:
-                    layer.register_forward_hook(get_features(name))
-                    break
+        digit = re.findall(r"\d+", self.layer_name)[0]
+        if "feature" in self.layer_name:
+            self.model.features[int(digit)].register_forward_hook(
+                get_activation(self.layer_name)
+            )
+        else:
+
+            self.model.classifier[int(digit)].register_forward_hook(
+                get_activation(self.layer_name)
+            )
 
         for instance in torch.utils.data.DataLoader(dataset):
             video_id = instance["id"][0]
             frames = instance["data"][0]
             embeddings[video_id] = []
-            batch_size = 64
+            batch_size = 32
 
             for start_index in range(0, len(frames), batch_size):
                 end_index = min(start_index + batch_size, len(frames))
@@ -150,11 +105,11 @@ class ResNet(UnimodalRepresentation):
                 frame_batch = frames[frame_ids_range]
 
                 _ = self.model(frame_batch)
-                values = res5c_output
-                pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
-
+                output = activations[self.layer_name]
+                if len(output.shape) == 4:
+                    output = torch.nn.functional.adaptive_avg_pool2d(output, (1, 1))
                 embeddings[video_id].extend(
-                    torch.flatten(pooled, 1)
+                    torch.flatten(output, 1)
                     .detach()
                     .cpu()
                     .float()
diff --git a/src/main/python/systemds/scuro/representations/x3d.py b/src/main/python/systemds/scuro/representations/x3d.py
index 1629ac6f30..7701865f82 100644
--- a/src/main/python/systemds/scuro/representations/x3d.py
+++ b/src/main/python/systemds/scuro/representations/x3d.py
@@ -18,34 +18,27 @@
 # under the License.
 #
 # -------------------------------------------------------------
+from systemds.scuro.utils.static_variables import get_device
 from systemds.scuro.utils.torch_dataset import CustomDataset
 from systemds.scuro.modality.transformed import TransformedModality
 from systemds.scuro.representations.unimodal import UnimodalRepresentation
-from typing import Callable, Dict, Tuple, Any
+from typing import Tuple, Any
 import torch.utils.data
 import torch
 from torchvision.models.video import r3d_18, s3d
 import torchvision.models as models
-import torchvision.transforms as transforms
 import numpy as np
 from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.drsearch.operator_registry import register_representation
-import math
 
-if torch.backends.mps.is_available():
-    DEVICE = torch.device("mps")
-elif torch.cuda.is_available():
-    DEVICE = torch.device("cuda")
-else:
-    DEVICE = torch.device("cpu")
 
-
-# @register_representation([ModalityType.VIDEO])
+@register_representation([ModalityType.VIDEO])
 class X3D(UnimodalRepresentation):
-    def __init__(self, layer="avgpool", model_name="r3d", output_file=None):
+    def __init__(self, layer="classifier.1", model_name="s3d", output_file=None):
+        self.data_type = torch.float32
         self.model_name = model_name
         parameters = self._get_parameters()
-        super().__init__("X3D", ModalityType.TIMESERIES, parameters)
+        super().__init__("X3D", ModalityType.EMBEDDING, parameters)
 
         self.output_file = output_file
         self.layer_name = layer
@@ -67,25 +60,37 @@ class X3D(UnimodalRepresentation):
     def model_name(self, model_name):
         self._model_name = model_name
         if model_name == "r3d":
-            self.model = r3d_18(pretrained=True).to(DEVICE)
+            self.model = r3d_18(pretrained=True).to(get_device())
         elif model_name == "s3d":
-            self.model = s3d(weights=models.video.S3D_Weights.DEFAULT).to(DEVICE)
+            self.model = s3d(weights=models.video.S3D_Weights.DEFAULT).to(get_device())
         else:
             raise NotImplementedError
 
     def _get_parameters(self, high_level=True):
         parameters = {"model_name": [], "layer_name": []}
-        for m in ["r3d", "s3d"]:
+        for m in ["c3d", "s3d"]:
             parameters["model_name"].append(m)
 
         if high_level:
             parameters["layer_name"] = [
-                "conv1",
-                "layer1",
-                "layer2",
-                "layer3",
-                "layer4",
+                "features.1",
+                "features.2",
+                "features.3",
+                "features.4",
+                "features.5",
+                "features.6",
+                "features.7",
+                "features.8",
+                "features.9",
+                "features.10",
+                "features.11",
+                "features.12",
+                "features.13",
+                "features.14",
+                "features.15",
                 "avgpool",
+                "classifier.0",
+                "classifier.1",
             ]
         else:
             for name, layer in self.model.named_modules():
@@ -93,17 +98,18 @@ class X3D(UnimodalRepresentation):
         return parameters
 
     def transform(self, modality):
-        dataset = CustomDataset(modality.data)
+        dataset = CustomDataset(modality.data, self.data_type, get_device())
+
         embeddings = {}
 
-        res5c_output = None
+        activation = None
 
         def get_features(name_):
             def hook(
                 _module: torch.nn.Module, input_: Tuple[torch.Tensor], output: Any
             ):
-                nonlocal res5c_output
-                res5c_output = output
+                nonlocal activation
+                activation = output
 
             return hook
 
@@ -115,15 +121,20 @@ class X3D(UnimodalRepresentation):
 
         for instance in dataset:
             video_id = instance["id"]
-            frames = instance["data"].to(DEVICE)
+            frames = instance["data"].to(get_device())
             embeddings[video_id] = []
 
             frames = frames.unsqueeze(0).permute(0, 2, 1, 3, 4)
+            if frames.shape[2] < 14:
+                pad_width = (0, 0, 0, 0, 0, 14 - frames.shape[2], 0, 0, 0, 0)
+                frames = torch.nn.functional.pad(frames, pad_width, mode="constant")
             _ = self.model(frames)
-            values = res5c_output
+            values = activation
             pooled = torch.nn.functional.adaptive_avg_pool2d(values, (1, 1))
 
-            embeddings[video_id].extend(torch.flatten(pooled, 1).detach().cpu().numpy())
+            embeddings[video_id].extend(
+                torch.flatten(pooled, 1).detach().cpu().numpy().flatten()
+            )
 
             embeddings[video_id] = np.array(embeddings[video_id])
 
@@ -137,13 +148,13 @@ class X3D(UnimodalRepresentation):
 
 
 class I3D(UnimodalRepresentation):
-    def __init__(self, layer="avgpool", model_name="i3d", output_file=None):
+    def __init__(self, layer="blocks.6", model_name="i3d", output_file=None):
         self.model_name = model_name
-        parameters = self._get_parameters()
         self.model = torch.hub.load(
             "facebookresearch/pytorchvideo", "i3d_r50", pretrained=True
-        ).to(DEVICE)
-        super().__init__("I3D", ModalityType.TIMESERIES, parameters)
+        ).to(get_device())
+        parameters = self._get_parameters()
+        super().__init__("I3D", ModalityType.EMBEDDING, parameters)
 
         self.output_file = output_file
         self.layer_name = layer
@@ -152,18 +163,17 @@ class I3D(UnimodalRepresentation):
             param.requires_grad = False
 
     def _get_parameters(self, high_level=True):
-        parameters = {"model_name": [], "layer_name": []}
-        for m in ["r3d", "s3d"]:
-            parameters["model_name"].append(m)
+        parameters = {"layer_name": []}
 
         if high_level:
             parameters["layer_name"] = [
-                "conv1",
-                "layer1",
-                "layer2",
-                "layer3",
-                "layer4",
-                "avgpool",
+                "blocks.0",
+                "blocks.1",
+                "blocks.2",
+                "blocks.3",
+                "blocks.4",
+                "blocks.5",
+                "blocks.6",
             ]
         else:
             for name, layer in self.model.named_modules():
@@ -171,28 +181,37 @@ class I3D(UnimodalRepresentation):
         return parameters
 
     def transform(self, modality):
-        dataset = CustomDataset(modality.data, torch.float32, DEVICE)
+        dataset = CustomDataset(modality.data, torch.float32, get_device())
         embeddings = {}
 
         features = None
 
-        def hook(module, input, output):
-            pooled = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()
-            nonlocal features
-            features = pooled.detach().cpu().numpy()
+        def get_features(name_):
+            def hook(
+                _module: torch.nn.Module, input_: Tuple[torch.Tensor], output: Any
+            ):
+                # pooled = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()
+                nonlocal features
+                features = output.detach().cpu().numpy()
+
+            return hook
 
-        handle = self.model.blocks[6].dropout.register_forward_hook(hook)
+        if self.layer_name:
+            for name, layer in self.model.named_modules():
+                if name == self.layer_name:
+                    layer.register_forward_hook(get_features(name))
+                    break
 
         for instance in dataset:
             video_id = instance["id"]
-            frames = instance["data"].to(DEVICE)
+            frames = instance["data"].to(get_device())
             embeddings[video_id] = []
 
             batch = torch.transpose(frames, 1, 0)
             batch = batch.unsqueeze(0)
             _ = self.model(batch)
 
-            embeddings[video_id] = features
+            embeddings[video_id] = features.flatten()
 
         transformed_modality = TransformedModality(
             modality, self, self.output_modality_type
diff --git a/src/main/python/systemds/scuro/utils/torch_dataset.py b/src/main/python/systemds/scuro/utils/torch_dataset.py
index c04be0ec7b..2a7ec1f963 100644
--- a/src/main/python/systemds/scuro/utils/torch_dataset.py
+++ b/src/main/python/systemds/scuro/utils/torch_dataset.py
@@ -20,12 +20,13 @@
 # -------------------------------------------------------------
 from typing import Dict
 
+import numpy as np
 import torch
 import torchvision.transforms as transforms
 
 
 class CustomDataset(torch.utils.data.Dataset):
-    def __init__(self, data, data_type, device, size=None):
+    def __init__(self, data, data_type, device, size=None, tf=None):
         self.data = data
         self.data_type = data_type
         self.device = device
@@ -33,7 +34,7 @@ class CustomDataset(torch.utils.data.Dataset):
         if size is None:
             self.size = (256, 224)
 
-        self.tf = transforms.Compose(
+        tf_default = transforms.Compose(
             [
                 transforms.ToPILImage(),
                 transforms.Resize(self.size[0]),
@@ -46,6 +47,11 @@ class CustomDataset(torch.utils.data.Dataset):
             ]
         )
 
+        if tf is None:
+            self.tf = tf_default
+        else:
+            self.tf = tf
+
     def __getitem__(self, index) -> Dict[str, object]:
         data = self.data[index]
         output = torch.empty(
@@ -54,12 +60,20 @@ class CustomDataset(torch.utils.data.Dataset):
             device=self.device,
         )
 
-        for i, d in enumerate(data):
-            if data[0].ndim < 3:
-                d = torch.tensor(d)
-                d = d.repeat(3, 1, 1)
+        if isinstance(data, np.ndarray) and data.ndim == 3:
+            # image
+            data = torch.tensor(data).permute(2, 0, 1)
+            output = self.tf(data).to(self.device)
+        else:
+            for i, d in enumerate(data):
+                if data[0].ndim < 3:
+                    d = torch.tensor(d)
+                    d = d.repeat(3, 1, 1)
 
-            output[i] = self.tf(d)
+                tf = self.tf(d)
+                if tf.shape[0] != 3:
+                    tf = tf[:3, :, :]
+                output[i] = tf
 
         return {"id": index, "data": output}
 
diff --git a/src/main/python/tests/scuro/data_generator.py b/src/main/python/tests/scuro/data_generator.py
index 11f034d9ce..7676906505 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -195,26 +195,38 @@ class ModalityRandomDataGenerator:
     def create_visual_modality(
         self, num_instances, max_num_frames=1, height=28, width=28
     ):
-        data = [
-            np.random.randint(
-                0,
-                256,
-                (np.random.randint(5, max_num_frames + 1), height, width, 3),
-                dtype=np.uint8,
-            )
-            for _ in range(num_instances)
-        ]
-        if max_num_frames == 1:
-            print(f"TODO: create image metadata")
-        else:
+        if max_num_frames > 1:
+            data = [
+                np.random.randint(
+                    0,
+                    256,
+                    (np.random.randint(1, max_num_frames + 1), height, width, 3),
+                    dtype=np.uint8,
+                )
+                for _ in range(num_instances)
+            ]
             metadata = {
                 i: ModalityType.VIDEO.create_video_metadata(
                     30, data[i].shape[0], width, height, 3
                 )
                 for i in range(num_instances)
             }
+        else:
+            data = [
+                np.random.randint(
+                    0,
+                    256,
+                    (height, width, 3),
+                    dtype=np.uint8,
+                )
+                for _ in range(num_instances)
+            ]
+            metadata = {
+                i: ModalityType.IMAGE.create_image_metadata(width, height, 3)
+                for i in range(num_instances)
+            }
 
-        return (data, metadata)
+        return data, metadata
 
     def create_balanced_labels(self, num_instances, num_classes=2):
         if num_instances % num_classes != 0:
diff --git a/src/main/python/tests/scuro/test_operator_registry.py b/src/main/python/tests/scuro/test_operator_registry.py
index b5fa4b01b4..c33eb5fcc2 100644
--- a/src/main/python/tests/scuro/test_operator_registry.py
+++ b/src/main/python/tests/scuro/test_operator_registry.py
@@ -21,6 +21,7 @@
 
 import unittest
 
+from systemds.scuro import FrequencyMagnitude
 from systemds.scuro.representations.covarep_audio_features import (
     ZeroCrossing,
     Spectral,
@@ -29,6 +30,9 @@ from systemds.scuro.representations.covarep_audio_features import (
 )
 from systemds.scuro.representations.mfcc import MFCC
 from systemds.scuro.representations.swin_video_transformer import SwinVideoTransformer
+from systemds.scuro.representations.clip import CLIPText, CLIPVisual
+from systemds.scuro.representations.vgg import VGG19
+from systemds.scuro.representations.x3d import X3D, I3D
 from systemds.scuro.representations.wav2vec import Wav2Vec
 from systemds.scuro.representations.window_aggregation import (
     WindowAggregation,
@@ -39,6 +43,22 @@ from systemds.scuro.representations.bow import BoW
 from systemds.scuro.representations.word2vec import W2V
 from systemds.scuro.representations.tfidf import TfIdf
 from systemds.scuro.drsearch.operator_registry import Registry
+from systemds.scuro.representations.timeseries_representations import (
+    Max,
+    Mean,
+    Min,
+    RMS,
+    Sum,
+    Std,
+    Skew,
+    Kurtosis,
+    SpectralCentroid,
+    BandpowerFFT,
+    ACF,
+    Quantile,
+    ZeroCrossingRate,
+    FrequencyMagnitude,
+)
 from systemds.scuro.modality.type import ModalityType
 from systemds.scuro.representations.average import Average
 from systemds.scuro.representations.bert import Bert
@@ -49,7 +69,6 @@ from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
 from systemds.scuro.representations.spectrogram import Spectrogram
 from systemds.scuro.representations.hadamard import Hadamard
 from systemds.scuro.representations.resnet import ResNet
-from systemds.scuro.representations.sum import Sum
 from systemds.scuro.representations.multimodal_attention_fusion import AttentionFusion
 
 
@@ -71,16 +90,34 @@ class TestOperatorRegistry(unittest.TestCase):
         registry = Registry()
         assert registry.get_representations(ModalityType.VIDEO) == [
             ResNet,
-            # SwinVideoTransformer,
+            SwinVideoTransformer,
+            X3D,
+            VGG19,
+            CLIPVisual,
         ]
 
-    # def test_timeseries_representations_in_registry(self):
-    #     registry = Registry()
-    #     assert registry.get_representations(ModalityType.TIMESERIES) == [ResNet]
+    def test_timeseries_representations_in_registry(self):
+        registry = Registry()
+        assert registry.get_representations(ModalityType.TIMESERIES) == [
+            Mean,
+            Min,
+            Max,
+            Sum,
+            Std,
+            Skew,
+            Quantile,
+            Kurtosis,
+            RMS,
+            ZeroCrossingRate,
+            ACF,
+            FrequencyMagnitude,
+            SpectralCentroid,
+            BandpowerFFT,
+        ]
 
     def test_text_representations_in_registry(self):
         registry = Registry()
-        for representation in [BoW, TfIdf, W2V, Bert]:
+        for representation in [CLIPText, BoW, TfIdf, W2V, Bert]:
             assert representation in registry.get_representations(
                 ModalityType.TEXT
             ), f"{representation} not in registry"
diff --git a/src/main/python/tests/scuro/test_unimodal_representations.py b/src/main/python/tests/scuro/test_unimodal_representations.py
index 6789786cfd..3bc28ee23c 100644
--- a/src/main/python/tests/scuro/test_unimodal_representations.py
+++ b/src/main/python/tests/scuro/test_unimodal_representations.py
@@ -23,6 +23,7 @@ import unittest
 import copy
 import numpy as np
 
+from systemds.scuro.representations.clip import CLIPVisual, CLIPText
 from systemds.scuro.representations.bow import BoW
 from systemds.scuro.representations.covarep_audio_features import (
     Spectral,
@@ -34,6 +35,9 @@ from systemds.scuro.representations.wav2vec import Wav2Vec
 from systemds.scuro.representations.spectrogram import Spectrogram
 from systemds.scuro.representations.word2vec import W2V
 from systemds.scuro.representations.tfidf import TfIdf
+from systemds.scuro.representations.x3d import X3D
+from systemds.scuro.representations.x3d import I3D
+from systemds.scuro.representations.color_histogram import ColorHistogram
 from systemds.scuro.modality.unimodal_modality import UnimodalModality
 from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
 from systemds.scuro.representations.mfcc import MFCC
@@ -59,6 +63,7 @@ from systemds.scuro.representations.timeseries_representations import (
     ZeroCrossingRate,
     BandpowerFFT,
 )
+from systemds.scuro.representations.vgg import VGG19
 
 
 class TestUnimodalRepresentations(unittest.TestCase):
@@ -143,11 +148,34 @@ class TestUnimodalRepresentations(unittest.TestCase):
             for i in range(self.num_instances):
                 assert (ts.data[i] == original_data[i]).all()
 
+    def test_image_representations(self):
+        image_representations = [ColorHistogram(), CLIPVisual(), ResNet()]
+
+        image_data, image_md = ModalityRandomDataGenerator().create_visual_modality(
+            self.num_instances, 1
+        )
+
+        image = UnimodalModality(
+            TestDataLoader(
+                self.indices, None, ModalityType.IMAGE, image_data, np.float32, image_md
+            )
+        )
+
+        for representation in image_representations:
+            r = image.apply_representation(representation)
+            assert r.data is not None
+            assert len(r.data) == self.num_instances
+
     def test_video_representations(self):
         video_representations = [
+            CLIPVisual(),
+            ColorHistogram(),
+            I3D(),
+            X3D(),
+            VGG19(),
             ResNet(),
             SwinVideoTransformer(),
-        ]  # Todo: add other video representations
+        ]
         video_data, video_md = ModalityRandomDataGenerator().create_visual_modality(
             self.num_instances, 60
         )
@@ -160,10 +188,9 @@ class TestUnimodalRepresentations(unittest.TestCase):
             r = video.apply_representation(representation)
             assert r.data is not None
             assert len(r.data) == self.num_instances
-            assert r.data[0].ndim == 2
 
     def test_text_representations(self):
-        test_representations = [BoW(2, 2), TfIdf(), W2V()]
+        test_representations = [CLIPText(), BoW(2, 2), TfIdf(), W2V()]
         text_data, text_md = ModalityRandomDataGenerator().create_text_data(
             self.num_instances
         )


Reply via email to