This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 428016c  [SYSTEMDS-2550] Federated Parameter Server
428016c is described below

commit 428016c38fb55a3b6094334c7c876b71bf4bf3f7
Author: Tobias Rieger <[email protected]>
AuthorDate: Tue Aug 25 10:23:00 2020 +0200

    [SYSTEMDS-2550] Federated Parameter Server
    
    This commit adds a federated parameter server to the system.
    This allows for federated training of neural networks.
    In particular, two architectures are verified: a feed-forward NN and
    a convolutional NN.
    
    Closes #1075
---
 .../apache/sysds/hops/recompile/Recompiler.java    |   2 +-
 .../java/org/apache/sysds/parser/Statement.java    |  16 +-
 .../controlprogram/caching/MatrixObject.java       |  21 +-
 .../federated/FederatedWorkerHandler.java          |   6 +-
 .../paramserv/FederatedPSControlThread.java        | 560 +++++++++++++++++++++
 .../controlprogram/paramserv/ParamServer.java      |   5 +
 .../controlprogram/paramserv/ParamservUtils.java   |  47 +-
 .../paramserv/dp/DataPartitionFederatedScheme.java |  88 ++++
 .../paramserv/dp/FederatedDataPartitioner.java     |  46 ++
 .../dp/KeepDataOnWorkerFederatedScheme.java        |  32 ++
 .../paramserv/dp/ShuffleFederatedScheme.java       |  33 ++
 .../sysds/runtime/instructions/cp/ListObject.java  | 145 +++++-
 .../cp/ParamservBuiltinCPInstruction.java          |  98 +++-
 .../sysds/runtime/util/ProgramConverter.java       |  40 +-
 .../component/paramserv/SerializationTest.java     |  86 +++-
 .../paramserv/FederatedParamservTest.java          | 195 +++++++
 .../scripts/functions/federated/paramserv/CNN.dml  | 474 +++++++++++++++++
 .../federated/paramserv/FederatedParamservTest.dml |  57 +++
 .../functions/federated/paramserv/TwoNN.dml        | 299 +++++++++++
 19 files changed, 2185 insertions(+), 65 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index d048863..c785cfc 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -1039,7 +1039,7 @@ public class Recompiler
                }
        }
        
-       private static void rRecompileProgramBlock2Forced( ProgramBlock pb, 
long tid, HashSet<String> fnStack, ExecType et ) {
+       public static void rRecompileProgramBlock2Forced( ProgramBlock pb, long 
tid, HashSet<String> fnStack, ExecType et ) {
                if (pb instanceof WhileProgramBlock)
                {
                        WhileProgramBlock pbTmp = (WhileProgramBlock)pb;
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java 
b/src/main/java/org/apache/sysds/parser/Statement.java
index f0bdd66..b61b0d6 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -71,7 +71,7 @@ public abstract class Statement implements ParseInfo
        public static final String PS_MODE = "mode";
        public static final String PS_GRADIENTS = "gradients";
        public enum PSModeType {
-               LOCAL, REMOTE_SPARK
+               FEDERATED, LOCAL, REMOTE_SPARK
        }
        public static final String PS_UPDATE_TYPE = "utype";
        public enum PSUpdateType {
@@ -94,12 +94,26 @@ public abstract class Statement implements ParseInfo
        public enum PSScheme {
                DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, 
OVERLAP_RESHUFFLE
        }
+       public enum FederatedPSScheme {
+               KEEP_DATA_ON_WORKER, SHUFFLE
+       }
        public static final String PS_HYPER_PARAMS = "hyperparams";
        public static final String PS_CHECKPOINTING = "checkpointing";
        public enum PSCheckpointing {
                NONE, EPOCH, EPOCH10
        }
 
+       // Names of execution-context variables set on federated workers by the federated parameter
+       // server; the "1701-NCC-" prefix keeps them from clashing with user-defined variable names.
+       public static final String PS_FED_BATCH_SIZE = "1701-NCC-batch_size";
+       public static final String PS_FED_DATA_SIZE = "1701-NCC-data_size";
+       public static final String PS_FED_NUM_BATCHES = "1701-NCC-num_batches";
+       public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
+       public static final String PS_FED_GRADIENTS_FNAME = 
"1701-NCC-gradients_fname";
+       public static final String PS_FED_AGGREGATION_FNAME = 
"1701-NCC-aggregation_fname";
+       public static final String PS_FED_BATCHCOUNTER_VARID = 
"1701-NCC-batchcounter_varid";
+       public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid";
+
 
        public abstract boolean controlStatement();
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 85d8a8f..4fd2b06 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -66,7 +66,7 @@ import org.apache.sysds.runtime.util.IndexRange;
 public class MatrixObject extends CacheableData<MatrixBlock>
 {
        private static final long serialVersionUID = 6374712373206495637L;
-       
+
        public enum UpdateType {
                COPY,
                INPLACE,
@@ -87,7 +87,7 @@ public class MatrixObject extends CacheableData<MatrixBlock>
        private int _partitionSize = -1; //indicates n for BLOCKWISE_N
        private String _partitionCacheName = null; //name of cache block
        private MatrixBlock _partitionInMemory = null;
-       
+
        /**
         * Constructor that takes the value type and the HDFS filename.
         * 
@@ -112,6 +112,23 @@ public class MatrixObject extends 
CacheableData<MatrixBlock>
                _cache = null;
                _data = null;
        }
+
+       /**
+        * Creates a matrix object that wraps an already materialized matrix block,
+        * e.g., when re-creating a matrix object after serialization.
+        *
+        * @param vt   value type of the matrix
+        * @param file file name associated with the matrix
+        * @param mtd  metadata of the matrix
+        * @param data in-memory matrix block to wrap
+        */
+       public MatrixObject(ValueType vt, String file, MetaData mtd, MatrixBlock data) {
+               super(DataType.MATRIX, vt);
+               _hdfsFileName = file;
+               _metaData = mtd;
+               _data = data;
+               // the wrapped block is not registered with any cache yet
+               _cache = null;
+       }
        
        /**
         * Copy constructor that copies meta data but NO data.
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 6764f12..e932785 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -244,11 +244,15 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                }
                
                //wrap transferred cache block into cacheable data
-               Data data = null;
+               Data data;
                if( request.getParam(0) instanceof CacheBlock )
                        data = 
ExecutionContext.createCacheableData((CacheBlock) request.getParam(0));
                else if( request.getParam(0) instanceof ScalarObject )
                        data = (ScalarObject) request.getParam(0);
+               else if( request.getParam(0) instanceof ListObject )
+                       data = (ListObject) request.getParam(0);
+               else // report the offending type; null-safe because a missing parameter also ends up here
+                       throw new DMLRuntimeException("FederatedWorkerHandler: Unsupported object type "
+                               + (request.getParam(0) == null ? "null" : request.getParam(0).getClass().getSimpleName())
+                               + ", has to be of type CacheBlock, ScalarObject or ListObject");
                
                //set variable and construct empty response
                ec.setVariable(varname, data);
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
new file mode 100644
index 0000000..8fa0698
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -0,0 +1,560 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv;
+
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.ProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.IntObject;
+import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.util.ProgramConverter;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+
+import static org.apache.sysds.runtime.util.ProgramConverter.*;
+
+public class FederatedPSControlThread extends PSWorker implements 
Callable<Void> {
+       // federated data handles of the features and labels on the worker (resolved in setup())
+       FederatedData _featuresData;
+       FederatedData _labelsData;
+       // IDs under which the batch counter and the model are put on the federated worker
+       final long _batchCounterVarID;
+       final long _modelVarID;
+       // number of batches per epoch (computed in setup())
+       int _totalNumBatches;
+
+       /**
+        * Creates a control thread that drives a single federated worker.
+        *
+        * @param workerID  id of this worker at the parameter server
+        * @param updFunc   name of the update/gradient function (passed through to PSWorker)
+        * @param freq      update frequency; call() supports BATCH and EPOCH
+        * @param epochs    number of epochs
+        * @param batchSize number of rows per batch
+        * @param ec        execution context of the control program
+        * @param ps        parameter server to pull models from and push gradients to
+        */
+       public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
+               super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+               // reserve one ID for the batch counter and one for the model; the worker-side
+               // values stored under these IDs are overwritten on every put
+               this._batchCounterVarID = FederationUtils.getNextFedDataID();
+               this._modelVarID = FederationUtils.getNextFedDataID();
+       }
+
+       /**
+        * Sets up the federated worker and this control thread. It resolves the federated
+        * data handles of features and labels, computes the number of batches per epoch and
+        * ships the serialized gradient program to the federated worker. For EPOCH frequency
+        * the aggregation program is shipped as well. The required metadata is sent along.
+        *
+        * @throws DMLRuntimeException if the setup UDF cannot be executed or reports failure
+        */
+       public void setup() {
+               // prepare features and labels
+               // NOTE(review): every callback overwrites the handle, so this assumes exactly one
+               // federated range per matrix - with several ranges the chosen one is arbitrary
+               _features.getFedMapping().forEachParallel((range, data) -> {
+                       _featuresData = data;
+                       return null;
+               });
+               _labels.getFedMapping().forEachParallel((range, data) -> {
+                       _labelsData = data;
+                       return null;
+               });
+
+               // calculate number of batches and get data size
+               long dataSize = _features.getNumRows();
+               _totalNumBatches = (int) Math.ceil((double) dataSize / _batchSize);
+
+               // create program blocks for the instruction filtering: the gradient function is
+               // always needed, the aggregation function only for per-epoch updates
+               ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
+
+               BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
+               gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
+               programBlocks.add(gradientProgramBlock);
+
+               if(_freq == Statement.PSFrequency.EPOCH) {
+                       BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
+                       aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
+                       programBlocks.add(aggProgramBlock);
+               }
+
+               // serialize program
+               StringBuilder sb = new StringBuilder();
+               sb.append(PROG_BEGIN);
+               sb.append(NEWLINE);
+               sb.append(ProgramConverter.serializeProgram(_ec.getProgram(), programBlocks, new HashMap<>(), false));
+               sb.append(PROG_END);
+               String programSerialized = sb.toString();
+
+               // write program and meta data to worker
+               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(),
+                               new setupFederatedWorker(_batchSize, dataSize, _totalNumBatches, programSerialized,
+                                       _inst.getNamespace(), _inst.getFunctionName(), _ps.getAggInst().getFunctionName(),
+                                       _ec.getListObject(Statement.PS_HYPER_PARAMS), _batchCounterVarID, _modelVarID)));
+
+               // wait for the worker; keep transport failures and a negative response apart
+               FederatedResponse response;
+               try {
+                       response = udfResponse.get();
+               }
+               catch(InterruptedException e) {
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("FederatedPSControlThread: interrupted while executing Setup UDF", e);
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException("FederatedPSControlThread: failed to execute Setup UDF: " + e.getMessage(), e);
+               }
+               if(!response.isSuccessful())
+                       throw new DMLRuntimeException("FederatedPSControlThread: Setup UDF failed");
+       }
+
+       /**
+        * UDF that runs on the federated worker during setup. It parses the serialized
+        * program and installs it in the worker's execution context. It then publishes the
+        * batching metadata, function names, hyper parameters and variable IDs under the
+        * reserved {@code Statement.PS_FED_*} names and {@code Statement.PS_HYPER_PARAMS}.
+        */
+       private static class setupFederatedWorker extends FederatedUDF {
+               // batching metadata
+               long _batchSize;
+               long _dataSize;
+               long _numBatches;
+               // serialized program and the functions looked up in it
+               String _programString;
+               String _namespace;
+               String _gradientsFunctionName;
+               String _aggregationFunctionName;
+               ListObject _hyperParams;
+               // IDs under which the control thread puts the batch counter and the model
+               long _batchCounterVarID;
+               long _modelVarID;
+
+               protected setupFederatedWorker(long batchSize, long dataSize, long numBatches,
+                       String programString, String namespace, String gradientsFunctionName,
+                       String aggregationFunctionName, ListObject hyperParams,
+                       long batchCounterVarID, long modelVarID)
+               {
+                       // no worker-side input variables; all values are carried in this object's fields
+                       super(new long[0]);
+                       this._batchSize = batchSize;
+                       this._dataSize = dataSize;
+                       this._numBatches = numBatches;
+                       this._programString = programString;
+                       this._namespace = namespace;
+                       this._gradientsFunctionName = gradientsFunctionName;
+                       this._aggregationFunctionName = aggregationFunctionName;
+                       this._hyperParams = hyperParams;
+                       this._batchCounterVarID = batchCounterVarID;
+                       this._modelVarID = modelVarID;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       // install the shipped program so its functions can be looked up by name
+                       ec.setProgram(ProgramConverter.parseProgram(_programString, 0, false));
+
+                       // batching metadata
+                       ec.setVariable(Statement.PS_FED_BATCH_SIZE, new IntObject(_batchSize));
+                       ec.setVariable(Statement.PS_FED_DATA_SIZE, new IntObject(_dataSize));
+                       ec.setVariable(Statement.PS_FED_NUM_BATCHES, new IntObject(_numBatches));
+                       // function lookup and hyper parameters
+                       ec.setVariable(Statement.PS_FED_NAMESPACE, new StringObject(_namespace));
+                       ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new StringObject(_gradientsFunctionName));
+                       ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new StringObject(_aggregationFunctionName));
+                       ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
+                       // variable IDs of the batch counter and model
+                       ec.setVariable(Statement.PS_FED_BATCHCOUNTER_VARID, new IntObject(_batchCounterVarID));
+                       ec.setVariable(Statement.PS_FED_MODEL_VARID, new IntObject(_modelVarID));
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+               }
+       }
+
+       /**
+        * Cleans up the execution context of the federated worker by executing the teardown UDF.
+        *
+        * @throws DMLRuntimeException if the teardown UDF cannot be executed or reports failure
+        */
+       public void teardown() {
+               // request removal of the federated PS variables on the worker
+               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(),
+                               new teardownFederatedWorker()));
+
+               // wait for the worker; keep transport failures and a negative response apart
+               FederatedResponse response;
+               try {
+                       response = udfResponse.get();
+               }
+               catch(InterruptedException e) {
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("FederatedPSControlThread: interrupted while executing Teardown UDF", e);
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException("FederatedPSControlThread: failed to execute Teardown UDF: " + e.getMessage(), e);
+               }
+               if(!response.isSuccessful())
+                       throw new DMLRuntimeException("FederatedPSControlThread: Teardown UDF failed");
+       }
+
+       /**
+        * UDF that runs on the federated worker during teardown. It removes the federated
+        * parameter server variables (metadata, function names, variable IDs, hyper
+        * parameters and gradients) from the worker's execution context.
+        */
+       private static class teardownFederatedWorker extends FederatedUDF {
+               // scalar and string variables, removed in this order
+               private static final String[] PLAIN_VARIABLES = {
+                       Statement.PS_FED_BATCH_SIZE,
+                       Statement.PS_FED_DATA_SIZE,
+                       Statement.PS_FED_NUM_BATCHES,
+                       Statement.PS_FED_NAMESPACE,
+                       Statement.PS_FED_GRADIENTS_FNAME,
+                       Statement.PS_FED_AGGREGATION_FNAME,
+                       Statement.PS_FED_BATCHCOUNTER_VARID,
+                       Statement.PS_FED_MODEL_VARID
+               };
+
+               protected teardownFederatedWorker() {
+                       // no worker-side input variables are needed
+                       super(new long[0]);
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       for(String name : PLAIN_VARIABLES)
+                               ec.removeVariable(name);
+                       // list objects are released through ParamservUtils rather than a plain remove
+                       ParamservUtils.cleanupListObject(ec, Statement.PS_HYPER_PARAMS);
+                       ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
+               }
+       }
+
+       /**
+        * Entry point of the control thread. It runs all epochs with the configured update
+        * frequency and then tears down the federated worker. Teardown also runs after a failure.
+        *
+        * @return always {@code null}
+        * @throws Exception a {@link DMLRuntimeException} in case the computation or the teardown fails
+        */
+       @Override
+       public Void call() throws Exception {
+               DMLRuntimeException failure = null;
+               try {
+                       switch (_freq) {
+                               case BATCH:
+                                       computeBatch(_totalNumBatches);
+                                       break;
+                               case EPOCH:
+                                       computeEpoch();
+                                       break;
+                               default:
+                                       throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
+                       }
+               } catch (Exception e) {
+                       failure = new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
+               }
+
+               // always clean up the worker; a teardown error must not hide the original failure,
+               // so in that case it is attached as a suppressed exception
+               try {
+                       teardown();
+               } catch (RuntimeException e) {
+                       if(failure == null)
+                               throw e;
+                       failure.addSuppressed(e);
+               }
+
+               if(failure != null)
+                       throw failure;
+               return null;
+       }
+
+       /**
+        * Pulls the current global model from the parameter server.
+        *
+        * @return the model as list object
+        */
+       protected ListObject pullModel() {
+               return _ps.pull(_workerID);
+       }
+
+       /**
+        * Pushes the gradients computed for this worker to the parameter server.
+        *
+        * @param gradients the gradients to push
+        */
+       protected void pushGradients(ListObject gradients) {
+               _ps.push(_workerID, gradients);
+       }
+
+       /**
+        * Runs all epochs. The model is pulled before every batch and the batch gradients are
+        * pushed after it.
+        *
+        * @param numBatches number of batches per epoch
+        */
+       protected void computeBatch(int numBatches) {
+               for (int epoch = 0; epoch < _epochs; epoch++) {
+                       for (int batch = 0; batch < numBatches; batch++) {
+                               ListObject currentModel = pullModel();
+                               ListObject batchGradients = computeBatchGradients(currentModel, batch);
+                               pushGradients(batchGradients);
+                               // release the local copies once they have been pushed
+                               ParamservUtils.cleanupListObject(currentModel);
+                               ParamservUtils.cleanupListObject(batchGradients);
+                       }
+                       System.out.println("[+] " + getWorkerName() + " completed epoch " + epoch);
+               }
+       }
+
+       /**
+        * Computes the gradients of a single batch on the federated worker. The batch counter
+        * and the current model are put on the worker under the IDs reserved in the
+        * constructor, then the batch gradient UDF is executed there.
+        *
+        * @param model        the current model from the parameter server
+        * @param batchCounter zero-based batch number, used on the worker to slice features and labels
+        * @return the gradients of this batch
+        * @throws DMLRuntimeException if a put or the UDF fails
+        */
+       protected ListObject computeBatchGradients(ListObject model, int batchCounter) {
+               // put batch counter and current model on the federated worker (both requests in flight)
+               Future<FederatedResponse> putBatchCounterResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _batchCounterVarID, new IntObject(batchCounter)));
+               Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+
+               boolean putsSuccessful;
+               try {
+                       putsSuccessful = putParamsResponse.get().isSuccessful() && putBatchCounterResponse.get().isSuccessful();
+               }
+               catch(InterruptedException e) {
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("FederatedPSControlThread: interrupted while executing put", e);
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException("FederatedPSControlThread: failed to execute put: " + e.getMessage(), e);
+               }
+               if(!putsSuccessful)
+                       throw new DMLRuntimeException("FederatedPSControlThread: put was not successful");
+
+               // create and execute the udf on the remote worker
+               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(),
+                               new federatedComputeBatchGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _batchCounterVarID, _modelVarID})));
+
+               try {
+                       Object[] responseData = udfResponse.get().getData();
+                       return (ListObject) responseData[0];
+               }
+               catch(InterruptedException e) {
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("FederatedPSControlThread: interrupted while executing UDF", e);
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException("FederatedPSControlThread: failed to execute UDF: " + e.getMessage(), e);
+               }
+       }
+
+       /**
+        * UDF executed on the federated worker to compute the gradients of a single batch.
+        * Inputs, by variable ID and in this order, are: features, labels, batch counter and model.
+        * It slices rows {@code batchCounter*batchSize+1 .. min((batchCounter+1)*batchSize, dataSize)}
+        * of features and labels and runs the gradient function installed by the setup UDF.
+        * The function's first output, a list object, is returned in the response.
+        */
+       private static class federatedComputeBatchGradients extends FederatedUDF {
+               protected federatedComputeBatchGradients(long[] inIDs) {
+                       super(inIDs);
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       // read in data by varid
+                       MatrixObject features = (MatrixObject) data[0];
+                       MatrixObject labels = (MatrixObject) data[1];
+                       long batchCounter = ((IntObject) data[2]).getLongValue();
+                       ListObject model = (ListObject) data[3];
+
+                       // get the metadata installed by the setup UDF from the execution context
+                       long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
+                       long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
+                       String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
+                       String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
+
+                       // slice batch from feature and label matrix (begin is 1-based)
+                       long begin = batchCounter * batchSize + 1;
+                       long end = Math.min((batchCounter + 1) * batchSize, dataSize);
+                       MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
+                       MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
+
+                       // prepare execution context
+                       ec.setVariable(Statement.PS_MODEL, model);
+                       ec.setVariable(Statement.PS_FEATURES, bFeatures);
+                       ec.setVariable(Statement.PS_LABELS, bLabels);
+
+                       // recreate gradient instruction and output
+                       FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
+                       ArrayList<DataIdentifier> inputs = func.getInputParams();
+                       ArrayList<DataIdentifier> outputs = func.getOutputParams();
+                       CPOperand[] boundInputs = inputs.stream()
+                               .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+                               .toArray(CPOperand[]::new);
+                       ArrayList<String> outputNames = outputs.stream().map(DataIdentifier::getName)
+                               .collect(Collectors.toCollection(ArrayList::new));
+                       Instruction gradientsInstruction = new FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
+                               func.getInputParamNames(), outputNames, "gradient function");
+                       DataIdentifier gradientsOutput = outputs.get(0);
+
+                       // calculate gradients; they are read from the function's first output
+                       gradientsInstruction.processInstruction(ec);
+                       ListObject gradients = ec.getListObject(gradientsOutput.getName());
+
+                       // clean up the batch counter variable (stored under its ID) and the sliced batch;
+                       // the name is taken explicitly from the IntObject value instead of relying on toString()
+                       String batchCounterVarName = ((IntObject) ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID)).getStringValue();
+                       ec.removeVariable(batchCounterVarName);
+                       ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
+                       ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+
+                       // model clean up: the model is bound both under its variable ID and under PS_MODEL;
+                       // cleaning it up twice is not an issue
+                       String modelVarName = ((IntObject) ec.getVariable(Statement.PS_FED_MODEL_VARID)).getStringValue();
+                       ParamservUtils.cleanupListObject(ec, modelVarName);
+                       ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, gradients);
+               }
+       }
+
+       /**
+        * Runs all epochs and synchronizes with the parameter server once per epoch. The
+        * model is pulled before each epoch, and the gradients returned for the epoch are
+        * pushed afterwards.
+        */
+       protected void computeEpoch() {
+               for (int epoch = 0; epoch < _epochs; epoch++) {
+                       ListObject currentModel = pullModel();
+                       ListObject epochGradients = computeEpochGradients(currentModel);
+                       pushGradients(epochGradients);
+                       System.out.println("[+] " + getWorkerName() + " completed epoch " + epoch);
+                       // release the local copies once they have been pushed
+                       ParamservUtils.cleanupListObject(currentModel);
+                       ParamservUtils.cleanupListObject(epochGradients);
+               }
+       }
+
+       /**
+        * Computes one epoch on the federated worker. The current model is put on the worker
+        * under the model ID reserved in the constructor, then the epoch UDF is executed there.
+        *
+        * @param model the current model from the parameter server
+        * @return the gradients returned by the worker for this epoch
+        * @throws DMLRuntimeException if the put or the UDF fails
+        */
+       protected ListObject computeEpochGradients(ListObject model) {
+               // put current model on federated worker
+               Future<FederatedResponse> putParamsResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, _modelVarID, model));
+
+               boolean putSuccessful;
+               try {
+                       putSuccessful = putParamsResponse.get().isSuccessful();
+               }
+               catch(InterruptedException e) {
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("FederatedPSControlThread: interrupted while executing put", e);
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException("FederatedPSControlThread: failed to execute put: " + e.getMessage(), e);
+               }
+               if(!putSuccessful)
+                       throw new DMLRuntimeException("FederatedPSControlThread: put was not successful");
+
+               // create and execute the udf on the remote worker
+               Future<FederatedResponse> udfResponse = _featuresData.executeFederatedOperation(
+                       new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, _featuresData.getVarID(),
+                               new federatedComputeEpochGradients(new long[]{_featuresData.getVarID(), _labelsData.getVarID(), _modelVarID})));
+
+               try {
+                       Object[] responseData = udfResponse.get().getData();
+                       return (ListObject) responseData[0];
+               }
+               catch(InterruptedException e) {
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("FederatedPSControlThread: interrupted while executing UDF", e);
+               }
+               catch(Exception e) {
+                       throw new DMLRuntimeException("FederatedPSControlThread: failed to execute UDF: " + e.getMessage(), e);
+               }
+       }
+
+       /**
+        * UDF executed on the federated worker to compute one epoch. The worker's data is
+        * processed in batches of PS_FED_BATCH_SIZE rows. For every batch the gradient function
+        * is called and its result is accumulated; after every batch except the last one the
+        * aggregation function is called to update the worker-local model. The accumulated
+        * gradients are returned to the parameter server.
+        */
+       private static class federatedComputeEpochGradients extends FederatedUDF {
+               protected federatedComputeEpochGradients(long[] inIDs) {
+                       super(inIDs);
+               }
+
+               /**
+                * @param ec   execution context of the federated worker; must hold the PS_FED_*
+                *             configuration variables read below
+                * @param data features, labels and model, in the order of the ids given to the constructor
+                * @return a SUCCESS response carrying the accumulated gradients (null if numBatches is 0)
+                */
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       // read in data by varid
+                       MatrixObject features = (MatrixObject) data[0];
+                       MatrixObject labels = (MatrixObject) data[1];
+                       ListObject model = (ListObject) data[2];
+
+                       // get data from execution context
+                       long batchSize = ((IntObject) ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
+                       long dataSize = ((IntObject) ec.getVariable(Statement.PS_FED_DATA_SIZE)).getLongValue();
+                       long numBatches = ((IntObject) ec.getVariable(Statement.PS_FED_NUM_BATCHES)).getLongValue();
+                       String namespace = ((StringObject) ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
+                       String gradientsFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
+                       String aggregationFunctionName = ((StringObject) ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
+
+                       // recreate gradient instruction and output
+                       FunctionProgramBlock gradientsFunc = ec.getProgram().getFunctionProgramBlock(namespace, gradientsFunctionName, false);
+                       Instruction gradientsInstruction = createFunctionCall(gradientsFunc, namespace, gradientsFunctionName, "gradient function");
+                       DataIdentifier gradientsOutput = gradientsFunc.getOutputParams().get(0);
+
+                       // recreate aggregation instruction and output
+                       FunctionProgramBlock aggregationFunc = ec.getProgram().getFunctionProgramBlock(namespace, aggregationFunctionName, false);
+                       Instruction aggregationInstruction = createFunctionCall(aggregationFunc, namespace, aggregationFunctionName, "aggregation function");
+                       DataIdentifier aggregationOutput = aggregationFunc.getOutputParams().get(0);
+
+                       ListObject accGradients = null;
+                       // prepare execution context
+                       ec.setVariable(Statement.PS_MODEL, model);
+                       for (int batchCounter = 0; batchCounter < numBatches; batchCounter++) {
+                               // slice batch from feature and label matrix (1-based, inclusive row range)
+                               long begin = batchCounter * batchSize + 1;
+                               long end = Math.min((batchCounter + 1) * batchSize, dataSize);
+                               MatrixObject bFeatures = ParamservUtils.sliceMatrix(features, begin, end);
+                               MatrixObject bLabels = ParamservUtils.sliceMatrix(labels, begin, end);
+
+                               // bind the batch under the names the gradient function reads
+                               ec.setVariable(Statement.PS_FEATURES, bFeatures);
+                               ec.setVariable(Statement.PS_LABELS, bLabels);
+                               // the last batch's update is left to the parameter server
+                               boolean localUpdate = batchCounter < numBatches - 1;
+
+                               // calculate intermediate gradients
+                               gradientsInstruction.processInstruction(ec);
+                               ListObject gradients = ec.getListObject(gradientsOutput.getName());
+
+                               // TODO: is this equivalent for momentum based and AMS prob?
+                               accGradients = ParamservUtils.accrueGradients(accGradients, gradients, false);
+
+                               // Update the local model with gradients
+                               if(localUpdate) {
+                                       // Invoke the aggregate function
+                                       aggregationInstruction.processInstruction(ec);
+                                       // Get the new model
+                                       model = ec.getListObject(aggregationOutput.getName());
+                                       // Set new model in execution context
+                                       ec.setVariable(Statement.PS_MODEL, model);
+                                       // clean up gradients and result
+                                       ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+                                       ParamservUtils.cleanupListObject(ec, aggregationOutput.getName());
+                               }
+
+                               // clean up sliced batch
+                               ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
+                               ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
+                       }
+
+                       // model clean up - doing this twice is not an issue
+                       ParamservUtils.cleanupListObject(ec, ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
+                       ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL);
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
+               }
+
+               /**
+                * Builds a call instruction for the given function. Every input operand is named
+                * after the corresponding function parameter, so the caller is expected to store
+                * each input in the execution context under the parameter's name.
+                *
+                * @param func         the function program block to call
+                * @param namespace    namespace of the function
+                * @param functionName name of the function
+                * @param description  free-text description attached to the instruction
+                * @return the function call instruction
+                */
+               private static Instruction createFunctionCall(FunctionProgramBlock func, String namespace,
+                       String functionName, String description)
+               {
+                       CPOperand[] boundInputs = func.getInputParams().stream()
+                                       .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+                                       .toArray(CPOperand[]::new);
+                       ArrayList<String> outputNames = func.getOutputParams().stream()
+                                       .map(DataIdentifier::getName)
+                                       .collect(Collectors.toCollection(ArrayList::new));
+                       return new FunctionCallCPInstruction(namespace, functionName, false, boundInputs,
+                                       func.getInputParamNames(), outputNames, description);
+               }
+       }
+
+       // Statistics methods
+       /**
+        * Returns the display name of this worker, built from its worker id.
+        *
+        * @return "Federated worker_" followed by the worker id
+        */
+       @Override
+       public String getWorkerName() {
+               return String.format("Federated worker_%d", _workerID);
+       }
+
+       /**
+        * No-op: this control thread does not increment any worker counter.
+        * NOTE(review): confirm whether federated workers should be counted in the statistics.
+        */
+       @Override
+       protected void incWorkerNumber() {
+
+       }
+
+       /**
+        * No-op: the given timing is ignored. The local model updates of an epoch run inside the
+        * remote UDF on the federated worker, so presumably nothing is measured here — TODO confirm.
+        */
+       @Override
+       protected void accLocalModelUpdateTime(Timing time) {
+
+       }
+
+       /**
+        * No-op: the given timing is ignored. Batch slicing happens inside the remote UDF on the
+        * federated worker, so presumably nothing is measured here — TODO confirm.
+        */
+       @Override
+       protected void accBatchIndexingTime(Timing time) {
+
+       }
+
+       /**
+        * No-op: the given timing is ignored. Gradients are computed inside the remote UDF on the
+        * federated worker, so presumably nothing is measured here — TODO confirm.
+        */
+       @Override
+       protected void accGradientComputeTime(Timing time) {
+
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 276f56c..e420ed8 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -57,6 +57,7 @@ public abstract class ParamServer
        //aggregation service
        protected ExecutionContext _ec;
        private Statement.PSUpdateType _updateType;
+
        private FunctionCallCPInstruction _inst;
        private String _outputName;
        private boolean[] _finishedStates;  // Workers' finished states
@@ -232,4 +233,8 @@ public abstract class ParamServer
                if (DMLScript.STATISTICS)
                        Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
        }
+
+       /**
+        * Returns the function call instruction held in {@code _inst}, one of the fields of
+        * this server's aggregation service.
+        *
+        * @return the aggregation function call instruction
+        */
+       public FunctionCallCPInstruction getAggInst() {
+               return _inst;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index 968cb1d..e63fb14 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -31,6 +31,7 @@ import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.MultiThreadedHop;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.lops.LopProperties;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DMLTranslator;
 import org.apache.sysds.parser.Statement;
@@ -214,15 +215,21 @@ public class ParamservUtils {
        }
 
        public static ExecutionContext createExecutionContext(ExecutionContext 
ec,
-               LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+               LocalVariableMap varsMap, String updFunc, String aggFunc, int k)
+       {
+               return createExecutionContext(ec, varsMap, updFunc, aggFunc, k, 
false);
+       }
+
+       public static ExecutionContext createExecutionContext(ExecutionContext 
ec,
+               LocalVariableMap varsMap, String updFunc, String aggFunc, int 
k, boolean forceExecTypeCP)
        {
                Program prog = ec.getProgram();
 
                // 1. Recompile the internal program blocks 
-               recompileProgramBlocks(k, prog.getProgramBlocks());
+               recompileProgramBlocks(k, prog.getProgramBlocks(), 
forceExecTypeCP);
                // 2. Recompile the imported function blocks
                prog.getFunctionProgramBlocks(false)
-                       .forEach((fname, fvalue) -> recompileProgramBlocks(k, 
fvalue.getChildBlocks()));
+                       .forEach((fname, fvalue) -> recompileProgramBlocks(k, 
fvalue.getChildBlocks(), forceExecTypeCP));
 
                // 3. Copy all functions 
                return ExecutionContextFactory.createContext(
@@ -249,6 +256,10 @@ public class ParamservUtils {
        }
 
        public static void recompileProgramBlocks(int k, List<ProgramBlock> 
pbs) {
+               recompileProgramBlocks(k, pbs, false);
+       }
+
+       public static void recompileProgramBlocks(int k, List<ProgramBlock> 
pbs, boolean forceExecTypeCP) {
                // Reset the visit status from root
                for (ProgramBlock pb : pbs)
                        
DMLTranslator.resetHopsDAGVisitStatus(pb.getStatementBlock());
@@ -256,43 +267,49 @@ public class ParamservUtils {
                // Should recursively assign the level of parallelism
                // and recompile the program block
                try {
-                       rAssignParallelism(pbs, k, false);
+                       if(forceExecTypeCP)
+                               rAssignParallelismAndRecompile(pbs, k, true, 
forceExecTypeCP);
+                       else
+                               rAssignParallelismAndRecompile(pbs, k, false, 
forceExecTypeCP);
                } catch (IOException e) {
                        throw new DMLRuntimeException(e);
                }
        }
 
-       private static boolean rAssignParallelism(List<ProgramBlock> pbs, int k, boolean recompiled) throws IOException {
+       /**
+        * Recursively walks the given program blocks. Parfor blocks get degree of parallelism
+        * {@code k}, and their children are processed with 1. For/while/function/if blocks
+        * recurse into their children. Hops of all other blocks are passed to the hop-level
+        * overload with {@code k}. After a block is processed, it is recompiled if
+        * {@code recompiled} is true.
+        *
+        * @param pbs             program blocks to process
+        * @param k               degree of parallelism to assign
+        * @param recompiled      whether recompilation is already required; the flag is sticky, so
+        *                        once it becomes true every later block in {@code pbs} is recompiled too
+        * @param forceExecTypeCP if true, recompile with execution type forced to CP
+        * @return the recompile flag after processing {@code pbs}
+        * @throws IOException propagated from recompilation
+        */
+       private static boolean rAssignParallelismAndRecompile(List<ProgramBlock> pbs, int k, boolean recompiled, boolean forceExecTypeCP) throws IOException {
                for (ProgramBlock pb : pbs) {
                        if (pb instanceof ParForProgramBlock) {
                                ParForProgramBlock pfpb = (ParForProgramBlock) pb;
                                pfpb.setDegreeOfParallelism(k);
-                               recompiled |= rAssignParallelism(pfpb.getChildBlocks(), 1, recompiled);
+                               recompiled |= rAssignParallelismAndRecompile(pfpb.getChildBlocks(), 1, recompiled, forceExecTypeCP);
                        } else if (pb instanceof ForProgramBlock) {
-                               recompiled |= rAssignParallelism(((ForProgramBlock) pb).getChildBlocks(), k, recompiled);
+                               recompiled |= rAssignParallelismAndRecompile(((ForProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
                        } else if (pb instanceof WhileProgramBlock) {
-                               recompiled |= rAssignParallelism(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled);
+                               recompiled |= rAssignParallelismAndRecompile(((WhileProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
                        } else if (pb instanceof FunctionProgramBlock) {
-                               recompiled |= rAssignParallelism(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled);
+                               recompiled |= rAssignParallelismAndRecompile(((FunctionProgramBlock) pb).getChildBlocks(), k, recompiled, forceExecTypeCP);
                        } else if (pb instanceof IfProgramBlock) {
                                IfProgramBlock ipb = (IfProgramBlock) pb;
-                               recompiled |= rAssignParallelism(ipb.getChildBlocksIfBody(), k, recompiled);
+                               recompiled |= rAssignParallelismAndRecompile(ipb.getChildBlocksIfBody(), k, recompiled, forceExecTypeCP);
                                if (ipb.getChildBlocksElseBody() != null)
-                                       recompiled |= rAssignParallelism(ipb.getChildBlocksElseBody(), k, recompiled);
+                                       recompiled |= rAssignParallelismAndRecompile(ipb.getChildBlocksElseBody(), k, recompiled, forceExecTypeCP);
                        } else {
                                StatementBlock sb = pb.getStatementBlock();
                                for (Hop hop : sb.getHops())
-                                       recompiled |= rAssignParallelism(hop, k, recompiled);
+                                       recompiled |= rAssignParallelismAndRecompile(hop, k, recompiled);
                        }
                        // Recompile the program block
                        if (recompiled) {
-                               Recompiler.recompileProgramBlockInstructions(pb);
+                               // forced mode: recompile the block with execution type CP
+                               // (rRecompile... presumably also recurses into child blocks — TODO confirm)
+                               if(forceExecTypeCP)
+                                       Recompiler.rRecompileProgramBlock2Forced(pb, pb.getThreadID(), new HashSet<>(), LopProperties.ExecType.CP);
+                               else
+                                       Recompiler.recompileProgramBlockInstructions(pb);
                        }
                }
                return recompiled;
        }
 
-       private static boolean rAssignParallelism(Hop hop, int k, boolean 
recompiled) {
+       private static boolean rAssignParallelismAndRecompile(Hop hop, int k, 
boolean recompiled) {
                if (hop.isVisited()) {
                        return recompiled;
                }
@@ -304,7 +321,7 @@ public class ParamservUtils {
                }
                ArrayList<Hop> inputs = hop.getInput();
                for (Hop h : inputs) {
-                       recompiled |= rAssignParallelism(h, k, recompiled);
+                       recompiled |= rAssignParallelismAndRecompile(h, k, 
recompiled);
                }
                hop.setVisited();
                return recompiled;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
new file mode 100644
index 0000000..4183372
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+
+/**
+ * Base class of the data partitioning schemes of the federated parameter server. A scheme
+ * splits row-federated features and labels into per-worker partitions.
+ */
+public abstract class DataPartitionFederatedScheme {
+
+       /** Result of a partitioning: features and labels per worker, and the number of workers. */
+       public static final class Result {
+               public final List<MatrixObject> pFeatures;
+               public final List<MatrixObject> pLabels;
+               public final int workerNum;
+
+               public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels, int workerNum) {
+                       this.pFeatures = pFeatures;
+                       this.pLabels = pLabels;
+                       this.workerNum = workerNum;
+               }
+       }
+
+       /**
+        * Partitions the given federated features and labels.
+        *
+        * @param features row-federated feature matrix
+        * @param labels   row-federated label matrix
+        * @return the per-worker partitions
+        */
+       public abstract Result doPartitioning(MatrixObject features, MatrixObject labels);
+
+       /**
+        * Takes a row federated Matrix and slices it into one matrix per federated range. Each
+        * slice is a row-federated matrix whose federation map holds only that range and its
+        * existing remote data, so no data is moved.
+        *
+        * <p>The ranges are visited in parallel. The slices are therefore ordered explicitly by
+        * the first row of their range, so that slicing features and labels that are federated
+        * over the same row ranges yields lists whose i-th entries refer to the same site.
+        *
+        * @param fedMatrix the federated input matrix
+        * @return one slice per federated range, ordered by the first row of the range
+        * @throws DMLRuntimeException if the matrix is not row federated
+        */
+       static List<MatrixObject> sliceFederatedMatrix(MatrixObject fedMatrix) {
+               if (fedMatrix.isFederated(FederationMap.FType.ROW)) {
+
+                       // slices keyed by the first row of their range; guarded by its own monitor
+                       HashMap<Long, MatrixObject> slicesByFirstRow = new HashMap<>();
+                       fedMatrix.getFedMapping().forEachParallel((range, data) -> {
+                               // Create sliced matrix object
+                               MatrixObject slice = new MatrixObject(fedMatrix.getValueType(), Dag.getNextUniqueVarname(Types.DataType.MATRIX));
+                               // Warning needs MetaDataFormat instead of MetaData
+                               slice.setMetaData(new MetaDataFormat(
+                                               new MatrixCharacteristics(range.getSize(0), range.getSize(1)),
+                                               Types.FileFormat.BINARY)
+                               );
+
+                               // Create new federation map
+                               HashMap<FederatedRange, FederatedData> newFedHashMap = new HashMap<>();
+                               newFedHashMap.put(range, data);
+                               slice.setFedMapping(new FederationMap(fedMatrix.getFedMapping().getID(), newFedHashMap));
+                               slice.getFedMapping().setType(FederationMap.FType.ROW);
+
+                               synchronized (slicesByFirstRow) {
+                                       slicesByFirstRow.put(range.getBeginDims()[0], slice);
+                               }
+                               return null;
+                       });
+
+                       // deterministic order, independent of which parallel task finished first
+                       synchronized (slicesByFirstRow) {
+                               List<Long> firstRows = new ArrayList<>(slicesByFirstRow.keySet());
+                               Collections.sort(firstRows);
+                               List<MatrixObject> slices = new ArrayList<>(firstRows.size());
+                               for (Long firstRow : firstRows)
+                                       slices.add(slicesByFirstRow.get(firstRow));
+                               return slices;
+                       }
+               }
+               else {
+                       throw new DMLRuntimeException("Federated data partitioner: " +
+                                       "currently only supports row federated data");
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
new file mode 100644
index 0000000..4cdfb95
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+
+/**
+ * Partitions federated features and labels for the federated parameter server by delegating
+ * to the scheme selected at construction time.
+ */
+public class FederatedDataPartitioner {
+
+       private final DataPartitionFederatedScheme _scheme;
+
+       /**
+        * @param scheme the federated partitioning scheme to use
+        * @throws DMLRuntimeException if the scheme is not supported
+        */
+       public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
+               _scheme = createScheme(scheme);
+       }
+
+       /** Maps a scheme enum value to its implementation. */
+       private static DataPartitionFederatedScheme createScheme(Statement.FederatedPSScheme scheme) {
+               switch (scheme) {
+                       case KEEP_DATA_ON_WORKER:
+                               return new KeepDataOnWorkerFederatedScheme();
+                       case SHUFFLE:
+                               return new ShuffleFederatedScheme();
+                       default:
+                               throw new DMLRuntimeException(String.format("FederatedDataPartitioner: not support data partition scheme '%s'", scheme));
+               }
+       }
+
+       /**
+        * @param features row-federated feature matrix
+        * @param labels   row-federated label matrix
+        * @return the per-worker partitions produced by the configured scheme
+        */
+       public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject features, MatrixObject labels) {
+               return _scheme.doPartitioning(features, labels);
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
new file mode 100644
index 0000000..06feded
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import java.util.List;
+
+/**
+ * Federated partitioning scheme that leaves the data where it is: each federated row range
+ * becomes the partition of one worker.
+ */
+public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
+       @Override
+       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+               final List<MatrixObject> featureSlices = sliceFederatedMatrix(features);
+               // one worker per feature slice; labels are sliced after the features, as before
+               return new Result(featureSlices, sliceFederatedMatrix(labels), featureSlices.size());
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
new file mode 100644
index 0000000..d6d8cfc
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+
+import java.util.List;
+
+public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
+       /**
+        * Partitions row-federated features and labels into one partition per federated range.
+        *
+        * NOTE(review): despite the scheme's name, no shuffling happens yet. The result is the
+        * same per-range slicing the keep-data-on-worker scheme produces. TODO implement the
+        * shuffle or document this limitation for users of the SHUFFLE scheme.
+        */
+       @Override
+       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+               final List<MatrixObject> featureParts = sliceFederatedMatrix(features);
+               final List<MatrixObject> labelParts = sliceFederatedMatrix(labels);
+               final int workers = featureParts.size();
+               return new Result(featureParts, labelParts, workers);
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
index 4f726ee..a8397cb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ListObject.java
@@ -19,17 +19,29 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
 
-public class ListObject extends Data {
+public class ListObject extends Data implements Externalizable {
        private static final long serialVersionUID = 3652422061598967358L;
 
        private final List<Data> _data;
@@ -37,6 +49,14 @@ public class ListObject extends Data {
        private List<String> _names = null;
        private int _nCacheable;
        private List<LineageItem> _lineage = null;
+
+       /**
+        * No-argument constructor required by the {@link Externalizable} interface. Java
+        * deserialization creates the instance through it and then populates it with
+        * {@link #readExternal(ObjectInput)}.
+        */
+       public ListObject() {
+               super(DataType.LIST, ValueType.UNKNOWN);
+               // mutable and initially empty: readExternal appends the deserialized entries
+               _data = new ArrayList<>();
+       }
        
        public ListObject(List<Data> data) {
                this(data, null, null);
@@ -286,4 +306,127 @@ public class ListObject extends Data {
                sb.append(")");
                return sb.toString();
        }
+
+       /**
+        * Writes this list for the {@link Externalizable} protocol using {@link ObjectOutput}
+        * calls. {@link #readExternal(ObjectInput)} must mirror this layout:
+        * <ul>
+        * <li>the number of entries and the number of cacheable entries;</li>
+        * <li>a has-names flag, followed by the names for named lists;</li>
+        * <li>per entry: its data type, value type and privacy constraint, then a type-specific
+        * payload. That payload is either a nested list; or the matrix rows, cols, blocksize,
+        * non-zeros, file format and matrix block; or the string value of a scalar.</li>
+        * </ul>
+        *
+        * @param out object output
+        * @throws IOException if IOException occurs
+        * @throws DMLRuntimeException if an entry's data type cannot be serialized
+        */
+       @Override
+       public void writeExternal(ObjectOutput out) throws IOException {
+               // write out length
+               out.writeInt(getLength());
+               // write out num cacheable
+               out.writeInt(_nCacheable);
+
+               // write out names for named list
+               out.writeBoolean(getNames() != null);
+               if(getNames() != null) {
+                       for (int i = 0; i < getLength(); i++) {
+                               out.writeObject(_names.get(i));
+                       }
+               }
+
+               // write out data; matrix metadata is written boxed via writeObject because
+               // readExternal reads it back with readObject
+               for(int i = 0; i < getLength(); i++) {
+                       Data d = getData(i);
+                       DataType dt = d.getDataType();
+                       out.writeObject(dt);
+                       out.writeObject(d.getValueType());
+                       out.writeObject(d.getPrivacyConstraint());
+                       switch(dt) {
+                               case LIST:
+                                       ListObject lo = (ListObject) d;
+                                       out.writeObject(lo);
+                                       break;
+                               case MATRIX:
+                                       MatrixObject mo = (MatrixObject) d;
+                                       // NOTE(review): assumes the matrix carries MetaDataFormat metadata
+                                       MetaDataFormat md = (MetaDataFormat) mo.getMetaData();
+                                       DataCharacteristics dc = md.getDataCharacteristics();
+
+                                       out.writeObject(dc.getRows());
+                                       out.writeObject(dc.getCols());
+                                       out.writeObject(dc.getBlocksize());
+                                       out.writeObject(dc.getNonZeros());
+                                       out.writeObject(md.getFileFormat());
+                                       out.writeObject(mo.acquireReadAndRelease());
+                                       break;
+                               case SCALAR:
+                                       ScalarObject so = (ScalarObject) d;
+                                       out.writeObject(so.getStringValue());
+                                       break;
+                               default:
+                                       // report the failing entry's type; the unqualified dataType
+                                       // field describes this list itself, not the entry
+                                       throw new DMLRuntimeException("Unable to serialize datatype " + dt);
+                       }
+               }
+       }
+
+       /**
+        * Reads a list written by {@link #writeExternal(ObjectOutput)} and appends the decoded
+        * entries to this list. The list is expected to be empty, as created by the no-arg
+        * constructor. Matrix entries become new MatrixObjects with a fresh variable name that
+        * wrap the deserialized matrix block. Scalars are parsed back from their string form
+        * according to their value type.
+        *
+        * @param in object input
+        * @throws IOException if IOException occurs
+        * @throws ClassNotFoundException if the class of a serialized object cannot be found
+        * @throws DMLRuntimeException if an entry's data type or scalar value type is not supported
+        */
+       @Override
+       public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
+               // read in length
+               int length = in.readInt();
+               // read in num cacheable
+               _nCacheable = in.readInt();
+
+               // read in names (present only for named lists)
+               boolean hasNames = in.readBoolean();
+               if(hasNames) {
+                       _names = new ArrayList<>(length);
+                       for (int i = 0; i < length; i++) {
+                               _names.add((String) in.readObject());
+                       }
+               }
+
+               // read in data
+               for(int i = 0; i < length; i++) {
+                       DataType dataType = (DataType) in.readObject();
+                       ValueType valueType = (ValueType) in.readObject();
+                       PrivacyConstraint privacyConstraint = (PrivacyConstraint) in.readObject();
+                       Data d;
+                       switch(dataType) {
+                               case LIST:
+                                       d = (ListObject) in.readObject();
+                                       break;
+                               case MATRIX:
+                                       // metadata was written boxed via writeObject
+                                       long rows = (long) in.readObject();
+                                       long cols = (long) in.readObject();
+                                       int blockSize = (int) in.readObject();
+                                       long nonZeros = (long) in.readObject();
+                                       Types.FileFormat fileFormat = (Types.FileFormat) in.readObject();
+
+                                       // construct objects and set meta data
+                                       MatrixCharacteristics matrixCharacteristics = new MatrixCharacteristics(rows, cols, blockSize, nonZeros);
+                                       MetaDataFormat metaDataFormat = new MetaDataFormat(matrixCharacteristics, fileFormat);
+                                       MatrixBlock matrixBlock = (MatrixBlock) in.readObject();
+
+                                       d = new MatrixObject(valueType, Dag.getNextUniqueVarname(Types.DataType.MATRIX), metaDataFormat, matrixBlock);
+                                       break;
+                               case SCALAR:
+                                       String value = (String) in.readObject();
+                                       ScalarObject so;
+                                       switch (valueType) {
+                                               case INT64:     so = new IntObject(Long.parseLong(value)); break;
+                                               case FP64:      so = new DoubleObject(Double.parseDouble(value)); break;
+                                               case BOOLEAN:   so = new BooleanObject(Boolean.parseBoolean(value)); break;
+                                               case STRING:    so = new StringObject(value); break;
+                                               default:
+                                                       throw new DMLRuntimeException("Unable to parse valuetype " + valueType);
+                                       }
+                                       d = so;
+                                       break;
+                               default:
+                                       throw new DMLRuntimeException("Unable to deserialize datatype " + dataType);
+                       }
+                       d.setPrivacyConstraints(privacyConstraint);
+                       _data.add(d);
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index c42ec91..5e8ad32 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -52,6 +52,7 @@ import org.apache.spark.util.LongAccumulator;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.hops.recompile.Recompiler;
 import org.apache.sysds.lops.LopProperties;
+import org.apache.sysds.parser.Statement;
 import org.apache.sysds.parser.Statement.PSFrequency;
 import org.apache.sysds.parser.Statement.PSModeType;
 import org.apache.sysds.parser.Statement.PSScheme;
@@ -61,13 +62,16 @@ import 
org.apache.sysds.runtime.controlprogram.LocalVariableMap;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import 
org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread;
 import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
 import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
 import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
 import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
 import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
+import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
 import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner;
 import 
org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
 import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
@@ -91,16 +95,87 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               PSModeType mode = getPSMode();
-               switch (mode) {
-                       case LOCAL:
-                               runLocally(ec, mode);
-                               break;
-                       case REMOTE_SPARK:
-                               runOnSpark((SparkExecutionContext) ec, mode);
-                               break;
-                       default:
-                               throw new 
DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+               // federated features or labels take precedence: train via the federated parameter server without consulting the configured PS mode
+               if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
+                               
ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+                       runFederated(ec);
+               }
+               // non-federated inputs: dispatch on the configured PS mode (LOCAL or REMOTE_SPARK), anything else is rejected
+               else {
+                       PSModeType mode = getPSMode();
+                       switch (mode) {
+                               case LOCAL:
+                                       runLocally(ec, mode);
+                                       break;
+                               case REMOTE_SPARK:
+                                       runOnSpark((SparkExecutionContext) ec, 
mode);
+                                       break;
+                               default:
+                                       throw new 
DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
+                       }
+               }
+       }
+
+       /**
+        * Trains the model with the federated parameter server; used when the features
+        * and/or the labels are federated matrices.
+        *
+        * The federated data is partitioned with the KEEP_DATA_ON_WORKER scheme, one
+        * FederatedPSControlThread is created per resulting worker, and all control threads
+        * are run on a fixed-size thread pool until completion. The final model is bound to
+        * the instruction's output variable.
+        *
+        * @param ec execution context holding the inputs and receiving the output model
+        */
+       private void runFederated(ExecutionContext ec) {
+               System.out.println("PARAMETER SERVER");
+               System.out.println("[+] Running in federated mode");
+
+               // get inputs
+               PSFrequency freq = getFrequency();
+               PSUpdateType updateType = getUpdateType();
+               String updFunc = getParam(PS_UPDATE_FUN);
+               String aggFunc = getParam(PS_AGGREGATION_FUN);
+
+               // partition federated data (features and labels are looked up only once)
+               MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES));
+               MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS));
+               DataPartitionFederatedScheme.Result result =
+                       new FederatedDataPartitioner(Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER)
+                       .doPartitioning(features, labels);
+               List<MatrixObject> pFeatures = result.pFeatures;
+               List<MatrixObject> pLabels = result.pLabels;
+               int workerNum = result.workerNum;
+
+               // control thread i consumes pFeatures.get(i) and pLabels.get(i), so the partitioner
+               // must produce exactly one feature and one label partition per worker
+               if(pFeatures.size() != workerNum || pLabels.size() != workerNum) {
+                       throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
+               }
+
+               // Get the compiled execution context
+               LocalVariableMap newVarsMap = createVarsMap(ec);
+               // Level of par is 1 because one worker will be launched per task
+               // TODO: Fix recompilation
+               ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, 1, true);
+               // Create workers' execution context
+               List<ExecutionContext> federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
+               // Create the agg service's execution context
+               ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
+               // Create the parameter server
+               ListObject model = ec.getListObject(getParam(PS_MODEL));
+               ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
+               // Create one control thread per federated worker
+               List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
+                       .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, getEpochs(), getBatchSize(), federatedWorkerECs.get(i), ps))
+                       .collect(Collectors.toList());
+
+               // Set features and labels for the control threads and write the program, instructions
+               // and hyperparameters to the federated workers
+               for (int i = 0; i < workerNum; i++) {
+                       threads.get(i).setFeatures(pFeatures.get(i));
+                       threads.get(i).setLabels(pLabels.get(i));
+                       threads.get(i).setup();
+               }
+
+               // setup threading; the pool is created only after all setups succeeded and is always
+               // shut down in the finally block, so no path can leak its threads
+               BasicThreadFactory factory = new BasicThreadFactory.Builder()
+                       .namingPattern("workers-pool-thread-%d").build();
+               ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
+               try {
+                       // Launch the worker threads and wait for completion;
+                       // get() rethrows a worker failure as ExecutionException
+                       for (Future<Void> ret : es.invokeAll(threads))
+                               ret.get();
+                       // Fetch the final model from ps
+                       ec.setVariable(output.getName(), ps.getResult());
+               } catch (InterruptedException e) {
+                       // restore the interrupt status before translating into an unchecked exception
+                       Thread.currentThread().interrupt();
+                       throw new DMLRuntimeException("ParamservBuiltinCPInstruction: interrupted while waiting for federated workers: ", e);
+               } catch (ExecutionException e) {
+                       throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
+               } finally {
+                       es.shutdownNow();
                }
        }
 
@@ -150,7 +225,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                LongAccumulator aEpoch = 
sec.getSparkContext().sc().longAccumulator("numEpochs");
                
                // Create remote workers
-               SparkPSWorker worker = new 
SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN), 
+               SparkPSWorker worker = new 
SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
                        getFrequency(), getEpochs(), getBatchSize(), program, 
clsMap, sec.getSparkContext().getConf(),
                        server.getPort(), aSetup, aWorker, aUpdate, aIndex, 
aGrad, aRPC, aBatch, aEpoch);
 
@@ -333,6 +408,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         */
        private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
                switch (mode) {
+                       case FEDERATED:
                        case LOCAL:
                        case REMOTE_SPARK:
                                return LocalParamServer.create(model, aggFunc, 
updateType, ec, workerNum);
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java 
b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 7dad319..16472c7 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -817,7 +817,7 @@ public class ProgramConverter
                //handle program
                sb.append(PROG_BEGIN);
                sb.append( NEWLINE );
-               sb.append( serializeProgram(prog, pbs, clsMap) );
+               sb.append( serializeProgram(prog, pbs, clsMap, true) );
                sb.append(PROG_END);
                sb.append( NEWLINE );
                sb.append( COMPONENTS_DELIM );
@@ -849,32 +849,32 @@ public class ProgramConverter
                return sb.toString();
        }
 
-       private static String serializeProgram( Program prog, 
ArrayList<ProgramBlock> pbs, HashMap<String, byte[]> clsMap ) {
-               //note program contains variables, programblocks and function 
program blocks 
+       /**
+        * Serializes the function program blocks of {@code prog}, using the functions that are
+        * (transitively) called from {@code pbs} as serialization candidates.
+        *
+        * @param prog   program whose function program blocks are looked up
+        * @param pbs    program blocks that are scanned recursively for function calls
+        * @param clsMap forwarded unchanged to rSerializeFunctionProgramBlocks
+        * @param opt    forwarded to Program#getFunctionProgramBlocks(boolean) and to the candidate
+        *               search; presumably selects optimized vs. unoptimized function variants
+        *               (NOTE(review): confirm against Program)
+        * @return the serialized function program blocks
+        */
+       public static String serializeProgram( Program prog, 
ArrayList<ProgramBlock> pbs, HashMap<String, byte[]> clsMap, boolean opt) {
+               //note program contains variables, programblocks and function 
program blocks
                //but in order to avoid redundancy, we only serialize function 
program blocks
-               HashMap<String, FunctionProgramBlock> fpb = 
prog.getFunctionProgramBlocks();
+               HashMap<String, FunctionProgramBlock> fpb = 
prog.getFunctionProgramBlocks(opt);
+               // collect the keys of all functions called (also via call chains) from pbs
                HashSet<String> cand = new HashSet<>();
-               rFindSerializationCandidates(pbs, cand);
-               return rSerializeFunctionProgramBlocks( fpb, cand, clsMap );
+               rFindSerializationCandidates(pbs, cand, opt);
+               return rSerializeFunctionProgramBlocks(fpb, cand, clsMap);
        }
 
-       private static void rFindSerializationCandidates( 
ArrayList<ProgramBlock> pbs, HashSet<String> cand ) 
+       private static void rFindSerializationCandidates( 
ArrayList<ProgramBlock> pbs, HashSet<String> cand, boolean opt)
        {
                for( ProgramBlock pb : pbs )
                {
                        if( pb instanceof WhileProgramBlock ) {
                                WhileProgramBlock wpb = (WhileProgramBlock) pb;
-                               
rFindSerializationCandidates(wpb.getChildBlocks(), cand );
+                               
rFindSerializationCandidates(wpb.getChildBlocks(), cand, opt);
                        }
                        else if ( pb instanceof ForProgramBlock || pb 
instanceof ParForProgramBlock ) {
                                ForProgramBlock fpb = (ForProgramBlock) pb; 
-                               
rFindSerializationCandidates(fpb.getChildBlocks(), cand);
+                               
rFindSerializationCandidates(fpb.getChildBlocks(), cand, opt);
                        }
                        else if ( pb instanceof IfProgramBlock ) {
                                IfProgramBlock ipb = (IfProgramBlock) pb;
-                               
rFindSerializationCandidates(ipb.getChildBlocksIfBody(), cand);
+                               
rFindSerializationCandidates(ipb.getChildBlocksIfBody(), cand, opt);
                                if( ipb.getChildBlocksElseBody() != null )
-                                       
rFindSerializationCandidates(ipb.getChildBlocksElseBody(), cand);
+                                       
rFindSerializationCandidates(ipb.getChildBlocksElseBody(), cand, opt);
                        }
                        else if( pb instanceof BasicProgramBlock ) { 
                                BasicProgramBlock bpb = (BasicProgramBlock) pb;
@@ -885,8 +885,8 @@ public class ProgramConverter
                                                if( !cand.contains(fkey) ) { 
//memoization for multiple calls, recursion
                                                        cand.add( fkey ); //add 
to candidates
                                                        //investigate chains of 
function calls
-                                                       FunctionProgramBlock 
fpb = pb.getProgram().getFunctionProgramBlock(fci.getNamespace(), 
fci.getFunctionName());
-                                                       
rFindSerializationCandidates(fpb.getChildBlocks(), cand);
+                                                       FunctionProgramBlock 
fpb = pb.getProgram().getFunctionProgramBlock(fci.getNamespace(), 
fci.getFunctionName(), opt);
+                                                       
rFindSerializationCandidates(fpb.getChildBlocks(), cand, opt);
                                                }
                                        }
                        }
@@ -985,12 +985,12 @@ public class ProgramConverter
        }
 
        @SuppressWarnings("all")
-       private static String serializeInstructions( ArrayList<Instruction> 
inst, HashMap<String, byte[]> clsMap ) 
+       private static String serializeInstructions( ArrayList<Instruction> 
inst, HashMap<String, byte[]> clsMap )
        {
                StringBuilder sb = new StringBuilder();
                int count = 0;
                for( Instruction linst : inst ) {
-                       //check that only cp instruction are transmitted 
+                       //check that only cp instruction are transmitted
                        if( !( linst instanceof CPInstruction) )
                                throw new DMLRuntimeException( 
NOT_SUPPORTED_SPARK_INSTRUCTION + " " +linst.getClass().getName()+"\n"+linst );
                        
@@ -1098,7 +1098,6 @@ public class ProgramConverter
                                continue;
                        if( count>0 ) {
                                sb.append( ELEMENT_DELIM );
-                               sb.append( NEWLINE );
                        }
                        sb.append( pb.getKey() );
                        sb.append( KEY_VALUE_DELIM );
@@ -1115,7 +1114,6 @@ public class ProgramConverter
                for( ProgramBlock pb : pbs ) {
                        if( count>0 ) {
                                sb.append( ELEMENT_DELIM );
-                               sb.append(NEWLINE);
                        }
                        sb.append( rSerializeProgramBlock(pb, clsMap) );
                        count++;
@@ -1339,6 +1337,10 @@ public class ProgramConverter
        }
 
        public static Program parseProgram( String in, int id ) {
+               return parseProgram(in, id, true);
+       }
+
+       public static Program parseProgram( String in, int id, boolean opt ) {
                String lin = in.substring( PROG_BEGIN.length(),in.length()- 
PROG_END.length()).trim();
                Program prog = new Program();
                HashMap<String,FunctionProgramBlock> fc = 
parseFunctionProgramBlocks(lin, prog, id);
@@ -1346,7 +1348,7 @@ public class ProgramConverter
                        String[] keypart = e.getKey().split( Program.KEY_DELIM 
);
                        String namespace = keypart[0];
                        String name      = keypart[1];
-                       prog.addFunctionProgramBlock(namespace, name, 
e.getValue());
+                       prog.addFunctionProgramBlock(namespace, name, 
e.getValue(), opt);
                }
                return prog;
        }
@@ -1354,7 +1356,7 @@ public class ProgramConverter
        private static LocalVariableMap parseVariables(String in) {
                LocalVariableMap ret = null;
                if( in.length()> VARS_BEGIN.length() + VARS_END.length()) {
-                       String varStr = in.substring( 
VARS_BEGIN.length(),in.length()- VARS_END.length()).trim();
+                       String varStr = in.substring( 
VARS_BEGIN.length(),in.length() - VARS_END.length()).trim();
                        ret = LocalVariableMap.deserialize(varStr);
                }
                else { //empty input symbol table
diff --git 
a/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
 
b/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
index 665997a..bf47f19 100644
--- 
a/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/paramserv/SerializationTest.java
@@ -19,8 +19,15 @@
 
 package org.apache.sysds.test.component.paramserv;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.ObjectInput;
+import java.io.ObjectOutputStream;
+import java.io.ObjectInputStream;
 import java.util.Arrays;
+import java.util.Collection;
 
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.junit.Assert;
 import org.junit.Test;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -29,32 +36,80 @@ import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.ProgramConverter;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
+@RunWith(Parameterized.class)
 public class SerializationTest {
+       /** 1 if the list under test carries element names, 0 if it is unnamed. */
+       private final int _named;
+
+       /**
+        * Runs every test once with an unnamed (0) and once with a named (1) list.
+        *
+        * @return one single-element parameter array per variant
+        */
+       @Parameterized.Parameters
+       public static Collection<Object[]> named() {
+               return Arrays.asList(new Object[][] {
+                               { 0 },
+                               { 1 }
+               });
+       }
+
+       public SerializationTest(Integer named) {
+               this._named = named;
+       }
 
        @Test
-       public void serializeUnnamedListObject() {
+       public void serializeListObject() throws Exception {
                MatrixObject mo1 = generateDummyMatrix(10);
                MatrixObject mo2 = generateDummyMatrix(20);
                IntObject io = new IntObject(30);
-               ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io));
-               String serial = ProgramConverter.serializeDataObject("key", lo);
-               Object[] obj = ProgramConverter.parseDataObject(serial);
-               ListObject actualLO = (ListObject) obj[1];
-               MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
-               MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
-               IntObject actualIO = (IntObject) actualLO.slice(2);
-               Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), actualMO1.acquireRead().getDenseBlockValues(), 0);
-               Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), actualMO2.acquireRead().getDenseBlockValues(), 0);
-               Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
+               // the second element is itself a list, so nested (de)serialization is covered as well
+               ListObject lot = new ListObject(Arrays.asList(mo2));
+               ListObject lo;
+               if (_named == 1)
+                       lo = new ListObject(Arrays.asList(mo1, lot, io), Arrays.asList("e1", "e2", "e3"));
+               else
+                       lo = new ListObject(Arrays.asList(mo1, lot, io));
+
+               // round trip through Java serialization; any exception propagates and fails the test
+               // with its full stack trace (a bare assert(false) only fails with JVM assertions enabled)
+               ByteArrayOutputStream bos = new ByteArrayOutputStream();
+               try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
+                       out.writeObject(lo);
+               }
+               ListObject loDeserialized;
+               try (ObjectInput in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
+                       loDeserialized = (ListObject) in.readObject();
+               }
+
+               MatrixObject mo1Deserialized = (MatrixObject) loDeserialized.getData(0);
+               ListObject lotDeserialized = (ListObject) loDeserialized.getData(1);
+               MatrixObject mo2Deserialized = (MatrixObject) lotDeserialized.getData(0);
+               IntObject ioDeserialized = (IntObject) loDeserialized.getData(2);
+
+               if (_named == 1)
+                       Assert.assertEquals(lo.getNames(), loDeserialized.getNames());
+
+               Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), mo1Deserialized.acquireRead().getDenseBlockValues(), 0);
+               Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), mo2Deserialized.acquireRead().getDenseBlockValues(), 0);
+               Assert.assertEquals(io.getLongValue(), ioDeserialized.getLongValue());
        }
 
        @Test
-       public void serializeNamedListObject() {
+       public void serializeListObjectProgramConverter() {
                MatrixObject mo1 = generateDummyMatrix(10);
                MatrixObject mo2 = generateDummyMatrix(20);
                IntObject io = new IntObject(30);
-               ListObject lo = new ListObject(Arrays.asList(mo1, mo2, io), 
Arrays.asList("e1", "e2", "e3"));
+               ListObject lo;
+               if (_named == 1)
+                       lo = new ListObject(Arrays.asList(mo1, mo2, io), 
Arrays.asList("e1", "e2", "e3"));
+               else
+                       lo = new ListObject(Arrays.asList(mo1, mo2, io));
 
                String serial = ProgramConverter.serializeDataObject("key", lo);
                Object[] obj = ProgramConverter.parseDataObject(serial);
@@ -62,7 +117,10 @@ public class SerializationTest {
                MatrixObject actualMO1 = (MatrixObject) actualLO.slice(0);
                MatrixObject actualMO2 = (MatrixObject) actualLO.slice(1);
                IntObject actualIO = (IntObject) actualLO.slice(2);
-               Assert.assertEquals(lo.getNames(), actualLO.getNames());
+
+               if (_named == 1)
+                       Assert.assertEquals(lo.getNames(), actualLO.getNames());
+
                
Assert.assertArrayEquals(mo1.acquireRead().getDenseBlockValues(), 
actualMO1.acquireRead().getDenseBlockValues(), 0);
                
Assert.assertArrayEquals(mo2.acquireRead().getDenseBlockValues(), 
actualMO2.acquireRead().getDenseBlockValues(), 0);
                Assert.assertEquals(io.getLongValue(), actualIO.getLongValue());
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
new file mode 100644
index 0000000..194df09
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+/**
+ * End-to-end test of the federated parameter server: starts local federated workers, each
+ * holding a random MNIST-shaped row partition of features and labels, and runs
+ * FederatedParamservTest.dml against them for every parameter combination.
+ */
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedParamservTest extends AutomatedTestBase {
+    private final static String TEST_DIR = "functions/federated/paramserv/";
+    private final static String TEST_NAME = "FederatedParamservTest";
+    private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/";
+    private final static int _blocksize = 1024;
+
+    private final String _networkType;
+    private final int _numFederatedWorkers;
+    private final int _examplesPerWorker;
+    private final int _epochs;
+    private final int _batch_size;
+    private final double _eta;
+    private final String _utype;
+    private final String _freq;
+
+    private Types.ExecMode _platformOld;
+
+    // parameters
+    @Parameterized.Parameters
+    public static Collection<Object[]> parameters() {
+        return Arrays.asList(new Object[][] {
+                //Network type, number of federated workers, examples per worker, batch size, epochs, learning rate, update type, update frequency
+                {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+                {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+                {"TwoNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+                {"TwoNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "BSP", "BATCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "ASP", "BATCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "BSP", "EPOCH"},
+                {"CNN", 2, 2, 1, 5, 0.01, "ASP", "EPOCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+                {"TwoNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "BATCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "BATCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "BSP", "EPOCH"},
+                {"CNN", 5, 1000, 32, 2, 0.01, "ASP", "EPOCH"}
+        });
+    }
+
+    public FederatedParamservTest(String networkType, int numFederatedWorkers, int examplesPerWorker, int batch_size, int epochs, double eta, String utype, String freq) {
+        _networkType = networkType;
+        _numFederatedWorkers = numFederatedWorkers;
+        _examplesPerWorker = examplesPerWorker;
+        _batch_size = batch_size;
+        _epochs = epochs;
+        _eta = eta;
+        _utype = utype;
+        _freq = freq;
+    }
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+
+        _platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
+    }
+
+    @Override
+    public void tearDown() {
+        rtplatform = _platformOld;
+    }
+
+    @Test
+    public void federatedParamserv() {
+        // config
+        getAndLoadTestConfiguration(TEST_NAME);
+        String HOME = SCRIPT_DIR + TEST_DIR;
+        setOutputBuffering(true);
+
+        int C = 1, Hin = 28, Win = 28;
+        int numFeatures = C*Hin*Win;
+        int numLabels = 10;
+
+        // dml name
+        fullDMLScriptName = HOME + TEST_NAME + ".dml";
+        // generate program args
+        List<String> programArgsList = new ArrayList<>(Arrays.asList(
+                "-stats",
+                "-nvargs",
+                "examples_per_worker=" + _examplesPerWorker,
+                "num_features=" + numFeatures,
+                "num_labels=" + numLabels,
+                "epochs=" + _epochs,
+                "batch_size=" + _batch_size,
+                "eta=" + _eta,
+                "utype=" + _utype,
+                "freq=" + _freq,
+                "network_type=" + _networkType,
+                "channels=" + C,
+                "hin=" + Hin,
+                "win=" + Win
+        ));
+
+        List<Thread> threads = new ArrayList<>();
+        try {
+            // for each worker: write its row partition to disk, start it, and register it in the program args
+            for(int i = 0; i < _numFederatedWorkers; i++) {
+                // row partitioned features
+                writeInputMatrixWithMTD("X" + i, generateDummyMNISTFeatures(_examplesPerWorker, C, Hin, Win), false,
+                        new MatrixCharacteristics(_examplesPerWorker, numFeatures, _blocksize, _examplesPerWorker * numFeatures));
+                // row partitioned one-hot labels: exactly one non-zero per row
+                writeInputMatrixWithMTD("y" + i, generateDummyMNISTLabels(_examplesPerWorker, numLabels), false,
+                        new MatrixCharacteristics(_examplesPerWorker, numLabels, _blocksize, _examplesPerWorker));
+
+                int port = getRandomAvailablePort();
+                threads.add(startLocalFedWorkerThread(port));
+
+                programArgsList.add("X" + i + "=" + TestUtils.federatedAddress(port, input("X" + i)));
+                programArgsList.add("y" + i + "=" + TestUtils.federatedAddress(port, input("y" + i)));
+            }
+
+            programArgs = programArgsList.toArray(new String[0]);
+            runTest(null);
+        }
+        finally {
+            // always stop the federated workers, even if worker startup or the test run fails
+            for(Thread t : threads)
+                TestUtils.shutdownThreads(t);
+        }
+    }
+
+    /**
+     * Generates a feature matrix that has the same format as the MNIST dataset,
+     * but is completely random and normalized
+     *
+     *  @param numExamples Number of examples to generate
+     *  @param C Channels in the input data
+     *  @param Hin Height in Pixels of the input data
+     *  @param Win Width in Pixels of the input data
+     *  @return a dummy MNIST feature matrix
+     */
+    private double[][] generateDummyMNISTFeatures(int numExamples, int C, int Hin, int Win) {
+        // Seed -1 takes the time in milliseconds as a seed
+        // Sparsity 1 means no sparsity
+        return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+    }
+
+    /**
+     * Generates a label matrix that has the same format as the MNIST dataset, but is completely random and consists
+     * of one hot encoded vectors as rows
+     *
+     *  @param numExamples Number of examples to generate
+     *  @param numLabels Number of labels to generate
+     *  @return a dummy MNIST label matrix
+     */
+    private double[][] generateDummyMNISTLabels(int numExamples, int numLabels) {
+        // unseeded like the features: every row gets a single 1 at a random class index
+        java.util.Random rand = new java.util.Random();
+        double[][] labels = new double[numExamples][numLabels];
+        for(double[] row : labels)
+            row[rand.nextInt(numLabels)] = 1;
+        return labels;
+    }
+}
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml 
b/src/test/scripts/functions/federated/paramserv/CNN.dml
new file mode 100644
index 0000000..55d05dc
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -0,0 +1,474 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a convolutional 
neural network of the "LeNet" architecture
+ * on different execution schemes and with different inputs, for example a 
federated input matrix.
+ */
+
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * Trains a convolutional net of the "LeNet" architecture in a single thread,
+ * the conventional way (a plain mini-batch loop, no parameter server).
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N_val, C*Hin*Win); currently unused
+ *  - y_val: Target validation matrix, of shape (N_val, K); currently unused
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List in the order (W1, W2, W3, W4, b1, b2, b3, b4,
+ *    vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4), where
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ *       - vW*, vb*: SGD Nesterov momentum state, same shapes as the corresponding W*, b*
+ */
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int C, int Hin, int Win, int epochs, int batch_size, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant (only forwarded via hyperparams; not applied here)
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+  # Regularization
+  lambda = 5e-04
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+  # Calculate iterations
+  iters = ceil(N / batch_size)
+  # Print one progress tick about every 1/25th of an epoch. Clamp to 1: with fewer
+  # than 25 iterations floor() yields 0 and "i %% 0" would never equal 0.
+  print_interval = max(1, floor(iters / 25))
+
+  print("[+] Starting optimization")
+  print("[+]  Learning rate: " + learning_rate)
+  print("[+]  Batch size: " + batch_size)
+  print("[+]  Iterations per epoch: " + iters + "\n")
+
+  for (e in 1:epochs) {
+    print("[+] Starting epoch: " + e)
+    print("|")
+    for(i in 1:iters) {
+      # Create the model list
+      model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
+      model_updated = aggregation(model_list, hyperparams, gradients_list)
+
+      W1 = as.matrix(model_updated[1])
+      W2 = as.matrix(model_updated[2])
+      W3 = as.matrix(model_updated[3])
+      W4 = as.matrix(model_updated[4])
+      b1 = as.matrix(model_updated[5])
+      b2 = as.matrix(model_updated[6])
+      b3 = as.matrix(model_updated[7])
+      b4 = as.matrix(model_updated[8])
+      vW1 = as.matrix(model_updated[9])
+      vW2 = as.matrix(model_updated[10])
+      vW3 = as.matrix(model_updated[11])
+      vW4 = as.matrix(model_updated[12])
+      vb1 = as.matrix(model_updated[13])
+      vb2 = as.matrix(model_updated[14])
+      vb3 = as.matrix(model_updated[15])
+      vb4 = as.matrix(model_updated[16])
+      if((i %% print_interval) == 0) {
+        print("█")
+      }
+    }
+    print("|")
+  }
+
+  model_trained = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
+
+/*
+ * Trains a convolutional net of the "LeNet" architecture with the paramserv
+ * builtin, using the given parameter server configuration.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.  The targets, y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - y: Target matrix, of shape (N, K)
+ *  - X_val: Input validation data matrix, of shape (N_val, C*Hin*Win); passed as val_features
+ *  - y_val: Target validation matrix, of shape (N_val, K); passed as val_labels
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - epochs: Total number of full training loops over the full data set
+ *  - workers: Number of workers to create (paramserv argument k)
+ *  - utype: Update type, passed as paramserv argument utype
+ *  - freq: Update frequency, passed as paramserv argument freq
+ *  - batch_size: Batch size
+ *  - scheme: Data partitioning scheme, passed as paramserv argument scheme
+ *  - mode: Execution mode (local or distributed), passed as paramserv argument mode
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List in the order (W1, W2, W3, W4, b1, b2, b3, b4,
+ *    vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4), where
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ *       - vW*, vb*: SGD Nesterov momentum state, same shapes as the corresponding W*, b*
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int C, int Hin, int Win, int epochs, int workers,
+                 string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  K = ncol(y)
+
+  # Create network:
+  ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # For same dimensions, (Hf - stride) / 2
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  mu = 0.9  # momentum
+  decay = 0.95  # learning rate decay constant (only forwarded via hyperparams)
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+  # Regularization
+  lambda = 5e-04
+  # Create the model list; the velocities travel with the model so aggregation can update them
+  model_list = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+  # Create the hyper parameter list
+  params = list(learning_rate=learning_rate, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+  # Use paramserv function
+  model_trained = paramserv(model=model_list, features=X, labels=y,
+    val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
+    mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+}
+
+/*
+ * Computes class-probability predictions of a convolutional net of the
+ * "LeNet" architecture, processing X in mini-batches inside a parfor loop.
+ *
+ * Each row of X is one example: a C x Hin x Win volume unrolled into a
+ * single vector.
+ *
+ * Inputs:
+ *  - X: Input data matrix, of shape (N, C*Hin*Win)
+ *  - C: Number of input channels (dimensionality of input depth)
+ *  - Hin: Input height
+ *  - Win: Input width
+ *  - batch_size: Number of rows predicted per parfor iteration
+ *  - model: List laid out as (W1, W2, W3, W4, b1, b2, b3, b4, ...); only the
+ *    first eight entries are read:
+ *       - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3)
+ *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ *       - b1: 1st layer biases vector, of shape (F1, 1)
+ *       - b2: 2nd layer biases vector, of shape (F2, 1)
+ *       - b3: 3rd layer biases vector, of shape (1, N3)
+ *       - b4: 4th layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size, list[unknown] model)
+    return (matrix[double] probs) {
+
+  # Unpack each layer's weights together with its biases
+  W1 = as.matrix(model[1]); b1 = as.matrix(model[5])
+  W2 = as.matrix(model[2]); b2 = as.matrix(model[6])
+  W3 = as.matrix(model[3]); b3 = as.matrix(model[7])
+  W4 = as.matrix(model[4]); b4 = as.matrix(model[8])
+
+  # Convolution geometry; must match the values used during training
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # keeps spatial dimensions: (Hf - stride) / 2
+
+  # Layer widths are derived from the weight shapes
+  F1 = nrow(W1)  # conv1 filters
+  F2 = nrow(W2)  # conv2 filters
+  K = ncol(W4)   # number of classes
+
+  N = nrow(X)
+  probs = matrix(0, rows=N, cols=K)
+  num_batches = ceil(N / batch_size)
+  parfor(bi in 1:num_batches, check=0) {
+    # Row range of this mini-batch
+    row_beg = ((bi-1) * batch_size) %% N + 1
+    row_end = min(N, row_beg + batch_size - 1)
+    X_chunk = X[row_beg:row_end,]
+
+    # conv1 -> relu1 -> pool1
+    [c1, Hc1, Wc1] = conv2d::forward(X_chunk, W1, b1, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+    r1 = relu::forward(c1)
+    [p1, Hp1, Wp1] = max_pool2d::forward(r1, F1, Hc1, Wc1, 2, 2, 2, 2, 0, 0)
+    # conv2 -> relu2 -> pool2
+    [c2, Hc2, Wc2] = conv2d::forward(p1, W2, b2, F1, Hp1, Wp1, Hf, Wf, stride, stride, pad, pad)
+    r2 = relu::forward(c2)
+    [p2, Hp2, Wp2] = max_pool2d::forward(r2, F2, Hc2, Wc2, 2, 2, 2, 2, 0, 0)
+    # affine3 -> relu3
+    a3 = affine::forward(p2, W3, b3)
+    r3 = relu::forward(a3)
+    # affine4 -> softmax
+    a4 = affine::forward(r3, W4, b4)
+    chunk_probs = softmax::forward(a4)
+
+    # Write this batch's predictions into the shared result matrix
+    probs[row_beg:row_end,] = chunk_probs
+  }
+}
+
+/*
+ * Evaluates class-probability predictions of the "LeNet" convolutional net
+ * against one-hot encoded targets.
+ *
+ * Inputs:
+ *  - probs: Class probabilities for N examples over K classes, of shape (N, K)
+ *  - y: One-hot target matrix, of shape (N, K)
+ *
+ * Outputs:
+ *  - loss: Scalar cross-entropy loss, of shape (1)
+ *  - accuracy: Scalar fraction of rows whose arg-max prediction matches the
+ *    arg-max target, of shape (1)
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  hits = rowIndexMax(y) == rowIndexMax(probs)
+  accuracy = mean(hits)
+  loss = cross_entropy_loss::forward(probs, y)
+}
+
+# Computes the gradients of the "LeNet" convolutional net for one batch.
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments
+# and return the gradients of type list
+#
+# Inputs:
+#  - model: list (W1, W2, W3, W4, b1, b2, b3, b4, ...); only the first 8 entries are read
+#  - hyperparams: named list providing C, Hin, Win, Hf, Wf, stride, pad, lambda, F1, F2
+#  - features: batch input data, of shape (N_batch, C*Hin*Win)
+#  - labels: batch one-hot targets, of shape (N_batch, K)
+#
+# Outputs:
+#  - gradients: list (dW1, dW2, dW3, dW4, db1, db2, db3, db4); the weight
+#    gradients include the L2 regularization term
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+          return (list[unknown] gradients) {
+
+  C = as.integer(as.scalar(hyperparams["C"]))
+  Hin = as.integer(as.scalar(hyperparams["Hin"]))
+  Win = as.integer(as.scalar(hyperparams["Win"]))
+  Hf = as.integer(as.scalar(hyperparams["Hf"]))
+  Wf = as.integer(as.scalar(hyperparams["Wf"]))
+  stride = as.integer(as.scalar(hyperparams["stride"]))
+  pad = as.integer(as.scalar(hyperparams["pad"]))
+  lambda = as.double(as.scalar(hyperparams["lambda"]))
+  F1 = as.integer(as.scalar(hyperparams["F1"]))
+  F2 = as.integer(as.scalar(hyperparams["F2"]))
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  W4 = as.matrix(model[4])
+  b1 = as.matrix(model[5])
+  b2 = as.matrix(model[6])
+  b3 = as.matrix(model[7])
+  b4 = as.matrix(model[8])
+
+  # Compute forward pass
+  ## layer 1: conv1 -> relu1 -> pool1
+  [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf, Wf,
+                                              stride, stride, pad, pad)
+  outr1 = relu::forward(outc1)
+  [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  ## layer 2: conv2 -> relu2 -> pool2
+  [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                            stride, stride, pad, pad)
+  outr2 = relu::forward(outc2)
+  [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  outa3 = affine::forward(outp2, W3, b3)
+  outr3 = relu::forward(outa3)
+  [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+  ## layer 4:  affine4 -> softmax
+  outa4 = affine::forward(outd3, W4, b4)
+  probs = softmax::forward(outa4)
+
+  # Compute loss & accuracy for training data
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
+
+  # Compute data backward pass
+  ## loss
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  ## layer 4:  affine4 -> softmax
+  douta4 = softmax::backward(dprobs, outa4)
+  # affine4 consumed the dropout output outd3 in the forward pass, so its
+  # weight gradient must be computed from outd3 (not the pre-dropout outr3)
+  [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)
+  ## layer 3:  affine3 -> relu3 -> dropout
+  doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+  douta3 = relu::backward(doutr3, outa3)
+  [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+  ## layer 2: conv2 -> relu2 -> pool2
+  doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, 2, 2, 2, 2, 0, 0)
+  doutc2 = relu::backward(doutr2, outc2)
+  [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                        Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+  ## layer 1: conv1 -> relu1 -> pool1
+  doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, 2, 2, 2, 2, 0, 0)
+  doutc1 = relu::backward(doutr1, outc1)
+  [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features, W1, b1, C, Hin, Win,
+                                          Hf, Wf, stride, stride, pad, pad)
+
+  # Compute regularization backward pass (L2 on the weights only, not the biases)
+  dW1_reg = l2_reg::backward(W1, lambda)
+  dW2_reg = l2_reg::backward(W2, lambda)
+  dW3_reg = l2_reg::backward(W3, lambda)
+  dW4_reg = l2_reg::backward(W4, lambda)
+  dW1 = dW1 + dW1_reg
+  dW2 = dW2 + dW2_reg
+  dW3 = dW3 + dW3_reg
+  dW4 = dW4 + dW4_reg
+
+  gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Applies one SGD-with-Nesterov-momentum step to the "LeNet" model.
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and return always a model of type list
+#
+# Inputs:
+#  - model: list (W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+#  - hyperparams: named list providing learning_rate and mu
+#  - gradients: list (dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+#
+# Outputs:
+#  - model_result: updated model in the same 16-entry order as the input model
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+  lr = as.double(as.scalar(hyperparams["learning_rate"]))
+  momentum = as.double(as.scalar(hyperparams["mu"]))
+
+  # Layer 1: parameters at 1/5, gradients at 1/5, velocities at 9/13
+  W1 = as.matrix(model[1]); dW1 = as.matrix(gradients[1]); vW1 = as.matrix(model[9])
+  b1 = as.matrix(model[5]); db1 = as.matrix(gradients[5]); vb1 = as.matrix(model[13])
+  [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, momentum, vW1)
+  [b1, vb1] = sgd_nesterov::update(b1, db1, lr, momentum, vb1)
+
+  # Layer 2: parameters at 2/6, gradients at 2/6, velocities at 10/14
+  W2 = as.matrix(model[2]); dW2 = as.matrix(gradients[2]); vW2 = as.matrix(model[10])
+  b2 = as.matrix(model[6]); db2 = as.matrix(gradients[6]); vb2 = as.matrix(model[14])
+  [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, momentum, vW2)
+  [b2, vb2] = sgd_nesterov::update(b2, db2, lr, momentum, vb2)
+
+  # Layer 3: parameters at 3/7, gradients at 3/7, velocities at 11/15
+  W3 = as.matrix(model[3]); dW3 = as.matrix(gradients[3]); vW3 = as.matrix(model[11])
+  b3 = as.matrix(model[7]); db3 = as.matrix(gradients[7]); vb3 = as.matrix(model[15])
+  [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, momentum, vW3)
+  [b3, vb3] = sgd_nesterov::update(b3, db3, lr, momentum, vb3)
+
+  # Layer 4: parameters at 4/8, gradients at 4/8, velocities at 12/16
+  W4 = as.matrix(model[4]); dW4 = as.matrix(gradients[4]); vW4 = as.matrix(model[12])
+  b4 = as.matrix(model[8]); db4 = as.matrix(gradients[8]); vb4 = as.matrix(model[16])
+  [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, momentum, vW4)
+  [b4, vb4] = sgd_nesterov::update(b4, db4, lr, momentum, vb4)
+
+  model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
+}
\ No newline at end of file
diff --git 
a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml 
b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
new file mode 100644
index 0000000..16c72c4
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
+source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
+
+# Federated inputs: worker 0 holds rows [0, examples_per_worker),
+# worker 1 holds rows [examples_per_worker, 2 * examples_per_worker)
+features = federated(addresses=list($X0, $X1),
+    ranges=list(list(0, 0), list($examples_per_worker, $num_features),
+                list($examples_per_worker, 0), list($examples_per_worker * 2, $num_features)))
+labels = federated(addresses=list($y0, $y1),
+    ranges=list(list(0, 0), list($examples_per_worker, $num_labels),
+                list($examples_per_worker, 0), list($examples_per_worker * 2, $num_labels)))
+
+# Training configuration supplied by the test driver
+network_type = $network_type
+epochs = $epochs
+batch_size = $batch_size
+learning_rate = $eta
+utype = $utype
+freq = $freq
+
+# Input geometry, only consumed by the CNN
+channels = $channels
+hin = $hin
+win = $win
+
+# Fixed settings; per the original author these are currently ignored
+workers = 1
+scheme = "DISJOINT_CONTIGUOUS"
+paramserv_mode = "LOCAL"
+
+# This test trains without validation data
+no_val_X = matrix(0, rows=0, cols=0)
+no_val_y = matrix(0, rows=0, cols=0)
+
+if(network_type == "TwoNN") {
+  model = TwoNN::train_paramserv(features, labels, no_val_X, no_val_y,
+    epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+}
+else {
+  model = CNN::train_paramserv(features, labels, no_val_X, no_val_y,
+    channels, hin, win, epochs, workers, utype, freq, batch_size, scheme, paramserv_mode, learning_rate)
+}
+print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml 
b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
new file mode 100644
index 0000000..3bcfe84
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -0,0 +1,299 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a simple feed forward 
neural network
+ * on different execution schemes and with different inputs, for example a 
federated input matrix.
+ */
+
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd.dml") as sgd
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers in a
+ * single thread, the conventional way (a plain mini-batch loop, no parameter server).
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D); currently unused
+ *  - y_val: Target validation matrix of shape (N_val, K); currently unused
+ *  - epochs: Total number of full training loops over the full data set
+ *  - batch_size: Batch size
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List in the order (W1, W2, W3, b1, b2, b3), where
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int epochs, int batch_size, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network:
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200)
+  [W2, b2] = affine::init(200, 200)
+  [W3, b3] = affine::init(200, K)
+  W3 = W3 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Create the hyper parameter list
+  hyperparams = list(learning_rate=learning_rate)
+  # Calculate iterations
+  iters = ceil(N / batch_size)
+  # Print one progress tick about every 1/25th of an epoch. Clamp to 1: with fewer
+  # than 25 iterations floor() yields 0 and "i %% 0" would never equal 0.
+  print_interval = max(1, floor(iters / 25))
+
+  print("[+] Starting optimization")
+  print("[+]  Learning rate: " + learning_rate)
+  print("[+]  Batch size: " + batch_size)
+  print("[+]  Iterations per epoch: " + iters + "\n")
+
+  for (e in 1:epochs) {
+    print("[+] Starting epoch: " + e)
+    print("|")
+    for(i in 1:iters) {
+      # Create the model list
+      model_list = list(W1, W2, W3, b1, b2, b3)
+
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      gradients_list = gradients(model_list, hyperparams, X_batch, y_batch)
+      model_updated = aggregation(model_list, hyperparams, gradients_list)
+
+      W1 = as.matrix(model_updated[1])
+      W2 = as.matrix(model_updated[2])
+      W3 = as.matrix(model_updated[3])
+      b1 = as.matrix(model_updated[4])
+      b2 = as.matrix(model_updated[5])
+      b3 = as.matrix(model_updated[6])
+
+      if((i %% print_interval) == 0) {
+        print("█")
+      }
+    }
+    print("|")
+  }
+
+  model_trained = list(W1, W2, W3, b1, b2, b3)
+}
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers
+ * using the paramserv builtin with the given parameter server configuration.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ *  - X: Input data matrix of shape (N, D)
+ *  - y: Target matrix of shape (N, K)
+ *  - X_val: Input validation data matrix of shape (N_val, D); passed as val_features
+ *  - y_val: Target validation matrix of shape (N_val, K); passed as val_labels
+ *  - epochs: Total number of full training loops over the full data set
+ *  - workers: Number of workers to create (paramserv argument k)
+ *  - utype: Update type, passed as paramserv argument utype
+ *  - freq: Update frequency, passed as paramserv argument freq
+ *  - batch_size: Batch size
+ *  - scheme: Data partitioning scheme, passed as paramserv argument scheme
+ *  - mode: Execution mode (local or distributed), passed as paramserv argument mode
+ *  - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ *  - model_trained: List in the order (W1, W2, W3, b1, b2, b3), where
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int epochs, int workers,
+                 string utype, string freq, int batch_size, string scheme, string mode, double learning_rate)
+    return (list[unknown] model_trained) {
+
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create the network:
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  [W1, b1] = affine::init(D, 200)
+  [W2, b2] = affine::init(200, 200)
+  [W3, b3] = affine::init(200, K)
+  # Same initialization as train(): the last layer feeds softmax instead of relu
+  W3 = W3 / sqrt(2)
+
+  # Create the model list
+  model_list = list(W1, W2, W3, b1, b2, b3)
+  # Create the hyper parameter list
+  params = list(learning_rate=learning_rate)
+  # Use paramserv function
+  model_trained = paramserv(model=model_list, features=X, labels=y,
+    val_features=X_val, val_labels=y_val,
+    upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
+    agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
+    mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+    k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
+}
+
+/*
+ * Computes the class probability predictions of a simple feed forward neural network.
+ *
+ * Inputs:
+ *  - X: The input data matrix of shape (N, D)
+ *  - model: List in the order (W1, W2, W3, b1, b2, b3), where
+ *       - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ *       - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ *       - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ *       - b1: 1st layer biases vector, of shape (1, 200)
+ *       - b2: 2nd layer biases vector, of shape (1, 200)
+ *       - b3: 3rd layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ *  - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X,
+                   list[unknown] model)
+    return (matrix[double] probs) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  out1relu = relu::forward(affine::forward(X, W1, b1))
+  out2relu = relu::forward(affine::forward(out1relu, W2, b2))
+  probs = softmax::forward(affine::forward(out2relu, W3, b3))
+}
+
+/*
+ * Evaluates class-probability predictions of the feed forward network
+ * against one-hot encoded targets.
+ *
+ * Inputs:
+ *  - probs: Class probabilities for N examples over K classes, of shape (N, K).
+ *  - y: One-hot target matrix, of shape (N, K).
+ *
+ * Outputs:
+ *  - loss: Scalar cross-entropy loss, of shape (1).
+ *  - accuracy: Scalar fraction of rows whose arg-max prediction matches the
+ *    arg-max target, of shape (1).
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+
+  hits = rowIndexMax(y) == rowIndexMax(probs)
+  accuracy = mean(hits)
+  loss = cross_entropy_loss::forward(probs, y)
+}
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments
+# and return the gradients of type list
+#
+# Runs a forward pass of the 3-layer network on one batch, prints the batch
+# training loss and accuracy, then back-propagates the cross-entropy loss.
+#
+# Arguments:
+#  - model: list(W1, W2, W3, b1, b2, b3) -- all weights first, then biases
+#  - hyperparams: not read by this function (kept for the required signature)
+#  - features: batch feature matrix, one row per example
+#  - labels: batch labels, one-hot encoded (compared via rowIndexMax)
+# Returns:
+#  - gradients: list(dW1, dW2, dW3, db1, db2, db3), same layout as model
+gradients = function(list[unknown] model,
+                     list[unknown] hyperparams,
+                     matrix[double] features,
+                     matrix[double] labels)
+    return (list[unknown] gradients) {
+
+  W1 = as.matrix(model[1])
+  W2 = as.matrix(model[2])
+  W3 = as.matrix(model[3])
+  b1 = as.matrix(model[4])
+  b2 = as.matrix(model[5])
+  b3 = as.matrix(model[6])
+
+  # Compute forward pass
+  ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+  out1 = affine::forward(features, W1, b1)
+  out1relu = relu::forward(out1)
+  out2 = affine::forward(out1relu, W2, b2)
+  out2relu = relu::forward(out2)
+  out3 = affine::forward(out2relu, W3, b3)
+  probs = softmax::forward(out3)
+
+  # Compute loss & accuracy for training data (reported only, not returned)
+  loss = cross_entropy_loss::forward(probs, labels)
+  accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+  print("[+] Completed forward pass on batch: train loss: " + loss + ", train accuracy: " + accuracy)
+
+  # Compute data backward pass
+  dprobs = cross_entropy_loss::backward(probs, labels)
+  dout3 = softmax::backward(dprobs, out3)
+  [dout2relu, dW3, db3] = affine::backward(dout3, out2relu, W3, b3)
+  dout2 = relu::backward(dout2relu, out2)
+  [dout1relu, dW2, db2] = affine::backward(dout2, out1relu, W2, b2)
+  dout1 = relu::backward(dout1relu, out1)
+  # dfeatures (gradient w.r.t. the input batch) is not used
+  [dfeatures, dW1, db1] = affine::backward(dout1, features, W1, b1)
+
+  gradients = list(dW1, dW2, dW3, db1, db2, db3)
+}
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and return always a model of type list
+#
+# Applies sgd::update with hyperparams["learning_rate"] to every parameter of
+# the model using its matching gradient.
+#
+# Arguments:
+#  - model: list(W1, W2, W3, b1, b2, b3)
+#  - hyperparams: list holding at least "learning_rate"
+#  - gradients: list(dW1, dW2, dW3, db1, db2, db3), same layout as model
+# Returns:
+#  - model_result: updated list(W1, W2, W3, b1, b2, b3)
+aggregation = function(list[unknown] model,
+                       list[unknown] hyperparams,
+                       list[unknown] gradients)
+    return (list[unknown] model_result) {
+
+  lr = as.double(as.scalar(hyperparams["learning_rate"]))
+
+  # model and gradients share one layout, so index i pairs a parameter with
+  # its gradient; the six updates are independent of each other.
+  newW1 = sgd::update(as.matrix(model[1]), as.matrix(gradients[1]), lr)
+  newW2 = sgd::update(as.matrix(model[2]), as.matrix(gradients[2]), lr)
+  newW3 = sgd::update(as.matrix(model[3]), as.matrix(gradients[3]), lr)
+  newb1 = sgd::update(as.matrix(model[4]), as.matrix(gradients[4]), lr)
+  newb2 = sgd::update(as.matrix(model[5]), as.matrix(gradients[5]), lr)
+  newb3 = sgd::update(as.matrix(model[6]), as.matrix(gradients[6]), lr)
+
+  model_result = list(newW1, newW2, newW3, newb1, newb2, newb3)
+}
\ No newline at end of file

Reply via email to