This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 192e81c  [SYSTEMDS-3067] Optional model averaging in local/federated paramserv
192e81c is described below

commit 192e81c95ff46c6467171184370f3a2efed26848
Author: Atefeh Asayesh <[email protected]>
AuthorDate: Mon Aug 23 19:02:25 2021 +0200

    [SYSTEMDS-3067] Optional model averaging in local/federated paramserv
    
    Closes #1358.
---
 .../ParameterizedBuiltinFunctionExpression.java    |   4 +-
 .../java/org/apache/sysds/parser/Statement.java    |   7 +-
 .../runtime/compress/CompressedMatrixBlock.java    |   4 +-
 .../paramserv/FederatedPSControlThread.java        |  70 ++-
 .../controlprogram/paramserv/LocalPSWorker.java    |  75 ++--
 .../controlprogram/paramserv/LocalParamServer.java |   8 +-
 .../runtime/controlprogram/paramserv/PSWorker.java |  10 +-
 .../controlprogram/paramserv/ParamServer.java      | 125 +++++-
 .../controlprogram/paramserv/ParamservUtils.java   |  35 ++
 .../controlprogram/paramserv/SparkPSWorker.java    |  17 +-
 .../cp/ParamservBuiltinCPInstruction.java          |  49 ++-
 .../paramserv/AvgModelFederatedParamservTest.java  | 245 +++++++++++
 .../paramserv/ParamservLocalNNAveragingTest.java   |  75 ++++
 .../paramserv/AvgModelFederatedParamservTest.dml   |  61 +++
 .../functions/federated/paramserv/CNNModelAvg.dml  | 474 +++++++++++++++++++++
 .../federated/paramserv/TwoNNModelAvg.dml          | 307 +++++++++++++
 .../paramserv/mnist_lenet_paramserv_avg.dml        | 372 ++++++++++++++++
 .../paramserv/paramserv-averaging-test.dml         |  49 +++
 18 files changed, 1861 insertions(+), 126 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 26a54ac..1f5fd16 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -102,7 +102,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                        new ParameterizedBuiltinFunctionExpression(ctx, 
pbifop,varParams, fileName);
                return retVal;
        }
-       
+
                        
        public ParameterizedBuiltinFunctionExpression(ParserRuleContext ctx, 
Builtins op, LinkedHashMap<String,Expression> varParams,
                        String filename) {
@@ -315,7 +315,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                        Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, 
Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
                        Statement.PS_VAL_FUN, Statement.PS_MODE, 
Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
                        Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, 
Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
-                       Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, 
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
+                       Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, 
Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_MODELAVG);
                checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
                // check existence and correctness of parameters
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index d15fd44..f9f5911 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -33,10 +33,10 @@ public abstract class Statement implements ParseInfo
        // parameter names for seq()
        public static final String SEQ_FROM = "from"; 
        public static final String SEQ_TO   = "to";
-       public static final String SEQ_INCR     = "incr";
+       public static final String SEQ_INCR = "incr";
        
-       public static final String SOURCE       = "source";
-       public static final String SETWD        = "setwd";
+       public static final String SOURCE   = "source";
+       public static final String SETWD    = "setwd";
 
        public static final String MATRIX_DATA_TYPE = "matrix";
        public static final String FRAME_DATA_TYPE = "frame";
@@ -72,6 +72,7 @@ public abstract class Statement implements ParseInfo
        public static final String PS_MODE = "mode";
        public static final String PS_GRADIENTS = "gradients";
        public static final String PS_SEED = "seed";
+       public static final String PS_MODELAVG = "modelAvg";
        public enum PSModeType {
                FEDERATED, LOCAL, REMOTE_SPARK
        }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index bae7bae..033368e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -1224,11 +1224,11 @@ public class CompressedMatrixBlock extends MatrixBlock {
                return getUncompressed();
        }
 
-       private void printDecompressWarning(String operation) {
+       private static void printDecompressWarning(String operation) {
                LOG.warn("Operation '" + operation + "' not supported yet - 
decompressing for ULA operations.");
        }
 
-       private void printDecompressWarning(String operation, MatrixBlock m2) {
+       private static void printDecompressWarning(String operation, 
MatrixBlock m2) {
                if(isCompressed(m2))
                        LOG.warn("Operation '" + operation + "' not supported 
yet - decompressing for ULA operations.");
                else
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index c77ddf4..de99e20 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -44,6 +44,7 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BooleanObject;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
@@ -84,15 +85,16 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
 
        public FederatedPSControlThread(int workerID, String updFunc, 
Statement.PSFrequency freq,
                PSRuntimeBalancing runtimeBalancing, boolean weighting, int 
epochs, long batchSize,
-               int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer 
ps)
+               int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer 
ps, boolean modelAvg)
        {
-               super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+               super(workerID, updFunc, freq, epochs, batchSize, ec, ps, 
modelAvg);
 
                _numBatchesPerEpoch = numBatchesPerGlobalEpoch;
                _runtimeBalancing = runtimeBalancing;
                _weighting = weighting;
                // generate the ID for the model
                _modelVarID = FederationUtils.getNextFedDataID();
+               _modelAvg = modelAvg;
        }
 
        /**
@@ -150,7 +152,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        aggProgramBlock.setInstructions(new 
ArrayList<>(Collections.singletonList(_ps.getAggInst())));
                        pbs.add(aggProgramBlock);
                }
-               
+
                programSerialized = InstructionUtils.concatStrings(
                        PROG_BEGIN, NEWLINE,
                        ProgramConverter.serializeProgram(_ec.getProgram(), 
pbs, new HashMap<>()),
@@ -159,17 +161,10 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                // write program and meta data to worker
                Future<FederatedResponse> udfResponse = 
_featuresData.executeFederatedOperation(
                        new FederatedRequest(RequestType.EXEC_UDF, 
_featuresData.getVarID(),
-                               new SetupFederatedWorker(_batchSize,
-                                       dataSize,
-                                       _possibleBatchesPerLocalEpoch,
-                                       programSerialized,
-                                       _inst.getNamespace(),
-                                       _inst.getFunctionName(),
-                                       _ps.getAggInst().getFunctionName(),
-                                       _ec.getListObject("hyperparams"),
-                                       _modelVarID
-                               )
-               ));
+                               new SetupFederatedWorker(_batchSize, dataSize, 
_possibleBatchesPerLocalEpoch,
+                                       programSerialized, 
_inst.getNamespace(), _inst.getFunctionName(),
+                                       _ps.getAggInst().getFunctionName(), 
_ec.getListObject("hyperparams"),
+                                       _modelVarID, _modelAvg)));
 
                try {
                        FederatedResponse response = udfResponse.get();
@@ -195,10 +190,11 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                private final String _aggregationFunctionName;
                private final ListObject _hyperParams;
                private final long _modelVarID;
+               private final boolean _modelAvg;
 
                protected SetupFederatedWorker(long batchSize, long dataSize, 
int possibleBatchesPerLocalEpoch,
                        String programString, String namespace, String 
gradientsFunctionName, String aggregationFunctionName,
-                       ListObject hyperParams, long modelVarID)
+                       ListObject hyperParams, long modelVarID, boolean 
modelAvg)
                {
                        super(new long[]{});
                        _batchSize = batchSize;
@@ -210,6 +206,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        _aggregationFunctionName = aggregationFunctionName;
                        _hyperParams = hyperParams;
                        _modelVarID = modelVarID;
+                       _modelAvg = modelAvg;
                }
 
                @Override
@@ -226,6 +223,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new 
StringObject(_aggregationFunctionName));
                        ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
                        ec.setVariable(Statement.PS_FED_MODEL_VARID, new 
IntObject(_modelVarID));
+                       ec.setVariable(Statement.PS_MODELAVG, new 
BooleanObject(_modelAvg));
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
@@ -277,7 +275,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
                        ec.removeVariable(Statement.PS_FED_MODEL_VARID);
                        ParamservUtils.cleanupListObject(ec, 
Statement.PS_HYPER_PARAMS);
-                       
+
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
                }
 
@@ -334,7 +332,6 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        });
                        accFedPSGradientWeightingTime(tWeighting);
                }
-
                // Push the gradients to ps
                _ps.push(_workerID, gradients);
        }
@@ -344,31 +341,26 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        }
 
        /**
-        * Computes all epochs and updates after each batch 
+        * Computes all epochs and updates after each batch
         */
        protected void computeWithBatchUpdates() {
                for (int epochCounter = 0; epochCounter < _epochs; 
epochCounter++) {
                        int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : 
_numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
-
                        for (int batchCounter = 0; batchCounter < 
_numBatchesPerEpoch; batchCounter++) {
                                int localStartBatchNum = 
getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
                                ListObject model = pullModel();
                                ListObject gradients = 
computeGradientsForNBatches(model, 1, localStartBatchNum);
-                               weightAndPushGradients(gradients);
-                               ParamservUtils.cleanupListObject(model);
+                               if (_modelAvg)
+                                       model = _ps.updateLocalModel(_ec, 
gradients, model);
+                               else
+                                       ParamservUtils.cleanupListObject(model);
+                               weightAndPushGradients(_modelAvg ? model : 
gradients);
                                ParamservUtils.cleanupListObject(gradients);
                        }
                }
        }
 
        /**
-        * Computes all epochs and updates after N batches
-        */
-       protected void computeWithNBatchUpdates() {
-               throw new NotImplementedException();
-       }
-
-       /**
         * Computes all epochs and updates after each epoch
         */
        protected void computeWithEpochUpdates() {
@@ -376,6 +368,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        int localStartBatchNum = (_cycleStartAt0) ? 0 : 
_numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
 
                        // Pull the global parameters from ps
+                       // TODO double check if model averaging is handled 
correctly (internally?)
                        ListObject model = pullModel();
                        ListObject gradients = 
computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, 
true);
                        weightAndPushGradients(gradients);
@@ -469,6 +462,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        String namespace = ((StringObject) 
ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
                        String gradientsFunc = ((StringObject) 
ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
                        String aggFunc = ((StringObject) 
ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
+                       boolean modelAvg = ((BooleanObject) 
ec.getVariable(Statement.PS_MODELAVG)).getBooleanValue();
 
                        // recreate gradient instruction and output
                        boolean opt = 
!ec.getProgram().containsFunctionProgramBlock(namespace, gradientsFunc, false);
@@ -481,7 +475,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
                                
.collect(Collectors.toCollection(ArrayList::new));
                        Instruction gradientsInstruction = new 
FunctionCallCPInstruction(namespace, gradientsFunc,
-                               opt, boundInputs,func.getInputParamNames(), 
outputNames, "gradient function");
+                               opt, boundInputs, func.getInputParamNames(), 
outputNames, "gradient function");
                        DataIdentifier gradientsOutput = outputs.get(0);
 
                        // recreate aggregation instruction and output if needed
@@ -505,7 +499,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        int currentLocalBatchNumber = _localStartBatchNum;
                        // prepare execution context
                        ec.setVariable(Statement.PS_MODEL, model);
-                       for (int batchCounter = 0; batchCounter < 
_numBatchesToCompute; batchCounter++) {
+                       for(int batchCounter = 0; batchCounter < 
_numBatchesToCompute; batchCounter++) {
                                int localBatchNum = 
getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch);
 
                                // slice batch from feature and label matrix
@@ -521,13 +515,14 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                                // calculate gradients for batch
                                gradientsInstruction.processInstruction(ec);
                                ListObject gradients = 
ec.getListObject(gradientsOutput.getName());
-
+                               
                                // accrue the computed gradients - In the 
single batch case this is just a list copy
                                // is this equivalent for momentum based and 
AMS prob?
-                               accGradients = 
ParamservUtils.accrueGradients(accGradients, gradients, false);
-
+                               accGradients = modelAvg ? null :
+                                       
ParamservUtils.accrueGradients(accGradients, gradients, false);
+                               
                                // update the local model with gradients if 
needed
-                               if(_localUpdate && batchCounter < 
_numBatchesToCompute - 1) {
+                               if((_localUpdate && batchCounter < 
_numBatchesToCompute - 1) | modelAvg) {
                                        // Invoke the aggregate function
                                        assert aggregationInstruction != null;
                                        
aggregationInstruction.processInstruction(ec);
@@ -540,17 +535,18 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                                }
 
                                // clean up
-                               ParamservUtils.cleanupListObject(ec, 
gradientsOutput.getName());
                                ParamservUtils.cleanupData(ec, 
Statement.PS_FEATURES);
                                ParamservUtils.cleanupData(ec, 
Statement.PS_LABELS);
                        }
 
                        // model clean up
                        ParamservUtils.cleanupListObject(ec, 
ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
-                       ParamservUtils.cleanupListObject(ec, 
Statement.PS_MODEL);
+                       // TODO double check cleanup gradients and models
+                       
                        // stop timing
                        DoubleObject gradientsTime = new 
DoubleObject(tGradients.stop());
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new 
Object[]{accGradients, gradientsTime});
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+                               new Object[]{modelAvg ? model : accGradients, 
gradientsTime});
                }
 
                @Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index 8ba81f6..fef617b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -43,16 +43,16 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
        protected LocalPSWorker() {}
 
        public LocalPSWorker(int workerID, String updFunc, 
Statement.PSFrequency freq,
-               int epochs, long batchSize, ExecutionContext ec, ParamServer ps)
+               int epochs, long batchSize, ExecutionContext ec, ParamServer 
ps, boolean modelAvg)
        {
-               super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+               super(workerID, updFunc, freq, epochs, batchSize, ec, ps, 
modelAvg);
        }
 
        @Override
        public String getWorkerName() {
                return String.format("Local worker_%d", _workerID);
        }
-       
+
        @Override
        public Void call() throws Exception {
                incWorkerNumber();
@@ -81,7 +81,7 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
        }
 
        private void computeEpoch(long dataSize, int batchIter) {
-               for (int i = 0; i < _epochs; i++) {
+               for(int i = 0; i < _epochs; i++) {
                        // Pull the global parameters from ps
                        ListObject params = pullModel();
                        Future<ListObject> accGradients = 
ConcurrentUtils.constantFuture(null);
@@ -89,32 +89,32 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                        try {
                                for (int j = 0; j < batchIter; j++) {
                                        ListObject gradients = 
computeGradients(params, dataSize, batchIter, i, j);
-       
+
                                        boolean localUpdate = j < batchIter - 1;
-                                       
-                                       // Accumulate the intermediate 
gradients (async for overlap w/ model updates 
+
+                                       // Accumulate the intermediate 
gradients (async for overlap w/ model updates
                                        // and gradient computation, sequential 
over gradient matrices to avoid deadlocks)
                                        ListObject accGradientsPrev = 
accGradients.get();
-                                       accGradients = _tpool.submit(() -> 
ParamservUtils.accrueGradients(
-                                               accGradientsPrev, gradients, 
false, !localUpdate));
-       
+                                       accGradients = _modelAvg ? 
ConcurrentUtils.constantFuture(null) :
+                                               _tpool.submit(() -> 
ParamservUtils.accrueGradients(
+                                                       accGradientsPrev, 
gradients, false, !localUpdate));
+                                       
                                        // Update the local model with gradients
-                                       if(localUpdate)
+                                       if(localUpdate | _modelAvg)
                                                params = updateModel(params, 
gradients, i, j, batchIter);
-       
+
                                        accNumBatches(1);
                                }
-
-                               // Push the gradients to ps
-                               pushGradients(accGradients.get());
-                               ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
+                               pushGradients(_modelAvg ? params : 
accGradients.get());
+                               if (!_modelAvg)
+                                       ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
                        }
                        catch(ExecutionException | InterruptedException ex) {
                                throw new DMLRuntimeException(ex);
                        }
-                       
+
                        accNumEpochs(1);
-                       if (LOG.isDebugEnabled()) {
+                       if(LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
                        }
                }
@@ -126,9 +126,9 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                globalParams = _ps.updateLocalModel(_ec, gradients, 
globalParams);
 
                accLocalModelUpdateTime(tUpd);
-               
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug(String.format("%s: local global parameter 
[size:%d kb] updated. "
+
+               if(LOG.isDebugEnabled()) {
+                       LOG.debug(String.format("%s: local global parameter 
[size:%d kb] updated. " 
                                + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]",
                                getWorkerName(), globalParams.getDataSize(), i 
+ 1, _epochs, j + 1, batchIter));
                }
@@ -136,19 +136,24 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
        }
 
        private void computeBatch(long dataSize, int totalIter) {
-               for (int i = 0; i < _epochs; i++) {
-                       for (int j = 0; j < totalIter; j++) {
+               for(int i = 0; i < _epochs; i++) {
+                       for(int j = 0; j < totalIter; j++) {
                                ListObject globalParams = pullModel();
-
                                ListObject gradients = 
computeGradients(globalParams, dataSize, totalIter, i, j);
-
-                               // Push the gradients to ps
-                               pushGradients(gradients);
-                               ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
                                
+                               if(_modelAvg) {
+                                       // Update locally  & Push the local 
update model to ps
+                                       ListObject model = 
updateModel(globalParams, gradients, i, j, totalIter);
+                                       pushGradients(model);
+                               }
+                               else {
+                                       // Push the gradients to ps
+                                       pushGradients(gradients);
+                                       ParamservUtils.cleanupListObject(_ec, 
Statement.PS_MODEL);
+                               }
                                accNumBatches(1);
                        }
-                       
+
                        accNumEpochs(1);
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("%s: finished %d 
epoch.", getWorkerName(), i + 1));
@@ -159,7 +164,7 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
        private ListObject pullModel() {
                // Pull the global parameters from ps
                ListObject globalParams = _ps.pull(_workerID);
-               if (LOG.isDebugEnabled()) {
+               if(LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: successfully pull the 
global parameters "
                                + "[size:%d kb] from ps.", getWorkerName(), 
globalParams.getDataSize() / 1024));
                }
@@ -169,7 +174,7 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
        private void pushGradients(ListObject gradients) {
                // Push the gradients to ps
                _ps.push(_workerID, gradients);
-               if (LOG.isDebugEnabled()) {
+               if(LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: successfully push the 
gradients "
                                + "[size:%d kb] to ps.", getWorkerName(), 
gradients.getDataSize() / 1024));
                }
@@ -189,11 +194,11 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                _ec.setVariable(Statement.PS_FEATURES, bFeatures);
                _ec.setVariable(Statement.PS_LABELS, bLabels);
 
-               if (LOG.isDebugEnabled()) {
+               if(LOG.isDebugEnabled()) {
                        LOG.debug(String.format("%s: got batch data [size:%d 
kb] of index from %d to %d [last index: %d]. "
                                + "[Epoch:%d  Total epoch:%d  Iteration:%d  
Total iteration:%d]", getWorkerName(),
-                               bFeatures.getDataSize() / 1024 + 
bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
-                               j + 1, batchIter));
+                               bFeatures.getDataSize() / 1024 + 
bLabels.getDataSize() / 1024, begin, end, dataSize,
+                               i + 1, _epochs, j + 1, batchIter));
                }
 
                // Invoke the update function
@@ -208,7 +213,7 @@ public class LocalPSWorker extends PSWorker implements 
Callable<Void> {
                ParamservUtils.cleanupData(_ec, Statement.PS_LABELS);
                return gradients;
        }
-       
+
        @Override
        protected void incWorkerNumber() {
                if (DMLScript.STATISTICS)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
index 7bd96f2..ebf6698 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -33,17 +33,17 @@ public class LocalParamServer extends ParamServer {
 
        public static LocalParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
                Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
-               MatrixObject valFeatures, MatrixObject valLabels)
+               MatrixObject valFeatures, MatrixObject valLabels, boolean 
modelAvg)
        {
                return new LocalParamServer(model, aggFunc, updateType, freq, 
ec,
-                       workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels);
+                       workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels, modelAvg);
        }
 
        private LocalParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
                Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
-               MatrixObject valFeatures, MatrixObject valLabels)
+               MatrixObject valFeatures, MatrixObject valLabels, boolean 
modelAvg)
        {
-               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels);
+               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels, modelAvg);
        }
 
        @Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 99ec9e2..a1e55f3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -37,7 +37,7 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 
-public abstract class PSWorker implements Serializable 
+public abstract class PSWorker implements Serializable
 {
        private static final long serialVersionUID = -3510485051178200118L;
 
@@ -45,7 +45,7 @@ public abstract class PSWorker implements Serializable
        // Note: we use a non-static variable to obtain the live maintenance 
thread pool
        // which is important in scenarios w/ multiple scripts in a single JVM 
(e.g., tests)
        protected ExecutorService _tpool = LazyWriteBuffer.getUtilThreadPool();
-       
+
        protected int _workerID;
        protected int _epochs;
        protected long _batchSize;
@@ -57,10 +57,11 @@ public abstract class PSWorker implements Serializable
        protected MatrixObject _labels;
        protected String _updFunc;
        protected Statement.PSFrequency _freq;
+       protected boolean _modelAvg;
 
        protected PSWorker() {}
 
-       protected PSWorker(int workerID, String updFunc, Statement.PSFrequency 
freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
+       protected PSWorker(int workerID, String updFunc, Statement.PSFrequency 
freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps, boolean 
modelAvg) {
                _workerID = workerID;
                _updFunc = updFunc;
                _freq = freq;
@@ -68,6 +69,7 @@ public abstract class PSWorker implements Serializable
                _batchSize = batchSize;
                _ec = ec;
                _ps = ps;
+               _modelAvg = modelAvg;
                setupUpdateFunction(updFunc, ec);
        }
 
@@ -148,7 +150,7 @@ public abstract class PSWorker implements Serializable
        protected void accNumEpochs(int n) {
                //do nothing
        }
-       
+
        protected void accNumBatches(int n) {
                //do nothing
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 9f5b126..dc5b85f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -28,6 +28,7 @@ import java.util.concurrent.BlockingQueue;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -41,17 +42,20 @@ import 
org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.utils.Statistics;
 
-public abstract class ParamServer 
+public abstract class ParamServer
 {
        protected static final Log LOG = 
LogFactory.getLog(ParamServer.class.getName());
        protected static final boolean ACCRUE_BSP_GRADIENTS = true;
-       
+
        // worker input queues and global model
        protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
        private ListObject _model;
@@ -72,16 +76,17 @@ public abstract class ParamServer
        private String _accuracyOutput;
 
        private int _syncCounter = 0;
-       private int _epochCounter = 0 ;
+       private int _epochCounter = 0;
        private int _numBatchesPerEpoch;
 
        private int _numWorkers;
+       private boolean _modelAvg;
+       private ListObject _accModels = null;
 
        protected ParamServer() {}
 
        protected ParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
-               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc,
-               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels)
+               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels, boolean modelAvg)
        {
                // init worker queues and global model
                _modelMap = new HashMap<>(workerNum);
@@ -90,7 +95,7 @@ public abstract class ParamServer
                        _modelMap.put(i, new ArrayBlockingQueue<>(1));
                });
                _model = model;
-               
+
                // init aggregation service
                _ec = ec;
                _updateType = updateType;
@@ -103,7 +108,8 @@ public abstract class ParamServer
                }
                _numBatchesPerEpoch = numBatchesPerEpoch;
                _numWorkers = workerNum;
-               
+               _modelAvg = modelAvg;
+
                // broadcast initial model
                broadcastModel(true);
        }
@@ -179,9 +185,17 @@ public abstract class ParamServer
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updateGlobalModel(int workerID, ListObject 
params) {
+               if(_modelAvg) {
+                       updateAverageModel(workerID, params);
+               }
+               else
+                       updateGlobalGradients(workerID, params);
+       }
+
+       protected synchronized void updateGlobalGradients(int workerID, 
ListObject gradients) {
                try {
-                       if (LOG.isDebugEnabled()) {
+                       if(LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
                                        gradients.getDataSize() / 1024, 
workerID));
                        }
@@ -221,7 +235,7 @@ public abstract class ParamServer
                                                        _epochCounter++;
                                                        _syncCounter = 0;
                                                }
-                                               
+
                                                // Broadcast the updated model
                                                resetFinishedStates();
                                                broadcastModel(true);
@@ -256,7 +270,7 @@ public abstract class ParamServer
                                default:
                                        throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name());
                        }
-               } 
+               }
                catch (Exception e) {
                        throw new DMLRuntimeException("Aggregation or 
validation service failed: ", e);
                }
@@ -293,8 +307,93 @@ public abstract class ParamServer
                ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
                return newModel;
        }
-       
-       private boolean allFinished() {
+
+       protected synchronized void updateAverageModel(int workerID, ListObject 
model) {
+               try {
+                       if(LOG.isDebugEnabled()) {
+                               LOG.debug(String.format("Successfully pulled 
the models [size:%d kb] of worker_%d.",
+                                       model.getDataSize() / 1024, workerID));
+                       }
+                       Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : 
null;
+
+                       //first weight the models based on number of workers
+                       ListObject weightParams = weightModels(model, 
_numWorkers);
+                       switch(_updateType) {
+                               case BSP: {
+                                       setFinishedState(workerID);
+                                       // second, accumulate the given 
weighted model into the accrued models
+                                       _accModels = 
ParamservUtils.accrueGradients(_accModels, weightParams, true);
+
+                                       if(allFinished()) {
+                                               _model = setParams(_ec, 
_accModels, _model);
+                                               if (DMLScript.STATISTICS && 
tAgg != null)
+                                                       
Statistics.accPSAggregationTime((long) tAgg.stop());
+                                               _accModels = null; //reset for 
next accumulation
+
+                                               // This condition has grown 
quite complex, but its function is simple: validate at the end of each epoch.
+                                               // In the BSP batch case that 
occurs once the sync counter reaches the number of batches, and in the
+                                               // BSP epoch case every time.
+                                               if(_numBatchesPerEpoch != -1 && 
(_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH 
&& ++_syncCounter % _numBatchesPerEpoch == 0))) {
+
+                                                       if(LOG.isInfoEnabled())
+                                                               LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
+                                                       time_epoch();
+                                                       if(_validationPossible) 
{
+                                                               validate();
+                                                       }
+                                                       _epochCounter++;
+                                                       _syncCounter = 0;
+                                               }
+                                               // Broadcast the updated model
+                                               resetFinishedStates();
+                                               broadcastModel(true);
+                                               if(LOG.isDebugEnabled())
+                                                       LOG.debug("Global 
parameter is broadcasted successfully ");
+                                       }
+                                       break;
+                               }
+                               case ASP:
+                                       throw new NotImplementedException();
+
+                               default:
+                                       throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name());
+                       }
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException("Aggregation or 
validation service failed: ", e);
+               }
+       }
+
+       protected  ListObject weightModels(ListObject params, int numWorkers) {
+               double _averagingFactor = 1d / numWorkers;
+
+               if( _averagingFactor != 1) {
+                       double final_averagingFactor = _averagingFactor;
+                       params.getData().parallelStream().forEach((matrix) -> {
+                               MatrixObject matrixObject = (MatrixObject) 
matrix;
+                               MatrixBlock input = 
matrixObject.acquireReadAndRelease().scalarOperations(
+                                       new 
RightScalarOperator(Multiply.getMultiplyFnObject(), final_averagingFactor), new 
MatrixBlock());
+                               matrixObject.acquireModify(input);
+                               matrixObject.release();
+                       });
+               }
+               return  params;
+       }
+
+       /** A service method for replacing the model with the accrued (averaged) models
+        *
+        * @param ec execution context
+        * @param accModels list of models
+        * @param model old model
+        * @return new model (accModels)
+        */
+       protected  ListObject setParams(ExecutionContext ec, ListObject 
accModels, ListObject model) {
+               ec.setVariable(Statement.PS_MODEL, model);
+               ec.setVariable(Statement.PS_GRADIENTS, accModels);
+               return accModels;
+       }
+
+               private boolean allFinished() {
                return !ArrayUtils.contains(_finishedStates, false);
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index cb4271d..da1e9f7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -479,4 +479,39 @@ public class ParamservUtils {
                        ParamservUtils.cleanupListObject(gradients);
                return accGradients;
        }
+
+       /**
+        * Accumulate the given models into the accrued models
+        *
+        * @param accModels accrued models list object
+        * @param model given models list object
+        * @param cleanup clean up the given models list object
+        * @return new accrued models list object
+        */
+       public static ListObject accrueModels(ListObject accModels, ListObject 
model, boolean cleanup) {
+               return accrueModels(accModels, model, false, cleanup);
+       }
+
+       /**
+        * Accumulate the given models into the accrued models
+        *
+        * @param accModels accrued models list object
+        * @param model given models list object
+        * @param par parallel execution
+        * @param cleanup clean up the given models list object
+        * @return new accrued models list object
+        */
+       public static ListObject accrueModels(ListObject accModels, ListObject 
model, boolean par, boolean cleanup) {
+               if (accModels == null)
+                       return ParamservUtils.copyList(model, cleanup);
+               IntStream range = IntStream.range(0, accModels.getLength());
+               (par ? range.parallel() : range).forEach(i -> {
+                       MatrixBlock mb1 = ((MatrixObject) 
accModels.getData().get(i)).acquireReadAndRelease();
+                       MatrixBlock mb2 = ((MatrixObject) 
model.getData().get(i)).acquireReadAndRelease();
+                       mb1.binaryOperationsInPlace(new 
BinaryOperator(Plus.getPlusFnObject()), mb2);
+               });
+               if (cleanup)
+                       ParamservUtils.cleanupListObject(model);
+               return accModels;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
index 870fd9c..1f3cd1a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
@@ -53,8 +53,8 @@ public class SparkPSWorker extends LocalPSWorker implements 
VoidFunction<Tuple2<
        private final LongAccumulator _aRPC; // accumulator for rpc request
        private final LongAccumulator _nBatches; //number of executed batches
        private final LongAccumulator _nEpochs; //number of executed epoches
-       
-       public SparkPSWorker(String updFunc, String aggFunc, 
Statement.PSFrequency freq, int epochs, long batchSize, String program, 
HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator 
aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator 
aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, 
LongAccumulator aEpochs) {
+
+       public SparkPSWorker(String updFunc, String aggFunc, 
Statement.PSFrequency freq, int epochs, long batchSize, String program, 
HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator 
aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator 
aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, 
LongAccumulator aEpochs, boolean modelAvg) {
                _updFunc = updFunc;
                _aggFunc = aggFunc;
                _freq = freq;
@@ -72,13 +72,14 @@ public class SparkPSWorker extends LocalPSWorker implements 
VoidFunction<Tuple2<
                _aRPC = aRPC;
                _nBatches = aBatches;
                _nEpochs = aEpochs;
+               _modelAvg = modelAvg;
        }
 
        @Override
        public String getWorkerName() {
                return String.format("Spark worker_%d", _workerID);
        }
-       
+
        @Override
        public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> 
input) throws Exception {
                Timing tSetup = new Timing(true);
@@ -116,13 +117,13 @@ public class SparkPSWorker extends LocalPSWorker 
implements VoidFunction<Tuple2<
                setFeatures(ParamservUtils.newMatrixObject(input._2._1, false));
                setLabels(ParamservUtils.newMatrixObject(input._2._2, false));
        }
-       
+
 
        @Override
        protected void incWorkerNumber() {
                _aWorker.add(1);
        }
-       
+
        @Override
        protected void accLocalModelUpdateTime(Timing time) {
                if( time != null )
@@ -140,17 +141,17 @@ public class SparkPSWorker extends LocalPSWorker 
implements VoidFunction<Tuple2<
                if( time != null )
                        _aGrad.add((long) time.stop());
        }
-       
+
        @Override
        protected void accNumEpochs(int n) {
                _nEpochs.add(n);
        }
-       
+
        @Override
        protected void accNumBatches(int n) {
                _nBatches.add(n);
        }
-       
+
        private void accSetupTime(Timing time) {
                if( time != null )
                        _aSetup.add((long) time.stop());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index bc6ee67..53987a0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -39,6 +39,7 @@ import static 
org.apache.sysds.parser.Statement.PS_HYPER_PARAMS;
 import static org.apache.sysds.parser.Statement.PS_LABELS;
 import static org.apache.sysds.parser.Statement.PS_MODE;
 import static org.apache.sysds.parser.Statement.PS_MODEL;
+import static org.apache.sysds.parser.Statement.PS_MODELAVG;
 import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
@@ -89,7 +90,7 @@ import org.apache.sysds.utils.Statistics;
 
 public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruction {
        private static final Log LOG = 
LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
-       
+
        public static final int DEFAULT_BATCH_SIZE = 64;
        private static final PSFrequency DEFAULT_UPDATE_FREQUENCY = 
PSFrequency.EPOCH;
        private static final PSScheme DEFAULT_SCHEME = 
PSScheme.DISJOINT_CONTIGUOUS;
@@ -97,6 +98,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        private static final FederatedPSScheme DEFAULT_FEDERATED_SCHEME = 
FederatedPSScheme.KEEP_DATA_ON_WORKER;
        private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL;
        private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
+       private static final Boolean DEFAULT_MODELAVG = false;
 
        public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, 
String> paramsMap, CPOperand out, String opcode, String istr) {
                super(op, paramsMap, out, opcode, istr);
@@ -106,7 +108,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        public void processInstruction(ExecutionContext ec) {
                // check if the input is federated
                if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
-                               
ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+                       ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
                        runFederated(ec);
                }
                // if not federated check mode
@@ -181,13 +183,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                ListObject model = ec.getListObject(getParam(PS_MODEL));
                MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) 
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
                MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? 
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
+               boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
                ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, 
updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
-                               getNumBatchesPerEpoch(runtimeBalancing, 
result._balanceMetrics), val_features, val_labels);
+                       getNumBatchesPerEpoch(runtimeBalancing, 
result._balanceMetrics), val_features, val_labels, modelAvg);
                // Create the local workers
                int finalNumBatchesPerEpoch = 
getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
                List<FederatedPSControlThread> threads = IntStream.range(0, 
workerNum)
                        .mapToObj(i -> new FederatedPSControlThread(i, updFunc, 
freq, runtimeBalancing, weighting,
-                               getEpochs(), getBatchSize(), 
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
+                               getEpochs(), getBatchSize(), 
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, modelAvg))
                        .collect(Collectors.toList());
                if(workerNum != threads.size()) {
                        throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning 
does not match threads!");
@@ -223,6 +226,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                int workerNum = getWorkerNum(mode);
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
+               boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
 
                // Get the compiled execution context
                LocalVariableMap newVarsMap = createVarsMap(sec);
@@ -234,7 +238,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
                // Create the parameter server
                ListObject model = sec.getListObject(getParam(PS_MODEL));
-               ParamServer ps = createPS(mode, aggFunc, getUpdateType(), 
getFrequency(), workerNum, model, aggServiceEC);
+               ParamServer ps = createPS(mode, aggFunc, getUpdateType(), 
getFrequency(), workerNum, model, aggServiceEC, modelAvg);
 
                // Get driver host
                String host = 
sec.getSparkContext().getConf().get("spark.driver.host");
@@ -260,11 +264,11 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                LongAccumulator aRPC = 
sec.getSparkContext().sc().longAccumulator("rpcRequest");
                LongAccumulator aBatch = 
sec.getSparkContext().sc().longAccumulator("numBatches");
                LongAccumulator aEpoch = 
sec.getSparkContext().sc().longAccumulator("numEpochs");
-               
+
                // Create remote workers
                SparkPSWorker worker = new 
SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
                        getFrequency(), getEpochs(), getBatchSize(), program, 
clsMap, sec.getSparkContext().getConf(),
-                       server.getPort(), aSetup, aWorker, aUpdate, aIndex, 
aGrad, aRPC, aBatch, aEpoch);
+                       server.getPort(), aSetup, aWorker, aUpdate, aIndex, 
aGrad, aRPC, aBatch, aEpoch, modelAvg);
 
                if (DMLScript.STATISTICS)
                        Statistics.accPSSetupTime((long) tSetup.stop());
@@ -326,13 +330,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                ListObject model = ec.getListObject(getParam(PS_MODEL));
                MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) 
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
                MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? 
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
-               ParamServer ps = createPS(mode, aggFunc, updateType, freq, 
workerNum, model, aggServiceEC, getValFunction(),
-                               num_batches_per_epoch, val_features, 
val_labels);
+               boolean modelAvg = getModelAvg();
+               ParamServer ps = createPS(mode, aggFunc, updateType, freq, 
workerNum, model, aggServiceEC,
+                       getValFunction(), num_batches_per_epoch, val_features, 
val_labels, modelAvg);
 
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
                        .mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
-                               getEpochs(), getBatchSize(), workerECs.get(i), 
ps))
+                               getEpochs(), getBatchSize(), workerECs.get(i), 
ps, modelAvg))
                        .collect(Collectors.toList());
 
                // Do data partition
@@ -468,21 +473,21 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         * @return parameter server
         */
        private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
-               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec)
+               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec, boolean modelAvg)
        {
-               return createPS(mode, aggFunc, updateType, freq, workerNum, 
model, ec, null, -1, null, null);
+               return createPS(mode, aggFunc, updateType, freq, workerNum, 
model, ec, null, -1, null, null, modelAvg);
        }
 
        // When this creation is used the parameter server is able to validate 
after each epoch
        private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
                PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec, String valFunc,
-               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels)
+               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels, boolean modelAvg)
        {
                switch (mode) {
                        case FEDERATED:
                        case LOCAL:
                        case REMOTE_SPARK:
-                               return LocalParamServer.create(model, aggFunc, 
updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels);
+                               return LocalParamServer.create(model, aggFunc, 
updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels, modelAvg);
                        default:
                                throw new DMLRuntimeException("Unsupported 
parameter server: " + mode.name());
                }
@@ -575,17 +580,25 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private boolean getWeighting() {
-               return getParameterMap().containsKey(PS_FED_WEIGHTING) && 
Boolean.parseBoolean(getParam(PS_FED_WEIGHTING));
+               return getParameterMap().containsKey(PS_FED_WEIGHTING)
+                       && Boolean.parseBoolean(getParam(PS_FED_WEIGHTING));
        }
 
        private String getValFunction() {
-               if (getParameterMap().containsKey(PS_VAL_FUN)) {
+               if (getParameterMap().containsKey(PS_VAL_FUN))
                        return getParam(PS_VAL_FUN);
-               }
                return null;
        }
 
        private int getSeed() {
-               return (getParameterMap().containsKey(PS_SEED)) ? 
Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
+               return getParameterMap().containsKey(PS_SEED) ?
+                       Integer.parseInt(getParam(PS_SEED)) :
+                       (int) System.currentTimeMillis();
+       }
+       
+       private boolean getModelAvg() {
+               if(!getParameterMap().containsKey(PS_MODELAVG))
+                       return DEFAULT_MODELAVG;
+               return Boolean.parseBoolean(getParam(PS_MODELAVG));
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
new file mode 100644
index 0000000..3583cee
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class AvgModelFederatedParamservTest extends AutomatedTestBase {
+       private static final Log LOG = 
LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
+       private final static String TEST_DIR = "functions/federated/paramserv/";
+       private final static String TEST_NAME = 
"AvgModelFederatedParamservTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
AvgModelFederatedParamservTest.class.getSimpleName() + "/";
+
+       private final String _networkType;
+       private final int _numFederatedWorkers;
+       private final int _dataSetSize;
+       private final int _epochs;
+       private final int _batch_size;
+       private final double _eta;
+       private final String _utype;
+       private final String _freq;
+       private final String _scheme;
+       private final String _runtime_balancing;
+       private final String _weighting;
+       private final String _data_distribution;
+       private final int _seed;
+
+       // parameters
+       @Parameterized.Parameters
+       public static Collection<Object[]> parameters() {
+               return Arrays.asList(new Object[][] {
+                       // Network type, number of federated workers, data set 
size, batch size, epochs, learning rate, update type, update frequency, scheme, runtime balancing, weighting, data distribution, seed
+                       // basic functionality
+                       //{"TwoNN",     4, 60000, 32, 4, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE" ,                "false","BALANCED",             
200},
+
+                       // One important point is that we do the model 
averaging in the case of BSP
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",             "true", "IMBALANCED",  
 200},
+                       {"CNN",         2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "SHUFFLE",                              "BASELINE",             
"true", "IMBALANCED",   200},
+                       {"TwoNN",       5, 1000, 100, 2, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE",                 "true", "BALANCED",             
200},
+
+                       /*
+                               // runtime balancing
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MIN",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MIN",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX",    "true", "IMBALANCED",   
200},
+
+                               // data partitioning
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SHUFFLE",                              "CYCLE_AVG",    "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "REPLICATE_TO_MAX",             "NONE",                 "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SUBSAMPLE_TO_MIN",             "NONE",                 "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "BALANCE_TO_AVG",               "NONE",                 "true", 
"IMBALANCED",   200},
+
+                               // balanced tests
+                               {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "NONE",                 "true", "BALANCED",    
         200}
+                       */
+               });
+       }
+
+       public AvgModelFederatedParamservTest(String networkType, int 
numFederatedWorkers, int dataSetSize, int batch_size,
+               int epochs, double eta, String utype, String freq, String 
scheme, String runtime_balancing, String weighting, String data_distribution, 
int seed) {
+
+               _networkType = networkType;
+               _numFederatedWorkers = numFederatedWorkers;
+               _dataSetSize = dataSetSize;
+               _batch_size = batch_size;
+               _epochs = epochs;
+               _eta = eta;
+               _utype = utype;
+               _freq = freq;
+               _scheme = scheme;
+               _runtime_balancing = runtime_balancing;
+               _weighting = weighting;
+               _data_distribution = data_distribution;
+               _seed = seed;
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+       }
+
+       @Test
+       public void AvgmodelfederatedParamservSingleNode() {
+               AvgmodelfederatedParamserv(ExecMode.SINGLE_NODE, true);
+       }
+
+       @Test
+       public void AvgmodelfederatedParamservHybrid() {
+               AvgmodelfederatedParamserv(ExecMode.HYBRID, true);
+       }
+
+       private void AvgmodelfederatedParamserv(ExecMode mode, boolean 
modelAvg) {
+               // Warning: Statistics accumulate in unit tests
+               // config
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+               setOutputBuffering(false);
+
+               int C = 1, Hin = 28, Win = 28;
+               int numLabels = 10;
+
+               ExecMode platformOld = setExecMode(mode);
+
+               try {
+                       // start threads
+                       List<Integer> ports = new ArrayList<>();
+                       List<Thread> threads = new ArrayList<>();
+                       for(int i = 0; i < _numFederatedWorkers; i++) {
+                               ports.add(getRandomAvailablePort());
+                               
threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+                       }
+
+                       // generate test data
+                       double[][] features = 
generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
+                       double[][] labels = 
generateDummyMNISTLabels(_dataSetSize, numLabels);
+                       String featuresName = "";
+                       String labelsName = "";
+
+                       // federate test data balanced or imbalanced
+                       if(_data_distribution.equals("IMBALANCED")) {
+                               featuresName = "X_IMBALANCED_" + 
_numFederatedWorkers;
+                               labelsName = "y_IMBALANCED_" + 
_numFederatedWorkers;
+                               double[][] ranges = {{0,1}, {1,4}};
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, 
_numFederatedWorkers, ports, ranges);
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, 
_numFederatedWorkers, ports, ranges);
+                       }
+                       else {
+                               featuresName = "X_BALANCED_" + 
_numFederatedWorkers;
+                               labelsName = "y_BALANCED_" + 
_numFederatedWorkers;
+                               double[][] ranges = 
generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length);
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features, 
_numFederatedWorkers, ports, ranges);
+                               
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels, 
_numFederatedWorkers, ports, ranges);
+                       }
+
+                       try {
+                               //wait for all workers to be setup
+                               Thread.sleep(FED_WORKER_WAIT);
+                       }
+                       catch(InterruptedException e) {
+                               e.printStackTrace();
+                       }
+
+                       // dml name
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       // generate program args
+                       List<String> programArgsList = new 
ArrayList<>(Arrays.asList("-stats",
+                               "-nvargs",
+                               "features=" + input(featuresName),
+                               "labels=" + input(labelsName),
+                               "epochs=" + _epochs,
+                               "batch_size=" + _batch_size,
+                               "eta=" + _eta,
+                               "utype=" + _utype,
+                               "freq=" + _freq,
+                               "scheme=" + _scheme,
+                               "runtime_balancing=" + _runtime_balancing,
+                               "weighting=" + _weighting,
+                               "network_type=" + _networkType,
+                               "channels=" + C,
+                               "hin=" + Hin,
+                               "win=" + Win,
+                               "seed=" + _seed,
+                               "modelAvg=" +  
Boolean.toString(modelAvg).toUpperCase()));
+
+                       programArgs = programArgsList.toArray(new String[0]);
+                       LOG.debug(runTest(null));
+                       Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
+
+                       // shut down threads
+                       for(int i = 0; i < _numFederatedWorkers; i++) {
+                               TestUtils.shutdownThreads(threads.get(i));
+                       }
+               }
+               finally {
+                       resetExecMode(platformOld);
+               }
+       }
+
+       /**
+        * Generates a feature matrix that has the same format as the MNIST 
dataset,
+        * but is completely random and normalized
+        *
+        *  @param numExamples Number of examples to generate
+        *  @param C Channels in the input data
+        *  @param Hin Height in Pixels of the input data
+        *  @param Win Width in Pixels of the input data
+        *  @return a dummy MNIST feature matrix
+        */
+       private double[][] generateDummyMNISTFeatures(int numExamples, int C, 
int Hin, int Win) {
+               // Seed -1 takes the time in milliseconds as a seed
+               // Sparsity 1 means no sparsity
+               return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+       }
+
+       /**
+        * Generates a label matrix that has the same format as the MNIST 
dataset, but is completely random and consists
+        * of one hot encoded vectors as rows
+        *
+        *  @param numExamples Number of examples to generate
+        *  @param numLabels Number of labels to generate
+        *  @return a dummy MNIST label matrix
+        */
+       private double[][] generateDummyMNISTLabels(int numExamples, int 
numLabels) {
+               // Seed -1 takes the time in milliseconds as a seed
+               // Sparsity 1 means no sparsity
+               return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
new file mode 100644
index 0000000..396f24f
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.paramserv;
+
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Test;
+
[email protected]
+public class ParamservLocalNNAveragingTest extends AutomatedTestBase {
+
+       private static final String TEST_NAME = "paramserv-averaging-test";
+
+       private static final String TEST_DIR = "functions/paramserv/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
ParamservLocalNNAveragingTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {}));
+       }
+
+       @Test
+       public void testParamservBSPBatchDisjointContiguous() {
+               runDMLTest(10, 2, Statement.PSUpdateType.BSP, 
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, true);
+       }
+
+       @Test
+       public void testParamservBSPEpoch() {
+               runDMLTest(10, 2, Statement.PSUpdateType.BSP, 
Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, true);
+       }
+
+       @Test
+       public void testParamservBSPBatchDisjointRoundRobin() {
+               runDMLTest(10, 2, Statement.PSUpdateType.BSP, 
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN, true);
+       }
+
+       @Test
+       public void testParamservBSPBatchDisjointRandom() {
+               runDMLTest(10, 2, Statement.PSUpdateType.BSP, 
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM, true);
+       }
+
+       @Test
+       public void testParamservBSPBatchOverlapReshuffle() {
+               runDMLTest(10, 2, Statement.PSUpdateType.BSP, 
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE, true);
+       }
+
+       private void runDMLTest(int epochs, int workers, Statement.PSUpdateType 
utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme, 
boolean modelAvg) {
+               TestConfiguration config = 
getTestConfiguration(ParamservLocalNNAveragingTest.TEST_NAME);
+               loadTestConfiguration(config);
+               programArgs = new String[] { "-stats", "-nvargs", "mode=LOCAL", 
"epochs=" + epochs,
+                       "workers=" + workers, "utype=" + utype, "freq=" + freq, 
"batchsize=" + batchsize,
+                       "scheme=" + scheme, "modelAvg=" +modelAvg };
+               String HOME = SCRIPT_DIR + TEST_DIR;
+               fullDMLScriptName = HOME + 
ParamservLocalNNAveragingTest.TEST_NAME + ".dml";
+               runTest(true, false, null, null, -1);
+       }
+}
diff --git 
a/src/test/scripts/functions/federated/paramserv/AvgModelFederatedParamservTest.dml
 
b/src/test/scripts/functions/federated/paramserv/AvgModelFederatedParamservTest.dml
new file mode 100644
index 0000000..b802186
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/paramserv/AvgModelFederatedParamservTest.dml
@@ -0,0 +1,61 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
+source("src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml") as 
TwoNNModelAvg
+source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
+source("src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml") as 
CNNModelAvg
+
+
+# create federated input matrices
+features = read($features)
+labels = read($labels)
+
+if($network_type == "TwoNN") {
+  if(!as.logical($modelAvg)) {
+    model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed)
+    print("Test results:")
+    [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), model, list())
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test 
+ "\n")
+  }
+  else if (as.logical($modelAvg)){
+    model = TwoNNModelAvg::train_paramserv(features, labels, matrix(0, 
rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed, $modelAvg)
+    print("Test results:")
+    [loss_test, accuracy_test] = TwoNNModelAvg::validate(matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), model, list())
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test 
+ "\n")
+  }
+}
+else if($network_type == "CNN") {
+  if(!as.logical($modelAvg)) {
+    model = CNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, 
$win, $seed)
+    print("Test results:")
+    hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+    [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784), 
matrix(0, rows=100, cols=10), model, hyperparams)
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test 
+ "\n")
+  }
+  else if (as.logical($modelAvg)){
+    model = CNNModelAvg::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, 
$win, $seed, $modelAvg)
+    print("Test results:")
+    hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+    [loss_test, accuracy_test] = CNNModelAvg::validate(matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), model, hyperparams)
+    print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test 
+ "\n")
+  }
+}
diff --git a/src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml 
b/src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml
new file mode 100644
index 0000000..d14cc18
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml
@@ -0,0 +1,474 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a convolutional 
neural network of the "LeNet" architecture
+ * on different execution schemes and with different inputs, for example a 
federated input matrix.
+ */
+
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * Trains a convolutional net using the "LeNet" architecture, single-threaded, 
the conventional way.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ *  - y_val: Target validation matrix, of shape (N, K)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
+  matrix[double] y_val, int epochs, int batch_size, double eta, int C, int Hin,
+       int Win, int seed = -1, boolean modelAvg) return (list[unknown] model)
+{
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> 
relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target 
dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed)  # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed)  # inputs: (N, 
F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed)  # inputs: 
(N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, 
vb3, vb4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, 
Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, 
N3=N3)
+  # Calculate iterations
+  iters = ceil(N / batch_size)
+
+  for (e in 1:epochs) {
+    for(i in 1:iters) {
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+      model = aggregation(model, hyperparams, gradients_list)
+    }
+  }
+}
+
+/*
+ * Trains a convolutional net using the "LeNet" architecture using a parameter 
server with specified properties.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - Y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ *  - Y_val: Target validation matrix, of shape (N, K)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *  - workers: Number of workers to create
+ *  - utype: parameter server framework to use
+ *  - scheme: update schema
+ *  - mode: local or distributed
+ *  - modelAvg: Optional boolean parameter to select between updating or 
averaging the model on the parameter server side.
+ *
+ * Outputs:
+ *  - model_trained: List containing
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+  matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
+  string utype, string freq, int batch_size, string scheme, string 
runtime_balancing,
+  string weighting, double eta, int C, int Hin, int Win, int seed = -1, 
boolean modelAvg)
+  return (list[unknown] model)
+{
+
+
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> 
relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target 
dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed)  # inputs: (N, C*Hin*Win)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed)  # inputs: (N, 
F1*(Hin/2)*(Win/2))
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed)  # inputs: 
(N, F2*(Hin/2/2)*(Win/2/2))
+  lseed = ifelse(seed==-1, -1, seed + 3);
+  [W4, b4] = affine::init(N3, K, seed = lseed)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  learning_rate = eta  # learning rate
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+  # Regularization
+  lambda = 5e-04
+  # Create the model list
+  model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, 
vb3, vb4)
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin, 
Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, 
N3=N3)
+
+  # Use paramserv function
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, 
val_labels=y_val,
+    
upd="./src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml::gradients",
+    
agg="./src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml::aggregation",
+    
val="./src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml::validate",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, 
hyperparams=hyperparams, seed=seed, modelAvg=modelAvg)
+}
+
+/*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - batch_size: Number of rows of X scored per forward pass
+ *  - model: List read by position as (W1..W4, b1..b4); later entries are unused
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, 
list[unknown] model)
+    return (matrix[double] probs) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+  N = nrow(X)  # num examples
+
+  # Network (filter size, stride and pad are fixed here, not read from a hyperparams list):
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> 
relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1  # passed as both the height and the width stride below
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions 
(num classes)
+
+  # Compute predictions over mini-batches (the last batch may hold fewer rows)
+  probs = matrix(0, rows=N, cols=K)
+  iters = ceil(N / batch_size)
+  for(i in 1:iters, check=0) {
+    # Get next batch; 'end' is clamped to the last row of X
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, 
Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 
2, 2, 2, 2, 0, 0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, 
Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 
2, 2, 2, 2, 0, 0)
+    ## layer 3:  affine3 -> relu3 (no dropout layer in this prediction pass)
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions in the rows that belong to this batch
+    probs[beg:end,] = probs_batch
+  }
+}
+
+/*
+ * Scores class probability predictions against one-hot encoded targets.
+ *
+ * Row i of probs holds the predicted distribution over the K classes
+ * for example i; row i of y is the one-hot label of that example.
+ * A prediction counts as correct when both rows peak in the same column.
+ *
+ * Inputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ *  - y: Target matrix, of shape (N, K)
+ *
+ * Outputs:
+ *  - loss: Cross-entropy loss as a scalar
+ *  - accuracy: Fraction of correctly classified examples as a scalar
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  # Accuracy: share of rows whose arg-max column agrees between probs and y
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(y))
+  # Loss: cross-entropy between the predicted probabilities and the targets
+  loss = cross_entropy_loss::forward(probs, y)
+}
+
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for 
validation.
+ * For inputs see eval and predict; C, Hin and Win are read from hyperparams.
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels, 
+  list[unknown] model, list[unknown] hyperparams) 
+       return (double loss, double accuracy)
+{
+  [loss, accuracy] = eval(predict(val_features, 
as.integer(as.scalar(hyperparams["C"])),
+    as.integer(as.scalar(hyperparams["Hin"])), 
as.integer(as.scalar(hyperparams["Win"])), 
+    32, model), val_labels)  # 32 is the hard-coded batch_size handed to predict
+}
+
+# Mini-batch gradient function: must keep the argument names 'model',
+# 'hyperparams', 'features' (batch features) and 'labels' (batch labels),
+# and return the gradients of type list, ordered (dW1..dW4, db1..db4)
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  N3 = as.integer(as.scalar(hyperparams["N3"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, 
Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 
2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, 
Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 
2, 2, 2, 0, 0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute loss & accuracy for training data (only read by the commented-out print)
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  # print("[+] Completed forward pass on batch: train loss: " + loss + ", 
train accuracy: " + accuracy)
+
+  # Compute data backward pass
+  ## loss
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax (backward gets outd3, the input affine4 saw in the forward pass)
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, 
Woutc2, 2, 2, 2, 2, 0, 0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, 
F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, 
stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, 
Woutc1, 2, 2, 2, 2, 0, 0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, 
W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass (L2 term is added to weights only, not biases)
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Applies one SGD w/ Nesterov momentum step; must use the arguments named
+# 'model', 'gradients', 'hyperparams' and always return a model of type list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+   W1 = as.matrix(model[1])
+   W2 = as.matrix(model[2])
+   W3 = as.matrix(model[3])
+   W4 = as.matrix(model[4])
+   b1 = as.matrix(model[5])
+   b2 = as.matrix(model[6])
+   b3 = as.matrix(model[7])
+   b4 = as.matrix(model[8])
+   dW1 = as.matrix(gradients[1])
+   dW2 = as.matrix(gradients[2])
+   dW3 = as.matrix(gradients[3])
+   dW4 = as.matrix(gradients[4])
+   db1 = as.matrix(gradients[5])
+   db2 = as.matrix(gradients[6])
+   db3 = as.matrix(gradients[7])
+   db4 = as.matrix(gradients[8])
+   vW1 = as.matrix(model[9])  # entries 9-16 hold the momentum velocities
+   vW2 = as.matrix(model[10])
+   vW3 = as.matrix(model[11])
+   vW4 = as.matrix(model[12])
+   vb1 = as.matrix(model[13])
+   vb2 = as.matrix(model[14])
+   vb3 = as.matrix(model[15])
+   vb4 = as.matrix(model[16])
+   learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))  # 'decay' is not read here
+   mu = as.double(as.scalar(hyperparams["mu"]))
+
+   # Optimize with SGD w/ Nesterov momentum (each call returns the parameter and its velocity)
+   [W1, vW1] = sgd_nesterov::update(W1, dW1, learning_rate, mu, vW1)
+   [b1, vb1] = sgd_nesterov::update(b1, db1, learning_rate, mu, vb1)
+   [W2, vW2] = sgd_nesterov::update(W2, dW2, learning_rate, mu, vW2)
+   [b2, vb2] = sgd_nesterov::update(b2, db2, learning_rate, mu, vb2)
+   [W3, vW3] = sgd_nesterov::update(W3, dW3, learning_rate, mu, vW3)
+   [b3, vb3] = sgd_nesterov::update(b3, db3, learning_rate, mu, vb3)
+   [W4, vW4] = sgd_nesterov::update(W4, dW4, learning_rate, mu, vW4)
+   [b4, vb4] = sgd_nesterov::update(b4, db4, learning_rate, mu, vb4)
+
+   model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, 
vb1, vb2, vb3, vb4)
+}
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml 
b/src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml
new file mode 100644
index 0000000..049bea8
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml
@@ -0,0 +1,307 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a simple feed forward 
neural network
+ * on different execution schemes and with different inputs, for example a 
federated input matrix.
+ */
+
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd.dml") as sgd
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers single 
threaded the conventional way.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D) (not read here)
+ *  - y_val: Target validation matrix of shape (N_val, K) (not read here)
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - eta: The learning rate for the SGD
+ *  - seed, modelAvg: Seed for the affine::init calls (-1 is forwarded as-is); 
modelAvg is accepted but not read by this function.
+ *
+ * Outputs:
+ *  - model: List containing (in this order: W1, W2, W3, b1, b2, b3)
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int epochs, int batch_size, double eta,
+                 int seed = -1 , boolean modelAvg )
+    return (list[unknown] model) {
+
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network (each layer gets its own seed unless seed is -1):
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200, seed = seed)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = affine::init(200, 200,  seed = lseed)
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(200, K, seed = lseed)
+  W3 = W3 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
+  model = list(W1, W2, W3, b1, b2, b3)
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta)
+  # Calculate iterations per epoch (the last batch may be smaller)
+  iters = ceil(N / batch_size)
+
+  for (e in 1:epochs) {
+    for(i in 1:iters) {
+      # Get next batch; 'end' is clamped to the last row
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+      model = aggregation(model, hyperparams, gradients_list)
+    }
+  }
+}
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers
+ * using a parameter server with specified properties.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D)
+ *  - y_val: Target validation matrix of shape (N_val, K)
+ *  - epochs, batch_size: Number of epochs and batch size, passed to paramserv
+ *  - num_workers: Number of workers to create (paramserv argument k)
+ *  - eta: The learning rate for the SGD (stored as hyperparams learning_rate)
+ *  - utype, freq: Forwarded unchanged to paramserv
+ *  - scheme, runtime_balancing, weighting: Forwarded unchanged to paramserv
+ *  - seed: Seed for the affine::init calls, also forwarded to paramserv
+ *  - modelAvg: Boolean forwarded to paramserv as modelAvg
+ *
+ * Outputs:
+ *  - model: List containing (in this order: W1, W2, W3, b1, b2, b3)
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int num_workers, int epochs, string utype, string freq, int 
batch_size, string scheme, string runtime_balancing, string weighting,
+                 double eta, int seed = -1, boolean modelAvg)
+    return (list[unknown] model) {
+
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network (each layer gets its own seed unless seed is -1):
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200, seed = seed)
+  lseed = ifelse(seed==-1, -1, seed + 1);
+  [W2, b2] = affine::init(200, 200,  seed = lseed)
+  lseed = ifelse(seed==-1, -1, seed + 2);
+  [W3, b3] = affine::init(200, K, seed = lseed)
+  # W3 = W3 / sqrt(2) # different initialization, since being fed into 
softmax, instead of relu
+
+  # NOTE: unlike train(), the W3 / sqrt(2) rescaling above is disabled here,
+  # so train() and train_paramserv() start from a different W3 even when
+  # both are given the same seed.
+
+  # Create the model list
+  model = list(W1, W2, W3, b1, b2, b3)
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=eta)
+
+  # Use paramserv function (upd/agg/val point at the functions of this script)
+  model = paramserv(model=model, features=X, labels=y, val_features=X_val, 
val_labels=y_val,
+    
upd="./src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml::gradients",
+    
agg="./src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml::aggregation",
+    
val="./src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml::validate",
+    k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, 
hyperparams=hyperparams, seed=seed, modelAvg=modelAvg)
+}
+
+/*
+ * Runs the forward pass of the two-hidden-layer feed forward network and
+ * returns the softmax class probabilities.
+ *
+ * Inputs:
+ *  - X: Data matrix to score, of shape (N, D)
+ *  - model: List ordered as (W1, W2, W3, b1, b2, b3), where
+ *       - W1: weights of affine1, of shape (D, 200)
+ *       - W2: weights of affine2, of shape (200, 200)
+ *       - W3: weights of affine3, of shape (200, K)
+ *       - b1, b2: biases of affine1/affine2, of shape (1, 200)
+ *       - b3: biases of affine3, of shape (1, K)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X,
+                   list[unknown] model)
+    return (matrix[double] probs) {
+
+  # Unpack layer by layer; the list stores all weights first, then all biases
+  W1 = as.matrix(model[1]); b1 = as.matrix(model[4])
+  W2 = as.matrix(model[2]); b2 = as.matrix(model[5])
+  W3 = as.matrix(model[3]); b3 = as.matrix(model[6])
+
+  # input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  hidden1 = relu::forward(affine::forward(X, W1, b1))
+  hidden2 = relu::forward(affine::forward(hidden1, W2, b2))
+  scores = affine::forward(hidden2, W3, b3)
+  probs = softmax::forward(scores)
+}
+
+/*
+ * Scores class probability predictions against one-hot encoded targets.
+ *
+ * Row i of probs holds the predicted distribution over the K classes
+ * for example i; row i of y is the one-hot label of that example.
+ * A prediction counts as correct when both rows peak in the same column.
+ *
+ * Inputs:
+ *  - probs: Class probabilities, of shape (N, K).
+ *  - y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ *  - loss: Cross-entropy loss as a scalar.
+ *  - accuracy: Fraction of correctly classified examples as a scalar.
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  # Accuracy: share of rows whose arg-max column agrees between probs and y
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(y))
+  # Loss: cross-entropy between the predicted probabilities and the targets
+  loss = cross_entropy_loss::forward(probs, y)
+}
+
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for 
validation.
+ * For inputs see eval and predict; hyperparams is accepted but not read.
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels, 
list[unknown] model, list[unknown] hyperparams)
+    return (double loss, double accuracy) {
+  [loss, accuracy] = eval(predict(val_features, model), val_labels)
+}
+
+# Mini-batch gradient function: should always use 'features' (batch features),
+# 'labels' (batch labels), 'hyperparams', 'model' as the arguments
+# and return the gradients of type list, ordered (dW1, dW2, dW3, db1, db2, db3)
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+    return (list[unknown] gradients) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+
+  # Compute forward pass (pre-activations out1..out3 are kept for the backward pass)
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  out1 = affine::forward(features, W1, b1)
+  out1relu = relu::forward(out1)
+  out2 = affine::forward(out1relu, W2, b2)
+  out2relu = relu::forward(out2)
+  out3 = affine::forward(out2relu, W3, b3)
+  probs = softmax::forward(out3)
+
+  # Compute loss & accuracy for training data (only read by the commented-out print)
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  # print("[+] Completed forward pass on batch: train loss: " + loss + ", 
train accuracy: " + accuracy)
+
+  # Compute data backward pass (layers visited in reverse order of the forward pass)
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  dout3 = softmax::backward(dprobs, out3)
+  [dout2relu, dW3, db3] = affine::backward(dout3, out2relu, W3, b3)
+  dout2 = relu::backward(dout2relu, out2)
+  [dout1relu, dW2, db2] = affine::backward(dout2, out1relu, W2, b2)
+  dout1 = relu::backward(dout1relu, out1)
+  [dfeatures, dW1, db1] = affine::backward(dout1, features, W1, b1)
+
+  gradients = list(dW1, dW2, dW3, db1, db2, db3)
+}
+
+# Applies one plain SGD step; should use the arguments named 'model',
+# 'gradients', 'hyperparams' and return always a model of type list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+  dW1 = as.matrix(gradients[1])
+  dW2 = as.matrix(gradients[2])
+  dW3 = as.matrix(gradients[3])
+  db1 = as.matrix(gradients[4])
+  db2 = as.matrix(gradients[5])
+  db3 = as.matrix(gradients[6])
+  learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+
+  # Optimize with SGD (one sgd::update call per parameter, last layer first)
+  W3 = sgd::update(W3, dW3, learning_rate)
+  b3 = sgd::update(b3, db3, learning_rate)
+  W2 = sgd::update(W2, dW2, learning_rate)
+  b2 = sgd::update(b2, db2, learning_rate)
+  W1 = sgd::update(W1, dW1, learning_rate)
+  b1 = sgd::update(b1, db1, learning_rate)
+
+  model_result = list(W1, W2, W3, b1, b2, b3)
+}
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml 
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml
new file mode 100644
index 0000000..bd5fd7d
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml
@@ -0,0 +1,372 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+                 matrix[double] X_val, matrix[double] Y_val,
+                 int C, int Hin, int Win, int epochs, int workers,
+                 string utype, string freq, int batchsize, string scheme, 
string mode, boolean modelAvg)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, Y, have K
+   * classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - Y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+   *  - Y_val: Target validation matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin, Win: Input height and input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *  - workers, utype, freq, batchsize, scheme, mode: Forwarded to paramserv.
+   *  - modelAvg: Boolean parameter to select between updating or 
averaging the model in paramserver side.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape 
(F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  N = nrow(X)
+  K = ncol(Y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> 
affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target 
dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1)  # inputs: (N, 
F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1)  # inputs: (N, 
F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K, -1)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, 
instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  #0.5  # momentum
+  decay = 0.95  # learning rate decay constant (kept in params; aggregation does not read it)
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Create the model list (8 parameters followed by their 8 velocities)
+  modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, 
vb2, vb3, vb4)
+
+  # Create the hyper parameter list
+  params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, 
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function (upd/agg point at gradients/aggregation of this script)
+  modelList2 = paramserv(model=modelList, features=X, labels=Y, 
val_features=X_val, val_labels=Y_val, 
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml::gradients",
 
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml::aggregation",
 mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batchsize, 
k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE", 
modelAvg=modelAvg)
+
+  W1 = as.matrix(modelList2[1])
+  W2 = as.matrix(modelList2[2])
+  W3 = as.matrix(modelList2[3])
+  W4 = as.matrix(modelList2[4])
+  b1 = as.matrix(modelList2[5])
+  b2 = as.matrix(modelList2[6])
+  b3 = as.matrix(modelList2[7])
+  b4 = as.matrix(modelList2[8])
+}
+
+# Mini-batch gradient function: must keep the argument names 'model',
+# 'hyperparams', 'features' (batch features) and 'labels' (batch labels),
+# and return the gradients of type list, ordered (dW1..dW4, db1..db4)
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  N3 = as.integer(as.scalar(hyperparams["N3"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, 
Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 
2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, 
Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 
2, 2, 2, 0, 0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute data backward pass
+  ## loss:
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax (backward gets outd3, the input affine4 saw in the forward pass)
+  douta4 = softmax::backward(dprobs, outa4)
+  [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, 
Woutc2, 2, 2, 2, 2, 0, 0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, 
F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, 
stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, 
Woutc1, 2, 2, 2, 2, 0, 0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, 
W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass (L2 term is added to weights only, not biases)
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Aggregation function for paramserv: it should use the arguments named
+# 'model', 'gradients', 'hyperparams' and always return a model of type list
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] modelResult) {
+  lr = as.double(as.scalar(hyperparams["lr"]))
+  mu = as.double(as.scalar(hyperparams["mu"]))
+
+  # Optimize with SGD w/ Nesterov momentum: parameter i is model[i], its gradient gradients[i], its velocity model[i+8]
+  W1 = as.matrix(model[1])
+  dW1 = as.matrix(gradients[1])
+  vW1 = as.matrix(model[9])
+  [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+  b1 = as.matrix(model[5])
+  db1 = as.matrix(gradients[5])
+  vb1 = as.matrix(model[13])
+  [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+  W2 = as.matrix(model[2])
+  dW2 = as.matrix(gradients[2])
+  vW2 = as.matrix(model[10])
+  [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+  b2 = as.matrix(model[6])
+  db2 = as.matrix(gradients[6])
+  vb2 = as.matrix(model[14])
+  [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+  W3 = as.matrix(model[3])
+  dW3 = as.matrix(gradients[3])
+  vW3 = as.matrix(model[11])
+  [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+  b3 = as.matrix(model[7])
+  db3 = as.matrix(gradients[7])
+  vb3 = as.matrix(model[15])
+  [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+  W4 = as.matrix(model[4])
+  dW4 = as.matrix(gradients[4])
+  vW4 = as.matrix(model[12])
+  [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+  b4 = as.matrix(model[8])
+  db4 = as.matrix(gradients[8])
+  vb4 = as.matrix(model[16])
+  [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+  modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
+
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Predicts class probabilities with a convolutional net using the
+   * "LeNet" architecture, scoring the input one mini-batch at a time.
+   *
+   * Each of the N rows of X is one example, given as a 3D volume
+   * unrolled into a single vector.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - batch_size: Number of rows of X scored per mini-batch.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # (Hf - stride) / 2, for same dimensions
+
+  # Score X in mini-batches; every batch fills its own rows of probs
+  iters = ceil(N / batch_size)
+  probs = matrix(0, rows=N, cols=K)
+  parfor(i in 1:iters, check=0) {
+    # Row range of the i-th batch
+    row_beg = ((i-1) * batch_size) %% N + 1
+    row_end = min(N, row_beg + batch_size - 1)
+    Xb = X[row_beg:row_end,]
+
+    # Forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [c1, Hc1, Wc1] = conv2d::forward(Xb, W1, b1, C, Hin, Win, Hf, Wf,
+                                     stride, stride, pad, pad)
+    r1 = relu::forward(c1)
+    [p1, Hp1, Wp1] = max_pool2d::forward(r1, F1, Hc1, Wc1, 2, 2, 2, 2, 0, 0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [c2, Hc2, Wc2] = conv2d::forward(p1, W2, b2, F1, Hp1, Wp1, Hf, Wf,
+                                     stride, stride, pad, pad)
+    r2 = relu::forward(c2)
+    [p2, Hp2, Wp2] = max_pool2d::forward(r2, F2, Hc2, Wc2, 2, 2, 2, 2, 0, 0)
+    ## layer 3:  affine3 -> relu3
+    a3 = affine::forward(p2, W3, b3)
+    r3 = relu::forward(a3)
+    ## layer 4:  affine4 -> softmax
+    a4 = affine::forward(r3, W4, b4)
+    probs_batch = softmax::forward(a4)
+
+    # Store predictions
+    probs[row_beg:row_end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+    return (double loss, double accuracy) {
+  /*
+   * Scores the predictions of a convolutional net using the "LeNet"
+   * architecture against one-hot encoded targets.
+   *
+   * Both probs and Y cover N examples and K classes; probs holds the
+   * predicted class probabilities, Y the one-hot encoded targets.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - Y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Accuracy: mean over the rows of whether the highest-probability class equals the target class
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(Y))
+  # Loss: cross-entropy loss of the predictions w.r.t. the targets
+  loss = cross_entropy_loss::forward(probs, Y)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D), with D = C*Hin*Win.
+   *  - Y: One-hot encoded target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  Y = table(seq(1, N), classes, N, K)  # one-hot encoding; explicit dims keep Y at (N, K) even if class K is never drawn
+}
diff --git a/src/test/scripts/functions/paramserv/paramserv-averaging-test.dml 
b/src/test/scripts/functions/paramserv/paramserv-averaging-test.dml
new file mode 100644
index 0000000..073e216
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-averaging-test.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml") as mnist_lenet_avg
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as mnist_lenet
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Test driver: trains LeNet via mnist_lenet_avg::train using the script arguments
+# and prints the loss and accuracy on a held-out validation split.
+
+# Generate the dummy data once; it is split into training and validation below
+[images, labels, C, Hin, Win] = mnist_lenet_avg::generate_dummy_data()
+n = nrow(images)
+
+# Split into training and validation: the first 10% of the rows are held out
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet_avg::train(X, Y, X_val, Y_val, C, Hin, Win, $epochs, $workers, $utype, $freq, $batchsize, $scheme, $mode, $modelAvg)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet_avg::predict(X_val, C, Hin, Win, $batchsize, W1, b1, W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)

Reply via email to