This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 773d876a12 [SYSTEMDS-3928] New builtin function for Independent Subnet 
Training
773d876a12 is described below

commit 773d876a12b49de5d2e87bdb5674beaeab645586
Author: Arno Bock <[email protected]>
AuthorDate: Sat Mar 28 16:23:16 2026 +0100

    [SYSTEMDS-3928] New builtin function for Independent Subnet Training
    
    Closes #2427.
---
 scripts/builtin/independentSubnetTrain.dml         | 504 +++++++++++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../controlprogram/context/ExecutionContext.java   |   2 +-
 .../sysds/runtime/instructions/cp/ListObject.java  |   2 +-
 src/test/config/SystemDS-config.xml                |   2 +-
 .../builtin/part1/BuiltinIndSubnetTest.java        |  95 ++++
 .../builtin/indSubnetTest_mnist_lenet.dml          | 391 ++++++++++++++++
 7 files changed, 994 insertions(+), 3 deletions(-)

diff --git a/scripts/builtin/independentSubnetTrain.dml 
b/scripts/builtin/independentSubnetTrain.dml
new file mode 100644
index 0000000000..2009ab0515
--- /dev/null
+++ b/scripts/builtin/independentSubnetTrain.dml
@@ -0,0 +1,504 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Independent Subnet Training (IST)
+#
+# This builtin implements independent subnet training as a
+# second-order function. It orchestrates distributed / parallel
+# training over disjoint subnets using parfor, while delegating
+# architecture-specific logic to user-provided functions.
+# ------------------------------------------------------------
+# INPUT:
+#   model                : initial model parameters. A list of matrices for NN.
+#   features             : X
+#   labels               : Y
+#   val_features         : validation X (currently not used by the training loop)
+#   val_labels           : validation Y (currently not used by the training loop)
+#   upd                  : computes gradients and performs optimizer step
+#   agg                  : aggregation logic to combine updates of shared 
parameters (across subnets)
+#   epochs               : number of epochs
+#   batchsize            : batchsize for training
+#   j                    : number of gradient steps until aggregation -> 
determines length of the IST round (aggregation frequency)
+#   numSubnets           : number of independent subnets/workers
+#   hyperparams          : list of hyperparameters (e.g. lr, reg, mask params, 
etc.)
+#   verbose              : print progress (boolean)
+#   paramsPerLayer       : number of parameter blocks per layer (e.g. 2 for W and b); must divide length(model)
+#   fullyConnectedLayers : list of all FC layer indices (starting at idx=1)
+#
+# OUTPUT:
+#   trained_model : trained model parameters (IST: W)
+#
+# ASSUMPTION:
+#   - the last layer is the output layer
+# ------------------------------------------------------------
+
+m_independentSubnetTrain = function(
+   list[unknown]        model,
+   matrix[double]       features,
+   matrix[double]       labels,
+   matrix[double]       val_features,
+   matrix[double]       val_labels,
+   string               upd,
+   string               agg,
+   int                  epochs,
+   int                  batchsize,
+   int                  j,
+   int                  numSubnets,
+   list[unknown]        hyperparams,
+   boolean              verbose,
+   int                  paramsPerLayer,
+   list[int]            fullyConnectedLayers
+)
+return (list[unknown] trained_model)
+{
+    # Trains `model` in IST rounds: per round, disjoint subnet masks are drawn,
+    # every subnet runs up to `j` local steps via `upd` inside a parfor, and the
+    # results are merged back (shared params via `agg`, partitioned params by sum).
+    # The model list is expected in block layout (block-major, layer-minor), i.e.
+    # the parameter of block `q` for layer `l` sits at index (q-1)*L + l.
+
+    # ------------------------------------------------------------
+    # Setup
+    # ------------------------------------------------------------
+    model_out = model
+
+    P = length(model)
+    N = nrow(features)
+    if (P %% paramsPerLayer != 0) stop("Model length not divisible by paramsPerLayer")
+    # seq(1, stepsPerEpoch, j) and 1:numSubnets below require positive values
+    if (j < 1 | numSubnets < 1) stop("j and numSubnets must be >= 1")
+    L = as.integer(P / paramsPerLayer)  # total layers
+
+    # I. determine shared parameters
+    isSharedParam = matrix(0, 1, P)
+
+    # - create mask for all FC layers
+    fcLayers = fullyConnectedLayers
+    isFC = matrix(0, rows=1, cols=L)
+    for (i in 1:length(fcLayers)) {
+       idx = as.integer(as.scalar(fcLayers[i]))
+       isFC[1, idx] = 1  # TODO vectorize
+    }
+
+    # - expand layer mask across all parameters
+    isFC_rep = isFC
+    for (r in 2:paramsPerLayer) {
+        isFC_rep = cbind(isFC_rep, isFC)
+    }
+    if (ncol(isFC_rep)!=P) stop("Dimension mismatch for FC layer mask.")
+
+    # - all non-FC layers are shared
+    isSharedParam = 1 - isFC_rep
+
+    # - edge case: FC bias parameters are shared in: output layer or at the end of a FC block
+    for (paramId in seq(2, paramsPerLayer, 2)) {   # iterate bias blocks only
+        for (l in 1:L) {
+            if (as.scalar(isFC[1,l])==1 & l==L) {
+                p_out_bias = (paramId - 1) * L + L  # output bias is shared across subnets
+                isSharedParam[1, p_out_bias] = 1  # TODO vectorize
+            }
+            else if (as.scalar(isFC[1,l])==1 & l<L & as.scalar(isFC[1,l+1])==0) {
+                p_out_bias = (paramId - 1) * L + l  # end of FC block's bias is shared across subnets
+                isSharedParam[1, p_out_bias] = 1  # TODO vectorize
+            }
+        }
+    }
+    if (ncol(isSharedParam) != P) stop("isSharedParam dimension mismatch!")
+
+    # II. calculate update-steps per epoch
+    if (batchsize<=0 | batchsize>N) {
+        stop("Batch size is out of bounds!")
+    } else {
+        stepsPerEpoch = ceil(N / batchsize)
+    }
+
+    # III. training loop
+    for (epoch in 1:epochs) {
+        if (verbose) print("Entered epoch: " + epoch)
+
+        # A.) reshuffle indices each epoch
+        #     (S is a b x N permutation matrix built from the random row order)
+        allSampleIndicesRandom = order(target=rand(rows=N, cols=1), by=1, decreasing=FALSE, index.return=TRUE)
+        batchIndices = allSampleIndicesRandom[, 1]
+        b = nrow(batchIndices)
+        I = seq(1, b, 1)
+        V = matrix(1, rows=b, cols=1)
+        S = table(I, batchIndices, V, b, N)
+
+        features_shuffled = S %*% features
+        labels_shuffled = S %*% labels
+
+        # B.) iterate IST rounds
+        for (step in seq(1, stepsPerEpoch, j)) {
+            if (verbose) print("Starting new IST round at step: " + step)
+            round_model = model_out  # prevent accidental mutation of model_out
+
+            # 1.) create masks for all subnets
+            [masks, masks_meta_info] = ist_create_disjoint_masks(round_model, numSubnets, L, fcLayers, paramsPerLayer, isFC, verbose)
+
+            # 2.) preallocate list to store all subnets TODO move outside epoch loop?  to prevent constantly allocating...
+            updatedSubnets = list()
+            updatedSubnetsMasks = list()
+            for (s in 1:numSubnets) {
+                updatedSubnets      = append(updatedSubnets, list())
+                updatedSubnetsMasks = append(updatedSubnetsMasks, list())
+            }
+
+            # 3) create a template for each subnet based on input model (allows indexing in subsequent parfor-loop) TODO move outside epoch loop?  to prevent constantly allocating...
+            subnetModelTemplate = list()
+            subnetModelMaskTemplate = list()
+            for (pIdx in 1:P) {
+                subnetModelTemplate      = append(subnetModelTemplate, as.matrix(model[pIdx]))
+                subnetModelMaskTemplate  = append(subnetModelMaskTemplate, as.matrix(model[pIdx]))
+            }
+
+            # local optimization steps / IST round (the last round of an epoch may be shorter)
+            localSteps = min(j, (stepsPerEpoch-step+1))
+
+            # 4.) obtain all minibatches for this IST round (doing it once prevents parfor confusion)
+            shuffled_features = list()
+            shuffled_labels = list()
+            for (localStep in 1:localSteps) {
+                mb = (step-1) + localStep
+                mb_local = mb-1
+                start = mb_local*batchsize + 1
+                end   = min(mb*batchsize, N)
+
+                Xb = features_shuffled[start:end, 1:ncol(features_shuffled)]
+                yb = labels_shuffled[start:end, 1:ncol(labels_shuffled)]
+                shuffled_features = append(shuffled_features, Xb)
+                shuffled_labels = append(shuffled_labels, yb)
+            }
+
+            # 5.) perform 'j' local gradient steps for each subnet
+            parfor (subnet in 1:numSubnets) {
+
+                # a.) obtain masked subnet
+                subnet_model = subnetModelTemplate
+                subnet_model_mask = subnetModelMaskTemplate
+                for (subnet_p in 1:length(round_model)) {
+                    # meta columns: [start, end, rows, cols] of the flattened mask slice
+                    param_start_idx = as.integer(as.scalar(masks_meta_info[subnet_p,1]))
+                    param_end_idx   = as.integer(as.scalar(masks_meta_info[subnet_p,2]))
+                    param_rows     = as.integer(as.scalar(masks_meta_info[subnet_p,3]))
+                    param_cols     = as.integer(as.scalar(masks_meta_info[subnet_p,4]))
+
+                    vec = masks[subnet, param_start_idx:param_end_idx]
+                    param_mask = matrix(vec, rows=param_rows, cols=param_cols, byrow=TRUE)
+                    param = as.matrix(round_model[subnet_p])
+
+                    subnet_model[subnet_p]      = list(param * param_mask)  # TODO sparse masking! dense masking will probably increase computational efficiency
+                    subnet_model_mask[subnet_p] = list(param_mask)
+                }
+
+                # b.) local optimization steps / IST round
+                for (localStep in 1:localSteps) {
+                    feat = as.matrix(shuffled_features[localStep])
+                    lab = as.matrix(shuffled_labels[localStep])
+
+                    # compute gradients for subnet s + apply update step on owned params (only)
+                    subnet_model = as.list(evalList(upd, list(model=subnet_model, mask=subnet_model_mask, features=feat, labels=lab, hyperparams=hyperparams)))
+                }
+
+                # c.) save updated subnet and mask
+                updatedSubnets[subnet] = list(subnet_model)
+                updatedSubnetsMasks[subnet] = list(subnet_model_mask)
+            }
+            if (verbose) print("All subnets have run successfully.")
+
+            # 6.) aggregate updates into global model (i.e. model_out)
+            for (p in 1:P) {
+                if (as.scalar(isSharedParam[1, p])==1) {
+                    # construct full model update by aggregating shared parameter updates from all subnets
+                    subnetParams = list()
+                    subnetMasks  = list()
+                    for (s in 1:numSubnets) {
+                        subnet = as.list(updatedSubnets[s])
+                        subnetMask = as.list(updatedSubnetsMasks[s])
+
+                        subnetParams = append(subnetParams, as.matrix(subnet[p]))
+                        subnetMasks = append(subnetMasks, as.matrix(subnetMask[p]))
+                    }
+
+                    # aggregate shared parameters based on provided function
+                    averagedUpdatedParam = eval(agg, list(initialParam=as.matrix(round_model[p]), allSubnetsParam=subnetParams, allSubnetsMasks=subnetMasks))
+                    round_model[p] = averagedUpdatedParam
+                }
+                else {
+                    # construct full model update by filling with disjointly partitioned parameter updates from all subnets
+                    initialParam = as.matrix(round_model[p])
+                    updatedParam = matrix(0, nrow(initialParam), ncol(initialParam))
+                    owned = matrix(0, nrow(initialParam), ncol(initialParam))
+
+                    for (s in 1:numSubnets) {
+                        subnet = as.list(updatedSubnets[s])
+                        subnetMask = as.list(updatedSubnetsMasks[s])
+
+                        owned = owned + as.matrix(subnetMask[p])
+                        updatedParam = updatedParam + as.matrix(subnet[p])
+                    }
+
+                    # SANITY CHECK:
+                    max_freq = max(owned)
+                    if (max_freq > 1) stop("Overlap detected")
+
+                    # Entries that no subnet owned in this round (e.g. the off-block
+                    # weights of an FC layer between two FC layers) were not trained;
+                    # carry over their previous value instead of resetting them to 0.
+                    round_model[p] = updatedParam + initialParam * (owned == 0)
+                }
+            }
+            if (verbose) print("Aggregation of subnets finished. IST round has been successfully executed!")
+
+            # 7.) update global model (end of the IST round)
+            model_out = round_model
+        }
+        # TODO (potentially): add validation for early stopping etc.
+    }
+    trained_model = model_out
+}
+
+
+# 
----------------------------------------------------------------------------------------------------------------------
+# Independent Subnet Masking
+#
+# This helper function creates two matrices: one contains all flattened masks, 
the other contains the info on
+# how to reconstruct the mask matrices. Each mask is a binary vector 
indicating which parameters belong to that subnet.
+# 
----------------------------------------------------------------------------------------------------------------------
+# INPUT:
+#   model                : list of parameter tensors grouped by parameter type 
i.e. blocks
+#   numSubnets           : number of independent subnets/workers (K)
+#   L                    : total number of layers INCLUDING the output layer 
(layer indices are assumed to be 1..L)
+#   fullyConnectedLayers : list of all FC layer indices (starting at idx=1)
+#   paramsPerFCLayer     : number of parameter blocks per layer (W, b, plus optional optimizer state); must be even and >= 2
+#   isFC                 : indicator matrix encoding which layers are FC => 
isFC[l] ∈ {0,1}
+#   verbose              : print progress (boolean)
+#
+# OUTPUT:
+#   masks_new            : mask matrix defining disjoint neuron ownership 
across subnets
+#   masks_new_meta       : metadata matrix describing the mask layout and 
ownership mapping
+#
+# ASSUMPTIONS:
+#   - neuron ownership is defined via bias vectors
+#   - model is a list of parameter tensors
+#   - trainable parameters are grouped by parameter type i.e. param blocks 
like (W_l1, W_l2, ..., b_l1, b_l2, ...)
+#   - assumes W and b are always the first two param blocks
+#   - the pattern of optional optimizer state tensors (e.g., vW_l, vb_l) 
follow the same grouping and always W followed by b
+#   - (output layer & end of FC block) biases are shared -> gradients collide; 
must be handled by aggregation logic
+# 
----------------------------------------------------------------------------------------------------------------------
+
+ist_create_disjoint_masks = function(
+    list[unknown]   model,
+    int             numSubnets,
+    int             L,
+    list[int]       fullyConnectedLayers,
+    int             paramsPerFCLayer,
+    Matrix[Double]  isFC,
+    boolean         verbose
+)
+  return (
+    Matrix[Double]  masks_new,
+    Matrix[Double]  masks_new_meta
+  )
+{
+    # Builds one flattened 0/1 mask row per subnet (masks_new: numSubnets x total cells)
+    # plus a per-parameter layout table (masks_new_meta: [start, end, rows, cols]).
+    # Parameter index convention: block `q` of layer `l` is model[(q-1)*L + l].
+    P = length(model)
+
+    # SANITY CHECKS: ensure provided model can be masked correctly
+    if (as.integer(P / paramsPerFCLayer) != L) {
+        stop("Layer/parameter mismatch. Please make sure each layer has the same amount of parameters.")
+    };
+    if (paramsPerFCLayer < 2 | paramsPerFCLayer %% 2 != 0) {
+        stop("At least 1 pair of W and b needs to be present, as well as parameters need to be W&b pairs.")
+    }
+
+    # I.) initialize and preallocate masks
+    masks_new_meta = matrix(0, rows=length(model), cols=4)  # columns=[start,end,rows,cols]
+    current_position = 1
+    for (p in 1:length(model)) {
+        M = as.matrix(model[p])
+        param_length = ncol(M) * nrow(M)  # number of cells of this parameter
+
+        masks_new_meta[p,1] = current_position
+        masks_new_meta[p,2] = current_position + param_length -1
+        masks_new_meta[p,3] = nrow(M)
+        masks_new_meta[p,4] = ncol(M)
+
+        current_position = current_position + param_length
+    }
+    mask_size = current_position-1
+    masks_new = matrix(0, rows=numSubnets, cols=mask_size)  # all subnets in one matrix
+
+    # II.) iterate all layers
+    for (l in 1:L) {
+
+        # FC layer: create #{numSubnets} disjoint partitions for this layer across all parameters
+        if (as.scalar(isFC[1,l]) == 1) {
+            W = as.matrix(model[l])
+            b = as.matrix(model[l+L])
+            H = ncol(W);  # bias neurons in layer l
+
+            # SANITY CHECKS:
+            # (multi-argument print() is printf-style in DML, so values are concatenated explicitly)
+            if (nrow(b) != 1 | ncol(b) != H) {
+                if (verbose) print("Bias shape mismatch!")
+                if (verbose) print("b: " + nrow(b) + " x " + ncol(b))
+                if (verbose) print("expected: 1 x " + H)
+                stop("Invalid bias shape")
+            }
+            if (l!=L & numSubnets>ncol(b)) {  # TODO change to next layer is non-FC logic
+                if (verbose) print("More subnets than available neurons in layer:")
+                if (verbose) print(l)
+                stop("Please use a wider model or decrease the amount of subnets.")
+            }
+            # NOTE(review): the check above skips the output layer (l==L); with
+            # numSubnets > H some subnets get 0 neurons there and start > end in
+            # step C.1 below -- presumably an invalid index range; confirm.
+
+            # A.) shuffle all neuron indices
+            allNeuronIndicesRandom = order(target=rand(rows=H, cols=1), by=1, decreasing=FALSE, index.return=TRUE)
+
+            # B.) determine neuron ownership
+            #     (every subnet gets floor(H/numSubnets) neurons; the remainder is
+            #      handed out one-by-one to randomly chosen subnets)
+            chunk_size = floor(H/numSubnets)
+            remaining_neurons = H - chunk_size * numSubnets
+            amount_active_neurons = matrix(chunk_size, rows=numSubnets, cols=1)
+            if (remaining_neurons > 0) {
+                randomSubnetIndices = order(target=rand(rows=numSubnets, cols=1, seed=-1), by=1, decreasing=FALSE, index.return=TRUE)  # TODO replace seed for experiments
+                for (i in 1:remaining_neurons) {
+                  sid = as.integer(as.scalar(randomSubnetIndices[i,1]))
+                  amount_active_neurons[sid,1] = as.scalar(amount_active_neurons[sid,1]) + 1  # TODO VECTORIZE
+                }
+            }
+            neuron_end_indices = cumsum(amount_active_neurons)
+            neuron_start_indices = neuron_end_indices - amount_active_neurons + 1
+
+            # C.) obtain masks for all subnets
+            for(s in 1:numSubnets) {
+
+                # 1. obtain owned neurons for this layer
+                start = as.integer(as.scalar(neuron_start_indices[s,1]))
+                end = as.integer(as.scalar(neuron_end_indices[s,1]))
+                current_b_indices = allNeuronIndicesRandom[start:end, 1]
+
+                # 2. create masked bias
+                if(l==L) {  # output layer
+                    masked_b = matrix(1, rows=1, cols=ncol(b))
+                }
+                else if (l<L & as.scalar(isFC[1, l+1]) == 0) {  # next layer is not FC
+                    masked_b = matrix(1, rows=1, cols=ncol(b))
+                }
+                else {
+                    masked_b = matrix(0, rows=1, cols=ncol(b))
+                    for (i in 1:nrow(current_b_indices)) {  # TODO VECTORIZE
+                       idx = as.integer(as.scalar(current_b_indices[i,1]))
+                       masked_b[1, idx] = 1
+                    }
+                }
+
+                # 3. create masked weight
+                masked_W = matrix(0, rows=nrow(W), cols=ncol(W))
+                if(l==1) {
+                    # first layer: own all incoming weights of the owned neurons (full columns)
+                    for (i in 1:nrow(current_b_indices)) {  # TODO VECTORIZE
+                       idx = as.integer(as.scalar(current_b_indices[i,1]))
+                       masked_W[1:nrow(W), idx] = matrix(1, rows=nrow(W), cols=1)
+                    }
+                }
+                else if (l>1 & as.scalar(isFC[1, l-1])==0) {  # previous layer is not FC
+                    for (i in 1:nrow(current_b_indices)) {  # TODO VECTORIZE
+                       idx = as.integer(as.scalar(current_b_indices[i,1]))
+                       masked_W[1:nrow(W), idx] = matrix(1, rows=nrow(W), cols=1)
+                    }
+                }
+                else {
+                    # obtain active neurons of previous layer
+                    # (first bias block of layer l-1, already written for subnet s)
+                    p = L + (l-1)
+                    start = as.integer(as.scalar(masks_new_meta[p,1]))
+                    end   = as.integer(as.scalar(masks_new_meta[p,2]))
+                    r     = as.integer(as.scalar(masks_new_meta[p,3]))
+                    c     = as.integer(as.scalar(masks_new_meta[p,4]))
+                    vec = masks_new[s, start:end]
+                    previous_masked_b = matrix(vec, rows=r, cols=c, byrow=TRUE)
+
+                    # SANITY CHECK: dimensions with layers of previous layer match
+                    if (l > 1 & ncol(previous_masked_b) != nrow(W)) {
+                        if (verbose) print("W/prev layer mismatch in layer l=" + l)
+                        if (verbose) print("prev_b: " + nrow(previous_masked_b) + " x " + ncol(previous_masked_b))
+                        if (verbose) print("W: " + nrow(W) + " x " + ncol(W))
+                        stop("Invalid W shape wrt previous layer")
+                    }
+
+                    if (nrow(previous_masked_b)==1) previous_masked_b = t(previous_masked_b)
+                    if (ncol(masked_b) == 1) masked_b = t(masked_b)
+
+                    # rows follow the previous layer's owned neurons; columns are either
+                    # all outputs (shared-bias layers) or this layer's owned neurons
+                    if(l==L) {  # output layer
+                        masked_W = previous_masked_b %*% matrix(1, 1, ncol(masked_W))
+                    }
+                    else if (l<L & as.scalar(isFC[1, l+1]) == 0) {  # next layer is not FC
+                        masked_W = previous_masked_b %*% matrix(1, 1, ncol(masked_W))
+                    } else {
+                        masked_W = previous_masked_b %*% masked_b
+                    }
+                }
+
+                # 4. forward these masks to all parameters in this layer
+                #    (odd blocks reuse the W mask, even blocks reuse the b mask)
+                for (param in 1:paramsPerFCLayer) {
+                    k = (param-1)*L + l
+                    start = as.integer(as.scalar(masks_new_meta[k,1]))
+                    end   = as.integer(as.scalar(masks_new_meta[k,2]))
+                    len   = end - start + 1
+
+                    if (param %% 2 == 0) {
+                        flat = matrix(masked_b, rows=1, cols=len, byrow=TRUE)
+                    } else {
+                        flat = matrix(masked_W, rows=1, cols=len, byrow=TRUE)
+                    }
+                    masks_new[s, start:end] = flat
+                }
+            }
+
+            # SANITY CHECK: accumulating all subnets bias's result in vector of all 1s (in FC hidden layers only)
+            if (l < L) {
+                if (as.scalar(isFC[1, l+1]) == 1) {
+                    # bias parameter index for layer l is (L + l)
+                    p = L + l
+
+                    start = as.integer(as.scalar(masks_new_meta[p,1]));
+                    end   = as.integer(as.scalar(masks_new_meta[p,2]));
+                    r     = as.integer(as.scalar(masks_new_meta[p,3]));
+                    c     = as.integer(as.scalar(masks_new_meta[p,4]));
+
+                    # sum across subnets for this bias slice
+                    sumB_flat = colSums(masks_new[1:numSubnets, start:end])
+
+                    # reshape back to bias shape (usually Hx1 or 1xH depending on how you store it)
+                    sumB = matrix(sumB_flat, rows=r, cols=c, byrow=TRUE)
+
+                    if (min(sumB) != 1 | max(sumB) != 1) {
+                        if (verbose) print("Subnet bias masks not a partition in layer l=" + l)
+                        if (verbose) print("min(sumB)=" + min(sumB) + " max(sumB)=" + max(sumB))
+                        stop("Invalid subnet bias partition")
+                    }
+                }
+            }
+        }
+        else {
+            # Non-FC layer: independent subnet training will not create disjoint partitions for this layer //  shared across all subnets -> masks are all-ones
+            for (param in 1:paramsPerFCLayer) {
+                k = (param-1)*L + l
+                start = as.integer(as.scalar(masks_new_meta[k,1]))
+                end   = as.integer(as.scalar(masks_new_meta[k,2]))
+                r     = as.integer(as.scalar(masks_new_meta[k,3]))
+                c     = as.integer(as.scalar(masks_new_meta[k,4]))
+                len   = end - start + 1
+
+                if (param %% 2 == 0) {
+                    shared_b = matrix(1, rows=r, cols=c)
+                    flat = matrix(shared_b, rows=1, cols=len, byrow=TRUE)
+                } else {
+                    shared_W = matrix(1, rows=r, cols=c)
+                    flat = matrix(shared_W, rows=1, cols=len, byrow=TRUE)
+                }
+                # replicate the all-ones row for every subnet
+                masks_new[1:numSubnets, start:end] = matrix(1, numSubnets, 1) %*% flat
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index dc1f23b83f..797377432d 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -207,6 +207,7 @@ public enum Builtins {
        ISNA("is.na", "isNA", false),
        ISNAN("is.nan", "isNaN", false),
        ISINF("is.infinite", "isInf", false),
+    ISN_TRAIN("independentSubnetTrain", true),
        KM("km", true),
        KMEANS("kmeans", true),
        KMEANSPREDICT("kmeansPredict", true),
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index fa87d452d1..67cda352a7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -810,7 +810,7 @@ public class ExecutionContext {
                for (String varName : varList) {
                        Data dat = _variables.get(varName);
                        if (dat instanceof CacheableData<?>)
-                               
((CacheableData<?>)dat).enableCleanup(varsState.poll());
+                               
((CacheableData<?>)dat).enableCleanup(Boolean.TRUE.equals(varsState.poll()));
                        else if (dat instanceof ListObject)
                                ((ListObject)dat).enableCleanup(varsState);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 344f59535e..ae59eb54e9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -552,7 +552,7 @@ public class ListObject extends Data implements 
Externalizable {
        public void enableCleanup(Queue<Boolean> flags) {
                for (Data dat : this.getData()) {
                        if (dat instanceof CacheableData<?>)
-                               
((CacheableData<?>)dat).enableCleanup(flags.poll());
+                               
((CacheableData<?>)dat).enableCleanup(Boolean.TRUE.equals(flags.poll()));
                        else if (dat instanceof ListObject)
                                ((ListObject)dat).enableCleanup(flags);
                }
diff --git a/src/test/config/SystemDS-config.xml 
b/src/test/config/SystemDS-config.xml
index a899f5c71c..a051323af9 100644
--- a/src/test/config/SystemDS-config.xml
+++ b/src/test/config/SystemDS-config.xml
@@ -18,7 +18,7 @@
 -->
 
 <root>
-   <!-- The number of theads for the spark instance artificially selected-->
+   <!-- The number of threads for the spark instance artificially selected-->
    <sysds.local.spark.number.threads>2</sysds.local.spark.number.threads>
    <!-- The timeout of the federated tests to initialize the federated 
matrixes -->
    
<sysds.federated.initialization.timeout>2</sysds.federated.initialization.timeout>
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinIndSubnetTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinIndSubnetTest.java
new file mode 100644
index 0000000000..29ef6c3ca2
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinIndSubnetTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Runs the indSubnetTest_mnist_lenet DML script and checks that the accuracy
+ * it writes out exceeds a configured lower bound.
+ * NOTE(review): the class is @Ignore'd -- presumably because it needs the MNIST
+ * CSV on disk; confirm before enabling.
+ */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+@Ignore
+public class BuiltinIndSubnetTest extends AutomatedTestBase {
+
+       private static final Log LOG = LogFactory.getLog(BuiltinIndSubnetTest.class.getName());
+
+       protected final static String TEST_NAME = "indSubnetTest_mnist_lenet";
+       protected final static String TEST_DIR = "functions/builtin/";
+       protected String TEST_CLASS_DIR = TEST_DIR + BuiltinIndSubnetTest.class.getSimpleName() + "/";
+
+       // Path handed to the script as first argument
+       private final String dataset_path;
+       // Exclusive lower bound the reported accuracy has to exceed
+       private final double least_expected_acc;
+       // Name of the scalar output the script writes its accuracy to
+       private final String out_path;
+
+       public BuiltinIndSubnetTest(String dataset_path, double least_expected_acc, String out_path) {
+               this.dataset_path = dataset_path;
+               this.least_expected_acc = least_expected_acc;
+               this.out_path = out_path;
+       }
+
+       /** Parameter sets: {dataset path, minimum accuracy, output name}. */
+       @Parameters
+       public static Collection<Object[]> data() {
+               String path = "src/test/resources/datasets/MNIST/mnist_test.csv";
+               double least_expected_acc = 0.5;
+               String out_path = "accuracy";
+               List<Object[]> tests = new ArrayList<>();
+               tests.add(new Object[]{path, least_expected_acc, out_path});
+
+               return tests;
+       }
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+       }
+
+       @Test
+       public void testClassificationFit() {
+
+               getAndLoadTestConfiguration(TEST_NAME);
+
+               programArgs = new String[]{"-args", this.dataset_path, output(this.out_path)};
+
+               fullDMLScriptName = getScript();
+
+               // regular script output is not an error; log it at info level
+               LOG.info(runTest(null));
+
+               double[][] from_DML = TestUtils.convertHashMapToDoubleArray(readDMLScalarFromOutputDir(this.out_path));
+               double accuracy = from_DML[0][0];
+               assertTrue("Accuracy lower than expected: " + accuracy + " <= " + this.least_expected_acc,
+                       accuracy > this.least_expected_acc);
+       }
+}
diff --git a/src/test/scripts/functions/builtin/indSubnetTest_mnist_lenet.dml 
b/src/test/scripts/functions/builtin/indSubnetTest_mnist_lenet.dml
new file mode 100644
index 0000000000..eba5444026
--- /dev/null
+++ b/src/test/scripts/functions/builtin/indSubnetTest_mnist_lenet.dml
@@ -0,0 +1,391 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * MNIST LeNet Example
+ */
+
+#-------------------------------------------------------------
+# TRAINING
+#-------------------------------------------------------------
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers,
+                 int batchsize)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net with the "LeNet" architecture by handing
+   * the initialized model to independentSubnetTrain.
+   *
+   * X holds N examples, each a 3D volume unrolled into a single row.
+   * Y holds the one-hot encoded targets over K classes.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, with C*Hin*Win columns.
+   *  - Y_val: Target validation matrix, with K columns.
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *  - workers: Forwarded to independentSubnetTrain as numSubnets.
+   *  - batchsize: Forwarded to independentSubnetTrain as batchsize.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  print("Started training.")
+  K = ncol(Y)  # number of classes
+
+  # Architecture:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5      # filter height
+  Wf = 5      # filter width
+  stride = 1
+  pad = 2     # for same dimensions, (Hf - stride) / 2
+  F1 = 32     # num conv filters in conv1
+  F2 = 64     # num conv filters in conv2
+  N3 = 512    # num nodes in affine3 (affine4 has K nodes, one per class)
+
+  # Layer parameters
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1)               # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1)              # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K, -1)                       # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since it feeds softmax instead of relu
+
+  # Optimizer state for SGD w/ Nesterov momentum
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Optimizer and regularization settings
+  lr = 0.01       # learning rate
+  mu = 0.9        # momentum
+  decay = 0.95    # learning rate decay constant
+  lambda = 5e-04  # regularization strength
+
+  # Settings describing the model layout to independentSubnetTrain
+  paramsPerLayer = 4
+  fullyConnectedLayers = list(3,4)
+  ist_round = 20  # length of an IST round
+
+  # Model list layout: weights, biases, weight velocities, bias velocities
+  modelList = list(W1, W2, W3, W4,
+                   b1, b2, b3, b4,
+                   vW1, vW2, vW3, vW4,
+                   vb1, vb2, vb3, vb4)
+
+  # Hyper parameter list
+  params = list(lr=lr, mu=mu, decay=decay,
+                C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad,
+                lambda=lambda, F1=F1, F2=F2, N3=N3,
+                fullyConnectedLayers=fullyConnectedLayers)
+
+  # Run independent subnet training
+  modelList2 = independentSubnetTrain(
+      features=X, labels=Y, val_features=X_val, val_labels=Y_val,
+      model=modelList, upd="computeGradients", agg="aggregateSharedParameters",
+      epochs=epochs, batchsize=batchsize, j=ist_round, numSubnets=workers,
+      hyperparams=params, verbose=TRUE,
+      paramsPerLayer=paramsPerLayer, fullyConnectedLayers=fullyConnectedLayers)
+
+  # Unpack the trained weights (entries 1-4) and biases (entries 5-8)
+  W1 = as.matrix(modelList2[1]); b1 = as.matrix(modelList2[5])
+  W2 = as.matrix(modelList2[2]); b2 = as.matrix(modelList2[6])
+  W3 = as.matrix(modelList2[3]); b3 = as.matrix(modelList2[7])
+  W4 = as.matrix(modelList2[4]); b4 = as.matrix(modelList2[8])
+  print(toString(modelList2))
+  print("Training finished.")
+}
+
+#-------------------------------------------------------------
+# GRADIENTS
+#-------------------------------------------------------------
+computeGradients = function(
+    list[unknown] model,
+    list[unknown] mask,
+    matrix[double] features,
+    matrix[double] labels,
+    list[unknown] hyperparams
+
+) return (list[unknown] subnet_model) {
+    # Local update step for one subnet: computes the full gradients,
+    # multiplies them element-wise with the mask, applies the optimizer
+    # step, and finally masks the remaining (velocity) entries of the result.
+    #
+    # mask is indexed in parallel with model: entry p masks model entry p.
+
+    # Full gradients on this batch
+    fullGrads = gradients(model=model, hyperparams=hyperparams, features=features, labels=labels)
+    numGrads = length(fullGrads)
+
+    # Masked copy of every gradient
+    maskedGrads = list()
+    for (i in 1:numGrads) {
+        g = as.matrix(fullGrads[i])
+        m = as.matrix(mask[i])
+        maskedGrads = append(maskedGrads, g * m)
+    }
+
+    # Optimizer step on the masked gradients
+    subnet_model = aggregation(model=model, hyperparams=hyperparams, gradients=maskedGrads)
+
+    # Entries after the gradients' range are the velocities; mask them too
+    for (i in (numGrads+1):length(model)) {
+        v = as.matrix(subnet_model[i]) * as.matrix(mask[i])
+        subnet_model[i] = list(v)
+    }
+}
+
+# Computes the gradients of the LeNet model on one batch
+# (forward pass, data backward pass, plus L2 regularization on the weights).
+#
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments
+# and return the gradients of type list
+#
+# Inputs:
+#  - model: list whose entries 1-4 are W1..W4 and entries 5-8 are b1..b4.
+#  - hyperparams: named list providing C, Hin, Win, Hf, Wf, stride, pad,
+#      lambda, F1 and F2.
+#  - features: batch features.
+#  - labels: batch labels.
+#
+# Outputs:
+#  - gradients: list(dW1, dW2, dW3, dW4, db1, db2, db3, db4).
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  # affine4 consumed the dropout output (outd3) in the forward pass, so its
+  # backward pass must receive outd3 as well; passing outr3 would compute dW4
+  # against an input the layer never saw.
+  [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass (weights only, biases are not regularized)
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  # Same ordering as the first eight model entries: weights, then biases
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Applies one SGD w/ Nesterov momentum step to every weight and bias.
+#
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and return always a model of type list
+#
+# Model layout (input and output): entries 1-4 weights, 5-8 biases,
+# 9-12 weight velocities, 13-16 bias velocities.
+# Gradient layout: entries 1-4 weight gradients, 5-8 bias gradients.
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+   return (list[unknown] modelResult) {
+     # Learning rate and momentum
+     lr = as.double(as.scalar(hyperparams["lr"]))
+     mu = as.double(as.scalar(hyperparams["mu"]))
+
+     # ---- layer 1 ----
+     W1 = as.matrix(model[1]); dW1 = as.matrix(gradients[1]); vW1 = as.matrix(model[9])
+     b1 = as.matrix(model[5]); db1 = as.matrix(gradients[5]); vb1 = as.matrix(model[13])
+     [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+     [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+
+     # ---- layer 2 ----
+     W2 = as.matrix(model[2]); dW2 = as.matrix(gradients[2]); vW2 = as.matrix(model[10])
+     b2 = as.matrix(model[6]); db2 = as.matrix(gradients[6]); vb2 = as.matrix(model[14])
+     [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+     [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+
+     # ---- layer 3 ----
+     W3 = as.matrix(model[3]); dW3 = as.matrix(gradients[3]); vW3 = as.matrix(model[11])
+     b3 = as.matrix(model[7]); db3 = as.matrix(gradients[7]); vb3 = as.matrix(model[15])
+     [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+     [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+
+     # ---- layer 4 ----
+     W4 = as.matrix(model[4]); dW4 = as.matrix(gradients[4]); vW4 = as.matrix(model[12])
+     b4 = as.matrix(model[8]); db4 = as.matrix(gradients[8]); vb4 = as.matrix(model[16])
+     [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+     [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+     # Reassemble in the same layout as the incoming model
+     modelResult = list(W1, W2, W3, W4,
+                        b1, b2, b3, b4,
+                        vW1, vW2, vW3, vW4,
+                        vb1, vb2, vb3, vb4)
+   }
+
+
+#-------------------------------------------------------------
+# AGGREGATION
+#-------------------------------------------------------------
+aggregateSharedParameters = function(
+    Matrix[Double] initialParam,
+    list[unknown] allSubnetsParam,   # list of all subnets updates for a certain shared parameter
+    list[unknown] allSubnetsMasks
+) return (Matrix[Double] averagedUpdatedParam) {
+    # Combines the per-subnet versions of one shared parameter cell by cell:
+    # the masked updates are summed and divided by the summed masks; cells
+    # whose summed mask is zero keep their value from initialParam.
+
+    rows = nrow(initialParam)
+    cols = ncol(initialParam)
+    maskedSum = matrix(0, rows, cols)  # sum of (update * mask) over all subnets
+    coverage = matrix(0, rows, cols)   # sum of masks over all subnets
+
+    for (k in 1:length(allSubnetsParam)) {
+        maskK = as.matrix(allSubnetsMasks[k])
+        maskedSum = maskedSum + (as.matrix(allSubnetsParam[k]) * maskK)
+        coverage = coverage + maskK
+    }
+
+    # 1 where at least one subnet contributed, 0 elsewhere
+    isCovered = (coverage > 0)
+
+    # pmax(1, coverage) avoids dividing by zero in cells nobody covered
+    averagedUpdatedParam = initialParam * (1 - isCovered) + (maskedSum / pmax(1, coverage)) * isCovered
+}
+
+#-------------------------------------------------------------
+# MNIST
+#
+# Load CSV + preprocess + train/val split
+#
+# Returns:
+#       X, Y, X_val, Y_val, C, Hin, Win
+#-------------------------------------------------------------
+
+generate_mnist_datasplit = function(
+    boolean make_three_channels
+)
+return (matrix[double] X, matrix[double] Y,
+        matrix[double] X_val, matrix[double] Y_val,
+        int C, int Hin, int Win)
+{
+    # Reads the MNIST CSV (label in column 1, pixels in the remaining
+    # columns), rescales the pixels, one-hot encodes the labels and splits
+    # the rows deterministically into a validation part and a training part.
+
+    # MNIST image properties
+    Hin = 28
+    Win = 28
+    classes = 10
+
+    # Load the CSV and separate labels from pixel values
+    raw = read("./src/test/resources/datasets/MNIST/mnist_test.csv", format="csv")
+    N = nrow(raw)
+    digits = raw[, 1]
+    pixels = raw[, 2:ncol(raw)]
+
+    # Scale to [-1, 1]
+    X_all = (pixels / 255.0) * 2 - 1
+
+    # Channels: one by default; optionally replicate the image three times
+    # along the columns: (N, 784) -> (N, 2352)
+    C = 1
+    if (make_three_channels) {
+        X_all = cbind(X_all, X_all, X_all)
+        C = 3
+    }
+
+    # One-hot encode; labels in the file are typically 0..9, table expects 1..K
+    Y_all = table(seq(1, N), digits + 1, N, classes)
+
+    # First val_size rows form the validation set, the rest the training set
+    val_size = 5000
+    train_start = val_size + 1
+
+    X_val = X_all[1:val_size, ]
+    Y_val = Y_all[1:val_size, ]
+
+    X = X_all[train_start:N, ]
+    Y = Y_all[train_start:N, ]
+}
+
+#-------------------------------------------------------------
+# EXECUTOR
+#-------------------------------------------------------------
+
+# Script arguments (as passed by the JUnit driver):
+#   $1 - dataset path (currently unused; the loader reads a fixed path)
+#   $2 - location the validation accuracy scalar is written to
+
+# Training parameters
+epochs = 10
+batch_size = 128
+workers = 8
+
+# 1) Load train/val
+[X, Y, X_val, Y_val, C, Hin, Win] = generate_mnist_datasplit(FALSE)
+
+# 2) Train (IST happens inside train())
+[W1, b1, W2, b2, W3, b3, W4, b4] = train(X, Y, X_val, Y_val, C, Hin, Win, epochs, workers, batch_size)
+
+# 3) Evaluate on the validation split.
+# The test driver reads the accuracy back from $2, so it has to be computed
+# and written here. Forward pass only (no dropout), processed in batches to
+# keep the convolution intermediates small.
+Hf = 5; Wf = 5; stride = 1; pad = 2  # must match the values used in train()
+F1 = nrow(W1)  # W1 has shape (F1, C*Hf*Wf)
+F2 = nrow(W2)  # W2 has shape (F2, F1*Hf*Wf)
+
+N_val = nrow(X_val)
+num_correct = 0.0
+iters = ceil(N_val / batch_size)
+for (i in 1:iters) {
+  beg = (i-1) * batch_size + 1
+  end = min(N_val, beg + batch_size - 1)
+  X_batch = X_val[beg:end, ]
+  Y_batch = Y_val[beg:end, ]
+
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  ## layer 3: affine3 -> relu3
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  ## layer 4: affine4 -> softmax
+  outa4 = affine::forward(outr3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  num_correct = num_correct + sum(rowIndexMax(probs) == rowIndexMax(Y_batch))
+}
+
+accuracy = num_correct / N_val
+print("Validation accuracy: " + accuracy)
+write(accuracy, $2)
\ No newline at end of file


Reply via email to