This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new b6640d9  [SYSTEMDS-2550] Extended parameter server (validation function, stats)
b6640d9 is described below

commit b6640d93011e9bd1fa986c3e92ca7b2a9d8a276b
Author: Tobias Rieger <[email protected]>
AuthorDate: Sat Jan 30 22:08:50 2021 +0100

    [SYSTEMDS-2550] Extended parameter server (validation function, stats)
    
    Closes #1154.
---
 .../ParameterizedBuiltinFunctionExpression.java    |   3 +-
 .../java/org/apache/sysds/parser/Statement.java    |   2 +-
 .../paramserv/FederatedPSControlThread.java        |  94 +++++++++-------
 .../controlprogram/paramserv/LocalParamServer.java |  16 ++-
 .../runtime/controlprogram/paramserv/PSWorker.java |   1 -
 .../controlprogram/paramserv/ParamServer.java      | 125 +++++++++++++++++++--
 .../runtime/controlprogram/parfor/stat/Timing.java |  23 ++--
 .../cp/ParamservBuiltinCPInstruction.java          |  57 ++++++++--
 .../java/org/apache/sysds/utils/Statistics.java    |  66 ++++++++---
 .../paramserv/FederatedParamservTest.java          |  40 ++++---
 .../scripts/functions/federated/paramserv/CNN.dml  |  48 +++++---
 .../federated/paramserv/FederatedParamservTest.dml |  14 ++-
 .../functions/federated/paramserv/TwoNN.dml        |  18 ++-
 13 files changed, 372 insertions(+), 135 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 05bfc48..583c643 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -288,7 +288,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                //check for invalid parameters
                Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, 
Statement.PS_FEATURES, Statement.PS_LABELS,
                        Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, 
Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
-                       Statement.PS_MODE, Statement.PS_UPDATE_TYPE, 
Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
+                       Statement.PS_VAL_FUN, Statement.PS_MODE, 
Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
                        Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, 
Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
                        Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, 
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
                checkInvalidParameters(getOpCode(), getVarParams(), valid);
@@ -301,6 +301,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                checkDataValueType(true, fname, Statement.PS_VAL_LABELS, 
DataType.MATRIX, ValueType.FP64, conditional);
                checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, 
DataType.SCALAR, ValueType.STRING, conditional);
                checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, 
DataType.SCALAR, ValueType.STRING, conditional);
+               checkDataValueType(true, fname, Statement.PS_VAL_FUN, 
DataType.SCALAR, ValueType.STRING, conditional);
                checkStringParam(true, fname, Statement.PS_MODE, conditional);
                checkStringParam(true, fname, Statement.PS_UPDATE_TYPE, 
conditional);
                checkStringParam(true, fname, Statement.PS_FREQUENCY, 
conditional);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java 
b/src/main/java/org/apache/sysds/parser/Statement.java
index 9104246..38d16cd 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -66,6 +66,7 @@ public abstract class Statement implements ParseInfo
        public static final String PS_LABELS = "labels";
        public static final String PS_VAL_FEATURES = "val_features";
        public static final String PS_VAL_LABELS = "val_labels";
+       public static final String PS_VAL_FUN = "val";
        public static final String PS_UPDATE_FUN = "upd";
        public static final String PS_AGGREGATION_FUN = "agg";
        public static final String PS_MODE = "mode";
@@ -117,7 +118,6 @@ public abstract class Statement implements ParseInfo
        public static final String PS_FED_NAMESPACE = "1701-NCC-namespace";
        public static final String PS_FED_GRADIENTS_FNAME = 
"1701-NCC-gradients_fname";
        public static final String PS_FED_AGGREGATION_FNAME = 
"1701-NCC-aggregation_fname";
-       public static final String PS_FED_BATCHCOUNTER_VARID = 
"1701-NCC-batchcounter_varid";
        public static final String PS_FED_MODEL_VARID = "1701-NCC-model_varid";
 
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 98bc91a..13e029c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -23,6 +23,7 @@ import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.parser.Statement.PSFrequency;
@@ -45,6 +46,7 @@ import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
@@ -53,9 +55,10 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.util.ProgramConverter;
+import org.apache.sysds.utils.Statistics;
 
 import java.util.ArrayList;
-import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.concurrent.Callable;
 import java.util.concurrent.Future;
@@ -69,16 +72,15 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
 
        private FederatedData _featuresData;
        private FederatedData _labelsData;
-       private final long _localStartBatchNumVarID;
        private final long _modelVarID;
 
        // runtime balancing
-       private PSRuntimeBalancing _runtimeBalancing;
+       private final PSRuntimeBalancing _runtimeBalancing;
        private int _numBatchesPerEpoch;
        private int _possibleBatchesPerLocalEpoch;
-       private boolean _weighing;
+       private final boolean _weighing;
        private double _weighingFactor = 1;
-       private boolean _cycleStartAt0 = false;
+       private final boolean _cycleStartAt0 = false;
 
        public FederatedPSControlThread(int workerID, String updFunc, 
Statement.PSFrequency freq,
                PSRuntimeBalancing runtimeBalancing, boolean weighing, int 
epochs, long batchSize,
@@ -89,8 +91,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                _numBatchesPerEpoch = numBatchesPerGlobalEpoch;
                _runtimeBalancing = runtimeBalancing;
                _weighing = weighing;
-               // generate the IDs for model and batch counter. These get 
overwritten on the federated worker each time
-               _localStartBatchNumVarID = FederationUtils.getNextFedDataID();
+               // generate the ID for the model
                _modelVarID = FederationUtils.getNextFedDataID();
        }
 
@@ -100,6 +101,8 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
         * @param weighingFactor Gradients from this worker will be multiplied 
by this factor if weighing is enabled
         */
        public void setup(double weighingFactor) {
+               incWorkerNumber();
+
                // prepare features and labels
                _featuresData = (FederatedData) 
_features.getFedMapping().getMap().values().toArray()[0];
                _labelsData = (FederatedData) 
_labels.getFedMapping().getMap().values().toArray()[0];
@@ -125,9 +128,11 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
                }
 
-               LOG.info("Setup config for worker " + this.getWorkerName());
-               LOG.info("Batch size: " + _batchSize + " possible batches: " + 
_possibleBatchesPerLocalEpoch
-                               + " batches to run: " + _numBatchesPerEpoch + " 
weighing factor: " + _weighingFactor);
+               if( LOG.isInfoEnabled() ) {
+                       LOG.info("Setup config for worker " + 
this.getWorkerName());
+                       LOG.info("Batch size: " + _batchSize + " possible 
batches: " + _possibleBatchesPerLocalEpoch
+                                       + " batches to run: " + 
_numBatchesPerEpoch + " weighing factor: " + _weighingFactor);
+               }
 
                // serialize program
                // create program blocks for the instruction filtering
@@ -135,12 +140,12 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                ArrayList<ProgramBlock> pbs = new ArrayList<>();
 
                BasicProgramBlock gradientProgramBlock = new 
BasicProgramBlock(_ec.getProgram());
-               gradientProgramBlock.setInstructions(new 
ArrayList<>(Arrays.asList(_inst)));
+               gradientProgramBlock.setInstructions(new 
ArrayList<>(Collections.singletonList(_inst)));
                pbs.add(gradientProgramBlock);
 
                if(_freq == PSFrequency.EPOCH) {
                        BasicProgramBlock aggProgramBlock = new 
BasicProgramBlock(_ec.getProgram());
-                       aggProgramBlock.setInstructions(new 
ArrayList<>(Arrays.asList(_ps.getAggInst())));
+                       aggProgramBlock.setInstructions(new 
ArrayList<>(Collections.singletonList(_ps.getAggInst())));
                        pbs.add(aggProgramBlock);
                }
 
@@ -160,7 +165,6 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                                        _inst.getFunctionName(),
                                        _ps.getAggInst().getFunctionName(),
                                        _ec.getListObject("hyperparams"),
-                                       _localStartBatchNumVarID,
                                        _modelVarID
                                )
                ));
@@ -188,12 +192,11 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                private final String _gradientsFunctionName;
                private final String _aggregationFunctionName;
                private final ListObject _hyperParams;
-               private final long _batchCounterVarID;
                private final long _modelVarID;
 
                protected SetupFederatedWorker(long batchSize, long dataSize, 
int possibleBatchesPerLocalEpoch,
                        String programString, String namespace, String 
gradientsFunctionName, String aggregationFunctionName,
-                       ListObject hyperParams, long batchCounterVarID, long 
modelVarID)
+                       ListObject hyperParams, long modelVarID)
                {
                        super(new long[]{});
                        _batchSize = batchSize;
@@ -204,7 +207,6 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        _gradientsFunctionName = gradientsFunctionName;
                        _aggregationFunctionName = aggregationFunctionName;
                        _hyperParams = hyperParams;
-                       _batchCounterVarID = batchCounterVarID;
                        _modelVarID = modelVarID;
                }
 
@@ -221,7 +223,6 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        ec.setVariable(Statement.PS_FED_GRADIENTS_FNAME, new 
StringObject(_gradientsFunctionName));
                        ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new 
StringObject(_aggregationFunctionName));
                        ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
-                       ec.setVariable(Statement.PS_FED_BATCHCOUNTER_VARID, new 
IntObject(_batchCounterVarID));
                        ec.setVariable(Statement.PS_FED_MODEL_VARID, new 
IntObject(_modelVarID));
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
@@ -272,7 +273,6 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        ec.removeVariable(Statement.PS_FED_NAMESPACE);
                        ec.removeVariable(Statement.PS_FED_GRADIENTS_FNAME);
                        ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
-                       ec.removeVariable(Statement.PS_FED_BATCHCOUNTER_VARID);
                        ec.removeVariable(Statement.PS_FED_MODEL_VARID);
                        ParamservUtils.cleanupListObject(ec, 
Statement.PS_HYPER_PARAMS);
                        
@@ -319,9 +319,10 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                return _ps.pull(_workerID);
        }
 
-       protected void scaleAndPushGradients(ListObject gradients) {
+       protected void weighAndPushGradients(ListObject gradients) {
                // scale gradients - must only include MatrixObjects
                if(_weighing && _weighingFactor != 1) {
+                       Timing tWeighing = DMLScript.STATISTICS ? new 
Timing(true) : null;
                        gradients.getData().parallelStream().forEach((matrix) 
-> {
                                MatrixObject matrixObject = (MatrixObject) 
matrix;
                                MatrixBlock input = 
matrixObject.acquireReadAndRelease().scalarOperations(
@@ -329,6 +330,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                                matrixObject.acquireModify(input);
                                matrixObject.release();
                        });
+                       accFedPSGradientWeighingTime(tWeighing);
                }
 
                // Push the gradients to ps
@@ -350,12 +352,10 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                                int localStartBatchNum = 
getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
                                ListObject model = pullModel();
                                ListObject gradients = 
computeGradientsForNBatches(model, 1, localStartBatchNum);
-                               scaleAndPushGradients(gradients);
+                               weighAndPushGradients(gradients);
                                ParamservUtils.cleanupListObject(model);
                                ParamservUtils.cleanupListObject(gradients);
-                               LOG.info("[+] " + this.getWorkerName() + " 
completed BATCH " + localStartBatchNum);
                        }
-                       LOG.info("[+] " + this.getWorkerName() + " --- 
completed EPOCH " + epochCounter);
                }
        }
 
@@ -376,9 +376,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        // Pull the global parameters from ps
                        ListObject model = pullModel();
                        ListObject gradients = 
computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, 
true);
-                       scaleAndPushGradients(gradients);
-
-                       LOG.info("[+] " + this.getWorkerName() + " --- 
completed EPOCH " + epochCounter);
+                       weighAndPushGradients(gradients);
                        ParamservUtils.cleanupListObject(model);
                        ParamservUtils.cleanupListObject(gradients);
                }
@@ -401,15 +399,13 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        protected ListObject computeGradientsForNBatches(ListObject model,
                int numBatchesToCompute, int localStartBatchNum, boolean 
localUpdate)
        {
-               // put local start batch num on federated worker
-               Future<FederatedResponse> putBatchCounterResponse = 
_featuresData.executeFederatedOperation(
-                       new FederatedRequest(RequestType.PUT_VAR, 
_localStartBatchNumVarID, new IntObject(localStartBatchNum)));
+               Timing tFedCommunication = DMLScript.STATISTICS ? new 
Timing(true) : null;
                // put current model on federated worker
                Future<FederatedResponse> putParamsResponse = 
_featuresData.executeFederatedOperation(
                        new FederatedRequest(RequestType.PUT_VAR, _modelVarID, 
model));
 
                try {
-                       if(!putParamsResponse.get().isSuccessful() || 
!putBatchCounterResponse.get().isSuccessful())
+                       if(!putParamsResponse.get().isSuccessful())
                                throw new 
DMLRuntimeException("FederatedLocalPSThread: put was not successful");
                }
                catch(Exception e) {
@@ -420,14 +416,22 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                Future<FederatedResponse> udfResponse = 
_featuresData.executeFederatedOperation(
                        new FederatedRequest(RequestType.EXEC_UDF, 
_featuresData.getVarID(),
                                new federatedComputeGradientsForNBatches(new 
long[]{_featuresData.getVarID(), _labelsData.getVarID(),
-                               _localStartBatchNumVarID, _modelVarID}, 
numBatchesToCompute,localUpdate)
+                               _modelVarID}, numBatchesToCompute, localUpdate, 
localStartBatchNum)
                ));
 
                try {
                        Object[] responseData = udfResponse.get().getData();
+                       if(DMLScript.STATISTICS) {
+                               long total = (long) tFedCommunication.stop();
+                               long workerComputing = ((DoubleObject) 
responseData[1]).getLongValue();
+                               
Statistics.accFedPSWorkerComputing(workerComputing);
+                               Statistics.accFedPSCommunicationTime(total - 
workerComputing);
+                       }
                        return (ListObject) responseData[0];
                }
                catch(Exception e) {
+                       if(DMLScript.STATISTICS)
+                               tFedCommunication.stop();
                        throw new DMLRuntimeException("FederatedLocalPSThread: 
failed to execute UDF" + e.getMessage());
                }
        }
@@ -439,20 +443,22 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                private static final long serialVersionUID = 
-3075901536748794832L;
                int _numBatchesToCompute;
                boolean _localUpdate;
+               int _localStartBatchNum;
 
-               protected federatedComputeGradientsForNBatches(long[] inIDs, 
int numBatchesToCompute, boolean localUpdate) {
+               protected federatedComputeGradientsForNBatches(long[] inIDs, 
int numBatchesToCompute, boolean localUpdate, int localStartBatchNum) {
                        super(inIDs);
                        _numBatchesToCompute = numBatchesToCompute;
                        _localUpdate = localUpdate;
+                       _localStartBatchNum = localStartBatchNum;
                }
 
                @Override
                public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       Timing tGradients = new Timing(true);
                        // read in data by varid
                        MatrixObject features = (MatrixObject) data[0];
                        MatrixObject labels = (MatrixObject) data[1];
-                       int localStartBatchNum = (int) ((IntObject) 
data[2]).getLongValue();
-                       ListObject model = (ListObject) data[3];
+                       ListObject model = (ListObject) data[2];
 
                        // get data from execution context
                        long batchSize = ((IntObject) 
ec.getVariable(Statement.PS_FED_BATCH_SIZE)).getLongValue();
@@ -493,7 +499,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        }
 
                        ListObject accGradients = null;
-                       int currentLocalBatchNumber = localStartBatchNum;
+                       int currentLocalBatchNumber = _localStartBatchNum;
                        // prepare execution context
                        ec.setVariable(Statement.PS_MODEL, model);
                        for (int batchCounter = 0; batchCounter < 
_numBatchesToCompute; batchCounter++) {
@@ -534,14 +540,14 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                                ParamservUtils.cleanupListObject(ec, 
gradientsOutput.getName());
                                ParamservUtils.cleanupData(ec, 
Statement.PS_FEATURES);
                                ParamservUtils.cleanupData(ec, 
Statement.PS_LABELS);
-                               
ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
                        }
 
                        // model clean up
                        ParamservUtils.cleanupListObject(ec, 
ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
                        ParamservUtils.cleanupListObject(ec, 
Statement.PS_MODEL);
-
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, accGradients);
+                       // stop timing
+                       DoubleObject gradientsTime = new 
DoubleObject(tGradients.stop());
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new 
Object[]{accGradients, gradientsTime});
                }
 
                @Override
@@ -551,6 +557,11 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        }
 
        // Statistics methods
+       protected void accFedPSGradientWeighingTime(Timing time) {
+               if (DMLScript.STATISTICS && time != null)
+                       Statistics.accFedPSGradientWeighingTime((long) 
time.stop());
+       }
+
        @Override
        public String getWorkerName() {
                return String.format("Federated worker_%d", _workerID);
@@ -558,21 +569,22 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
 
        @Override
        protected void incWorkerNumber() {
-
+               if (DMLScript.STATISTICS)
+                       Statistics.incWorkerNumber();
        }
 
        @Override
        protected void accLocalModelUpdateTime(Timing time) {
-
+               throw new NotImplementedException();
        }
 
        @Override
        protected void accBatchIndexingTime(Timing time) {
-
+               throw new NotImplementedException();
        }
 
        @Override
        protected void accGradientComputeTime(Timing time) {
-
+               throw new NotImplementedException();
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
index 29193b0..7bd96f2 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.controlprogram.paramserv;
 
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 
@@ -30,12 +31,19 @@ public class LocalParamServer extends ParamServer {
                super();
        }
 
-       public static LocalParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
-               return new LocalParamServer(model, aggFunc, updateType, ec, 
workerNum);
+       public static LocalParamServer create(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
+               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
+               MatrixObject valFeatures, MatrixObject valLabels)
+       {
+               return new LocalParamServer(model, aggFunc, updateType, freq, 
ec,
+                       workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels);
        }
 
-       private LocalParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
-               super(model, aggFunc, updateType, ec, workerNum);
+       private LocalParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
+               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc, int numBatchesPerEpoch,
+               MatrixObject valFeatures, MatrixObject valLabels)
+       {
+               super(model, aggFunc, updateType, freq, ec, workerNum, valFunc, 
numBatchesPerEpoch, valFeatures, valLabels);
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 701e45c..c0389f3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -35,7 +35,6 @@ import 
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 
-// TODO use the validate features and labels to calculate the model precision 
when training
 public abstract class PSWorker implements Serializable 
 {
        private static final long serialVersionUID = -3510485051178200118L;
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index e420ed8..6315ef9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -38,9 +38,11 @@ import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.utils.Statistics;
@@ -57,15 +59,30 @@ public abstract class ParamServer
        //aggregation service
        protected ExecutionContext _ec;
        private Statement.PSUpdateType _updateType;
+       private Statement.PSFrequency _freq;
 
        private FunctionCallCPInstruction _inst;
        private String _outputName;
        private boolean[] _finishedStates;  // Workers' finished states
        private ListObject _accGradients = null;
 
+       private boolean _validationPossible;
+       private FunctionCallCPInstruction _valInst;
+       private String _lossOutput;
+       private String _accuracyOutput;
+
+       private int _syncCounter = 0;
+       private int _epochCounter = 0 ;
+       private int _numBatchesPerEpoch;
+
+       private int _numWorkers;
+
        protected ParamServer() {}
 
-       protected ParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
+       protected ParamServer(ListObject model, String aggFunc, 
Statement.PSUpdateType updateType,
+               Statement.PSFrequency freq, ExecutionContext ec, int workerNum, 
String valFunc,
+               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels)
+       {
                // init worker queues and global model
                _modelMap = new HashMap<>(workerNum);
                IntStream.range(0, workerNum).forEach(i -> {
@@ -77,8 +94,15 @@ public abstract class ParamServer
                // init aggregation service
                _ec = ec;
                _updateType = updateType;
+               _freq = freq;
                _finishedStates = new boolean[workerNum];
                setupAggFunc(_ec, aggFunc);
+
+               if(valFunc != null && numBatchesPerEpoch > 0) {
+                       setupValFunc(_ec, valFunc, valFeatures, valLabels);
+               }
+               _numBatchesPerEpoch = numBatchesPerEpoch;
+               _numWorkers = workerNum;
                
                // broadcast initial model
                broadcastModel(true);
@@ -110,6 +134,39 @@ public abstract class ParamServer
                        func.getInputParamNames(), outputNames, "aggregate 
function");
        }
 
+       protected void setupValFunc(ExecutionContext ec, String valFunc, 
MatrixObject valFeatures, MatrixObject valLabels) {
+               String[] cfn = DMLProgram.splitFunctionKey(valFunc);
+               String ns = cfn[0];
+               String fname = cfn[1];
+               FunctionProgramBlock func = 
ec.getProgram().getFunctionProgramBlock(ns, fname, false);
+               ArrayList<DataIdentifier> inputs = func.getInputParams();
+               ArrayList<DataIdentifier> outputs = func.getOutputParams();
+
+               // Check the output of the validate function
+               if (outputs.size() != 2) {
+                       throw new DMLRuntimeException(String.format("The output 
of the '%s' function should provide the loss and the accuracy in that order", 
valFunc));
+               }
+               if (outputs.get(0).getDataType() != DataType.SCALAR || 
outputs.get(1).getDataType() != DataType.SCALAR) {
+                       throw new DMLRuntimeException(String.format("The 
outputs of the '%s' function should both be scalars", valFunc));
+               }
+               _lossOutput = outputs.get(0).getName();
+               _accuracyOutput = outputs.get(1).getName();
+
+               CPOperand[] boundInputs = inputs.stream()
+                       .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
+                       .toArray(CPOperand[]::new);
+               ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
+                       .collect(Collectors.toCollection(ArrayList::new));
+               _valInst = new FunctionCallCPInstruction(ns, fname, false, 
boundInputs,
+                       func.getInputParamNames(), outputNames, "validate 
function");
+
+               // write validation data to execution context. hyper params are 
already in ec
+               _ec.setVariable(Statement.PS_VAL_FEATURES, valFeatures);
+               _ec.setVariable(Statement.PS_VAL_LABELS, valLabels);
+
+               _validationPossible = true;
+       }
+
        public abstract void push(int workerID, ListObject value);
 
        public abstract ListObject pull(int workerID);
@@ -119,7 +176,7 @@ public abstract class ParamServer
                // so we could return directly the result model
                return _model;
        }
-       
+
        protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
                try {
                        if (LOG.isDebugEnabled()) {
@@ -143,6 +200,22 @@ public abstract class ParamServer
                                                        
updateGlobalModel(_accGradients);
                                                        _accGradients = null;
                                                }
+
+                                               // This condition looks complex, but its purpose is simple: validate at the end of each epoch.
+                                               // In the BSP batch case that happens once the sync counter reaches the number of batches per epoch;
+                                               // in the BSP epoch case it happens on every synchronization.
+                                               if ((_freq == 
Statement.PSFrequency.EPOCH ||
+                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+
+                                                       if(LOG.isInfoEnabled())
+                                                               LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
+
+                                                       if(_validationPossible)
+                                                               validate();
+
+                                                       _epochCounter++;
+                                                       _syncCounter = 0;
+                                               }
                                                
                                                // Broadcast the updated model
                                                resetFinishedStates();
@@ -154,6 +227,21 @@ public abstract class ParamServer
                                }
                                case ASP: {
                                        updateGlobalModel(gradients);
+                                       // This condition works like the one for BSP, but divides the sync counter by the number of workers,
+                                       // creating "pseudo epochs".
+                                       if ((_freq == 
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
+                                               (_freq == 
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) 
_numBatchesPerEpoch == 0)) {
+
+                                               if(LOG.isInfoEnabled())
+                                                       LOG.info("[+] 
PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
+
+                                               if(_validationPossible)
+                                                       validate();
+
+                                               _epochCounter++;
+                                               _syncCounter = 0;
+                                       }
+
                                        broadcastModel(workerID);
                                        break;
                                }
@@ -162,14 +250,14 @@ public abstract class ParamServer
                        }
                } 
                catch (Exception e) {
-                       throw new DMLRuntimeException("Aggregation service 
failed: ", e);
+                       throw new DMLRuntimeException("Aggregation or 
validation service failed: ", e);
                }
        }
 
        private void updateGlobalModel(ListObject gradients) {
                Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
                _model = updateLocalModel(_ec, gradients, _model);
-               if (DMLScript.STATISTICS)
+               if (DMLScript.STATISTICS && tAgg != null)
                        Statistics.accPSAggregationTime((long) tAgg.stop());
        }
 
@@ -226,14 +314,37 @@ public abstract class ParamServer
 
        private void broadcastModel(int workerID) throws InterruptedException {
                Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
-
                //broadcast copy of model to specific worker, cleaned up by 
worker
                _modelMap.get(workerID).put(ParamservUtils.copyList(_model, 
false));
-
-               if (DMLScript.STATISTICS)
+               if (DMLScript.STATISTICS && tBroad != null)
                        Statistics.accPSModelBroadcastTime((long) 
tBroad.stop());
        }
 
+       /**
+        * Checks the current model against the validation set
+        */
+       private synchronized void validate() {
+               Timing tValidate = DMLScript.STATISTICS ? new Timing(true) : 
null;
+               _ec.setVariable(Statement.PS_MODEL, _model);
+
+               // Invoke the validation function
+               _valInst.processInstruction(_ec);
+
+               // Get the validation results
+               double loss = ((DoubleObject) 
_ec.getVariable(_lossOutput)).getDoubleValue();
+               double accuracy = ((DoubleObject) 
_ec.getVariable(_accuracyOutput)).getDoubleValue();
+
+               // cleanup
+               ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
+
+               // Log validation results
+               if(LOG.isInfoEnabled())
+                       LOG.info("[+] PARAMSERV: validation-loss: " + loss + " 
validation-accuracy: " + accuracy);
+
+               if(tValidate != null)
+                       Statistics.accPSValidationTime((long) tValidate.stop());
+       }
+
        public FunctionCallCPInstruction getAggInst() {
                return _inst;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java
index aec117a..79830ce 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/Timing.java
@@ -26,27 +26,22 @@ package org.apache.sysds.runtime.controlprogram.parfor.stat;
  */
 public class Timing 
 {
-
-       
        private long _start = -1;
        
        public Timing() {
                //default constructor
        }
        
-       public Timing(boolean start)
-       {
+       public Timing(boolean start) {
                //init and start the timer
-               if( start ){
+               if( start )
                        start();
-               }
        }
        
        /**
         * Starts the time measurement.
         */
-       public void start()
-       {
+       public void start() {
                _start = System.nanoTime();
        }
        
@@ -56,16 +51,15 @@ public class Timing
         * 
         * @return duration between start and stop
         */
-       public double stop()
-       {
+       public double stop() {
                if( _start == -1 )
                        throw new RuntimeException("Stop time measurement 
without prior start is invalid.");
        
-               long end = System.nanoTime();           
+               long end = System.nanoTime();
                double duration = ((double)(end-_start))/1000000;
                
                //carry end time over
-               _start = end;           
+               _start = end;
                return duration;
        }
        
@@ -73,11 +67,8 @@ public class Timing
         * Measures and returns the time since the last start() or stop() 
invocation,
         * restarts the measurement, and prints the last measurement to STDOUT.
         */
-       public void stopAndPrint()
-       {
+       public void stopAndPrint() {
                double tmp = stop();
-               
                System.out.println("PARFOR: time = "+tmp+"ms");
        }
-       
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a66e039..785915d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -46,6 +46,9 @@ import static 
org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
 import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
 import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
 import static org.apache.sysds.parser.Statement.PS_SEED;
+import static org.apache.sysds.parser.Statement.PS_VAL_FEATURES;
+import static org.apache.sysds.parser.Statement.PS_VAL_LABELS;
+import static org.apache.sysds.parser.Statement.PS_VAL_FUN;
 
 import org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.commons.logging.Log;
@@ -123,12 +126,15 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private void runFederated(ExecutionContext ec) {
+               Timing tExecutionTime = DMLScript.STATISTICS ? new Timing(true) 
: null;
+               Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
                LOG.info("PARAMETER SERVER");
                LOG.info("[+] Running in federated mode");
 
                // get inputs
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
+               String valFunc = getValFunction();
                PSUpdateType updateType = getUpdateType();
                PSFrequency freq = getFrequency();
                FederatedPSScheme federatedPSScheme = getFederatedScheme();
@@ -144,17 +150,24 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                        LOG.info("[+] Weighing: " + weighing);
                        LOG.info("[+] Seed: " + seed);
                }
-               
+               if (tSetup != null)
+                       Statistics.accPSSetupTime((long) tSetup.stop());
+
                // partition federated data
+               Timing tDataPartitioning = DMLScript.STATISTICS ? new 
Timing(true) : null;
                DataPartitionFederatedScheme.Result result = new 
FederatedDataPartitioner(federatedPSScheme, seed)
                        
.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), 
ec.getMatrixObject(getParam(PS_LABELS)));
                int workerNum = result._workerNum;
+               if (DMLScript.STATISTICS)
+                       Statistics.accFedPSDataPartitioningTime((long) 
tDataPartitioning.stop());
 
+
+               if (DMLScript.STATISTICS)
+                       tSetup.start();
                // setup threading
                BasicThreadFactory factory = new BasicThreadFactory.Builder()
                        .namingPattern("workers-pool-thread-%d").build();
                ExecutorService es = Executors.newFixedThreadPool(workerNum, 
factory);
-
                // Get the compiled execution context
                LocalVariableMap newVarsMap = createVarsMap(ec);
                // Level of par is -1 so each federated worker can scale to its 
cpu cores
@@ -165,24 +178,25 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                ExecutionContext aggServiceEC = 
ParamservUtils.copyExecutionContext(newEC, 1).get(0);
                // Create the parameter server
                ListObject model = ec.getListObject(getParam(PS_MODEL));
-               ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, 
updateType, workerNum, model, aggServiceEC);
+               ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, 
updateType, freq, workerNum, model, aggServiceEC, valFunc,
+                               getNumBatchesPerEpoch(runtimeBalancing, 
result._balanceMetrics), ec.getMatrixObject(getParam(PS_VAL_FEATURES)), 
ec.getMatrixObject(getParam(PS_VAL_LABELS)));
                // Create the local workers
                int finalNumBatchesPerEpoch = 
getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
                List<FederatedPSControlThread> threads = IntStream.range(0, 
workerNum)
                        .mapToObj(i -> new FederatedPSControlThread(i, updFunc, 
freq, runtimeBalancing, weighing,
                                getEpochs(), getBatchSize(), 
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
                        .collect(Collectors.toList());
-
                if(workerNum != threads.size()) {
                        throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning 
does not match threads!");
                }
-
                // Set features and lables for the control threads and write 
the program and instructions and hyperparams to the federated workers
                for (int i = 0; i < threads.size(); i++) {
                        threads.get(i).setFeatures(result._pFeatures.get(i));
                        threads.get(i).setLabels(result._pLabels.get(i));
                        threads.get(i).setup(result._weighingFactors.get(i));
                }
+               if (DMLScript.STATISTICS)
+                       Statistics.accPSSetupTime((long) tSetup.stop());
 
                try {
                        // Launch the worker threads and wait for completion
@@ -190,6 +204,8 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                                ret.get(); //error handling
                        // Fetch the final model from ps
                        ec.setVariable(output.getName(), ps.getResult());
+                       if (DMLScript.STATISTICS)
+                               Statistics.accPSExecutionTime((long) 
tExecutionTime.stop());
                } catch (InterruptedException | ExecutionException e) {
                        throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
                } finally {
@@ -215,7 +231,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
                // Create the parameter server
                ListObject model = sec.getListObject(getParam(PS_MODEL));
-               ParamServer ps = createPS(mode, aggFunc, getUpdateType(), 
workerNum, model, aggServiceEC);
+               ParamServer ps = createPS(mode, aggFunc, getUpdateType(), 
getFrequency(), workerNum, model, aggServiceEC);
 
                // Get driver host
                String host = 
sec.getSparkContext().getConf().get("spark.driver.host");
@@ -299,7 +315,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
                // Create the parameter server
                ListObject model = ec.getListObject(getParam(PS_MODEL));
-               ParamServer ps = createPS(mode, aggFunc, updateType, workerNum, 
model, aggServiceEC);
+               ParamServer ps = createPS(mode, aggFunc, updateType, freq, 
workerNum, model, aggServiceEC);
 
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
@@ -436,14 +452,24 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         *
         * @return parameter server
         */
-       private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
+       private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
+               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec)
+       {
+               return createPS(mode, aggFunc, updateType, freq, workerNum, 
model, ec, null, -1, null, null);
+       }
+
+       // With this overload, the parameter server can run the validation function after each epoch
+       private static ParamServer createPS(PSModeType mode, String aggFunc, 
PSUpdateType updateType,
+               PSFrequency freq, int workerNum, ListObject model, 
ExecutionContext ec, String valFunc,
+               int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject 
valLabels)
+       {
                switch (mode) {
                        case FEDERATED:
                        case LOCAL:
                        case REMOTE_SPARK:
-                               return LocalParamServer.create(model, aggFunc, 
updateType, ec, workerNum);
+                               return LocalParamServer.create(model, aggFunc, 
updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, 
valLabels);
                        default:
-                               throw new DMLRuntimeException("Unsupported 
parameter server: "+mode.name());
+                               throw new DMLRuntimeException("Unsupported 
parameter server: " + mode.name());
                }
        }
 
@@ -518,7 +544,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
         * @return numBatchesPerEpoch
         */
        private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, 
DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
-               int numBatchesPerEpoch = 0;
+               int numBatchesPerEpoch;
                if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
                        numBatchesPerEpoch = (int) 
Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
                } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
@@ -526,6 +552,8 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                        numBatchesPerEpoch = (int) 
Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
                } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
                        numBatchesPerEpoch = (int) 
Math.ceil(balanceMetrics._maxRows / (float) getBatchSize());
+               } else {
+                       numBatchesPerEpoch = (int) 
Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
                }
                return numBatchesPerEpoch;
        }
@@ -534,6 +562,13 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                return getParameterMap().containsKey(PS_FED_WEIGHING) && 
Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
        }
 
+       private String getValFunction() {
+               if (getParameterMap().containsKey(PS_VAL_FUN)) {
+                       return getParam(PS_VAL_FUN);
+               }
+               return null;
+       }
+
        private int getSeed() {
                return (getParameterMap().containsKey(PS_SEED)) ? 
Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
        }
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java 
b/src/main/java/org/apache/sysds/utils/Statistics.java
index b9059d9..b40e905 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -117,6 +117,7 @@ public class Statistics
        private static final LongAdder sparkBroadcastCount = new LongAdder();
 
        // Paramserv function stats (time is in milli sec)
+       private static final LongAdder psExecutionTime = new LongAdder();
        private static final LongAdder psNumWorkers = new LongAdder();
        private static final LongAdder psSetupTime = new LongAdder();
        private static final LongAdder psGradientComputeTime = new LongAdder();
@@ -125,6 +126,12 @@ public class Statistics
        private static final LongAdder psModelBroadcastTime = new LongAdder();
        private static final LongAdder psBatchIndexTime = new LongAdder();
        private static final LongAdder psRpcRequestTime = new LongAdder();
+       private static final LongAdder psValidationTime = new LongAdder();
+       // Federated parameter server specifics (time is in milli sec)
+       private static final LongAdder fedPSDataPartitioningTime = new 
LongAdder();
+       private static final LongAdder fedPSWorkerComputingTime = new 
LongAdder();
+       private static final LongAdder fedPSGradientWeighingTime = new 
LongAdder();
+       private static final LongAdder fedPSCommunicationTime = new LongAdder();
 
        //PARFOR optimization stats (low frequency updates)
        private static long parforOptTime = 0; //in milli sec
@@ -562,6 +569,10 @@ public class Statistics
                psNumWorkers.add(n);
        }
 
+       public static void accPSExecutionTime(long n) {
+               psExecutionTime.add(n);
+       }
+
        public static void accPSSetupTime(long t) {
                psSetupTime.add(t);
        }
@@ -590,6 +601,24 @@ public class Statistics
                psRpcRequestTime.add(t);
        }
 
+       public static void accPSValidationTime(long t) {
+               psValidationTime.add(t);
+       }
+
+       public static void accFedPSDataPartitioningTime(long t) {
+               fedPSDataPartitioningTime.add(t);
+       }
+
+       public static void accFedPSWorkerComputing(long t) {
+               fedPSWorkerComputingTime.add(t);
+       }
+
+       public static void accFedPSGradientWeighingTime(long t) {
+               fedPSGradientWeighingTime.add(t);
+       }
+
+       public static void accFedPSCommunicationTime(long t) { 
fedPSCommunicationTime.add(t);}
+
        public static String getCPHeavyHitterCode( Instruction inst )
        {
                String opcode = null;
@@ -758,13 +787,13 @@ public class Statistics
                                if(wrapIter == 0) {
                                        // Display instruction count
                                        sb.append(String.format(
-                                                       " %" + maxNumLen + "d  
%-" + maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "d",
-                                                       (i + 1), instStr, 
timeSString, count));
+                                               " %" + maxNumLen + "d  %-" + 
maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "d",
+                                               (i + 1), instStr, timeSString, 
count));
                                }
                                else {
                                        sb.append(String.format(
-                                                       " %" + maxNumLen + "s  
%-" + maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "s",
-                                                       "", instStr, "", ""));
+                                               " %" + maxNumLen + "s  %-" + 
maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "s",
+                                               "", instStr, "", ""));
                                }
                                sb.append("\n");
                        }
@@ -795,8 +824,8 @@ public class Statistics
 
                maxNameLength = Math.max(maxNameLength, "Object".length());
                StringBuilder res = new StringBuilder();
-               res.append(String.format("  %-" + numPadLen + "s" + "  %-" + 
maxNameLength + "s" + "  %s\n",
-                               "#", "Object", "Memory"));
+               res.append(String.format("  %-" + numPadLen + "s" + "  %-" 
+                       + maxNameLength + "s" + "  %s\n", "#", "Object", 
"Memory"));
 
                for (int ix = 1; ix <= numHittersToDisplay; ix++) {
                        String objName = entries[ix-1].getKey();
@@ -831,8 +860,7 @@ public class Statistics
        public static long getJITCompileTime(){
                long ret = -1; //unsupported
                CompilationMXBean cmx = 
ManagementFactory.getCompilationMXBean();
-               if( cmx.isCompilationTimeMonitoringSupported() )
-               {
+               if( cmx.isCompilationTimeMonitoringSupported() ) {
                        ret = cmx.getTotalCompilationTime();
                        ret += jitCompileTime; //add from remote processes
                }
@@ -1011,14 +1039,26 @@ public class Statistics
                                                                
sparkCollect.longValue()*1e-9));
                        }
                        if (psNumWorkers.longValue() > 0) {
+                               sb.append(String.format("Paramserv total 
execution time:\t%.3f secs.\n", psExecutionTime.doubleValue() / 1000));
                                sb.append(String.format("Paramserv total num 
workers:\t%d.\n", psNumWorkers.longValue()));
                                sb.append(String.format("Paramserv setup 
time:\t\t%.3f secs.\n", psSetupTime.doubleValue() / 1000));
-                               sb.append(String.format("Paramserv grad compute 
time:\t%.3f secs.\n", psGradientComputeTime.doubleValue() / 1000));
-                               sb.append(String.format("Paramserv model update 
time:\t%.3f/%.3f secs.\n",
+
+                               if(fedPSDataPartitioningTime.doubleValue() > 0) 
{       //if data partitioning happens this is the federated case
+                                       sb.append(String.format("PS fed data 
partitioning time:\t%.3f secs.\n", fedPSDataPartitioningTime.doubleValue() / 
1000));
+                                       sb.append(String.format("PS fed comm 
time (cum):\t\t%.3f secs.\n", fedPSCommunicationTime.doubleValue() / 1000));
+                                       sb.append(String.format("PS fed worker 
comp time (cum):\t%.3f secs.\n", fedPSWorkerComputingTime.doubleValue() / 
1000));
+                                       sb.append(String.format("PS fed grad 
weigh time (cum):\t%.3f secs.\n", fedPSGradientWeighingTime.doubleValue() / 
1000));
+                                       sb.append(String.format("PS fed global 
model agg time:\t%.3f secs.\n", psAggregationTime.doubleValue() / 1000));
+                               }
+                               else {
+                                       sb.append(String.format("Paramserv grad 
compute time:\t%.3f secs.\n", psGradientComputeTime.doubleValue() / 1000));
+                                       sb.append(String.format("Paramserv 
model update time:\t%.3f/%.3f secs.\n",
                                                
psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() / 
1000));
-                               sb.append(String.format("Paramserv model 
broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000));
-                               sb.append(String.format("Paramserv batch slice 
time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
-                               sb.append(String.format("Paramserv RPC request 
time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000));
+                                       sb.append(String.format("Paramserv 
model broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 
1000));
+                                       sb.append(String.format("Paramserv 
batch slice time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
+                                       sb.append(String.format("Paramserv RPC 
request time:\t%.3f secs.\n", psRpcRequestTime.doubleValue() / 1000));
+                               }
+                               sb.append(String.format("Paramserv valdiation 
time:\t%.3f secs.\n", psValidationTime.doubleValue() / 1000));
                        }
                        if( parforOptCount>0 ){
                                sb.append("ParFor loops optimized:\t\t" + 
getParforOptCount() + ".\n");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index a00e8dc..5d7c7e2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -64,29 +64,32 @@ public class FederatedParamservTest extends 
AutomatedTestBase {
                return Arrays.asList(new Object[][] {
                        // Network type, number of federated workers, data set 
size, batch size, epochs, learning rate, update type, update frequency
                        // basic functionality
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   200},
+                       //{"TwoNN",     4, 60000, 32, 4, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE" ,                "false","BALANCED",             
200},
+
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,             "true", "IMBALANCED",  
 200},
                        {"CNN",         2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "SHUFFLE",                              "NONE" ,                
"true", "IMBALANCED",   200},
                        {"CNN",         2, 4, 1, 4, 0.01,               "ASP", 
"BATCH", "REPLICATE_TO_MAX",     "RUN_MIN" ,     "true", "IMBALANCED",   200},
                        {"TwoNN",       2, 4, 1, 4, 0.01,               "ASP", 
"EPOCH", "BALANCE_TO_AVG",               "CYCLE_MAX" ,   "true", "IMBALANCED",  
 200},
                        {"TwoNN",       5, 1000, 100, 2, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE" ,                "true", "BALANCED",             
200},
 
-                       /* // runtime balancing
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   200},
-
-                       // data partitioning
-                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", 
"BATCH", "SHUFFLE",                              "CYCLE_AVG" ,   "true", 
"IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", 
"BATCH", "REPLICATE_TO_MAX",             "NONE" ,                "true", 
"IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", 
"BATCH", "SUBSAMPLE_TO_MIN",             "NONE" ,                "true", 
"IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", 
"BATCH", "BALANCE_TO_AVG",               "NONE" ,                "true", 
"IMBALANCED",   200},
-
-                       // balanced tests
-                       {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", "EPOCH", 
"KEEP_DATA_ON_WORKER",  "NONE" ,                "true", "BALANCED",             
200} */
-
+                       /*
+                               // runtime balancing
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   
200},
+
+                               // data partitioning
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SHUFFLE",                              "CYCLE_AVG" ,   "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "REPLICATE_TO_MAX",             "NONE" ,                "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SUBSAMPLE_TO_MIN",             "NONE" ,                "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "BALANCE_TO_AVG",               "NONE" ,                "true", 
"IMBALANCED",   200},
+
+                               // balanced tests
+                               {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "NONE" ,                "true", "BALANCED",    
         200}
+                       */
                });
        }
 
@@ -125,6 +128,7 @@ public class FederatedParamservTest extends 
AutomatedTestBase {
        }
 
        private void federatedParamserv(ExecMode mode) {
+               // Warning Statistics accumulate in unit test
                // config
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml 
b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 0f9ae63..79628ef 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -65,13 +65,10 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
  *       - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
  *       - b4: 4th layer biases vector, of shape (1, K)
  */
-train = function(matrix[double] X, matrix[double] y,
-                 matrix[double] X_val, matrix[double] y_val,
-                 int epochs, int batch_size, double eta,
-                 int C, int Hin, int Win,
-                 int seed = -1)
-    return (list[unknown] model) {
-
+train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
+  matrix[double] y_val, int epochs, int batch_size, double eta, int C, int Hin,
+       int Win, int seed = -1) return (list[unknown] model)
+{
   N = nrow(X)
   K = ncol(y)
 
@@ -162,12 +159,11 @@ train = function(matrix[double] X, matrix[double] y,
  *       - b4: 4th layer biases vector, of shape (1, K)
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
-                 matrix[double] X_val, matrix[double] y_val,
-                 int num_workers, int epochs, string utype, string freq, int 
batch_size, string scheme, string runtime_balancing, string weighing,
-                 double eta, int C, int Hin, int Win,
-                 int seed = -1)
-    return (list[unknown] model) {
-
+  matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
+  string utype, string freq, int batch_size, string scheme, string 
runtime_balancing,
+  string weighing, double eta, int C, int Hin, int Win, int seed = -1)
+  return (list[unknown] model)
+{
   N = nrow(X)
   K = ncol(y)
 
@@ -210,6 +206,7 @@ train_paramserv = function(matrix[double] X, matrix[double] 
y,
   model = paramserv(model=model, features=X, labels=y, val_features=X_val, 
val_labels=y_val,
     upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
     
agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
+    val="./src/test/scripts/functions/federated/paramserv/CNN.dml::validate",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
     scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, 
hyperparams=hyperparams, seed=seed)
 }
@@ -267,7 +264,7 @@ predict = function(matrix[double] X, int C, int Hin, int 
Win, int batch_size, li
   # Compute predictions over mini-batches
   probs = matrix(0, rows=N, cols=K)
   iters = ceil(N / batch_size)
-  parfor(i in 1:iters, check=0) {
+  for(i in 1:iters, check=0) {
     # Get next batch
     beg = ((i-1) * batch_size) %% N + 1
     end = min(N, beg + batch_size - 1)
@@ -320,6 +317,25 @@ eval = function(matrix[double] probs, matrix[double] y)
   accuracy = mean(correct_pred)
 }
 
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for 
validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels, 
+  list[unknown] model, list[unknown] hyperparams) 
+       return (double loss, double accuracy)
+{
+  [loss, accuracy] = eval(predict(val_features, 
as.integer(as.scalar(hyperparams["C"])),
+    as.integer(as.scalar(hyperparams["Hin"])), 
as.integer(as.scalar(hyperparams["Win"])), 
+    32, model), val_labels)
+}
+
 # Should always use 'features' (batch features), 'labels' (batch labels),
 # 'hyperparams', 'model' as the arguments
 # and return the gradients of type list
@@ -371,7 +387,7 @@ gradients = function(list[unknown] model,
   # Compute loss & accuracy for training data
   loss = cross_entropy_loss::forward(probs, labels)
   accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
-  print("[+] Completed forward pass on batch: train loss: " + loss + ", train 
accuracy: " + accuracy)
+  # print("[+] Completed forward pass on batch: train loss: " + loss + ", 
train accuracy: " + accuracy)
 
   # Compute data backward pass
   ## loss
@@ -452,4 +468,4 @@ aggregation = function(list[unknown] model,
    [b4, vb4] = sgd_nesterov::update(b4, db4, learning_rate, mu, vb4)
 
    model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, 
vb1, vb2, vb3, vb4)
-}
\ No newline at end of file
+}
diff --git 
a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml 
b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 5176cca..c7ad305 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -26,12 +26,16 @@ 
source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
 features = read($features)
 labels = read($labels)
 
-print($weighing)
-
 if($network_type == "TwoNN") {
-  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), 
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, 
$runtime_balancing, $weighing, $eta, $seed)
+  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed)
+  print("Test results:")
+  [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, cols=784), 
matrix(0, rows=100, cols=10), model, list())
+  print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + 
"\n")
 }
 else {
-  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), 
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, 
$runtime_balancing, $weighing, $eta, $channels, $hin, $win, $seed)
+  model = CNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin, 
$win, $seed)
+  print("Test results:")
+  hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+  [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784), 
matrix(0, rows=100, cols=10), model, hyperparams)
+  print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + 
"\n")
 }
-print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml 
b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index a6dc6f2..e7fc6d9 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -154,6 +154,7 @@ train_paramserv = function(matrix[double] X, matrix[double] 
y,
   model = paramserv(model=model, features=X, labels=y, val_features=X_val, 
val_labels=y_val,
     
upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
     
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
+    val="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::validate",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
     scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, 
hyperparams=hyperparams, seed=seed)
 }
@@ -214,6 +215,21 @@ eval = function(matrix[double] probs, matrix[double] y)
   accuracy = mean(correct_pred)
 }
 
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for 
validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ *  - loss: Scalar loss, of shape (1).
+ *  - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels, 
list[unknown] model, list[unknown] hyperparams)
+    return (double loss, double accuracy) {
+  [loss, accuracy] = eval(predict(val_features, model), val_labels)
+}
+
 # Should always use 'features' (batch features), 'labels' (batch labels),
 # 'hyperparams', 'model' as the arguments
 # and return the gradients of type list
@@ -242,7 +258,7 @@ gradients = function(list[unknown] model,
   # Compute loss & accuracy for training data
   loss = cross_entropy_loss::forward(probs, labels)
   accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
-  print("[+] Completed forward pass on batch: train loss: " + loss + ", train 
accuracy: " + accuracy)
+  # print("[+] Completed forward pass on batch: train loss: " + loss + ", 
train accuracy: " + accuracy)
 
   # Compute data backward pass
   dprobs = cross_entropy_loss::backward(probs, labels)

Reply via email to