This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 773d876a12 [SYSTEMDS-3928] New builtin function for Independent Subnet
Training
773d876a12 is described below
commit 773d876a12b49de5d2e87bdb5674beaeab645586
Author: Arno Bock <[email protected]>
AuthorDate: Sat Mar 28 16:23:16 2026 +0100
[SYSTEMDS-3928] New builtin function for Independent Subnet Training
Closes #2427.
---
scripts/builtin/independentSubnetTrain.dml | 504 +++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../controlprogram/context/ExecutionContext.java | 2 +-
.../sysds/runtime/instructions/cp/ListObject.java | 2 +-
src/test/config/SystemDS-config.xml | 2 +-
.../builtin/part1/BuiltinIndSubnetTest.java | 95 ++++
.../builtin/indSubnetTest_mnist_lenet.dml | 391 ++++++++++++++++
7 files changed, 994 insertions(+), 3 deletions(-)
diff --git a/scripts/builtin/independentSubnetTrain.dml
b/scripts/builtin/independentSubnetTrain.dml
new file mode 100644
index 0000000000..2009ab0515
--- /dev/null
+++ b/scripts/builtin/independentSubnetTrain.dml
@@ -0,0 +1,504 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Independent Subnet Training (IST)
+#
+# This builtin implements independent subnet training as a
+# second-order function. It orchestrates distributed / parallel
+# training over disjoint subnets using parfor, while delegating
+# architecture-specific logic to user-provided functions.
+# ------------------------------------------------------------
+# INPUT:
+# model : initial model parameters. A list of matrices for NN.
+# features : X
+# labels : Y
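+# val_features : validation X (accepted but currently not used by the training loop)
+# val_labels : validation Y (accepted but currently not used by the training loop)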
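+# upd : name of the update function (invoked via evalList with model, mask, features, labels, hyperparams); computes gradients and performs the optimizer step
+# agg : name of the aggregation function (invoked via eval with initialParam, allSubnetsParam, allSubnetsMasks); combines updates of shared parameters (across subnets)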
+# epochs : number of epochs
+# batchsize : batchsize for training
+# j : number of gradient steps until aggregation -> determines length of the IST round (aggregation frequency)
+# numSubnets : number of independent subnets/workers
+# hyperparams : list of hyperparameters (e.g. lr, reg, mask params, etc.)
+# verbose : print progress (boolean)
+# paramsPerLayer : number of parameter blocks per layer (W and b, plus optional optimizer state in W/b pairs); must be even and >= 2
+# fullyConnectedLayers : list of all FC layer indices (starting at idx=1)
+#
+# OUTPUT:
+# model_out : trained model parameters (IST: W)
+#
+# ASSUMPTION:
+# - the last layer is the output layer
+# ------------------------------------------------------------
+
+m_independentSubnetTrain = function(
+    list[unknown] model,
+    matrix[double] features,
+    matrix[double] labels,
+    matrix[double] val_features,
+    matrix[double] val_labels,
+    string upd,
+    string agg,
+    int epochs,
+    int batchsize,
+    int j,
+    int numSubnets,
+    list[unknown] hyperparams,
+    boolean verbose,
+    int paramsPerLayer,
+    list[int] fullyConnectedLayers
+)
+return (list[unknown] trained_model)
+{
+  # ------------------------------------------------------------
+  # Setup
+  # ------------------------------------------------------------
+  # NOTE: val_features and val_labels are accepted but not read anywhere in
+  # this function yet (see the early-stopping TODO after the round loop).
+  model_out = model
+
+  P = length(model)
+  N = nrow(features)
+
+  # Validate all loop bounds up front. A DML range with from > to (e.g. 1:0)
+  # iterates downward instead of being empty, so non-positive values would
+  # otherwise run extra iterations or fail later with a misleading message.
+  if (epochs < 1) stop("Number of epochs must be at least 1!")
+  if (j < 1) stop("Aggregation frequency j must be at least 1!")
+  if (numSubnets < 1) stop("Number of subnets must be at least 1!")
+  if (paramsPerLayer < 2 | paramsPerLayer %% 2 != 0) stop("paramsPerLayer must be an even number >= 2 (W&b pairs).")
+  if (length(fullyConnectedLayers) < 1) stop("At least one fully connected layer index is required.")
+  if (P %% paramsPerLayer != 0) stop("Model length not divisible by paramsPerLayer")
+  L = as.integer(P / paramsPerLayer) # total layers
+
+  # I. determine shared parameters
+  # - create mask for all FC layers
+  fcLayers = fullyConnectedLayers
+  isFC = matrix(0, rows=1, cols=L)
+  for (i in 1:length(fcLayers)) {
+    idx = as.integer(as.scalar(fcLayers[i]))
+    isFC[1, idx] = 1 # TODO vectorize
+  }
+
+  # - expand layer mask across all parameters
+  #   (parameter p of block k for layer l lives at index (k-1)*L + l)
+  isFC_rep = isFC
+  for (r in 2:paramsPerLayer) {
+    isFC_rep = cbind(isFC_rep, isFC)
+  }
+  if (ncol(isFC_rep)!=P) stop("Dimension mismatch for FC layer mask.")
+
+  # - all non-FC layers are shared
+  isSharedParam = 1 - isFC_rep
+
+  # - edge case: FC bias parameters are shared in: output layer or at the end of a FC block
+  for (paramId in seq(2, paramsPerLayer, 2)) { # iterate bias blocks only
+    for (l in 1:L) {
+      if (as.scalar(isFC[1,l]) == 1) {
+        # Nested tests on purpose: isFC[1,l+1] is only read for l < L, so the
+        # lookup can never go out of range even if the predicate operands are
+        # not short-circuited.
+        endsFCBlock = (l == L) # output layer
+        if (l < L) {
+          endsFCBlock = (as.scalar(isFC[1,l+1]) == 0) # next layer is not FC
+        }
+        if (endsFCBlock) {
+          p_out_bias = (paramId - 1) * L + l # this bias is shared across subnets
+          isSharedParam[1, p_out_bias] = 1 # TODO vectorize
+        }
+      }
+    }
+  }
+  if (ncol(isSharedParam) != P) stop("isSharedParam dimension mismatch!")
+
+  # II. calculate update-steps per epoch
+  if (batchsize<=0 | batchsize>N) {
+    stop("Batch size is out of bounds!")
+  } else {
+    stepsPerEpoch = ceil(N / batchsize)
+  }
+
+  # III. training loop
+  for (epoch in 1:epochs) {
+    if (verbose) print("Entered epoch: " + epoch)
+
+    # A.) reshuffle indices each epoch
+    #     (S is a b x N permutation matrix built from the random ordering)
+    allSampleIndicesRandom = order(target=rand(rows=N, cols=1), by=1, decreasing=FALSE, index.return=TRUE)
+    batchIndices = allSampleIndicesRandom[, 1]
+    b = nrow(batchIndices)
+    I = seq(1, b, 1)
+    V = matrix(1, rows=b, cols=1)
+    S = table(I, batchIndices, V, b, N)
+
+    features_shuffled = S %*% features
+    labels_shuffled = S %*% labels
+
+    # B.) iterate IST rounds
+    for (step in seq(1, stepsPerEpoch, j)) {
+      if (verbose) print("Starting new IST round at step: " + step)
+      round_model = model_out # prevent accidental mutation of model_out
+
+      # 1.) create masks for all subnets
+      [masks, masks_meta_info] = ist_create_disjoint_masks(round_model, numSubnets, L, fcLayers, paramsPerLayer, isFC, verbose)
+
+      # 2.) preallocate list to store all subnets TODO move outside epoch loop? to prevent constantly allocating...
+      updatedSubnets = list()
+      updatedSubnetsMasks = list()
+      for (s in 1:numSubnets) {
+        updatedSubnets = append(updatedSubnets, list())
+        updatedSubnetsMasks = append(updatedSubnetsMasks, list())
+      }
+
+      # 3) create a template for each subnet based on input model (allows indexing in subsequent parfor-loop) TODO move outside epoch loop? to prevent constantly allocating...
+      subnetModelTemplate = list()
+      subnetModelMaskTemplate = list()
+      for (pIdx in 1:P) {
+        subnetModelTemplate = append(subnetModelTemplate, as.matrix(model[pIdx]))
+        subnetModelMaskTemplate = append(subnetModelMaskTemplate, as.matrix(model[pIdx]))
+      }
+
+      # local optimization steps / IST round (the last round of an epoch may be shorter)
+      localSteps = min(j, (stepsPerEpoch-step+1))
+
+      # 4.) obtain all minibatches for this IST round (doing it once prevents parfor confusion)
+      shuffled_features = list()
+      shuffled_labels = list()
+      for (localStep in 1:localSteps) {
+        mb = (step-1) + localStep
+        mb_local = mb-1
+        start = mb_local*batchsize + 1
+        end = min(mb*batchsize, N)
+
+        Xb = features_shuffled[start:end, 1:ncol(features_shuffled)]
+        yb = labels_shuffled[start:end, 1:ncol(labels_shuffled)]
+        shuffled_features = append(shuffled_features, Xb)
+        shuffled_labels = append(shuffled_labels, yb)
+      }
+
+      # 5.) perform 'j' local gradient steps for each subnet
+      parfor (subnet in 1:numSubnets) {
+
+        # a.) obtain masked subnet
+        subnet_model = subnetModelTemplate
+        subnet_model_mask = subnetModelMaskTemplate
+        for (subnet_p in 1:length(round_model)) {
+          # meta columns = [start, end, rows, cols] of the flattened mask slice
+          param_start_idx = as.integer(as.scalar(masks_meta_info[subnet_p,1]))
+          param_end_idx = as.integer(as.scalar(masks_meta_info[subnet_p,2]))
+          param_rows = as.integer(as.scalar(masks_meta_info[subnet_p,3]))
+          param_cols = as.integer(as.scalar(masks_meta_info[subnet_p,4]))
+
+          vec = masks[subnet, param_start_idx:param_end_idx]
+          param_mask = matrix(vec, rows=param_rows, cols=param_cols, byrow=TRUE)
+          param = as.matrix(round_model[subnet_p])
+
+          subnet_model[subnet_p] = list(param * param_mask) # TODO sparse masking! dense masking will probably increase computational efficiency
+          subnet_model_mask[subnet_p] = list(param_mask)
+        }
+
+        # b.) local optimization steps / IST round
+        for (localStep in 1:localSteps) {
+          feat = as.matrix(shuffled_features[localStep])
+          lab = as.matrix(shuffled_labels[localStep])
+
+          # compute gradients for subnet s + apply update step on owned params (only)
+          subnet_model = as.list(evalList(upd, list(model=subnet_model, mask=subnet_model_mask, features=feat, labels=lab, hyperparams=hyperparams)))
+        }
+
+        # c.) save updated subnet and mask
+        updatedSubnets[subnet] = list(subnet_model)
+        updatedSubnetsMasks[subnet] = list(subnet_model_mask)
+      }
+      if (verbose) print("All subnets have run successfully.")
+
+      # 6.) aggregate updates into global model (i.e. model_out)
+      for (p in 1:P) {
+        if (as.scalar(isSharedParam[1, p])==1) {
+          # construct full model update by aggregating shared parameter updates from all subnets
+          subnetParams = list()
+          subnetMasks = list()
+          for (s in 1:numSubnets) {
+            subnet = as.list(updatedSubnets[s])
+            subnetMask = as.list(updatedSubnetsMasks[s])
+
+            subnetParams = append(subnetParams, as.matrix(subnet[p]))
+            subnetMasks = append(subnetMasks, as.matrix(subnetMask[p]))
+          }
+
+          # aggregate shared parameters based on provided function
+          averagedUpdatedParam = eval(agg, list(initialParam=as.matrix(round_model[p]), allSubnetsParam=subnetParams, allSubnetsMasks=subnetMasks))
+          round_model[p] = averagedUpdatedParam
+        }
+        else {
+          # construct full model update by filling with disjointly partitioned parameter updates from all subnets
+          initialParam = as.matrix(round_model[p])
+          updatedParam = matrix(0, nrow(initialParam), ncol(initialParam))
+          owned = matrix(0, nrow(initialParam), ncol(initialParam))
+
+          for (s in 1:numSubnets) {
+            subnet = as.list(updatedSubnets[s])
+            subnetMask = as.list(updatedSubnetsMasks[s])
+
+            owned = owned + as.matrix(subnetMask[p])
+            updatedParam = updatedParam + as.matrix(subnet[p])
+          }
+
+          # SANITY CHECK: every entry may be owned by at most one subnet
+          max_freq = max(owned)
+          if (max_freq > 1) stop("Overlap detected")
+
+          round_model[p] = updatedParam
+        }
+      }
+      if (verbose) print("Aggregation of subnets finished. IST round has been successfully executed!")
+
+      # 7.) update global model (end of the IST round)
+      model_out = round_model
+    }
+    # TODO (potentially): add validation for early stopping etc.
+  }
+  trained_model = model_out
+}
+
+
+#
----------------------------------------------------------------------------------------------------------------------
+# Independent Subnet Masking
+#
+# This helper function creates two matrices: one contains all flattened masks,
the other contains the info on
+# how to reconstruct the mask matrices. Each mask is a binary vector
indicating which parameters belong to that subnet.
+#
----------------------------------------------------------------------------------------------------------------------
+# INPUT:
+# model : list of parameter tensors grouped by parameter type
i.e. blocks
+# numSubnets : number of independent subnets/workers (K)
+# L : total number of layers INCLUDING the output layer
(layer indices are assumed to be 1..L)
+# fullyConnectedLayers : list of all FC layer indices (starting at idx=1)
+# paramsPerFCLayer : number of parameter blocks per layer (W and b, plus optional optimizer state in W/b pairs); must be even and >= 2
+# isFC : indicator matrix encoding which layers are FC =>
isFC[l] ∈ {0,1}
+# verbose : print progress (boolean)
+#
+# OUTPUT:
+# masks_new : mask matrix defining disjoint neuron ownership
across subnets
+# masks_new_meta : metadata matrix describing the mask layout and
ownership mapping
+#
+# ASSUMPTIONS:
+# - neuron ownership is defined via bias vectors
+# - model is a list of parameter tensors
+# - trainable parameters are grouped by parameter type i.e. param blocks
like (W_l1, W_l2, ..., b_l1, b_l2, ...)
+# - assumes W and b are always the first two param blocks
+# - the pattern of optional optimizer state tensors (e.g., vW_l, vb_l)
follow the same grouping and always W followed by b
+# - (output layer & end of FC block) biases are shared -> gradients collide;
must be handled by aggregation logic
+#
----------------------------------------------------------------------------------------------------------------------
+
+ist_create_disjoint_masks = function(
+    list[unknown] model,
+    int numSubnets,
+    int L,
+    list[int] fullyConnectedLayers,
+    int paramsPerFCLayer,
+    Matrix[Double] isFC,
+    boolean verbose
+)
+  return (
+    Matrix[Double] masks_new,
+    Matrix[Double] masks_new_meta
+  )
+{
+  P = length(model)
+
+  # SANITY CHECKS: ensure provided model can be masked correctly
+  if (as.integer(P / paramsPerFCLayer) != L) {
+    stop("Layer/parameter mismatch. Please make sure each layer has the same amount of parameters.")
+  }
+  if (paramsPerFCLayer < 2 | paramsPerFCLayer %% 2 != 0) {
+    stop("At least 1 pair of W and b needs to be present, as well as parameters need to be W&b pairs.")
+  }
+
+  # I.) initialize and preallocate masks
+  #     every parameter gets a contiguous slice [start,end] of the flattened
+  #     mask row; rows/cols allow reshaping the slice back to the parameter
+  masks_new_meta = matrix(0, rows=length(model), cols=4) # columns=[start,end,rows,cols]
+  current_position = 1
+  for (p in 1:length(model)) {
+    M = as.matrix(model[p])
+    param_length = ncol(M) * nrow(M)
+
+    masks_new_meta[p,1] = current_position
+    masks_new_meta[p,2] = current_position + param_length -1
+    masks_new_meta[p,3] = nrow(M)
+    masks_new_meta[p,4] = ncol(M)
+
+    current_position = current_position + param_length
+  }
+  mask_size = current_position-1
+  masks_new = matrix(0, rows=numSubnets, cols=mask_size) # all subnets in one matrix
+
+  # II.) iterate all layers
+  for (l in 1:L) {
+
+    # FC layer: create #{numSubnets} disjoint partitions for this layer across all parameters
+    if (as.scalar(isFC[1,l]) == 1) {
+      W = as.matrix(model[l])
+      b = as.matrix(model[l+L])
+      H = ncol(W) # bias neurons in layer l
+
+      # SANITY CHECKS:
+      # (diagnostics are concatenated into one string; multi-argument print
+      #  would treat the first argument as a printf-style format string)
+      if (nrow(b) != 1 | ncol(b) != H) {
+        if (verbose) print("Bias shape mismatch!")
+        if (verbose) print("b: " + nrow(b) + " x " + ncol(b))
+        if (verbose) print("expected: 1 x " + H)
+        stop("Invalid bias shape")
+      }
+      if (l!=L & numSubnets>ncol(b)) { # TODO change to next layer is non-FC logic
+        if (verbose) print("More subnets than available neurons in layer:")
+        if (verbose) print(l)
+        stop("Please use a wider model or decrease the amount of subnets.")
+      }
+
+      # A.) shuffle all neuron indices
+      allNeuronIndicesRandom = order(target=rand(rows=H, cols=1), by=1, decreasing=FALSE, index.return=TRUE)
+
+      # B.) determine neuron ownership
+      #     every subnet gets floor(H/numSubnets) neurons; the remainder is
+      #     handed out one by one to randomly chosen subnets
+      chunk_size = floor(H/numSubnets)
+      remaining_neurons = H - chunk_size * numSubnets
+      amount_active_neurons = matrix(chunk_size, rows=numSubnets, cols=1)
+      if (remaining_neurons > 0) {
+        randomSubnetIndices = order(target=rand(rows=numSubnets, cols=1, seed=-1), by=1, decreasing=FALSE, index.return=TRUE) # TODO replace seed for experiments
+        for (i in 1:remaining_neurons) {
+          sid = as.integer(as.scalar(randomSubnetIndices[i,1]))
+          amount_active_neurons[sid,1] = as.scalar(amount_active_neurons[sid,1]) + 1 # TODO VECTORIZE
+        }
+      }
+      neuron_end_indices = cumsum(amount_active_neurons)
+      neuron_start_indices = neuron_end_indices - amount_active_neurons + 1
+
+      # C.) obtain masks for all subnets
+      for(s in 1:numSubnets) {
+
+        # 1. obtain owned neurons for this layer
+        start = as.integer(as.scalar(neuron_start_indices[s,1]))
+        end = as.integer(as.scalar(neuron_end_indices[s,1]))
+        current_b_indices = allNeuronIndicesRandom[start:end, 1]
+
+        # 2. create masked bias
+        if(l==L) { # output layer
+          masked_b = matrix(1, rows=1, cols=ncol(b))
+        }
+        else if (l<L & as.scalar(isFC[1, l+1]) == 0) { # next layer is not FC
+          masked_b = matrix(1, rows=1, cols=ncol(b))
+        }
+        else {
+          masked_b = matrix(0, rows=1, cols=ncol(b))
+          for (i in 1:nrow(current_b_indices)) { # TODO VECTORIZE
+            idx = as.integer(as.scalar(current_b_indices[i,1]))
+            masked_b[1, idx] = 1
+          }
+        }
+
+        # 3. create masked weight
+        masked_W = matrix(0, rows=nrow(W), cols=ncol(W))
+        if(l==1) {
+          for (i in 1:nrow(current_b_indices)) { # TODO VECTORIZE
+            idx = as.integer(as.scalar(current_b_indices[i,1]))
+            masked_W[1:nrow(W), idx] = matrix(1, rows=nrow(W), cols=1)
+          }
+        }
+        else if (l>1 & as.scalar(isFC[1, l-1])==0) { # previous layer is not FC
+          for (i in 1:nrow(current_b_indices)) { # TODO VECTORIZE
+            idx = as.integer(as.scalar(current_b_indices[i,1]))
+            masked_W[1:nrow(W), idx] = matrix(1, rows=nrow(W), cols=1)
+          }
+        }
+        else {
+          # obtain active neurons of previous layer (bias mask of layer l-1)
+          p = L + (l-1)
+          start = as.integer(as.scalar(masks_new_meta[p,1]))
+          end = as.integer(as.scalar(masks_new_meta[p,2]))
+          r = as.integer(as.scalar(masks_new_meta[p,3]))
+          c = as.integer(as.scalar(masks_new_meta[p,4]))
+          vec = masks_new[s, start:end]
+          previous_masked_b = matrix(vec, rows=r, cols=c, byrow=TRUE)
+
+          # SANITY CHECK: width of the previous layer matches the rows of W
+          if (l > 1 & ncol(previous_masked_b) != nrow(W)) {
+            if (verbose) print("W/prev layer mismatch in layer l=" + l)
+            if (verbose) print("prev_b: " + nrow(previous_masked_b) + " x " + ncol(previous_masked_b))
+            if (verbose) print("W: " + nrow(W) + " x " + ncol(W))
+            stop("Invalid W shape wrt previous layer")
+          }
+
+          # outer product below needs a column vector times a row vector
+          if (nrow(previous_masked_b)==1) previous_masked_b = t(previous_masked_b)
+          if (ncol(masked_b) == 1) masked_b = t(masked_b)
+
+          if(l==L) { # output layer
+            masked_W = previous_masked_b %*% matrix(1, 1, ncol(masked_W))
+          }
+          else if (l<L & as.scalar(isFC[1, l+1]) == 0) { # next layer is not FC
+            masked_W = previous_masked_b %*% matrix(1, 1, ncol(masked_W))
+          } else {
+            masked_W = previous_masked_b %*% masked_b
+          }
+        }
+
+        # 4. forward these masks to all parameters in this layer
+        #    (odd blocks are W-shaped, even blocks are b-shaped)
+        for (param in 1:paramsPerFCLayer) {
+          k = (param-1)*L + l
+          start = as.integer(as.scalar(masks_new_meta[k,1]))
+          end = as.integer(as.scalar(masks_new_meta[k,2]))
+          len = end - start + 1
+
+          if (param %% 2 == 0) {
+            flat = matrix(masked_b, rows=1, cols=len, byrow=TRUE)
+          } else {
+            flat = matrix(masked_W, rows=1, cols=len, byrow=TRUE)
+          }
+          masks_new[s, start:end] = flat
+        }
+      }
+
+      # SANITY CHECK: accumulating the bias masks of all subnets results in a vector of all 1s (in FC hidden layers only)
+      if (l < L) {
+        if (as.scalar(isFC[1, l+1]) == 1) {
+          # bias parameter index for layer l is (L + l)
+          p = L + l
+
+          start = as.integer(as.scalar(masks_new_meta[p,1]))
+          end = as.integer(as.scalar(masks_new_meta[p,2]))
+          r = as.integer(as.scalar(masks_new_meta[p,3]))
+          c = as.integer(as.scalar(masks_new_meta[p,4]))
+
+          # sum across subnets for this bias slice
+          sumB_flat = colSums(masks_new[1:numSubnets, start:end])
+
+          # reshape back to the stored bias shape
+          sumB = matrix(sumB_flat, rows=r, cols=c, byrow=TRUE)
+
+          if (min(sumB) != 1 | max(sumB) != 1) {
+            if (verbose) print("Subnet bias masks not a partition in layer l=" + l)
+            if (verbose) print("min(sumB)=" + min(sumB) + " max(sumB)=" + max(sumB))
+            stop("Invalid subnet bias partition")
+          }
+        }
+      }
+    }
+    else {
+      # Non-FC layer: independent subnet training will not create disjoint partitions for this layer // shared across all subnets -> masks are all-ones
+      for (param in 1:paramsPerFCLayer) {
+        k = (param-1)*L + l
+        start = as.integer(as.scalar(masks_new_meta[k,1]))
+        end = as.integer(as.scalar(masks_new_meta[k,2]))
+        len = end - start + 1
+
+        # W- and b-shaped blocks get the same all-ones slice for every subnet
+        masks_new[1:numSubnets, start:end] = matrix(1, rows=numSubnets, cols=len)
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index dc1f23b83f..797377432d 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -207,6 +207,7 @@ public enum Builtins {
ISNA("is.na", "isNA", false),
ISNAN("is.nan", "isNaN", false),
ISINF("is.infinite", "isInf", false),
+ ISN_TRAIN("independentSubnetTrain", true),
KM("km", true),
KMEANS("kmeans", true),
KMEANSPREDICT("kmeansPredict", true),
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index fa87d452d1..67cda352a7 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -810,7 +810,7 @@ public class ExecutionContext {
for (String varName : varList) {
Data dat = _variables.get(varName);
if (dat instanceof CacheableData<?>)
-
((CacheableData<?>)dat).enableCleanup(varsState.poll());
+
((CacheableData<?>)dat).enableCleanup(Boolean.TRUE.equals(varsState.poll()));
else if (dat instanceof ListObject)
((ListObject)dat).enableCleanup(varsState);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 344f59535e..ae59eb54e9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -552,7 +552,7 @@ public class ListObject extends Data implements
Externalizable {
public void enableCleanup(Queue<Boolean> flags) {
for (Data dat : this.getData()) {
if (dat instanceof CacheableData<?>)
-
((CacheableData<?>)dat).enableCleanup(flags.poll());
+
((CacheableData<?>)dat).enableCleanup(Boolean.TRUE.equals(flags.poll()));
else if (dat instanceof ListObject)
((ListObject)dat).enableCleanup(flags);
}
diff --git a/src/test/config/SystemDS-config.xml
b/src/test/config/SystemDS-config.xml
index a899f5c71c..a051323af9 100644
--- a/src/test/config/SystemDS-config.xml
+++ b/src/test/config/SystemDS-config.xml
@@ -18,7 +18,7 @@
-->
<root>
- <!-- The number of theads for the spark instance artificially selected-->
+ <!-- The number of threads for the spark instance artificially selected-->
<sysds.local.spark.number.threads>2</sysds.local.spark.number.threads>
<!-- The timeout of the federated tests to initialize the federated
matrixes -->
<sysds.federated.initialization.timeout>2</sysds.federated.initialization.timeout>
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinIndSubnetTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinIndSubnetTest.java
new file mode 100644
index 0000000000..29ef6c3ca2
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinIndSubnetTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Parameterized test for the independentSubnetTrain builtin: runs the
+ * indSubnetTest_mnist_lenet DML script on a dataset and checks that the
+ * accuracy written by the script exceeds a minimum threshold.
+ * The class is currently disabled via {@code @Ignore}.
+ */
+@RunWith(value = Parameterized.class)
[email protected]
+@Ignore
+public class BuiltinIndSubnetTest extends AutomatedTestBase {
+
+	private static final Log LOG = LogFactory.getLog(BuiltinIndSubnetTest.class.getName());
+
+	protected final static String TEST_NAME = "indSubnetTest_mnist_lenet";
+	protected final static String TEST_DIR = "functions/builtin/";
+	protected String TEST_CLASS_DIR = TEST_DIR + BuiltinIndSubnetTest.class.getSimpleName() + "/";
+
+	// path of the input dataset passed to the script as first argument
+	private final String dataset_path;
+	// accuracy the script output has to exceed (strictly)
+	private final double least_expected_acc;
+	// name of the output file the script writes the accuracy scalar to
+	private final String out_path;
+
+	public BuiltinIndSubnetTest(String dataset_path, double least_expected_acc, String out_path) {
+		this.dataset_path = dataset_path;
+		this.least_expected_acc = least_expected_acc;
+		this.out_path = out_path;
+	}
+
+	/** Supplies the single parameter set: {dataset path, minimum accuracy, output name}. */
+	@Parameters
+	public static Collection<Object[]> data() {
+		String path = "src/test/resources/datasets/MNIST/mnist_test.csv";
+		double least_expected_acc = 0.5;
+		String out_path = "accuracy";
+		List<Object[]> tests = new ArrayList<>();
+		tests.add(new Object[]{path, least_expected_acc, out_path});
+
+		return tests;
+	}
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+	}
+
+	/** Runs the DML script and asserts the reported accuracy is above the threshold. */
+	@Test
+	public void testClassificationFit() {
+
+		getAndLoadTestConfiguration(TEST_NAME);
+
+		List<String> proArgs = new ArrayList<>();
+		proArgs.add("-args");
+		proArgs.add(this.dataset_path);
+		proArgs.add(output(this.out_path));
+
+		programArgs = proArgs.toArray(new String[proArgs.size()]);
+
+		fullDMLScriptName = getScript();
+
+		LOG.error(runTest(null));
+
+		// the script writes a single scalar; it is read back as a 1x1 array
+		double[][] from_DML = TestUtils.convertHashMapToDoubleArray(readDMLScalarFromOutputDir(this.out_path));
+		double accuracy = from_DML[0][0];
+		assertTrue("Accuracy lower than expected", accuracy > this.least_expected_acc);
+	}
+}
diff --git a/src/test/scripts/functions/builtin/indSubnetTest_mnist_lenet.dml
b/src/test/scripts/functions/builtin/indSubnetTest_mnist_lenet.dml
new file mode 100644
index 0000000000..eba5444026
--- /dev/null
+++ b/src/test/scripts/functions/builtin/indSubnetTest_mnist_lenet.dml
@@ -0,0 +1,391 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * MNIST LeNet Example
+ */
+
+#-------------------------------------------------------------
+# TRAINING
+#-------------------------------------------------------------
+
+train = function(matrix[double] X, matrix[double] Y,
+ matrix[double] X_val, matrix[double] Y_val,
+ int C, int Hin, int Win, int epochs, int workers,
+ int batchsize)
+ return (matrix[double] W1, matrix[double] b1,
+ matrix[double] W2, matrix[double] b2,
+ matrix[double] W3, matrix[double] b3,
+ matrix[double] W4, matrix[double] b4) {
+ /*
+ * Trains a convolutional net using the "LeNet" architecture.
+ *
+ * The parameters are initialized here; the training loop itself is
+ * delegated to independentSubnetTrain, with "computeGradients" as the
+ * update function and "aggregateSharedParameters" as the aggregation
+ * function.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win).
+ * - Y: Target matrix, of shape (N, K).
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+ * - Y_val: Target validation matrix, of shape (N, K).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - epochs: Total number of full training loops over the full data set.
+ * - workers: Forwarded to independentSubnetTrain as numSubnets.
+ * - batchsize: Forwarded to independentSubnetTrain as batchsize.
+ *
+ * Outputs:
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+ * - b1: 1st layer biases vector, of shape (F1, 1).
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+ * - b2: 2nd layer biases vector, of shape (F2, 1).
+ * - W3: 3rd layer weights (parameters) matrix, of shape
(F2*(Hin/4)*(Win/4), N3).
+ * - b3: 3rd layer biases vector, of shape (1, N3).
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+ * - b4: 4th layer biases vector, of shape (1, K).
+ */
+ print("Started training.")
+ N = nrow(X)
+ K = ncol(Y)
+
+ # Parameters in each layer
+ paramsPerLayer = 4
+ fullyConnectedLayers = list(3,4) # passed to independentSubnetTrain; the same list is also stored in params below
+
+ # Create network:
+ # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 ->
affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target
dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1) # inputs: (N, C*Hin*Win)
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1) # inputs: (N,
F1*(Hin/2)*(Win/2))
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1) # inputs: (N,
F2*(Hin/2/2)*(Win/2/2))
+ [W4, b4] = affine::init(N3, K, -1) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # different initialization, since being fed into softmax,
instead of relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ lr = 0.01 # learning rate
+ mu = 0.9 #0.5 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+ # Regularization
+ lambda = 5e-04
+
+ # Create the model list: entries 1-4 weights, 5-8 biases, 9-16 the optimizer state in the same order
+ modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1,
vb2, vb3, vb4)
+
+ # Create the hyper parameter list
+ params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf,
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3,
fullyConnectedLayers=list(3,4))
+
+ # Length of an IST round (passed to independentSubnetTrain as j)
+ ist_round = 20
+
+ # Use independent subnet training function
+ modelList2 = independentSubnetTrain(features=X, labels=Y,
val_features=X_val, val_labels=Y_val, model=modelList, upd="computeGradients",
agg="aggregateSharedParameters", epochs=epochs, batchsize=batchsize,
j=ist_round, numSubnets=workers, hyperparams=params, verbose=TRUE,
paramsPerLayer=paramsPerLayer, fullyConnectedLayers=fullyConnectedLayers)
+
+
+ W1 = as.matrix(modelList2[1])
+ W2 = as.matrix(modelList2[2])
+ W3 = as.matrix(modelList2[3])
+ W4 = as.matrix(modelList2[4])
+ b1 = as.matrix(modelList2[5])
+ b2 = as.matrix(modelList2[6])
+ b3 = as.matrix(modelList2[7])
+ b4 = as.matrix(modelList2[8])
+ print(toString(modelList2)) # NOTE(review): prints the whole returned model list - confirm this debug output is intended
+ print("Training finished.")
+}
+
+#-------------------------------------------------------------
+# GRADIENTS
+#-------------------------------------------------------------
+computeGradients = function(
+ list[unknown] model,
+ list[unknown] mask,
+ matrix[double] features,
+ matrix[double] labels,
+ list[unknown] hyperparams
+
+) return (list[unknown] subnet_model) {
+ /*
+  * Update function handed to independentSubnetTrain (upd="computeGradients").
+  *
+  * Inputs:
+  *  - model: list of parameters followed by their optimizer state.
+  *  - mask: list indexed like model; each entry is multiplied element-wise
+  *      with the matching gradient / state entry.
+  *  - features, labels: the batch passed on to gradients().
+  *  - hyperparams: passed on to gradients() and aggregation().
+  *
+  * Outputs:
+  *  - subnet_model: result of aggregation() on the masked gradients, with
+  *      every entry after the first length(grads) entries multiplied by its mask.
+  */
+ # 1) full gradients
+ grads = gradients(model=model, hyperparams=hyperparams, features=features,
labels=labels)
+
+ # 2) mask gradients
+ grads_masked = list()
+ for (p in 1:length(grads)) {
+ grads_masked = append(grads_masked, as.matrix(grads[p]) *
as.matrix(mask[p]))
+ }
+
+ # 3) apply optimizer step locally
+ subnet_model = aggregation(model=model, hyperparams=hyperparams,
gradients=grads_masked)
+
+ # 4) mask velocities (the entries of the model list that follow the parameters)
+ for (p in (length(grads)+1):length(model)) {
+ subnet_model[p] = list(as.matrix(subnet_model[p]) * as.matrix(mask[p]))
+ }
+}
+
+# Forward and backward pass of the LeNet model on one batch.
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments and return the gradients of type list
+gradients = function(list[unknown] model,
+ list[unknown] hyperparams,
+ matrix[double] features,
+ matrix[double] labels)
+ return (list[unknown] gradients) {
+ /*
+  * Inputs:
+  *  - model: list whose entries 1-4 are W1..W4 and 5-8 are b1..b4.
+  *  - hyperparams: named list providing C, Hin, Win, Hf, Wf, stride, pad,
+  *      lambda, F1, F2 and N3.
+  *  - features: batch features.
+  *  - labels: batch labels.
+  *
+  * Outputs:
+  *  - gradients: list(dW1, dW2, dW3, dW4, db1, db2, db3, db4); the weight
+  *      gradients include the L2 regularization term.
+  */
+ C = as.integer(as.scalar(hyperparams["C"]))
+ Hin = as.integer(as.scalar(hyperparams["Hin"]))
+ Win = as.integer(as.scalar(hyperparams["Win"]))
+ Hf = as.integer(as.scalar(hyperparams["Hf"]))
+ Wf = as.integer(as.scalar(hyperparams["Wf"]))
+ stride = as.integer(as.scalar(hyperparams["stride"]))
+ pad = as.integer(as.scalar(hyperparams["pad"]))
+ lambda = as.double(as.scalar(hyperparams["lambda"]))
+ F1 = as.integer(as.scalar(hyperparams["F1"]))
+ F2 = as.integer(as.scalar(hyperparams["F2"]))
+ N3 = as.integer(as.scalar(hyperparams["N3"]))
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf,
Wf,
+ stride, stride, pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2,
2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1,
Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2,
2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3 -> dropout
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outd3, W4, b4)
+ probs = softmax::forward(outa4)
+
+ # Compute data backward pass
+ ## loss:
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ ## layer 4: affine4 -> softmax
+ douta4 = softmax::backward(dprobs, outa4)
+ # affine4 consumed the dropout output (outd3) in the forward pass, so its
+ # backward pass must receive outd3 as the layer input, not outr3.
+ [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)
+ ## layer 3: affine3 -> relu3 -> dropout
+ doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+ douta3 = relu::backward(doutr3, outa3)
+ [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+ ## layer 2: conv2 -> relu2 -> pool2
+ doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2,
Woutc2, 2, 2, 2, 2, 0, 0)
+ doutc2 = relu::backward(doutr2, outc2)
+ [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2,
F1,
+ Houtp1, Woutp1, Hf, Wf, stride,
stride, pad, pad)
+ ## layer 1: conv1 -> relu1 -> pool1
+ doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1,
Woutc1, 2, 2, 2, 2, 0, 0)
+ doutc1 = relu::backward(doutr1, outc1)
+ [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features,
W1, b1, C, Hin, Win,
+ Hf, Wf, stride, stride, pad, pad)
+
+ # Compute regularization backward pass
+ dW1_reg = l2_reg::backward(W1, lambda)
+ dW2_reg = l2_reg::backward(W2, lambda)
+ dW3_reg = l2_reg::backward(W3, lambda)
+ dW4_reg = l2_reg::backward(W4, lambda)
+ dW1 = dW1 + dW1_reg
+ dW2 = dW2 + dW2_reg
+ dW3 = dW3 + dW3_reg
+ dW4 = dW4 + dW4_reg
+
+ # Same order as the first eight entries of the model list
+ gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Applies one sgd_nesterov::update step to every weight and bias of the model.
+# Should use the arguments named 'model', 'gradients', 'hyperparams' and return always a model of type list
+aggregation = function(list[unknown] model,
+ list[unknown] hyperparams,
+ list[unknown] gradients)
+ return (list[unknown] modelResult) {
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ dW1 = as.matrix(gradients[1])
+ dW2 = as.matrix(gradients[2])
+ dW3 = as.matrix(gradients[3])
+ dW4 = as.matrix(gradients[4])
+ db1 = as.matrix(gradients[5])
+ db2 = as.matrix(gradients[6])
+ db3 = as.matrix(gradients[7])
+ db4 = as.matrix(gradients[8])
+ vW1 = as.matrix(model[9])
+ vW2 = as.matrix(model[10])
+ vW3 = as.matrix(model[11])
+ vW4 = as.matrix(model[12])
+ vb1 = as.matrix(model[13])
+ vb2 = as.matrix(model[14])
+ vb3 = as.matrix(model[15])
+ vb4 = as.matrix(model[16])
+ lr = as.double(as.scalar(hyperparams["lr"]))
+ mu = as.double(as.scalar(hyperparams["mu"]))
+ # Layout read above: model[1..8] = W1..W4, b1..b4; model[9..16] = their optimizer state; gradients[1..8] in parameter order.
+ # Optimize with SGD w/ Nesterov momentum
+ [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+ [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+ [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+ [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+ [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+ [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+ [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+ [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+ modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4,
vb1, vb2, vb3, vb4)
+ }
+
+
+#-------------------------------------------------------------
+# AGGREGATION
+#-------------------------------------------------------------
+aggregateSharedParameters = function(
+ Matrix[Double] initialParam,
+ list[unknown] allSubnetsParam, # list of all subnets updates for a
certain shared parameter
+ list[unknown] allSubnetsMasks
+) return (Matrix[Double] averagedUpdatedParam) {
+ /*
+  * Aggregation function handed to independentSubnetTrain
+  * (agg="aggregateSharedParameters").
+  *
+  * For every cell, sums the mask-weighted subnet values (num) and the masks
+  * themselves (den). Cells with den > 0 receive num / den; cells with
+  * den == 0 keep the value from initialParam.
+  *
+  * Inputs:
+  *  - initialParam: parameter matrix; its dimensions size the accumulators.
+  *  - allSubnetsParam: one matrix per subnet for this parameter.
+  *  - allSubnetsMasks: one mask matrix per subnet, indexed like allSubnetsParam.
+  *
+  * Outputs:
+  *  - averagedUpdatedParam: matrix with the dimensions of initialParam.
+  */
+ num = matrix(0, nrow(initialParam), ncol(initialParam))
+ den = matrix(0, nrow(initialParam), ncol(initialParam))
+
+ for (s in 1:length(allSubnetsParam)) {
+ num = num + (as.matrix(allSubnetsParam[s]) *
as.matrix(allSubnetsMasks[s]))
+ den = den + as.matrix(allSubnetsMasks[s])
+ }
+
+ # avoid divide by zero: where den==0, keep base (pmax(1, den) keeps the divisor >= 1)
+ denNZ = (den > 0)
+
+ averagedUpdatedParam = initialParam * (1 - denNZ) + (num / pmax(1, den)) *
denNZ
+ #print("averagedUpdatedParam:")
+ #print(averagedUpdatedParam)
+}
+
+#-------------------------------------------------------------
+# MNIST
+#
+# Load CSV (path is hard-coded in the function) + preprocess + train/val split
+#
+# Returns:
+# X, Y, X_val, Y_val, C, Hin, Win
+#-------------------------------------------------------------
+/*
+ * Input:
+ *  - make_three_channels: if TRUE the pixel columns are repeated three times
+ *      with cbind and C is 3; otherwise C is 1.
+ *
+ * Outputs:
+ *  - X, Y: rows (val_size+1)..N of the scaled pixels / one-hot labels.
+ *  - X_val, Y_val: rows 1..val_size of the same matrices.
+ *  - C, Hin, Win: channel count as above; Hin and Win are both 28.
+ *
+ * Column 1 of the CSV is used as the label, all remaining columns as pixels.
+ * NOTE(review): assumes the file has more than val_size (5000) rows - confirm.
+ */
+generate_mnist_datasplit = function(
+ boolean make_three_channels
+)
+return (matrix[double] X, matrix[double] Y,
+ matrix[double] X_val, matrix[double] Y_val,
+ int C, int Hin, int Win)
+{
+ # Read the dataset (NOTE(review): the file is named mnist_test.csv but is split into train and val here - confirm)
+ train = read("./src/test/resources/datasets/MNIST/mnist_test.csv",
format="csv")
+
+ # MNIST image properties
+ classes = 10
+ Hin = 28
+ Win = 28
+
+ # Extract images/labels
+ images = train[, 2:ncol(train)]
+ labels = train[, 1]
+
+ N = nrow(images)
+
+ # Scale to [-1, 1] (presumes pixel values in 0..255)
+ X_all = (images / 255.0) * 2 - 1
+
+ # Channels: LeNet wants C=1; ResNet wants C=3
+ if (make_three_channels) {
+ # duplicate along channels: (N, 784) -> (N, 2352)
+ X_all = cbind(X_all, X_all, X_all)
+ C = 3
+ } else {
+ C = 1
+ }
+
+ # One-hot encode
+ # labels in file are typically 0..9, table expects 1..K
+ Y_all = table(seq(1, N), labels + 1, N, classes)
+
+ # Train/val split
+ val_size = 5000 # Use first val_size rows for val, rest for train
(deterministic)
+
+ X_val = X_all[1:val_size, ]
+ Y_val = Y_all[1:val_size, ]
+
+ X = X_all[(val_size+1):N, ]
+ Y = Y_all[(val_size+1):N, ]
+}
+
+#-------------------------------------------------------------
+# EXECUTOR
+#-------------------------------------------------------------
+
+# Training parameters
+epochs = 10
+batch_size = 128
+workers = 8
+
+# 1) Load train/val
+[X, Y, X_val, Y_val, C, Hin, Win] = generate_mnist_datasplit(FALSE)
+
+# 2) Train (IST happens inside train())
+[W1, b1, W2, b2, W3, b3, W4, b4] = train(X, Y, X_val, Y_val, C, Hin, Win,
epochs, workers, batch_size)
+
+# 3) Evaluate on the validation split.
+# The forward pass mirrors the one in gradients() without the dropout layer and
+# runs in mini-batches so the convolution intermediates stay small.
+# NOTE: filter size, stride and padding must match the values used in train().
+Hf = 5; Wf = 5; stride = 1; pad = 2
+F1 = nrow(W1) # number of conv1 filters
+F2 = nrow(W2) # number of conv2 filters
+N_val = nrow(X_val)
+eval_batch = 64
+eval_iters = ceil(N_val / eval_batch)
+correct = 0.0
+for (i in 1:eval_iters) {
+ beg = (i - 1) * eval_batch + 1
+ end = min(N_val, beg + eval_batch - 1)
+ X_eval = X_val[beg:end, ]
+ Y_eval = Y_val[beg:end, ]
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(X_eval, W1, b1, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outr3, W4, b4)
+ probs_eval = softmax::forward(outa4)
+ correct = correct + sum(rowIndexMax(probs_eval) == rowIndexMax(Y_eval))
+}
+accuracy = correct / N_val
+print("Validation accuracy: " + accuracy)
+
+# 4) Persist the accuracy at the output path given as the second script
+# argument; the JUnit test reads the scalar back from there.
+write(accuracy, $2)
\ No newline at end of file