This is an automated email from the ASF dual-hosted git repository.
cdionysio pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f8858408aa [SYSTEMDS-3913] Make learnable fusion representations work for multi-label tasks
f8858408aa is described below
commit f8858408aa5fe26414025bf1ee0904832e72ff2f
Author: Christina Dionysio <[email protected]>
AuthorDate: Mon Dec 8 15:38:26 2025 +0100
[SYSTEMDS-3913] Make learnable fusion representations work for multi-label tasks
This patch extends the learnable fusion methods (the LSTM and attention-based fusion representations) so that they also work with multi-label tasks.
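In essence, both representations now infer the task type from the label layout: a 2-D multi-hot label matrix switches training to BCEWithLogitsLoss on float targets, while class-index labels keep CrossEntropyLoss on long targets. A minimal standalone sketch of that dispatch (the helper name make_targets_and_loss is illustrative, not part of the patch):

    import numpy as np
    import torch
    import torch.nn as nn

    def make_targets_and_loss(labels):
        # Mirrors the label-shape check added to lstm.py and
        # multimodal_attention_fusion.py below.
        y = np.array(labels)
        if y.ndim == 2 and y.shape[1] > 1:
            # Multi-hot matrix (n_samples, n_classes): multi-label task.
            return torch.from_numpy(y).float(), nn.BCEWithLogitsLoss(), y.shape[1]
        if y.ndim == 2:
            y = y.ravel()  # single-column matrix -> plain class-index vector
        # Class-index vector: standard single-label classification.
        return torch.from_numpy(y).long(), nn.CrossEntropyLoss(), len(np.unique(y))

For multi-label batches, training accuracy is then tracked as Hamming accuracy over thresholded sigmoid outputs (torch.sigmoid(logits) > 0.5), as the attention-fusion hunks below show.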
---
.../systemds/scuro/dataloader/video_loader.py | 16 ++++-
.../python/systemds/scuro/modality/modality.py | 22 ++++++-
.../systemds/scuro/modality/unimodal_modality.py | 19 ------
.../systemds/scuro/representations/fusion.py | 12 +++-
.../python/systemds/scuro/representations/lstm.py | 49 +++++++++++---
.../representations/multimodal_attention_fusion.py | 76 ++++++++++++++++++----
6 files changed, 144 insertions(+), 50 deletions(-)
diff --git a/src/main/python/systemds/scuro/dataloader/video_loader.py b/src/main/python/systemds/scuro/dataloader/video_loader.py
index 2c154ecbaf..8471cc7c35 100644
--- a/src/main/python/systemds/scuro/dataloader/video_loader.py
+++ b/src/main/python/systemds/scuro/dataloader/video_loader.py
@@ -71,7 +71,13 @@ class VideoLoader(BaseLoader):
self.fps, length, width, height, num_channels
)
- frames = []
+ num_frames = (length + frame_interval - 1) // frame_interval
+
+ stacked_frames = np.zeros(
+ (num_frames, height, width, num_channels), dtype=self._data_type
+ )
+
+ frame_idx = 0
idx = 0
while cap.isOpened():
ret, frame = cap.read()
@@ -81,7 +87,11 @@ class VideoLoader(BaseLoader):
if idx % frame_interval == 0:
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = frame.astype(self._data_type) / 255.0
- frames.append(frame)
+ stacked_frames[frame_idx] = frame
+ frame_idx += 1
idx += 1
- self.data.append(np.stack(frames))
+ if frame_idx < num_frames:
+ stacked_frames = stacked_frames[:frame_idx]
+
+ self.data.append(stacked_frames)
diff --git a/src/main/python/systemds/scuro/modality/modality.py b/src/main/python/systemds/scuro/modality/modality.py
index 07f80cbd9f..f6e0320469 100644
--- a/src/main/python/systemds/scuro/modality/modality.py
+++ b/src/main/python/systemds/scuro/modality/modality.py
@@ -88,9 +88,8 @@ class Modality:
):
return
- md_copy = deepcopy(self.metadata)
- self.metadata = {}
- for i, (md_k, md_v) in enumerate(md_copy.items()):
+ for i, (md_k, md_v) in enumerate(self.metadata.items()):
+ md_v = selective_copy_metadata(md_v)
updated_md = self.modality_type.update_metadata(md_v, self.data[i])
self.metadata[md_k] = updated_md
if i == 0:
@@ -183,3 +182,20 @@ class Modality:
break
return aligned
+
+
+def selective_copy_metadata(metadata):
+ if isinstance(metadata, dict):
+ new_md = {}
+ for k, v in metadata.items():
+ if k == "data_layout":
+ new_md[k] = v.copy() if isinstance(v, dict) else v
+ elif isinstance(v, np.ndarray):
+ new_md[k] = v
+ else:
+ new_md[k] = selective_copy_metadata(v)
+ return new_md
+ elif isinstance(metadata, (list, tuple)):
+ return type(metadata)(selective_copy_metadata(item) for item in metadata)
+ else:
+ return metadata
diff --git a/src/main/python/systemds/scuro/modality/unimodal_modality.py b/src/main/python/systemds/scuro/modality/unimodal_modality.py
index 5898ea98c1..f7b7394e0f 100644
--- a/src/main/python/systemds/scuro/modality/unimodal_modality.py
+++ b/src/main/python/systemds/scuro/modality/unimodal_modality.py
@@ -146,8 +146,6 @@ class UnimodalModality(Modality):
else:
original_lengths.append(d.shape[0])
- new_modality.data = self.l2_normalize_features(new_modality.data)
-
if len(original_lengths) > 0 and min(original_lengths) < max(original_lengths):
target_length = max(original_lengths)
padded_embeddings = []
@@ -194,20 +192,3 @@ class UnimodalModality(Modality):
new_modality.transform_time = time.time() - start
new_modality.self_contained = representation.self_contained
return new_modality
-
- def l2_normalize_features(self, feature_list):
- normalized_features = []
- for feature in feature_list:
- original_shape = feature.shape
- flattened = feature.flatten()
-
- norm = np.linalg.norm(flattened)
- if norm > 0:
- normalized_flat = flattened / norm
- normalized_feature = normalized_flat.reshape(original_shape)
- else:
- normalized_feature = feature
-
- normalized_features.append(normalized_feature)
-
- return normalized_features
diff --git a/src/main/python/systemds/scuro/representations/fusion.py b/src/main/python/systemds/scuro/representations/fusion.py
index 693689bf92..addccadade 100644
--- a/src/main/python/systemds/scuro/representations/fusion.py
+++ b/src/main/python/systemds/scuro/representations/fusion.py
@@ -68,19 +68,25 @@ class Fusion(Representation):
return self.execute(mods)
def transform_with_training(self, modalities: List[Modality], task):
+ fusion_train_indices = task.fusion_train_indices
+
train_modalities = []
for modality in modalities:
train_data = [
- d for i, d in enumerate(modality.data) if i in task.train_indices
+ d for i, d in enumerate(modality.data) if i in fusion_train_indices
]
train_modality = TransformedModality(modality, self)
train_modality.data = copy.deepcopy(train_data)
train_modalities.append(train_modality)
transformed_train = self.execute(
- train_modalities, task.labels[task.train_indices]
+ train_modalities, task.labels[fusion_train_indices]
)
- transformed_val = self.transform_data(modalities, task.val_indices)
+
+ all_other_indices = [
+ i for i in range(len(modalities[0].data)) if i not in fusion_train_indices
+ ]
+ transformed_other = self.transform_data(modalities, all_other_indices)
transformed_data = np.zeros(
(len(modalities[0].data), transformed_train.shape[1])
diff --git a/src/main/python/systemds/scuro/representations/lstm.py b/src/main/python/systemds/scuro/representations/lstm.py
index c8e9644881..efc3127274 100644
--- a/src/main/python/systemds/scuro/representations/lstm.py
+++ b/src/main/python/systemds/scuro/representations/lstm.py
@@ -42,7 +42,7 @@ class LSTM(Fusion):
depth=1,
dropout_rate=0.1,
learning_rate=0.001,
- epochs=50,
+ epochs=20,
batch_size=32,
):
parameters = {
@@ -50,7 +50,7 @@ class LSTM(Fusion):
"depth": [1, 2, 3],
"dropout_rate": [0.1, 0.2, 0.3, 0.4, 0.5],
"learning_rate": [0.001, 0.0001, 0.01, 0.1],
- "epochs": [50, 100, 200],
+ "epochs": [10, 2050, 100, 200],
"batch_size": [8, 16, 32, 64, 128],
}
@@ -70,6 +70,7 @@ class LSTM(Fusion):
self.num_classes = None
self.is_trained = False
self.model_state = None
+ self.is_multilabel = False
self._set_random_seeds()
@@ -166,18 +167,32 @@ class LSTM(Fusion):
X = self._prepare_data(modalities)
y = np.array(labels)
+ if y.ndim == 2 and y.shape[1] > 1:
+ self.is_multilabel = True
+ self.num_classes = y.shape[1]
+ else:
+ self.is_multilabel = False
+ if y.ndim == 2:
+ y = y.ravel()
+ self.num_classes = len(np.unique(y))
+
self.input_dim = X.shape[2]
- self.num_classes = len(np.unique(y))
self.model = self._build_model(self.input_dim, self.num_classes)
device = get_device()
self.model.to(device)
- criterion = nn.CrossEntropyLoss()
+ if self.is_multilabel:
+ criterion = nn.BCEWithLogitsLoss()
+ else:
+ criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
X_tensor = torch.FloatTensor(X).to(device)
- y_tensor = torch.LongTensor(y).to(device)
+ if self.is_multilabel:
+ y_tensor = torch.FloatTensor(y).to(device)
+ else:
+ y_tensor = torch.LongTensor(y).to(device)
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
@@ -202,15 +217,23 @@ class LSTM(Fusion):
"state_dict": self.model.state_dict(),
"input_dim": self.input_dim,
"num_classes": self.num_classes,
+ "is_multilabel": self.is_multilabel,
"width": self.width,
"depth": self.depth,
"dropout_rate": self.dropout_rate,
}
self.model.eval()
+ all_features = []
with torch.no_grad():
- features, _ = self.model(X_tensor)
- return features.cpu().numpy()
+ inference_dataloader = DataLoader(
+ TensorDataset(X_tensor), batch_size=self.batch_size, shuffle=False
+ )
+ for (batch_X,) in inference_dataloader:
+ features, _ = self.model(batch_X)
+ all_features.append(features.cpu())
+
+ return torch.cat(all_features, dim=0).numpy()
def apply_representation(self, modalities: List[Modality]) -> np.ndarray:
if not self.is_trained or self.model is None:
@@ -222,12 +245,17 @@ class LSTM(Fusion):
self.model.to(device)
X_tensor = torch.FloatTensor(X).to(device)
-
+ all_features = []
self.model.eval()
with torch.no_grad():
- features, _ = self.model(X_tensor)
+ inference_dataloader = DataLoader(
+ TensorDataset(X_tensor), batch_size=self.batch_size, shuffle=False
+ )
+ for (batch_X,) in inference_dataloader:
+ features, _ = self.model(batch_X)
+ all_features.append(features.cpu())
- return features.cpu().numpy()
+ return torch.cat(all_features, dim=0).numpy()
def get_model_state(self) -> Dict[str, Any]:
return self.model_state
@@ -236,6 +264,7 @@ class LSTM(Fusion):
self.model_state = state
self.input_dim = state["input_dim"]
self.num_classes = state["num_classes"]
+ self.is_multilabel = state.get("is_multilabel", False)
self.model = self._build_model(self.input_dim, self.num_classes)
self.model.load_state_dict(state["state_dict"])
diff --git a/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py b/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
index 6f5f527f31..3f86610550 100644
--- a/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
+++ b/src/main/python/systemds/scuro/representations/multimodal_attention_fusion.py
@@ -40,7 +40,7 @@ class AttentionFusion(Fusion):
num_heads=8,
dropout=0.1,
batch_size=32,
- num_epochs=50,
+ num_epochs=20,
learning_rate=0.001,
):
parameters = {
@@ -48,7 +48,7 @@ class AttentionFusion(Fusion):
"num_heads": [2, 4, 8, 12],
"dropout": [0.0, 0.1, 0.2, 0.3, 0.4],
"batch_size": [8, 16, 32, 64, 128],
- "num_epochs": [50, 100, 150, 200],
+ "num_epochs": [10, 20, 50, 100, 150, 200],
"learning_rate": [1e-5, 1e-4, 1e-3, 1e-2],
}
super().__init__("AttentionFusion", parameters)
@@ -69,6 +69,7 @@ class AttentionFusion(Fusion):
self.num_classes = None
self.is_trained = False
self.model_state = None
+ self.is_multilabel = False
self._set_random_seeds()
@@ -122,9 +123,17 @@ class AttentionFusion(Fusion):
inputs, input_dimensions, max_sequence_length = self._prepare_data(modalities)
y = np.array(labels)
+ if y.ndim == 2 and y.shape[1] > 1:
+ self.is_multilabel = True
+ self.num_classes = y.shape[1]
+ else:
+ self.is_multilabel = False
+ if y.ndim == 2:
+ y = y.ravel()
+ self.num_classes = len(np.unique(y))
+
self.input_dim = input_dimensions
self.max_sequence_length = max_sequence_length
- self.num_classes = len(np.unique(y))
self.encoder = MultiModalAttentionFusion(
self.input_dim,
@@ -142,7 +151,10 @@ class AttentionFusion(Fusion):
self.encoder.to(device)
self.classification_head.to(device)
- criterion = nn.CrossEntropyLoss()
+ if self.is_multilabel:
+ criterion = nn.BCEWithLogitsLoss()
+ else:
+ criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
list(self.encoder.parameters())
+ list(self.classification_head.parameters()),
@@ -151,7 +163,11 @@ class AttentionFusion(Fusion):
for modality_name in inputs:
inputs[modality_name] = inputs[modality_name].to(device)
- labels_tensor = torch.from_numpy(y).long().to(device)
+
+ if self.is_multilabel:
+ labels_tensor = torch.from_numpy(y).float().to(device)
+ else:
+ labels_tensor = torch.from_numpy(y).long().to(device)
dataset_inputs = []
for i in range(len(y)):
@@ -197,9 +213,17 @@ class AttentionFusion(Fusion):
optimizer.step()
total_loss += loss.item()
- _, predicted = torch.max(logits.data, 1)
- total_correct += (predicted == batch_labels).sum().item()
- total_samples += batch_labels.size(0)
+
+ if self.is_multilabel:
+ predicted = (torch.sigmoid(logits) > 0.5).float()
+ correct = (predicted == batch_labels).float()
+ hamming_acc = correct.mean()
+ total_correct += hamming_acc.item() * batch_labels.size(0)
+ total_samples += batch_labels.size(0)
+ else:
+ _, predicted = torch.max(logits.data, 1)
+ total_correct += (predicted == batch_labels).sum().item()
+ total_samples += batch_labels.size(0)
self.is_trained = True
@@ -214,10 +238,24 @@ class AttentionFusion(Fusion):
"dropout": self.dropout,
}
+ all_features = []
+
with torch.no_grad():
- encoder_output = self.encoder(inputs)
+ for batch_start in range(
+ 0, len(inputs[list(inputs.keys())[0]]), self.batch_size
+ ):
+ batch_end = min(
+ batch_start + self.batch_size, len(inputs[list(inputs.keys())[0]])
+ )
+
+ batch_inputs = {}
+ for modality_name, tensor in inputs.items():
+ batch_inputs[modality_name] = tensor[batch_start:batch_end]
+
+ encoder_output = self.encoder(batch_inputs)
+ all_features.append(encoder_output["fused"].cpu())
- return encoder_output["fused"].cpu().numpy()
+ return torch.cat(all_features, dim=0).numpy()
def apply_representation(self, modalities: List[Modality]) -> np.ndarray:
if not self.is_trained or self.encoder is None:
@@ -232,10 +270,23 @@ class AttentionFusion(Fusion):
inputs[modality_name] = inputs[modality_name].to(device)
self.encoder.eval()
+ all_features = []
+
with torch.no_grad():
- encoder_output = self.encoder(inputs)
+ batch_size = self.batch_size
+ n_samples = len(inputs[list(inputs.keys())[0]])
+
+ for batch_start in range(0, n_samples, batch_size):
+ batch_end = min(batch_start + batch_size, n_samples)
+
+ batch_inputs = {}
+ for modality_name, tensor in inputs.items():
+ batch_inputs[modality_name] = tensor[batch_start:batch_end]
+
+ encoder_output = self.encoder(batch_inputs)
+ all_features.append(encoder_output["fused"].cpu())
- return encoder_output["fused"].cpu().numpy()
+ return torch.cat(all_features, dim=0).numpy()
def get_model_state(self) -> Dict[str, Any]:
return self.model_state
@@ -245,6 +296,7 @@ class AttentionFusion(Fusion):
self.input_dim = state["input_dimensions"]
self.max_sequence_length = state["max_sequence_length"]
self.num_classes = state["num_classes"]
+ self.is_multilabel = state.get("is_multilabel", False)
self.encoder = MultiModalAttentionFusion(
self.input_dim,