This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 758e060e33 [SYSTEMDS-3701] Additional scuro data representations
758e060e33 is described below
commit 758e060e33a637d649b12d0c5bdce2d5e5324b03
Author: Christina Dionysio <[email protected]>
AuthorDate: Tue Sep 24 09:38:56 2024 +0200
[SYSTEMDS-3701] Additional scuro data representations
Closes #2111.
---
src/main/python/systemds/scuro/__init__.py | 51 +++++++
.../python/systemds/scuro/aligner/dr_search.py | 97 +++++++-----
src/main/python/systemds/scuro/aligner/task.py | 2 +-
src/main/python/systemds/scuro/main.py | 24 +--
.../python/systemds/scuro/modality/__init__.py | 8 -
.../systemds/scuro/modality/aligned_modality.py | 13 +-
.../systemds/scuro/modality/audio_modality.py | 4 +-
.../python/systemds/scuro/modality/modality.py | 2 +-
.../systemds/scuro/modality/text_modality.py | 4 +-
.../systemds/scuro/modality/video_modality.py | 4 +-
src/main/python/systemds/scuro/models/__init__.py | 5 -
.../python/systemds/scuro/models/discrete_model.py | 2 +-
.../systemds/scuro/representations/__init__.py | 18 ---
.../systemds/scuro/representations/average.py | 15 +-
.../python/systemds/scuro/representations/bert.py | 96 ++++++++++++
.../scuro/representations/concatenation.py | 24 +--
.../systemds/scuro/representations/fusion.py | 24 +--
.../python/systemds/scuro/representations/lstm.py | 4 +-
.../python/systemds/scuro/representations/max.py | 72 +++++++++
.../scuro/representations/mel_spectrogram.py | 66 ++++++++
.../{average.py => multiplication.py} | 32 ++--
.../systemds/scuro/representations/resnet.py | 168 +++++++++++++++++++++
.../representations/{average.py => rowmax.py} | 51 +++++--
.../scuro/representations/{average.py => sum.py} | 28 ++--
.../systemds/scuro/representations/unimodal.py | 66 +-------
.../representations/{unimodal.py => utils.py} | 103 ++++++-------
.../RewriteMatrixMultChainOptSparseTest.java | 20 ++-
27 files changed, 702 insertions(+), 301 deletions(-)
diff --git a/src/main/python/systemds/scuro/__init__.py
b/src/main/python/systemds/scuro/__init__.py
index e66abb4646..1ef36539f0 100644
--- a/src/main/python/systemds/scuro/__init__.py
+++ b/src/main/python/systemds/scuro/__init__.py
@@ -18,3 +18,54 @@
# under the License.
#
# -------------------------------------------------------------
+from systemds.scuro.representations.representation import Representation
+from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.representations.fusion import Fusion
+from systemds.scuro.representations.sum import Sum
+from systemds.scuro.representations.max import RowMax
+from systemds.scuro.representations.multiplication import Multiplication
+from systemds.scuro.representations.mel_spectrogram import MelSpectrogram
+from systemds.scuro.representations.resnet import ResNet
+from systemds.scuro.representations.bert import Bert
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from systemds.scuro.representations.lstm import LSTM
+from systemds.scuro.representations.utils import NPY, Pickle, HDF5, JSON
+from systemds.scuro.models.model import Model
+from systemds.scuro.models.discrete_model import DiscreteModel
+from systemds.scuro.modality.aligned_modality import AlignedModality
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.modality.text_modality import TextModality
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.aligner.dr_search import DRSearch
+from systemds.scuro.aligner.task import Task
+
+
+__all__ = ["Representation",
+ "Average",
+ "Concatenation",
+ "Fusion",
+ "Sum",
+ "RowMax",
+ "Multiplication",
+ "MelSpectrogram",
+ "ResNet",
+ "Bert",
+ "UnimodalRepresentation",
+ "LSTM",
+ "NPY",
+ "Pickle",
+ "HDF5",
+ "JSON",
+ "Model",
+ "DiscreteModel",
+ "AlignedModality",
+ "AudioModality",
+ "VideoModality",
+ "TextModality",
+ "Modality",
+ "DRSearch",
+ "Task"
+ ]
+
diff --git a/src/main/python/systemds/scuro/aligner/dr_search.py
b/src/main/python/systemds/scuro/aligner/dr_search.py
index 4bdc7da4a2..b2a92ab75b 100644
--- a/src/main/python/systemds/scuro/aligner/dr_search.py
+++ b/src/main/python/systemds/scuro/aligner/dr_search.py
@@ -19,19 +19,24 @@
#
# -------------------------------------------------------------
import itertools
+import random
from typing import List
-from aligner.task import Task
-from modality.aligned_modality import AlignedModality
-from modality.modality import Modality
-from representations.representation import Representation
+from systemds.scuro.aligner.task import Task
+from systemds.scuro.modality.aligned_modality import AlignedModality
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.representation import Representation
+
+import warnings
+
+warnings.filterwarnings('ignore')
def get_modalities_by_name(modalities, name):
for modality in modalities:
if modality.name == name:
return modality
-
+
raise 'Modality ' + name + 'not in modalities'
@@ -51,9 +56,9 @@ class DRSearch:
self.best_modalities = None
self.best_representation = None
self.best_score = -1
-
+
def set_best_params(self, modality_name: str, representation:
Representation,
- score: float, modality_names: List[str]):
+ scores: List[float], modality_names: List[str]):
"""
Updates the best parameters for given modalities, representation, and
score
:param modality_name: The name of the aligned modality
@@ -62,43 +67,66 @@ class DRSearch:
:param modality_names: List of modality names used in this setting
:return:
"""
-
+
# check if modality name is already in dictionary
if modality_name not in self.scores.keys():
# if not add it to dictionary
self.scores[modality_name] = {}
-
+
# set score for representation
- self.scores[modality_name][representation] = score
-
+ self.scores[modality_name][representation] = scores
+
# compare current score with best score
- if score > self.best_score:
- self.best_score = score
+ if scores[1] > self.best_score:
+ self.best_score = scores[1]
self.best_representation = representation
self.best_modalities = modality_names
-
- def fit(self):
+
+ def reset_best_params(self):
+ self.best_score = -1
+ self.best_modalities = None
+ self.best_representation = None
+ self.scores = {}
+
+ def fit_random(self, seed=-1):
+ """
+ This method randomly selects a modality or combination of modalities
and representation
+ """
+ if seed != -1:
+ random.seed(seed)
+
+ modalities = []
+ for M in range(1, len(self.modalities) + 1):
+ for combination in itertools.combinations(self.modalities, M):
+ modalities.append(combination)
+
+ modality_combination = random.choice(modalities)
+ representation = random.choice(self.representations)
+
+ modality = AlignedModality(representation, list(modality_combination))
# noqa
+ modality.combine()
+
+ scores = self.task.run(modality.data)
+ self.set_best_params(modality.name, representation, scores,
modality.get_modality_names())
+
+ return self.best_representation, self.best_score, self.best_modalities
+
+ def fit_enumerate_all(self):
"""
This method finds the best representation out of a given List of
uni-modal modalities and
representations
:return: The best parameters found in the search procedure
"""
-
+
for M in range(1, len(self.modalities) + 1):
for combination in itertools.combinations(self.modalities, M):
- if len(combination) == 1:
- modality = combination[0]
- score =
self.task.run(modality.representation.scale_data(modality.data,
self.task.train_indices))
- self.set_best_params(modality.name,
modality.representation.name, score, [modality.name])
- self.scores[modality] = score
- else:
- for representation in self.representations:
- modality = AlignedModality(representation,
list(combination)) # noqa
- modality.combine(self.task.train_indices)
-
- score = self.task.run(modality.data)
- self.set_best_params(modality.name, representation,
score, modality.get_modality_names())
-
+ for representation in self.representations:
+ modality = AlignedModality(representation,
list(combination)) # noqa
+ modality.combine()
+
+ scores = self.task.run(modality.data)
+ self.set_best_params(modality.name, representation,
scores, modality.get_modality_names())
+
return self.best_representation, self.best_score, self.best_modalities
def transform(self, modalities: List[Modality]):
@@ -108,17 +136,16 @@ class DRSearch:
:param modalities: List of uni-modal modalities
:return: aligned data
"""
-
+
if self.best_score == -1:
raise 'Please fit representations first!'
-
+
used_modalities = []
-
+
for modality_name in self.best_modalities:
used_modalities.append(get_modalities_by_name(modalities,
modality_name))
-
+
modality = AlignedModality(self.best_representation, used_modalities)
# noqa
modality.combine(self.task.train_indices)
-
+
return modality.data
-
\ No newline at end of file
diff --git a/src/main/python/systemds/scuro/aligner/task.py
b/src/main/python/systemds/scuro/aligner/task.py
index 79f9690e65..efaafce32d 100644
--- a/src/main/python/systemds/scuro/aligner/task.py
+++ b/src/main/python/systemds/scuro/aligner/task.py
@@ -20,7 +20,7 @@
# -------------------------------------------------------------
from typing import List
-from models.model import Model
+from systemds.scuro.models.model import Model
class Task:
diff --git a/src/main/python/systemds/scuro/main.py
b/src/main/python/systemds/scuro/main.py
index 22477eb549..0648972fd8 100644
--- a/src/main/python/systemds/scuro/main.py
+++ b/src/main/python/systemds/scuro/main.py
@@ -22,16 +22,16 @@ import collections
import json
from datetime import datetime
-from representations.average import Averaging
-from representations.concatenation import Concatenation
-from modality.aligned_modality import AlignedModality
-from modality.text_modality import TextModality
-from modality.video_modality import VideoModality
-from modality.audio_modality import AudioModality
-from representations.unimodal import Pickle, JSON, HDF5, NPY
-from models.discrete_model import DiscreteModel
-from aligner.task import Task
-from aligner.dr_search import DRSearch
+from systemds.scuro.representations.average import Average
+from systemds.scuro.representations.concatenation import Concatenation
+from systemds.scuro.modality.aligned_modality import AlignedModality
+from systemds.scuro.modality.text_modality import TextModality
+from systemds.scuro.modality.video_modality import VideoModality
+from systemds.scuro.modality.audio_modality import AudioModality
+from systemds.scuro.representations.unimodal import Pickle, JSON, HDF5, NPY
+from systemds.scuro.models.discrete_model import DiscreteModel
+from systemds.scuro.aligner.task import Task
+from systemds.scuro.aligner.dr_search import DRSearch
class CustomTask(Task):
@@ -66,8 +66,8 @@ modalities = [text, audio, video]
model = DiscreteModel()
custom_task = CustomTask(model, labels, train_indices, val_indices)
-representations = [Concatenation(), Averaging()]
+representations = [Concatenation(), Average()]
dr_search = DRSearch(modalities, custom_task, representations)
-best_representation, best_score, best_modalities = dr_search.fit()
+best_representation, best_score, best_modalities = dr_search.fit_random()
aligned_representation = dr_search.transform(modalities)
diff --git a/src/main/python/systemds/scuro/modality/__init__.py
b/src/main/python/systemds/scuro/modality/__init__.py
index d09f468da2..e66abb4646 100644
--- a/src/main/python/systemds/scuro/modality/__init__.py
+++ b/src/main/python/systemds/scuro/modality/__init__.py
@@ -18,11 +18,3 @@
# under the License.
#
# -------------------------------------------------------------
-from systemds.scuro.modality.aligned_modality import AlignedModality
-from systemds.scuro.modality.audio_modality import AudioModality
-from systemds.scuro.modality.video_modality import VideoModality
-from systemds.scuro.modality.test_modality import TextModality
-from systemds.scuro.modality.modality import Modality
-
-
-__all__ = ["AlignedModality", "AudioModality", "VideoModality",
"TextModality", "Modality"]
\ No newline at end of file
diff --git a/src/main/python/systemds/scuro/modality/aligned_modality.py
b/src/main/python/systemds/scuro/modality/aligned_modality.py
index d4d20b962c..7950ec1919 100644
--- a/src/main/python/systemds/scuro/modality/aligned_modality.py
+++ b/src/main/python/systemds/scuro/modality/aligned_modality.py
@@ -20,8 +20,8 @@
# -------------------------------------------------------------
from typing import List
-from modality.modality import Modality
-from representations.fusion import Fusion
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.fusion import Fusion
class AlignedModality(Modality):
@@ -36,9 +36,16 @@ class AlignedModality(Modality):
name += modality.name
super().__init__(representation, modality_name=name)
self.modalities = modalities
-
+
def combine(self):
"""
Initiates the call to fuse the given modalities depending on the
Fusion type
"""
self.data = self.representation.fuse(self.modalities) # noqa
+
+ def get_modality_names(self):
+ names = []
+ for modality in self.modalities:
+ names.append(modality.name)
+
+ return names
\ No newline at end of file
diff --git a/src/main/python/systemds/scuro/modality/audio_modality.py
b/src/main/python/systemds/scuro/modality/audio_modality.py
index 01c71ad1e0..570faaad77 100644
--- a/src/main/python/systemds/scuro/modality/audio_modality.py
+++ b/src/main/python/systemds/scuro/modality/audio_modality.py
@@ -20,8 +20,8 @@
# -------------------------------------------------------------
import os
-from modality.modality import Modality
-from representations.unimodal import UnimodalRepresentation
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
class AudioModality(Modality):
diff --git a/src/main/python/systemds/scuro/modality/modality.py
b/src/main/python/systemds/scuro/modality/modality.py
index c7fe7cff8b..b15321be40 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -19,7 +19,7 @@
#
# -------------------------------------------------------------
-from representations.representation import Representation
+from systemds.scuro.representations.representation import Representation
class Modality:
diff --git a/src/main/python/systemds/scuro/modality/text_modality.py
b/src/main/python/systemds/scuro/modality/text_modality.py
index 71f384626d..ab6d7f0547 100644
--- a/src/main/python/systemds/scuro/modality/text_modality.py
+++ b/src/main/python/systemds/scuro/modality/text_modality.py
@@ -20,8 +20,8 @@
# -------------------------------------------------------------
import os
-from modality.modality import Modality
-from representations.unimodal import UnimodalRepresentation
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
class TextModality(Modality):
diff --git a/src/main/python/systemds/scuro/modality/video_modality.py
b/src/main/python/systemds/scuro/modality/video_modality.py
index 8062c26a89..110a13ffca 100644
--- a/src/main/python/systemds/scuro/modality/video_modality.py
+++ b/src/main/python/systemds/scuro/modality/video_modality.py
@@ -20,8 +20,8 @@
# -------------------------------------------------------------
import os
-from modality.modality import Modality
-from representations.unimodal import UnimodalRepresentation
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
class VideoModality(Modality):
diff --git a/src/main/python/systemds/scuro/models/__init__.py
b/src/main/python/systemds/scuro/models/__init__.py
index d7c003fb48..e66abb4646 100644
--- a/src/main/python/systemds/scuro/models/__init__.py
+++ b/src/main/python/systemds/scuro/models/__init__.py
@@ -18,8 +18,3 @@
# under the License.
#
# -------------------------------------------------------------
-from systemds.scuro.models import Model
-from systemds.scuro.discrete_model import DiscreteModel
-
-
-__all__ = ["Model", "DiscreteModel"]
\ No newline at end of file
diff --git a/src/main/python/systemds/scuro/models/discrete_model.py
b/src/main/python/systemds/scuro/models/discrete_model.py
index 994f0882e5..288643e5d8 100644
--- a/src/main/python/systemds/scuro/models/discrete_model.py
+++ b/src/main/python/systemds/scuro/models/discrete_model.py
@@ -18,7 +18,7 @@
# under the License.
#
# -------------------------------------------------------------
-from models.model import Model
+from systemds.scuro.models.model import Model
class DiscreteModel(Model):
diff --git a/src/main/python/systemds/scuro/representations/__init__.py
b/src/main/python/systemds/scuro/representations/__init__.py
index 9a2007319d..e66abb4646 100644
--- a/src/main/python/systemds/scuro/representations/__init__.py
+++ b/src/main/python/systemds/scuro/representations/__init__.py
@@ -18,21 +18,3 @@
# under the License.
#
# -------------------------------------------------------------
-from systemds.scuro.representations.representation import Representation
-from systemds.scuro.representations.average import Average
-from systemds.scuro.representations.concatenation import Concatenation
-from systemds.scuro.representations.fusion import Fusion
-from systemds.scuro.representations.unimodal import UnimodalRepresentation,
HDF5, NPY, Pickle, JSON
-from systemds.scuro.representations.lstm import LSTM
-
-
-__all__ = ["Representation",
- "Average",
- "Concatenation",
- "Fusion",
- "UnimodalRepresentation",
- "HDF5",
- "NPY",
- "Pickle",
- "JSON",
- "LSTM"]
diff --git a/src/main/python/systemds/scuro/representations/average.py
b/src/main/python/systemds/scuro/representations/average.py
index 77896b1914..11ce431566 100644
--- a/src/main/python/systemds/scuro/representations/average.py
+++ b/src/main/python/systemds/scuro/representations/average.py
@@ -23,18 +23,18 @@ from typing import List
import numpy as np
-from modality.modality import Modality
+from systemds.scuro.modality.modality import Modality
from keras.api.preprocessing.sequence import pad_sequences
-from representations.fusion import Fusion
+from systemds.scuro.representations.fusion import Fusion
-class Averaging(Fusion):
+class Average(Fusion):
def __init__(self):
"""
Combines modalities using averaging
"""
- super().__init__('Averaging')
+ super().__init__('Average')
def fuse(self, modalities: List[Modality]):
max_emb_size = self.get_max_embedding_size(modalities)
@@ -43,12 +43,11 @@ class Averaging(Fusion):
for modality in modalities:
d = pad_sequences(modality.data, maxlen=max_emb_size,
dtype='float32', padding='post')
padded_modalities.append(d)
-
+
data = padded_modalities[0]
for i in range(1, len(modalities)):
data += padded_modalities[i]
-
- data = self.scale_data(data, modalities[0].train_indices)
+
data /= len(modalities)
-
+
return np.array(data)
diff --git a/src/main/python/systemds/scuro/representations/bert.py
b/src/main/python/systemds/scuro/representations/bert.py
new file mode 100644
index 0000000000..365b39c322
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/bert.py
@@ -0,0 +1,96 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import pickle
+
+import numpy as np
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+import torch
+from transformers import BertTokenizer, BertModel
+import os
+
+
+def read_text_file(file_path):
+ with open(file_path, 'r', encoding='utf-8') as file:
+ text = file.read()
+ return text
+
+
+class Bert(UnimodalRepresentation):
+ def __init__(self, avg_layers=None, output_file=None):
+ super().__init__('Bert')
+
+ self.avg_layers = avg_layers
+ self.output_file = output_file
+
+ def parse_all(self, filepath, indices, get_sequences=False):
+ # Assumes text is stored in .txt files
+ data = []
+ if os.path.isdir(filepath):
+ for filename in os.listdir(filepath):
+ f = os.path.join(filepath, filename)
+ if os.path.isfile(f):
+ with open(f, 'r') as file:
+ data.append(file.readlines()[0])
+ else:
+ with open(filepath, 'r') as file:
+ data = file.readlines()
+
+ model_name = 'bert-base-uncased'
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+
+ if self.avg_layers is not None:
+ model = BertModel.from_pretrained(model_name,
output_hidden_states=True)
+ else:
+ model = BertModel.from_pretrained(model_name)
+
+ embeddings = self.create_embeddings(data, model, tokenizer)
+
+ if self.output_file is not None:
+ data = {}
+ for i in range(0, embeddings.shape[0]):
+ data[indices[i]] = embeddings[i]
+ self.save_embeddings(data)
+
+ return embeddings
+
+ def create_embeddings(self, data, model, tokenizer):
+ embeddings = []
+ for d in data:
+ inputs = tokenizer(d, return_tensors="pt", padding=True,
truncation=True)
+
+ with torch.no_grad():
+ outputs = model(**inputs)
+
+ if self.avg_layers is not None:
+ cls_embedding = [outputs.hidden_states[i][:, 0, :] for i in
range(-self.avg_layers, 0)]
+ cls_embedding = torch.mean(torch.stack(cls_embedding), dim=0)
+ else:
+ cls_embedding = outputs.last_hidden_state[:, 0,
:].squeeze().numpy()
+ embeddings.append(cls_embedding)
+
+ embeddings = np.array(embeddings)
+ return embeddings.reshape((embeddings.shape[0], embeddings.shape[-1]))
+
+ def save_embeddings(self, data):
+ with open(self.output_file, 'wb') as file:
+ pickle.dump(data, file)
diff --git a/src/main/python/systemds/scuro/representations/concatenation.py
b/src/main/python/systemds/scuro/representations/concatenation.py
index 149e1f8801..81b6fe33fc 100644
--- a/src/main/python/systemds/scuro/representations/concatenation.py
+++ b/src/main/python/systemds/scuro/representations/concatenation.py
@@ -23,10 +23,10 @@ from typing import List
import numpy as np
-from modality.modality import Modality
+from systemds.scuro.modality.modality import Modality
from keras.api.preprocessing.sequence import pad_sequences
-from representations.fusion import Fusion
+from systemds.scuro.representations.fusion import Fusion
class Concatenation(Fusion):
@@ -38,15 +38,21 @@ class Concatenation(Fusion):
self.padding = padding
def fuse(self, modalities: List[Modality]):
+ if len(modalities) == 1:
+ return np.array(modalities[0].data)
+
max_emb_size = self.get_max_embedding_size(modalities)
-
size = len(modalities[0].data)
- data = np.zeros((size, 0))
-
+
+ if modalities[0].data.ndim > 2:
+ data = np.zeros((size, max_emb_size, 0))
+ else:
+ data = np.zeros((size, 0))
+
for modality in modalities:
if self.padding:
- data = np.concatenate(pad_sequences(modality.data,
maxlen=max_emb_size, dtype='float32', padding='post'), axis=1)
+ data = np.concatenate([data, pad_sequences(modality.data,
maxlen=max_emb_size, dtype='float32', padding='post')], axis=-1)
else:
- data = np.concatenate([data, modality.data], axis=1)
-
- return self.scale_data(data, modalities[0].train_indices)
+ data = np.concatenate([data, modality.data], axis=-1)
+
+ return np.array(data)
\ No newline at end of file
diff --git a/src/main/python/systemds/scuro/representations/fusion.py
b/src/main/python/systemds/scuro/representations/fusion.py
index 4e242137f1..04e9ebbb64 100644
--- a/src/main/python/systemds/scuro/representations/fusion.py
+++ b/src/main/python/systemds/scuro/representations/fusion.py
@@ -22,8 +22,8 @@ from typing import List
from sklearn.preprocessing import StandardScaler
-from modality.modality import Modality
-from representations.representation import Representation
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.representation import Representation
class Fusion(Representation):
@@ -33,7 +33,7 @@ class Fusion(Representation):
:param name: Name of the fusion type
"""
super().__init__(name)
-
+
def fuse(self, modalities: List[Modality]):
"""
Implemented for every child class and creates a fused representation
out of
@@ -42,7 +42,7 @@ class Fusion(Representation):
:return: fused data
"""
raise f'Not implemented for Fusion: {self.name}'
-
+
def get_max_embedding_size(self, modalities: List[Modality]):
"""
Computes the maximum embedding size from a given list of modalities
@@ -56,17 +56,5 @@ class Fusion(Representation):
raise f'Modality sizes don\'t match!'
elif curr_shape[1] > max_size:
max_size = curr_shape[1]
-
- return max_size
-
- def scale_data(self, data, train_indices):
- """
- Scales the data using the StandardScaler.
- The scaler is fit on the training data before performing the scaling
on the whole data array
- :param data: data to be scaled
- :param train_indices:
- :return: scaled data
- """
- scaler = StandardScaler()
- scaler.fit(data[train_indices])
- return scaler.transform(data)
+
+ return max_size
\ No newline at end of file
diff --git a/src/main/python/systemds/scuro/representations/lstm.py
b/src/main/python/systemds/scuro/representations/lstm.py
index a38ca1e577..dcdd9b65c1 100644
--- a/src/main/python/systemds/scuro/representations/lstm.py
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -25,8 +25,8 @@ from typing import List
import numpy as np
-from modality.modality import Modality
-from representations.fusion import Fusion
+from systemds.scuro.modality.modality import Modality
+from systemds.scuro.representations.fusion import Fusion
class LSTM(Fusion):
diff --git a/src/main/python/systemds/scuro/representations/max.py
b/src/main/python/systemds/scuro/representations/max.py
new file mode 100644
index 0000000000..2f58581cb8
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/max.py
@@ -0,0 +1,72 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+import itertools
+from typing import List
+
+import numpy as np
+
+from systemds.scuro.modality.modality import Modality
+from keras.preprocessing.sequence import pad_sequences
+
+from systemds.scuro.representations.fusion import Fusion
+
+
+class RowMax(Fusion):
+ def __init__(self, split=4):
+ """
+ Combines modalities by computing the outer product of a modality
combination and
+ taking the row max
+ """
+ super().__init__('RowMax')
+ self.split = split
+
+ def fuse(self, modalities: List[Modality],):
+ if len(modalities) < 2:
+ return np.array(modalities[0].data)
+
+ max_emb_size = self.get_max_embedding_size(modalities)
+
+ padded_modalities = []
+ for modality in modalities:
+ d = pad_sequences(modality.data, maxlen=max_emb_size,
dtype='float32', padding='post')
+ padded_modalities.append(d)
+
+ split_rows = int(len(modalities[0].data) / self.split)
+
+ data = []
+
+ for combination in itertools.combinations(padded_modalities, 2):
+ combined = None
+ for i in range(0, self.split):
+ start = split_rows * i
+ end = split_rows * (i + 1) if i < (self.split - 1) else
len(modalities[0].data)
+ m = np.einsum('bi,bo->bio', combination[0][start:end],
combination[1][start:end])
+ m = m.max(axis=2)
+ if combined is None:
+ combined = m
+ else:
+ combined = np.concatenate((combined, m), axis=0)
+ data.append(combined)
+
+ data = np.stack(data)
+ data = data.max(axis=0)
+
+ return np.array(data)
diff --git a/src/main/python/systemds/scuro/representations/mel_spectrogram.py
b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
new file mode 100644
index 0000000000..395b2977a0
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/mel_spectrogram.py
@@ -0,0 +1,66 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import os
+import pickle
+
+import librosa
+import numpy as np
+from keras.src.utils import pad_sequences
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+
+
+class MelSpectrogram(UnimodalRepresentation):
+ def __init__(self, avg=True, output_file=None):
+ super().__init__('MelSpectrogram')
+ self.avg = avg
+ self.output_file = output_file
+
+ def parse_all(self, file_path, indices, get_sequences=False):
+ result = []
+ max_length = 0
+ if os.path.isdir(file_path):
+ for filename in os.listdir(file_path):
+ f = os.path.join(file_path, filename)
+ if os.path.isfile(f):
+ y, sr = librosa.load(f)
+ S = librosa.feature.melspectrogram(y=y, sr=sr)
+ S_dB = librosa.power_to_db(S, ref=np.max)
+ if S_dB.shape[-1] > max_length:
+ max_length = S_dB.shape[-1]
+ result.append(S_dB)
+
+ r = []
+ for elem in result:
+ d = pad_sequences(elem, maxlen=max_length, dtype='float32',
padding='post')
+ r.append(d)
+
+ np_array_r = np.array(r) if not self.avg else np.mean(np.array(r),
axis=1)
+
+ if self.output_file is not None:
+ data = {}
+ for i in range(0, np_array_r.shape[0]):
+ data[indices[i]] = np_array_r[i]
+ with open(self.output_file, 'wb') as file:
+ pickle.dump(data, file)
+
+ return np_array_r
diff --git a/src/main/python/systemds/scuro/representations/average.py
b/src/main/python/systemds/scuro/representations/multiplication.py
similarity index 59%
copy from src/main/python/systemds/scuro/representations/average.py
copy to src/main/python/systemds/scuro/representations/multiplication.py
index 77896b1914..2b3ae64eac 100644
--- a/src/main/python/systemds/scuro/representations/average.py
+++ b/src/main/python/systemds/scuro/representations/multiplication.py
@@ -23,32 +23,26 @@ from typing import List
import numpy as np
-from modality.modality import Modality
-from keras.api.preprocessing.sequence import pad_sequences
+from systemds.scuro.modality.modality import Modality
+from keras.preprocessing.sequence import pad_sequences
-from representations.fusion import Fusion
+from systemds.scuro.representations.fusion import Fusion
-class Averaging(Fusion):
+class Multiplication(Fusion):
def __init__(self):
"""
- Combines modalities using averaging
+ Combines modalities using element-wise multiplication
"""
- super().__init__('Averaging')
+ super().__init__('Multiplication')
- def fuse(self, modalities: List[Modality]):
+ def fuse(self, modalities: List[Modality], train_indices=None):
max_emb_size = self.get_max_embedding_size(modalities)
-
- padded_modalities = []
- for modality in modalities:
- d = pad_sequences(modality.data, maxlen=max_emb_size,
dtype='float32', padding='post')
- padded_modalities.append(d)
-
- data = padded_modalities[0]
- for i in range(1, len(modalities)):
- data += padded_modalities[i]
- data = self.scale_data(data, modalities[0].train_indices)
- data /= len(modalities)
+ data = pad_sequences(modalities[0].data, maxlen=max_emb_size,
dtype='float32', padding='post')
+
+ for m in range(1, len(modalities)):
+ # scaled = self.scale_data(modalities[m].data, train_indices)
+ data = np.multiply(data, pad_sequences(modalities[m].data,
maxlen=max_emb_size, dtype='float32', padding='post'))
- return np.array(data)
+ return data
diff --git a/src/main/python/systemds/scuro/representations/resnet.py
b/src/main/python/systemds/scuro/representations/resnet.py
new file mode 100644
index 0000000000..52802288de
--- /dev/null
+++ b/src/main/python/systemds/scuro/representations/resnet.py
@@ -0,0 +1,168 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+
+import h5py
+
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
+from typing import Callable, Dict, Tuple, Any
+import torch.utils.data
+import os
+import cv2
+import torch
+import torchvision.models as models
+import torchvision.transforms as transforms
+import numpy as np
+
+DEVICE = 'cpu'
+
+
+class ResNet(UnimodalRepresentation):
+ def __init__(self, output_file=None):
+ super().__init__('ResNet')
+
+ self.output_file = output_file
+
+ def parse_all(self, file_path, indices, get_sequences=False):
+ resnet = models.resnet152(weights=models.ResNet152_Weights.DEFAULT)
+ resnet.eval()
+
+ for param in resnet.parameters():
+ param.requires_grad = False
+
+ transform = transforms.Compose([
+ transforms.ToPILImage(),
+ transforms.Resize((224, 224)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229,
0.224, 0.225])
+ ])
+
+ dataset = ResNetDataset(transform=transform,
video_folder_path=file_path)
+ embeddings = {}
+
+ class Identity(torch.nn.Module):
+ def forward(self, input_: torch.Tensor) -> torch.Tensor:
+ return input_
+
+ resnet.fc = Identity()
+
+ res5c_output = None
+
+ def avg_pool_hook(_module: torch.nn.Module, input_:
Tuple[torch.Tensor], _output: Any) -> None:
+ nonlocal res5c_output
+ res5c_output = input_[0]
+
+ resnet.avgpool.register_forward_hook(avg_pool_hook)
+
+ for instance in torch.utils.data.DataLoader(dataset):
+ video_id = instance["id"][0]
+ frames = instance["frames"][0].to(DEVICE)
+ embeddings[video_id] = torch.empty((len(frames), 2048))
+ batch_size = 32
+ for start_index in range(0, len(frames), batch_size):
+ end_index = min(start_index + batch_size, len(frames))
+ frame_ids_range = range(start_index, end_index)
+ frame_batch = frames[frame_ids_range]
+
+ avg_pool_value = resnet(frame_batch)
+
+ embeddings[video_id][frame_ids_range] =
avg_pool_value.to(DEVICE)
+
+ if self.output_file is not None:
+ with h5py.File(self.output_file, 'w') as hdf:
+ for key, value in embeddings.items():
+ hdf.create_dataset(key, data=value)
+
+ emb = np.zeros((len(indices), 2048), dtype='float32')
+ if indices is not None:
+ for i in indices:
+ emb[i] = embeddings.get(str(i)).mean(dim=0).numpy()
+ else:
+ for i, key in enumerate(embeddings.keys()):
+ emb[i] = embeddings.get(key).mean(dim=0).numpy()
+
+ return emb
+
+ @staticmethod
+ def extract_features_from_video(video_path, model, transform):
+ cap = cv2.VideoCapture(video_path)
+ features = []
+ count = 0
+ success, frame = cap.read()
+
+ while success:
+ success, frame = cap.read()
+ transformed_frame = transform(frame).unsqueeze(0)
+
+ with torch.no_grad():
+ feature_vector = model(transformed_frame)
+ feature_vector = feature_vector.view(-1).numpy()
+
+ features.append(feature_vector)
+
+ count += 1
+
+ cap.release()
+ return features, count
+
+
+class ResNetDataset(torch.utils.data.Dataset):
+ def __init__(self, video_folder_path: str, transform: Callable = None):
+ self.video_folder_path = video_folder_path
+ self.transform = transform
+ self.video_ids = []
+ video_files = [f for f in os.listdir(self.video_folder_path) if
+ f.lower().endswith(('.mp4', '.avi', '.mov', '.mkv'))]
+ self.file_extension = video_files[0].split('.')[-1]
+
+ for video in video_files:
+ video_id, _ = video.split('/')[-1].split('.')
+ self.video_ids.append(video_id)
+
+ self.frame_count_by_video_id = {video_id: 0 for video_id in
self.video_ids}
+
+ def __getitem__(self, index) -> Dict[str, object]:
+ video_id = self.video_ids[index]
+ video_path = self.video_folder_path + '/' + video_id + '.' +
self.file_extension
+
+ frames = None
+ count = 0
+
+ cap = cv2.VideoCapture(video_path)
+
+ success, frame = cap.read()
+
+ num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ self.frame_count_by_video_id[video_id] = num_frames
+ if frames is None and success:
+ frames = torch.empty((num_frames, 3, 224, 224))
+
+ while success:
+ frame = self.transform(frame)
+ frames[count] = frame # noqa
+ success, frame = cap.read()
+ count += 1
+
+ cap.release()
+ return {"id": video_id, "frames": frames}
+
+ def __len__(self) -> int:
+ return len(self.video_ids)
diff --git a/src/main/python/systemds/scuro/representations/average.py
b/src/main/python/systemds/scuro/representations/rowmax.py
similarity index 50%
copy from src/main/python/systemds/scuro/representations/average.py
copy to src/main/python/systemds/scuro/representations/rowmax.py
index 77896b1914..c4184687a1 100644
--- a/src/main/python/systemds/scuro/representations/average.py
+++ b/src/main/python/systemds/scuro/representations/rowmax.py
@@ -18,7 +18,7 @@
# under the License.
#
# -------------------------------------------------------------
-
+import itertools
from typing import List
import numpy as np
@@ -29,26 +29,45 @@ from keras.api.preprocessing.sequence import pad_sequences
from representations.fusion import Fusion
-class Averaging(Fusion):
- def __init__(self):
+class RowMax(Fusion):
+ def __init__(self, split=1):
"""
- Combines modalities using averaging
+ Combines modalities by computing the outer product of a modality
combination and
+ taking the row max
"""
- super().__init__('Averaging')
-
- def fuse(self, modalities: List[Modality]):
+ super().__init__('RowMax')
+ self.split = split
+
+ def fuse(self, modalities: List[Modality], train_indices):
+ if len(modalities) < 2:
+ return np.array(modalities)
+
max_emb_size = self.get_max_embedding_size(modalities)
padded_modalities = []
for modality in modalities:
- d = pad_sequences(modality.data, maxlen=max_emb_size,
dtype='float32', padding='post')
+ scaled = self.scale_data(modality.data, train_indices)
+ d = pad_sequences(scaled, maxlen=max_emb_size, dtype='float32',
padding='post')
padded_modalities.append(d)
-
- data = padded_modalities[0]
- for i in range(1, len(modalities)):
- data += padded_modalities[i]
-
- data = self.scale_data(data, modalities[0].train_indices)
- data /= len(modalities)
-
+
+ split_rows = int(len(modalities[0].data) / self.split)
+
+ data = []
+
+ for combination in itertools.combinations(padded_modalities, 2):
+ combined = None
+ for i in range(0, self.split):
+ start = split_rows * i
+ end = split_rows * (i + 1) if i < (self.split - 1) else
len(modalities[0].data)
+ m = np.einsum('bi,bo->bio', combination[0][start:end],
combination[1][start:end])
+ m = m.max(axis=2)
+ if combined is None:
+ combined = m
+ else:
+ combined = np.concatenate((combined, m), axis=0)
+ data.append(combined)
+
+ data = np.stack(data)
+ data = data.max(axis=0)
+
return np.array(data)
diff --git a/src/main/python/systemds/scuro/representations/average.py
b/src/main/python/systemds/scuro/representations/sum.py
similarity index 60%
copy from src/main/python/systemds/scuro/representations/average.py
copy to src/main/python/systemds/scuro/representations/sum.py
index 77896b1914..9c75606627 100644
--- a/src/main/python/systemds/scuro/representations/average.py
+++ b/src/main/python/systemds/scuro/representations/sum.py
@@ -21,34 +21,26 @@
from typing import List
-import numpy as np
-from modality.modality import Modality
-from keras.api.preprocessing.sequence import pad_sequences
+from systemds.scuro.modality.modality import Modality
+from keras.preprocessing.sequence import pad_sequences
-from representations.fusion import Fusion
+from systemds.scuro.representations.fusion import Fusion
-class Averaging(Fusion):
+class Sum(Fusion):
def __init__(self):
"""
- Combines modalities using averaging
+ Combines modalities using column-wise sum
"""
- super().__init__('Averaging')
+ super().__init__('Sum')
def fuse(self, modalities: List[Modality]):
max_emb_size = self.get_max_embedding_size(modalities)
- padded_modalities = []
- for modality in modalities:
- d = pad_sequences(modality.data, maxlen=max_emb_size,
dtype='float32', padding='post')
- padded_modalities.append(d)
-
- data = padded_modalities[0]
- for i in range(1, len(modalities)):
- data += padded_modalities[i]
+ data = pad_sequences(modalities[0].data, maxlen=max_emb_size,
dtype='float32', padding='post')
- data = self.scale_data(data, modalities[0].train_indices)
- data /= len(modalities)
+ for m in range(1, len(modalities)):
+ data += pad_sequences(modalities[m].data, maxlen=max_emb_size,
dtype='float32', padding='post')
- return np.array(data)
+ return data
diff --git a/src/main/python/systemds/scuro/representations/unimodal.py
b/src/main/python/systemds/scuro/representations/unimodal.py
index 659ad32468..da0e721a57 100644
--- a/src/main/python/systemds/scuro/representations/unimodal.py
+++ b/src/main/python/systemds/scuro/representations/unimodal.py
@@ -18,13 +18,7 @@
# under the License.
#
# -------------------------------------------------------------
-import json
-import pickle
-
-import h5py
-import numpy as np
-
-from representations.representation import Representation
+from systemds.scuro.representations.representation import Representation
class UnimodalRepresentation(Representation):
@@ -42,61 +36,3 @@ class UnimodalRepresentation(Representation):
class PixelRepresentation(UnimodalRepresentation):
def __init__(self):
super().__init__('Pixel')
-
-
-class ResNet(UnimodalRepresentation):
- def __init__(self):
- super().__init__('ResNet')
-
-
-class Pickle(UnimodalRepresentation):
- def __init__(self):
- super().__init__('Pickle')
-
- def parse_all(self, filepath, indices):
- with open(filepath, "rb") as file:
- data = pickle.load(file, encoding='latin1')
-
- if indices is not None:
- for n, idx in enumerate(indices):
- result = np.empty((len(data), np.mean(data[idx][()],
axis=1).shape[0]))
- break
- for n, idx in enumerate(indices):
- result[n] = np.mean(data[idx], axis=1)
- return result
- else:
- return np.array([np.mean(data[index], axis=1) for index in data])
-
-
-class JSON(UnimodalRepresentation):
- def __init__(self):
- super().__init__('JSON')
-
- def parse_all(self, filepath, indices):
- with open(filepath) as file:
- return json.load(file)
-
-
-class NPY(UnimodalRepresentation):
- def __init__(self):
- super().__init__('NPY')
-
- def parse_all(self, filepath, indices):
- data = np.load(filepath)
-
- if indices is not None:
- return np.array([data[n, 0] for n, index in enumerate(indices)])
- else:
- return np.array([data[index, 0] for index in data])
-
-
-class HDF5(UnimodalRepresentation):
- def __init__(self):
- super().__init__('HDF5')
-
- def parse_all(self, filepath, indices=None):
- data = h5py.File(filepath)
- if indices is not None:
- return np.array([np.mean(data[index][()], axis=0) for index in
indices])
- else:
- return np.array([np.mean(data[index][()], axis=0) for index in
data])
diff --git a/src/main/python/systemds/scuro/representations/unimodal.py
b/src/main/python/systemds/scuro/representations/utils.py
similarity index 50%
copy from src/main/python/systemds/scuro/representations/unimodal.py
copy to src/main/python/systemds/scuro/representations/utils.py
index 659ad32468..d611cd9c71 100644
--- a/src/main/python/systemds/scuro/representations/unimodal.py
+++ b/src/main/python/systemds/scuro/representations/utils.py
@@ -18,85 +18,78 @@
# under the License.
#
# -------------------------------------------------------------
+
+
import json
import pickle
import h5py
import numpy as np
-from representations.representation import Representation
-
-
-class UnimodalRepresentation(Representation):
- def __init__(self, name):
- """
- Parent class for all unimodal representation types
- :param name: name of the representation
- """
- super().__init__(name)
-
- def parse_all(self, file_path, indices):
- raise f'Not implemented for {self.name}'
-
-
-class PixelRepresentation(UnimodalRepresentation):
- def __init__(self):
- super().__init__('Pixel')
-
-
-class ResNet(UnimodalRepresentation):
- def __init__(self):
- super().__init__('ResNet')
+from systemds.scuro.representations.unimodal import UnimodalRepresentation
-class Pickle(UnimodalRepresentation):
+class NPY(UnimodalRepresentation):
def __init__(self):
- super().__init__('Pickle')
+ super().__init__('NPY')
- def parse_all(self, filepath, indices):
- with open(filepath, "rb") as file:
- data = pickle.load(file, encoding='latin1')
+ def parse_all(self, filepath, indices, get_sequences=False):
+ data = np.load(filepath, allow_pickle=True)
if indices is not None:
- for n, idx in enumerate(indices):
- result = np.empty((len(data), np.mean(data[idx][()],
axis=1).shape[0]))
- break
- for n, idx in enumerate(indices):
- result[n] = np.mean(data[idx], axis=1)
- return result
+ return np.array([data[index] for index in indices])
else:
- return np.array([np.mean(data[index], axis=1) for index in data])
+ return np.array([data[index] for index in data])
-class JSON(UnimodalRepresentation):
+class Pickle(UnimodalRepresentation):
def __init__(self):
- super().__init__('JSON')
+ super().__init__('Pickle')
- def parse_all(self, filepath, indices):
- with open(filepath) as file:
- return json.load(file)
+ def parse_all(self, file_path, indices, get_sequences=False):
+ with open(file_path, 'rb') as f:
+ data = pickle.load(f)
+
+ embeddings = []
+ for n, idx in enumerate(indices):
+ embeddings.append(data[idx])
+
+ return np.array(embeddings)
-class NPY(UnimodalRepresentation):
+class HDF5(UnimodalRepresentation):
def __init__(self):
- super().__init__('NPY')
+ super().__init__('HDF5')
- def parse_all(self, filepath, indices):
- data = np.load(filepath)
+ def parse_all(self, filepath, indices=None, get_sequences=False):
+ data = h5py.File(filepath)
- if indices is not None:
- return np.array([data[n, 0] for n, index in enumerate(indices)])
+ if get_sequences:
+ max_emb = 0
+ for index in indices:
+ if max_emb < len(data[index][()]):
+ max_emb = len(data[index][()])
+
+ emb = []
+ if indices is not None:
+ for index in indices:
+ emb_i = data[index].tolist()
+ for i in range(len(emb_i), max_emb):
+ emb_i.append([0 for x in range(0, len(emb_i[0]))])
+ emb.append(emb_i)
+
+ return np.array(emb)
else:
- return np.array([data[index, 0] for index in data])
+ if indices is not None:
+ return np.array([np.mean(data[index], axis=0) for index in
indices])
+ else:
+ return np.array([np.mean(data[index][()], axis=0) for index in
data])
-class HDF5(UnimodalRepresentation):
+class JSON(UnimodalRepresentation):
def __init__(self):
- super().__init__('HDF5')
+ super().__init__('JSON')
- def parse_all(self, filepath, indices=None):
- data = h5py.File(filepath)
- if indices is not None:
- return np.array([np.mean(data[index][()], axis=0) for index in
indices])
- else:
- return np.array([np.mean(data[index][()], axis=0) for index in
data])
+ def parse_all(self, filepath, indices):
+ with open(filepath) as file:
+ return json.load(file)
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java
index a3bafd2ebb..f5c9bc09b7 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptSparseTest.java
@@ -1,3 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
package org.apache.sysds.test.functions.rewrite;
import org.apache.sysds.hops.OptimizerUtils;
@@ -76,7 +95,6 @@ public class RewriteMatrixMultChainOptSparseTest extends
AutomatedTestBase {
Assert.assertFalse(heavyHittersContainsSubString("mmchain") ||
heavyHittersContainsSubString("sp_mapmmchain"));
}
-
}
finally {
OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES =
oldFlag1;