This is an automated email from the ASF dual-hosted git repository.

cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new b0002481cd [SYSTEMDS-3835] Add timeseries representations to Scuro
b0002481cd is described below

commit b0002481cd29f09bb1c010ad4d78fbe97a4191a2
Author: Christina Dionysio <[email protected]>
AuthorDate: Thu Nov 13 11:59:01 2025 +0100

    [SYSTEMDS-3835] Add timeseries representations to Scuro
    
    This patch adds new timeseries representations and a new mechanism to
    compute windowed timeseries representations in the unimodal optimizer.
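    
    A minimal usage sketch of the new pieces (the data path, instance ids and
    signal name below are placeholders, and the UnimodalModality import path
    is an assumption; the rest follows the new loader and the added tests):
    
        from systemds.scuro import Mean, ACF
        from systemds.scuro.dataloader.timeseries_loader import TimeseriesLoader
        from systemds.scuro.modality.unimodal_modality import UnimodalModality
    
        # placeholder data directory, instance ids and signal name
        loader = TimeseriesLoader(
            source_path="data/ts/", indices=["0", "1"], signal_names=["hr"]
        )
        ts = UnimodalModality(loader)
        ts.extract_raw_data()
    
        # apply two of the new per-instance timeseries representations
        mean_rep = ts.apply_representation(Mean())
        acf_rep = ts.apply_representation(ACF(k=5))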
---
 src/main/python/systemds/scuro/__init__.py         |  28 +++
 .../systemds/scuro/dataloader/timeseries_loader.py | 129 +++++++++++++
 .../scuro/drsearch/hyperparameter_tuner.py         |   4 +-
 .../scuro/drsearch/multimodal_optimizer.py         |   2 +-
 .../systemds/scuro/drsearch/unimodal_optimizer.py  |  79 ++++++--
 src/main/python/systemds/scuro/modality/type.py    |   2 +
 .../systemds/scuro/representations/context.py      |  17 +-
 .../representations/multimodal_attention_fusion.py |  23 ++-
 .../scuro/representations/representation.py        |   3 +
 .../representations/timeseries_representations.py  | 213 +++++++++++++++++++++
 .../scuro/representations/window_aggregation.py    |  56 ++++--
 src/main/python/tests/scuro/data_generator.py      |  17 ++
 .../python/tests/scuro/test_unimodal_optimizer.py  |  19 +-
 .../tests/scuro/test_unimodal_representations.py   |  51 +++++
 14 files changed, 601 insertions(+), 42 deletions(-)

diff --git a/src/main/python/systemds/scuro/__init__.py b/src/main/python/systemds/scuro/__init__.py
index b567b30024..c1db4c3d49 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -41,6 +41,21 @@ from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
 from systemds.scuro.representations.multimodal_attention_fusion import (
     AttentionFusion,
 )
+from systemds.scuro.representations.timeseries_representations import (
+    Mean,
+    Max,
+    Min,
+    Kurtosis,
+    Skew,
+    Std,
+    RMS,
+    ACF,
+    FrequencyMagnitude,
+    SpectralCentroid,
+    Quantile,
+    ZeroCrossingRate,
+    BandpowerFFT,
+)
 from systemds.scuro.representations.mfcc import MFCC
 from systemds.scuro.representations.hadamard import Hadamard
 from systemds.scuro.representations.optical_flow import OpticalFlow
@@ -141,4 +156,17 @@ __all__ = [
     "AttentionFusion",
     "DynamicWindow",
     "StaticWindow",
+    "Min",
+    "Max",
+    "Mean",
+    "Std",
+    "Kurtosis",
+    "Skew",
+    "RMS",
+    "ACF",
+    "FrequencyMagnitude",
+    "SpectralCentroid",
+    "Quantile",
+    "BandpowerFFT",
+    "ZeroCrossingRate",
 ]
diff --git a/src/main/python/systemds/scuro/dataloader/timeseries_loader.py b/src/main/python/systemds/scuro/dataloader/timeseries_loader.py
new file mode 100644
index 0000000000..6887d6974f
--- /dev/null
+++ b/src/main/python/systemds/scuro/dataloader/timeseries_loader.py
@@ -0,0 +1,129 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import numpy as np
+from typing import List, Optional, Union
+import h5py
+
+
+from systemds.scuro.dataloader.base_loader import BaseLoader
+from systemds.scuro.modality.type import ModalityType
+
+
+class TimeseriesLoader(BaseLoader):
+    def __init__(
+        self,
+        source_path: str,
+        indices: List[str],
+        signal_names: List[str],
+        data_type: Union[np.dtype, str] = np.float32,
+        chunk_size: Optional[int] = None,
+        sampling_rate: Optional[int] = None,
+        normalize: bool = True,
+        file_format: str = "npy",
+    ):
+        super().__init__(
+            source_path, indices, data_type, chunk_size, ModalityType.TIMESERIES
+        )
+        self.signal_names = signal_names
+        self.sampling_rate = sampling_rate
+        self.normalize = normalize
+        self.file_format = file_format.lower()
+
+        if self.file_format not in ["npy", "mat", "hdf5", "txt"]:
+            raise ValueError(f"Unsupported file format: {self.file_format}")
+
+    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
+        self.file_sanity_check(file)
+
+        if self.file_format == "npy":
+            data = self._load_npy(file)
+        elif self.file_format in ["txt", "csv"]:
+            with open(file, "r") as f:
+                first_line = f.readline()
+            if any(name in first_line for name in self.signal_names):
+                data = self._load_csv_with_header(file)
+            else:
+                data = self._load_txt(file)
+
+        if data.ndim > 1 and len(self.signal_names) == 1:
+            data = data.flatten()
+
+        if self.normalize:
+            data = self._normalize_signals(data)
+
+        if file:
+            self.metadata[index] = self.modality_type.create_ts_metadata(
+                self.signal_names, data, self.sampling_rate
+            )
+        else:
+            for i, index in enumerate(self.indices):
+                self.metadata[str(index)] = self.modality_type.create_ts_metadata(
+                    self.signal_names, data[i], self.sampling_rate
+                )
+        self.data.append(data)
+
+    def _normalize_signals(self, data: np.ndarray) -> np.ndarray:
+        if data.ndim == 1:
+            mean = np.mean(data)
+            std = np.std(data)
+            return (data - mean) / (std + 1e-8)
+        else:
+            for i in range(data.shape[1]):
+                mean = np.mean(data[:, i])
+                std = np.std(data[:, i])
+                data[:, i] = (data[:, i] - mean) / (std + 1e-8)
+            return data
+
+    def _load_npy(self, file: str) -> np.ndarray:
+        data = np.load(file).astype(self._data_type)
+        return data
+
+    def _load_txt(self, file: str) -> np.ndarray:
+        data = np.loadtxt(file).astype(self._data_type)
+        return data
+
+    def _load_txt_with_header(self, file: str) -> np.ndarray:
+        with open(file, "r") as f:
+            header = f.readline().strip().split()
+
+        col_indices = [
+            header.index(name) for name in self.signal_names if name in header
+        ]
+        data = np.loadtxt(file, dtype=self._data_type, skiprows=1, usecols=col_indices)
+        return data
+
+    def _load_csv_with_header(self, file: str, delimiter: str = None) -> np.ndarray:
+        import pandas as pd
+
+        if delimiter is None:
+            with open(file, "r") as f:
+                sample = f.read(1024)
+            if "," in sample:
+                delimiter = ","
+            elif "\t" in sample:
+                delimiter = "\t"
+            else:
+                delimiter = " "
+        df = pd.read_csv(file, delimiter=delimiter)
+
+        selected = [name for name in self.signal_names if name in df.columns]
+        data = df[selected].to_numpy(dtype=self._data_type)
+        return data
diff --git a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
index 8902bb7d01..15136ac28f 100644
--- a/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
+++ b/src/main/python/systemds/scuro/drsearch/hyperparameter_tuner.py
@@ -146,8 +146,8 @@ class HyperparameterTuner:
                 visit_node(input_id)
             visited.add(node_id)
             if node.operation is not None:
-                if node.parameters:
-                    hyperparams.update(node.parameters)
+                if node.operation().parameters:
+                    hyperparams.update(node.operation().parameters)
                 reps.append(node.operation)
                 node_order.append(node_id)
             if node.modality_id is not None:
diff --git a/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
index 91d569bc59..fab3da1adc 100644
--- a/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/multimodal_optimizer.py
@@ -157,7 +157,7 @@ class MultimodalOptimizer:
                         fusion_id = new_builder.create_operation_node(
                             fusion_op.__class__,
                             [left_root, right_root],
-                            fusion_op.parameters,
+                            fusion_op.get_current_parameters(),
                         )
                         variants.append((new_builder, fusion_id))
 
diff --git a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
index 10b127f5b6..f678700bdc 100644
--- a/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
+++ b/src/main/python/systemds/scuro/drsearch/unimodal_optimizer.py
@@ -95,6 +95,11 @@ class UnimodalOptimizer:
         with open(file_name, "wb") as f:
             pickle.dump(self.operator_performance.results, f)
 
+    def load_results(self, file_name):
+        with open(file_name, "rb") as f:
+            self.operator_performance.results = pickle.load(f)
+            self.operator_performance.cache = None
+
     def optimize_parallel(self, n_workers=None):
         if n_workers is None:
             n_workers = min(len(self.modalities), mp.cpu_count())
@@ -177,7 +182,9 @@ class UnimodalOptimizer:
                 builder = self.builders[modality.modality_id]
                 agg_operator = AggregatedRepresentation()
                 rep_node_id = builder.create_operation_node(
-                    agg_operator.__class__, [dag.root_node_id], agg_operator.parameters
+                    agg_operator.__class__,
+                    [dag.root_node_id],
+                    agg_operator.get_current_parameters(),
                 )
                 dag = builder.build(rep_node_id)
                 representations = dag.execute([modality])
@@ -207,7 +214,7 @@ class UnimodalOptimizer:
                     rep_node_id = builder.create_operation_node(
                         agg_operator.__class__,
                         [dag.root_node_id],
-                        agg_operator.parameters,
+                        agg_operator.get_current_parameters(),
                     )
                     dag = builder.build(rep_node_id)
                     representations = dag.execute([modality])
@@ -235,7 +242,7 @@ class UnimodalOptimizer:
         leaf_id = builder.create_leaf_node(modality.modality_id)
 
         rep_node_id = builder.create_operation_node(
-            operator.__class__, [leaf_id], operator.parameters
+            operator.__class__, [leaf_id], operator.get_current_parameters()
         )
         current_node_id = rep_node_id
         dags.append(builder.build(current_node_id))
@@ -258,31 +265,68 @@ class UnimodalOptimizer:
                     combine_id = builder.create_operation_node(
                         combination.__class__,
                         [current_node_id, other_rep_id],
-                        combination.parameters,
+                        combination.get_current_parameters(),
                     )
                     dags.append(builder.build(combine_id))
                     current_node_id = combine_id
+            if modality.modality_type in [
+                ModalityType.EMBEDDING,
+                ModalityType.IMAGE,
+                ModalityType.AUDIO,
+            ]:
+                dags.extend(
+                    self.default_context_operators(
+                        modality, builder, leaf_id, current_node_id
+                    )
+                )
+            elif modality.modality_type == ModalityType.TIMESERIES:
+                dags.extend(
+                    self.temporal_context_operators(
+                        modality, builder, leaf_id, current_node_id
+                    )
+                )
+        return dags
 
+    def default_context_operators(self, modality, builder, leaf_id, current_node_id):
+        dags = []
         context_operators = self._get_context_operators()
-
         for context_op in context_operators:
-            if modality.modality_type != ModalityType.TEXT:
+            if (
+                modality.modality_type != ModalityType.TEXT
+                and modality.modality_type != ModalityType.VIDEO
+            ):
                 context_node_id = builder.create_operation_node(
                     context_op,
                     [leaf_id],
-                    context_op().parameters,
+                    context_op().get_current_parameters(),
                 )
                 dags.append(builder.build(context_node_id))
 
             context_node_id = builder.create_operation_node(
                 context_op,
                 [current_node_id],
-                context_op().parameters,
+                context_op().get_current_parameters(),
             )
             dags.append(builder.build(context_node_id))
 
         return dags
 
+    def temporal_context_operators(self, modality, builder, leaf_id, current_node_id):
+        aggregators = self.operator_registry.get_representations(modality.modality_type)
+        context_operators = self._get_context_operators()
+
+        dags = []
+        for agg in aggregators:
+            for context_operator in context_operators:
+                context_node_id = builder.create_operation_node(
+                    context_operator,
+                    [leaf_id],
+                    context_operator(agg()).get_current_parameters(),
+                )
+                dags.append(builder.build(context_node_id))
+
+        return dags
+
 
 class UnimodalResults:
     def __init__(self, modalities, tasks, debug=False, run=None):
@@ -339,13 +383,20 @@ class UnimodalResults:
             key=lambda x: task_results[x].val_score,
             reverse=True,
         )[:k]
+        if not self.cache:
+            cache = [
+                list(task_results[i].dag.execute([modality]).values())[-1]
+                for i in sorted_indices
+            ]
+        else:
+            cache_items = (
+                list(self.cache[modality.modality_id][task.model.name].items())
+                if self.cache[modality.modality_id][task.model.name]
+                else []
+            )
+            cache = [cache_items[i][1] for i in sorted_indices if i < len(cache_items)]
 
-        cache_items = list(self.cache[modality.modality_id][task.model.name].items())
-        reordered_cache = [
-            cache_items[i][1] for i in sorted_indices if i < len(cache_items)
-        ]
-
-        return results, reordered_cache
+        return results, cache
 
 
 @dataclass(frozen=True)
diff --git a/src/main/python/systemds/scuro/modality/type.py b/src/main/python/systemds/scuro/modality/type.py
index 2853e8135d..ef1e0eeab2 100644
--- a/src/main/python/systemds/scuro/modality/type.py
+++ b/src/main/python/systemds/scuro/modality/type.py
@@ -193,6 +193,7 @@ class ModalityType(Flag):
     IMAGE = auto()
     TIMESERIES = auto()
     EMBEDDING = auto()
+    PHYSIOLOGICAL = auto()
 
     def get_schema(self):
         return ModalitySchemas.get(self.name)
@@ -239,6 +240,7 @@ class ModalityType(Flag):
         md["length"] = data.shape[0]
         md["signal_names"] = signal_names
         md["timestamp"] = create_timestamps(md["frequency"], md["length"])
+        md["is_multivariate"] = len(signal_names) > 1
         return md
 
     def create_video_metadata(self, frequency, length, width, height, num_channels):
diff --git a/src/main/python/systemds/scuro/representations/context.py b/src/main/python/systemds/scuro/representations/context.py
index 54f22633cc..27c118bcba 100644
--- a/src/main/python/systemds/scuro/representations/context.py
+++ b/src/main/python/systemds/scuro/representations/context.py
@@ -25,12 +25,13 @@ from systemds.scuro.representations.representation import Representation
 
 
 class Context(Representation):
-    def __init__(self, name, parameters=None):
+    def __init__(self, name, parameters=None, is_ts_rep=False):
         """
         Parent class for different context operations
         :param name: Name of the context operator
         """
         super().__init__(name, parameters)
+        self.is_ts_rep = is_ts_rep
 
     @abc.abstractmethod
     def execute(self, modality: Modality):
@@ -40,3 +41,17 @@ class Context(Representation):
         :return: contextualized data
         """
         raise f"Not implemented for Context Operator: {self.name}"
+
+    def get_current_parameters(self):
+        current_params = {}
+        if not self.parameters:
+            return current_params
+        for parameter in list(self.parameters.keys()):
+            if self.is_ts_rep:
+                if parameter == "agg_params":
+                    current_params[parameter] = (
+                        self.aggregation_function.get_current_parameters()
+                    )
+                    continue
+            current_params[parameter] = getattr(self, parameter)
+        return current_params
diff --git a/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py b/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
index d17451932e..6f5f527f31 100644
--- a/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
+++ b/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
@@ -43,7 +43,7 @@ class AttentionFusion(Fusion):
         num_epochs=50,
         learning_rate=0.001,
     ):
-        params = {
+        parameters = {
             "hidden_dim": [32, 128, 256, 384, 512, 768],
             "num_heads": [2, 4, 8, 12],
             "dropout": [0.0, 0.1, 0.2, 0.3, 0.4],
@@ -51,7 +51,7 @@ class AttentionFusion(Fusion):
             "num_epochs": [50, 100, 150, 200],
             "learning_rate": [1e-5, 1e-4, 1e-3, 1e-2],
         }
-        super().__init__("AttentionFusion", params)
+        super().__init__("AttentionFusion", parameters)
 
         self.hidden_dim = int(hidden_dim)
         self.num_heads = int(num_heads)
@@ -64,7 +64,7 @@ class AttentionFusion(Fusion):
         self.needs_alignment = True
         self.encoder = None
         self.classification_head = None
-        self.input_dimensions = None
+        self.input_dim = None
         self.max_sequence_length = None
         self.num_classes = None
         self.is_trained = False
@@ -81,7 +81,7 @@ class AttentionFusion(Fusion):
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.benchmark = False
 
-    def _prepare_data(self, modalities: List[Modality]) -> Dict[str, torch.Tensor]:
+    def _prepare_data(self, modalities: List[Modality]):
         inputs = {}
         input_dimensions = {}
         max_sequence_length = 0
@@ -122,12 +122,12 @@ class AttentionFusion(Fusion):
         inputs, input_dimensions, max_sequence_length = self._prepare_data(modalities)
         y = np.array(labels)
 
-        self.input_dimensions = input_dimensions
+        self.input_dim = input_dimensions
         self.max_sequence_length = max_sequence_length
         self.num_classes = len(np.unique(y))
 
         self.encoder = MultiModalAttentionFusion(
-            self.input_dimensions,
+            self.input_dim,
             self.hidden_dim,
             self.num_heads,
             self.dropout,
@@ -206,7 +206,7 @@ class AttentionFusion(Fusion):
         self.model_state = {
             "encoder_state_dict": self.encoder.state_dict(),
             "classification_head_state_dict": 
self.classification_head.state_dict(),
-            "input_dimensions": self.input_dimensions,
+            "input_dimensions": self.input_dim,
             "max_sequence_length": self.max_sequence_length,
             "num_classes": self.num_classes,
             "hidden_dim": self.hidden_dim,
@@ -214,6 +214,11 @@ class AttentionFusion(Fusion):
             "dropout": self.dropout,
         }
 
+        with torch.no_grad():
+            encoder_output = self.encoder(inputs)
+
+        return encoder_output["fused"].cpu().numpy()
+
     def apply_representation(self, modalities: List[Modality]) -> np.ndarray:
         if not self.is_trained or self.encoder is None:
             raise ValueError("Model must be trained before applying 
representation")
@@ -237,12 +242,12 @@ class AttentionFusion(Fusion):
 
     def set_model_state(self, state: Dict[str, Any]):
         self.model_state = state
-        self.input_dimensions = state["input_dimensions"]
+        self.input_dim = state["input_dimensions"]
         self.max_sequence_length = state["max_sequence_length"]
         self.num_classes = state["num_classes"]
 
         self.encoder = MultiModalAttentionFusion(
-            self.input_dimensions,
+            self.input_dim,
             state["hidden_dim"],
             state["num_heads"],
             state["dropout"],
diff --git a/src/main/python/systemds/scuro/representations/representation.py b/src/main/python/systemds/scuro/representations/representation.py
index dac3bb2b98..9f1a91a0ea 100644
--- a/src/main/python/systemds/scuro/representations/representation.py
+++ b/src/main/python/systemds/scuro/representations/representation.py
@@ -34,6 +34,9 @@ class Representation:
 
     def get_current_parameters(self):
         current_params = {}
+        if not self.parameters:
+            return current_params
+
         for parameter in list(self.parameters.keys()):
             current_params[parameter] = getattr(self, parameter)
         return current_params
diff --git a/src/main/python/systemds/scuro/representations/timeseries_representations.py b/src/main/python/systemds/scuro/representations/timeseries_representations.py
new file mode 100644
index 0000000000..03464df7d4
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/timeseries_representations.py
@@ -0,0 +1,213 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import numpy as np
+from scipy import stats
+
+from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.modality.transformed import TransformedModality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.drsearch.operator_registry import register_representation
+
+
+class TimeSeriesRepresentation(UnimodalRepresentation):
+    def __init__(self, name, parameters=None):
+        if parameters is None:
+            parameters = {}
+        super().__init__(name, ModalityType.EMBEDDING, parameters, False)
+
+    def compute_feature(self, signal):
+        raise NotImplementedError("Subclasses should implement this method.")
+
+    def transform(self, modality):
+        transformed_modality = TransformedModality(
+            modality, self, self.output_modality_type
+        )
+        result = []
+
+        for signal in modality.data:
+            feature = self.compute_feature(signal)
+            result.append(feature)
+
+        transformed_modality.data = np.vstack(result)
+        return transformed_modality
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Mean(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("Mean")
+
+    def compute_feature(self, signal):
+        return np.array(np.mean(signal))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Min(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("Min")
+
+    def compute_feature(self, signal):
+        return np.array(np.min(signal))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Max(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("Max")
+
+    def compute_feature(self, signal):
+        return np.array(np.max(signal))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Sum(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("Sum")
+
+    def compute_feature(self, signal):
+        return np.array(np.sum(signal))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Std(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("Std")
+
+    def compute_feature(self, signal):
+        return np.array(np.std(signal))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Skew(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("Skew")
+
+    def compute_feature(self, signal):
+        return np.array(stats.skew(signal))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Quantile(TimeSeriesRepresentation):
+    def __init__(self, quantile=0.9):
+        super().__init__(
+            "Qunatile", {"quantile": [0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]}
+        )
+        self.quantile = quantile
+
+    def compute_feature(self, signal):
+        return np.array(np.quantile(signal, self.quantile))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class Kurtosis(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("Kurtosis")
+
+    def compute_feature(self, signal):
+        return np.array(stats.kurtosis(signal, fisher=True, bias=False))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class RMS(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("RMS")
+
+    def compute_feature(self, signal):
+        return np.array(np.sqrt(np.mean(np.square(signal))))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class ZeroCrossingRate(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("ZeroCrossingRate")
+
+    def compute_feature(self, signal):
+        return np.array(np.sum(np.diff(np.signbit(signal)) != 0))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class ACF(TimeSeriesRepresentation):
+    def __init__(self, k=1):
+        super().__init__("ACF", {"k": [1, 2, 5, 10, 20, 25, 50, 100, 200, 
500]})
+        self.k = k
+
+    def compute_feature(self, signal):
+        x = np.asarray(signal) - np.mean(signal)
+        k = int(self.k)
+        if k <= 0 or k >= len(x):
+            return np.array(0.0)
+        den = np.dot(x, x)
+        if not np.isfinite(den) or np.isclose(den, 0.0):
+            return np.array(0.0)
+        corr = np.correlate(x[:-k], x[k:])[0]
+        return np.array(corr / den)
+
+    def get_k_values(self, max_length, percent=0.2, num=10, log=False):
+        # TODO: Probably would be useful to invoke this function while tuning the hyperparameters depending on the max length of the signal
+        max_k = int(max_length * percent)
+        if log:
+            k_vals = np.unique(np.logspace(0, np.log10(max_k), num=num, dtype=int))
+        else:
+            k_vals = np.unique(np.linspace(1, max_k, num=num, dtype=int))
+        return k_vals.tolist()
+
+
+@register_representation([ModalityType.TIMESERIES])
+class FrequencyMagnitude(TimeSeriesRepresentation):
+    def __init__(self):
+        super().__init__("FrequencyMagnitude")
+
+    def compute_feature(self, signal):
+        return np.array(np.abs(np.fft.rfft(signal)))
+
+
+@register_representation([ModalityType.TIMESERIES])
+class SpectralCentroid(TimeSeriesRepresentation):
+    def __init__(self, fs=1.0):
+        super().__init__("SpectralCentroid", parameters={"fs": [1.0]})
+        self.fs = fs
+
+    def compute_feature(self, signal):
+        frequency_magnitude = FrequencyMagnitude().compute_feature(signal)
+        frequencies = np.fft.rfftfreq(len(signal), d=1.0 / self.fs)
+        num = np.sum(frequencies * frequency_magnitude)
+        den = np.sum(frequency_magnitude) + 1e-12
+        return np.array(num / den)
+
+
+@register_representation([ModalityType.TIMESERIES])
+class BandpowerFFT(TimeSeriesRepresentation):
+    def __init__(self, fs=1.0, f1=0.0, f2=0.5):
+        super().__init__(
+            "BandpowerFFT", parameters={"fs": [1.0], "f1": [0.0], "f2": [0.5]}
+        )
+        self.fs = fs
+        self.f1 = f1
+        self.f2 = f2
+
+    def compute_feature(
+        self,
+        signal,
+    ):
+        frequency_magnitude = FrequencyMagnitude().compute_feature(signal)
+        frequencies = np.fft.rfftfreq(len(signal), d=1.0 / self.fs)
+        m = (frequencies >= self.f1) & (frequencies < self.f2)
+        return np.array(np.sum(frequency_magnitude[m] ** 2))
diff --git a/src/main/python/systemds/scuro/representations/window_aggregation.py b/src/main/python/systemds/scuro/representations/window_aggregation.py
index c16f6d747f..adb92ceb53 100644
--- a/src/main/python/systemds/scuro/representations/window_aggregation.py
+++ b/src/main/python/systemds/scuro/representations/window_aggregation.py
@@ -32,10 +32,19 @@ from systemds.scuro.representations.context import Context
 
 class Window(Context):
     def __init__(self, name, aggregation_function):
-        parameters = {
-            "aggregation_function": 
list(Aggregation().get_aggregation_functions()),
-        }
-        super().__init__(name, parameters)
+        is_ts_rep = False
+        if isinstance(aggregation_function, str):
+            parameters = {
+                "aggregation_function": 
list(Aggregation().get_aggregation_functions()),
+            }
+        else:
+            is_ts_rep = True
+            parameters = {
+                "aggregation_function": aggregation_function.name,
+                "agg_params": aggregation_function.parameters,
+            }
+
+        super().__init__(name, parameters, is_ts_rep)
         self.aggregation_function = aggregation_function
 
     @property
@@ -44,7 +53,10 @@ class Window(Context):
 
     @aggregation_function.setter
     def aggregation_function(self, value):
-        self._aggregation_function = Aggregation(value)
+        if self.is_ts_rep:
+            self._aggregation_function = value
+        else:
+            self._aggregation_function = Aggregation(value)
 
 
 @register_context_operator()
@@ -118,13 +130,24 @@ class WindowAggregation(Window):
 
         result = []
         for i in range(0, new_length):
-            result.append(
-                self.aggregation_function.aggregate_instance(
-                    instance[
-                        i * self.window_size : i * self.window_size + self.window_size
-                    ]
+            if self.is_ts_rep:
+                result.append(
+                    self.aggregation_function.compute_feature(
+                        instance[
+                            i * self.window_size : i * self.window_size
+                            + self.window_size
+                        ]
+                    )
+                )
+            else:
+                result.append(
+                    self.aggregation_function.aggregate_instance(
+                        instance[
+                            i * self.window_size : i * self.window_size
+                            + self.window_size
+                        ]
+                    )
                 )
-            )
 
         return np.array(result)
 
@@ -132,9 +155,14 @@ class WindowAggregation(Window):
         result = [[] for _ in range(0, new_length)]
         data = np.stack(copy.deepcopy(instance))
         for i in range(0, new_length):
-            result[i] = self.aggregation_function.aggregate_instance(
-                data[i * self.window_size : i * self.window_size + self.window_size]
-            )
+            if self.is_ts_rep:
+                result[i] = self.aggregation_function.compute_feature(
+                    data[i * self.window_size : i * self.window_size + self.window_size]
+                )
+            else:
+                result[i] = self.aggregation_function.aggregate_instance(
+                    data[i * self.window_size : i * self.window_size + self.window_size]
+                )
 
         return np.array(result)
 
diff --git a/src/main/python/tests/scuro/data_generator.py b/src/main/python/tests/scuro/data_generator.py
index 4dcfa5a89c..11f034d9ce 100644
--- a/src/main/python/tests/scuro/data_generator.py
+++ b/src/main/python/tests/scuro/data_generator.py
@@ -85,6 +85,8 @@ class ModalityRandomDataGenerator:
                 self.metadata[i] = modality_type.create_video_metadata(
                     num_features / 30, 10, 0, 0, 1
                 )
+            elif modality_type == ModalityType.TIMESERIES:
+                self.metadata[i] = modality_type.create_ts_metadata(["test"], data[i])
             else:
                 raise NotImplementedError
 
@@ -112,6 +114,21 @@ class ModalityRandomDataGenerator:
 
         return data, metadata
 
+    def create_timeseries_data(self, num_instances, sequence_length, num_features=1):
+        data = [
+            np.random.rand(sequence_length, num_features).astype(self.data_type)
+            for _ in range(num_instances)
+        ]
+        if num_features == 1:
+            data = [d.squeeze(-1) for d in data]
+        metadata = {
+            i: ModalityType.TIMESERIES.create_ts_metadata(
+                [f"feature_{j}" for j in range(num_features)], data[i]
+            )
+            for i in range(num_instances)
+        }
+        return data, metadata
+
     def create_text_data(self, num_instances):
         subjects = [
             "The cat",
diff --git a/src/main/python/tests/scuro/test_unimodal_optimizer.py b/src/main/python/tests/scuro/test_unimodal_optimizer.py
index e2f0378d58..30ae725737 100644
--- a/src/main/python/tests/scuro/test_unimodal_optimizer.py
+++ b/src/main/python/tests/scuro/test_unimodal_optimizer.py
@@ -27,6 +27,12 @@ from sklearn import svm
 from sklearn.metrics import classification_report
 from sklearn.model_selection import train_test_split
 
+from systemds.scuro.representations.timeseries_representations import (
+    Mean,
+    Max,
+    Min,
+    ACF,
+)
 from systemds.scuro.drsearch.operator_registry import Registry
 from systemds.scuro.models.model import Model
 from systemds.scuro.drsearch.task import Task
@@ -163,6 +169,17 @@ class TestUnimodalRepresentationOptimizer(unittest.TestCase):
         )
         self.optimize_unimodal_representation_for_modality(text)
 
+    def test_unimodal_optimizer_for_ts_modality(self):
+        ts_data, ts_md = ModalityRandomDataGenerator().create_timeseries_data(
+            self.num_instances, 1000
+        )
+        ts = UnimodalModality(
+            TestDataLoader(
+                self.indices, None, ModalityType.TIMESERIES, ts_data, np.float32, ts_md
+            )
+        )
+        self.optimize_unimodal_representation_for_modality(ts)
+
     def test_unimodal_optimizer_for_video_modality(self):
         video_data, video_md = ModalityRandomDataGenerator().create_visual_modality(
             self.num_instances, 10, 10
@@ -181,7 +198,7 @@ class TestUnimodalRepresentationOptimizer(unittest.TestCase):
             {
                 ModalityType.TEXT: [W2V, BoW],
                 ModalityType.AUDIO: [Spectrogram, ZeroCrossing, Spectral, Pitch],
-                ModalityType.TIMESERIES: [ResNet],
+                ModalityType.TIMESERIES: [Mean, Max, Min, ACF],
                 ModalityType.VIDEO: [ResNet],
                 ModalityType.EMBEDDING: [],
             },
diff --git a/src/main/python/tests/scuro/test_unimodal_representations.py b/src/main/python/tests/scuro/test_unimodal_representations.py
index 8c8e9baa2d..6789786cfd 100644
--- a/src/main/python/tests/scuro/test_unimodal_representations.py
+++ b/src/main/python/tests/scuro/test_unimodal_representations.py
@@ -44,6 +44,21 @@ from tests.scuro.data_generator import (
     ModalityRandomDataGenerator,
 )
 from systemds.scuro.modality.type import ModalityType
+from systemds.scuro.representations.timeseries_representations import (
+    Mean,
+    Max,
+    Min,
+    Kurtosis,
+    Skew,
+    Std,
+    RMS,
+    ACF,
+    FrequencyMagnitude,
+    SpectralCentroid,
+    Quantile,
+    ZeroCrossingRate,
+    BandpowerFFT,
+)
 
 
 class TestUnimodalRepresentations(unittest.TestCase):
@@ -92,6 +107,42 @@ class TestUnimodalRepresentations(unittest.TestCase):
                 assert (audio.data[i] == original_data[i]).all()
             assert r.data[0].ndim == 2
 
+    def test_timeseries_representations(self):
+        ts_representations = [
+            Mean(),
+            Max(),
+            Min(),
+            Kurtosis(),
+            Skew(),
+            Std(),
+            RMS(),
+            ACF(),
+            FrequencyMagnitude(),
+            SpectralCentroid(),
+            Quantile(),
+            ZeroCrossingRate(),
+            BandpowerFFT(),
+        ]
+        ts_data, ts_md = ModalityRandomDataGenerator().create_timeseries_data(
+            self.num_instances, 1000
+        )
+
+        ts = UnimodalModality(
+            TestDataLoader(
+                self.indices, None, ModalityType.AUDIO, ts_data, np.float32, ts_md
+            )
+        )
+
+        ts.extract_raw_data()
+        original_data = copy.deepcopy(ts.data)
+
+        for representation in ts_representations:
+            r = ts.apply_representation(representation)
+            assert r.data is not None
+            assert len(r.data) == self.num_instances
+            for i in range(self.num_instances):
+                assert (ts.data[i] == original_data[i]).all()
+
     def test_video_representations(self):
         video_representations = [
             ResNet(),

