This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2931f6e  [SYSTEMDS-2550] Improved parameter server epoch timing/logging
2931f6e is described below

commit 2931f6ec82798e4e71281ac8ca9cc47a55381266
Author: Tobias Rieger <[email protected]>
AuthorDate: Sat Feb 20 18:13:41 2021 +0100

    [SYSTEMDS-2550] Improved parameter server epoch timing/logging
    
    Closes #1176.
---
 .../ParameterizedBuiltinFunctionExpression.java    |  4 +-
 .../java/org/apache/sysds/parser/Statement.java    |  5 +--
 .../paramserv/FederatedPSControlThread.java        | 44 +++++++++----------
 .../controlprogram/paramserv/ParamServer.java      | 37 +++++++++++++---
 .../paramserv/dp/BalanceToAvgFederatedScheme.java  |  4 +-
 .../paramserv/dp/DataPartitionFederatedScheme.java | 14 +++----
 .../dp/KeepDataOnWorkerFederatedScheme.java        |  4 +-
 .../dp/ReplicateToMaxFederatedScheme.java          |  4 +-
 .../paramserv/dp/ShuffleFederatedScheme.java       |  4 +-
 .../dp/SubsampleToMinFederatedScheme.java          |  4 +-
 .../cp/ParamservBuiltinCPInstruction.java          | 49 ++++++++++++++--------
 .../java/org/apache/sysds/utils/Statistics.java    | 22 ++++++++--
 .../paramserv/FederatedParamservTest.java          | 40 +++++++++---------
 .../scripts/functions/federated/paramserv/CNN.dml  |  4 +-
 .../federated/paramserv/FederatedParamservTest.dml |  4 +-
 .../functions/federated/paramserv/TwoNN.dml        |  4 +-
 16 files changed, 150 insertions(+), 97 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 583c643..4d111b0 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -290,7 +290,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                        Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, 
Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
                        Statement.PS_VAL_FUN, Statement.PS_MODE, 
Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
                        Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, 
Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
-                       Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, 
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
+                       Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, 
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
                checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
                // check existence and correctness of parameters
@@ -310,7 +310,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                checkDataValueType(true, fname, Statement.PS_PARALLELISM, 
DataType.SCALAR, ValueType.INT64, conditional);
                checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
                checkStringParam(true, fname, 
Statement.PS_FED_RUNTIME_BALANCING, conditional);
-               checkStringParam(true, fname, Statement.PS_FED_WEIGHING, 
conditional);
+               checkStringParam(true, fname, Statement.PS_FED_WEIGHTING, 
conditional);
                checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, 
DataType.LIST, ValueType.UNKNOWN, conditional);
                checkStringParam(true, fname, Statement.PS_CHECKPOINTING, 
conditional);
                checkDataValueType(true, fname, Statement.PS_SEED, 
DataType.SCALAR, ValueType.INT64, conditional);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java 
b/src/main/java/org/apache/sysds/parser/Statement.java
index eb32865..d15fd44 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -89,10 +89,10 @@ public abstract class Statement implements ParseInfo
        public enum PSFrequency {
                BATCH, EPOCH
        }
-       public static final String PS_FED_WEIGHING = "weighing";
+       public static final String PS_FED_WEIGHTING = "weighting";
        public static final String PS_FED_RUNTIME_BALANCING = 
"runtime_balancing";
        public enum PSRuntimeBalancing {
-               NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
+               NONE, BASELINE, CYCLE_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
        }
        public static final String PS_EPOCHS = "epochs";
        public static final String PS_BATCH_SIZE = "batchsize";
@@ -101,7 +101,6 @@ public abstract class Statement implements ParseInfo
        public enum PSScheme {
                DISJOINT_CONTIGUOUS, DISJOINT_ROUND_ROBIN, DISJOINT_RANDOM, 
OVERLAP_RESHUFFLE
        }
-       public static final String PS_FED_SCHEME = "fed_scheme";
        public enum FederatedPSScheme {
                KEEP_DATA_ON_WORKER, SHUFFLE, REPLICATE_TO_MAX, 
SUBSAMPLE_TO_MIN, BALANCE_TO_AVG
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 13e029c..10fee56 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -78,19 +78,19 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        private final PSRuntimeBalancing _runtimeBalancing;
        private int _numBatchesPerEpoch;
        private int _possibleBatchesPerLocalEpoch;
-       private final boolean _weighing;
-       private double _weighingFactor = 1;
-       private final boolean _cycleStartAt0 = false;
+       private final boolean _weighting;
+       private double _weightingFactor = 1;
+       private boolean _cycleStartAt0 = false;
 
        public FederatedPSControlThread(int workerID, String updFunc, 
Statement.PSFrequency freq,
-               PSRuntimeBalancing runtimeBalancing, boolean weighing, int 
epochs, long batchSize,
+               PSRuntimeBalancing runtimeBalancing, boolean weighting, int 
epochs, long batchSize,
                int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer 
ps)
        {
                super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
 
                _numBatchesPerEpoch = numBatchesPerGlobalEpoch;
                _runtimeBalancing = runtimeBalancing;
-               _weighing = weighing;
+               _weighting = weighting;
                // generate the ID for the model
                _modelVarID = FederationUtils.getNextFedDataID();
        }
@@ -98,40 +98,42 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        /**
         * Sets up the federated worker and control thread
         *
-        * @param weighingFactor Gradients from this worker will be multiplied 
by this factor if weighing is enabled
+        * @param weightingFactor Gradients from this worker will be multiplied 
by this factor if weighting is enabled
         */
-       public void setup(double weighingFactor) {
+       public void setup(double weightingFactor) {
                incWorkerNumber();
 
                // prepare features and labels
                _featuresData = (FederatedData) 
_features.getFedMapping().getMap().values().toArray()[0];
                _labelsData = (FederatedData) 
_labels.getFedMapping().getMap().values().toArray()[0];
 
-               // weighing factor is always set, but only used when weighing 
is specified
-               _weighingFactor = weighingFactor;
+               // weighting factor is always set, but only used when weighting 
is specified
+               _weightingFactor = weightingFactor;
 
                // different runtime balancing calculations
                long dataSize = _features.getNumRows();
 
                // calculate scaled batch size if balancing via batch size.
                // In some cases there will be some cycling
-               if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+               if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH)
                        _batchSize = (int) Math.ceil((double) dataSize / 
_numBatchesPerEpoch);
-               }
 
                // Calculate possible batches with batch size
                _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) 
dataSize / _batchSize);
 
                // If no runtime balancing is specified, just run possible 
number of batches
                // WARNING: Will get stuck on miss match
-               if(_runtimeBalancing == PSRuntimeBalancing.NONE) {
+               if(_runtimeBalancing == PSRuntimeBalancing.NONE)
                        _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
-               }
+
+               // If running in baseline mode, set _cycleStartAt0 to true
+               if(_runtimeBalancing == PSRuntimeBalancing.BASELINE)
+                       _cycleStartAt0 = true;
 
                if( LOG.isInfoEnabled() ) {
                        LOG.info("Setup config for worker " + 
this.getWorkerName());
                        LOG.info("Batch size: " + _batchSize + " possible 
batches: " + _possibleBatchesPerLocalEpoch
-                                       + " batches to run: " + 
_numBatchesPerEpoch + " weighing factor: " + _weighingFactor);
+                                       + " batches to run: " + 
_numBatchesPerEpoch + " weighting factor: " + _weightingFactor);
                }
 
                // serialize program
@@ -321,16 +323,16 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
 
        protected void weighAndPushGradients(ListObject gradients) {
                // scale gradients - must only include MatrixObjects
-               if(_weighing && _weighingFactor != 1) {
-                       Timing tWeighing = DMLScript.STATISTICS ? new 
Timing(true) : null;
+               if(_weighting && _weightingFactor != 1) {
+                       Timing tWeighting = DMLScript.STATISTICS ? new 
Timing(true) : null;
                        gradients.getData().parallelStream().forEach((matrix) 
-> {
                                MatrixObject matrixObject = (MatrixObject) 
matrix;
                                MatrixBlock input = 
matrixObject.acquireReadAndRelease().scalarOperations(
-                                       new 
RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new 
MatrixBlock());
+                                       new 
RightScalarOperator(Multiply.getMultiplyFnObject(), _weightingFactor), new 
MatrixBlock());
                                matrixObject.acquireModify(input);
                                matrixObject.release();
                        });
-                       accFedPSGradientWeighingTime(tWeighing);
+                       accFedPSGradientWeightingTime(tWeighting);
                }
 
                // Push the gradients to ps
@@ -342,7 +344,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        }
 
        /**
-        * Computes all epochs and updates after each batch
+        * Computes all epochs and updates after each batch 
         */
        protected void computeWithBatchUpdates() {
                for (int epochCounter = 0; epochCounter < _epochs; 
epochCounter++) {
@@ -557,9 +559,9 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
        }
 
        // Statistics methods
-       protected void accFedPSGradientWeighingTime(Timing time) {
+       protected void accFedPSGradientWeightingTime(Timing time) {
                if (DMLScript.STATISTICS && time != null)
-                       Statistics.accFedPSGradientWeighingTime((long) 
time.stop());
+                       Statistics.accFedPSGradientWeightingTime((long) 
time.stop());
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 6315ef9..4fe072c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -98,7 +98,7 @@ public abstract class ParamServer
                _finishedStates = new boolean[workerNum];
                setupAggFunc(_ec, aggFunc);
 
-               if(valFunc != null && numBatchesPerEpoch > 0) {
+               if(valFunc != null && numBatchesPerEpoch > 0 && valFeatures != 
null && valLabels != null) {
                        setupValFunc(_ec, valFunc, valFeatures, valLabels);
                }
                _numBatchesPerEpoch = numBatchesPerEpoch;
@@ -204,12 +204,15 @@ public abstract class ParamServer
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
-                                               if ((_freq == 
Statement.PSFrequency.EPOCH ||
+                                               if (_numBatchesPerEpoch != -1 &&
+                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
                                                        (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
+                                                       time_epoch();
+
                                                        if(_validationPossible)
                                                                validate();
 
@@ -229,12 +232,15 @@ public abstract class ParamServer
                                        updateGlobalModel(gradients);
                                        // This if works similarly to the one 
for BSP, but divides the sync couter through the number of workers,
                                        // creating "Pseudo Epochs"
-                                       if ((_freq == 
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
-                                               (_freq == 
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) 
_numBatchesPerEpoch == 0)) {
+                                       if (_numBatchesPerEpoch != -1 &&
+                                               ((_freq == 
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
+                                               (_freq == 
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) 
_numBatchesPerEpoch == 0))) {
 
                                                if(LOG.isInfoEnabled())
                                                        LOG.info("[+] 
PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
 
+                                               time_epoch();
+
                                                if(_validationPossible)
                                                        validate();
 
@@ -321,9 +327,28 @@ public abstract class ParamServer
        }
 
        /**
+        * If statistics are enabled, logs the accumulated execution time (excl. validation) at the end of an epoch
+        */
+       private void time_epoch() {
+               if (DMLScript.STATISTICS) {
+                       //TODO double check correctness with multiple, 
potentially concurrent paramserv invocations
+                       Statistics.accPSExecutionTime((long) 
Statistics.getPSExecutionTimer().stop());
+                       double current_total_execution_time = 
Statistics.getPSExecutionTime();
+                       double current_total_validation_time = 
Statistics.getPSValidationTime();
+                       double time_to_epoch = current_total_execution_time - 
current_total_validation_time;
+
+                       if (LOG.isInfoEnabled())
+                               if(_validationPossible)
+                                       LOG.info("[+] PARAMSERV: epoch timer 
(excl. validation): " + time_to_epoch / 1000 + " secs.");
+                               else
+                                       LOG.info("[+] PARAMSERV: epoch timer: " 
+ time_to_epoch / 1000 + " secs.");
+               }
+       }
+
+       /**
         * Checks the current model against the validation set
         */
-       private synchronized void validate() {
+       private void validate() {
                Timing tValidate = DMLScript.STATISTICS ? new Timing(true) : 
null;
                _ec.setVariable(Statement.PS_MODEL, _model);
 
@@ -338,7 +363,7 @@ public abstract class ParamServer
                ParamservUtils.cleanupListObject(_ec, Statement.PS_MODEL);
 
                // Log validation results
-               if(LOG.isInfoEnabled())
+               if (LOG.isInfoEnabled())
                        LOG.info("[+] PARAMSERV: validation-loss: " + loss + " 
validation-accuracy: " + accuracy);
 
                if(tValidate != null)
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index e3daf60..9c90767 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -52,7 +52,7 @@ public class BalanceToAvgFederatedScheme extends 
DataPartitionFederatedScheme {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
                BalanceMetrics balanceMetricsBefore = 
getBalanceMetrics(pFeatures);
-               List<Double> weighingFactors = getWeighingFactors(pFeatures, 
balanceMetricsBefore);
+               List<Double> weightingFactors = getWeightingFactors(pFeatures, 
balanceMetricsBefore);
 
                int average_num_rows = (int) balanceMetricsBefore._avgRows;
 
@@ -79,7 +79,7 @@ public class BalanceToAvgFederatedScheme extends 
DataPartitionFederatedScheme {
                        pLabels.get(i).updateDataCharacteristics(update);
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures), weighingFactors);
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures), weightingFactors);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index e00923e..c6429b4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -45,15 +45,15 @@ public abstract class DataPartitionFederatedScheme {
                public final List<MatrixObject> _pLabels;
                public final int _workerNum;
                public final BalanceMetrics _balanceMetrics;
-               public final List<Double> _weighingFactors;
+               public final List<Double> _weightingFactors;
 
 
-               public Result(List<MatrixObject> pFeatures, List<MatrixObject> 
pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double> 
weighingFactors) {
+               public Result(List<MatrixObject> pFeatures, List<MatrixObject> 
pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double> 
weightingFactors) {
                        _pFeatures = pFeatures;
                        _pLabels = pLabels;
                        _workerNum = workerNum;
                        _balanceMetrics = balanceMetrics;
-                       _weighingFactors = weighingFactors;
+                       _weightingFactors = weightingFactors;
                }
        }
 
@@ -125,12 +125,12 @@ public abstract class DataPartitionFederatedScheme {
                return new BalanceMetrics(minRows, sum / slices.size(), 
maxRows);
        }
 
-       static List<Double> getWeighingFactors(List<MatrixObject> pFeatures, 
BalanceMetrics balanceMetrics) {
-               List<Double> weighingFactors = new ArrayList<>();
+       static List<Double> getWeightingFactors(List<MatrixObject> pFeatures, 
BalanceMetrics balanceMetrics) {
+               List<Double> weightingFactors = new ArrayList<>();
                pFeatures.forEach((feature) -> {
-                       weighingFactors.add((double) feature.getNumRows() / 
balanceMetrics._avgRows);
+                       weightingFactors.add((double) feature.getNumRows() / 
balanceMetrics._avgRows);
                });
-               return weighingFactors;
+               return weightingFactors;
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index afbaf4d..ae8d874 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -35,7 +35,7 @@ public class KeepDataOnWorkerFederatedScheme extends 
DataPartitionFederatedSchem
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
                BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
-               List<Double> weighingFactors = getWeighingFactors(pFeatures, 
balanceMetrics);
-               return new Result(pFeatures, pLabels, pFeatures.size(), 
balanceMetrics, weighingFactors);
+               List<Double> weightingFactors = getWeightingFactors(pFeatures, 
balanceMetrics);
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
balanceMetrics, weightingFactors);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index e9c1b50..77b2287 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -52,7 +52,7 @@ public class ReplicateToMaxFederatedScheme extends 
DataPartitionFederatedScheme
        public Result partition(MatrixObject features, MatrixObject labels, int 
seed) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
-               List<Double> weighingFactors = getWeighingFactors(pFeatures, 
getBalanceMetrics(pFeatures));
+               List<Double> weightingFactors = getWeightingFactors(pFeatures, 
getBalanceMetrics(pFeatures));
 
                int max_rows = 0;
                for (MatrixObject pFeature : pFeatures) {
@@ -82,7 +82,7 @@ public class ReplicateToMaxFederatedScheme extends 
DataPartitionFederatedScheme
                        pLabels.get(i).updateDataCharacteristics(update);
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures), weighingFactors);
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures), weightingFactors);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 365554d..af95270 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -51,7 +51,7 @@ public class ShuffleFederatedScheme extends 
DataPartitionFederatedScheme {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
                BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
-               List<Double> weighingFactors = getWeighingFactors(pFeatures, 
balanceMetrics);
+               List<Double> weightingFactors = getWeightingFactors(pFeatures, 
balanceMetrics);
 
                for(int i = 0; i < pFeatures.size(); i++) {
                        // Works, because the map contains a single entry
@@ -71,7 +71,7 @@ public class ShuffleFederatedScheme extends 
DataPartitionFederatedScheme {
                        }
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), 
balanceMetrics, weighingFactors);
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
balanceMetrics, weightingFactors);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index e55b92e..369b3dd 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -52,7 +52,7 @@ public class SubsampleToMinFederatedScheme extends 
DataPartitionFederatedScheme
        public Result partition(MatrixObject features, MatrixObject labels, int 
seed) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
-               List<Double> weighingFactors = getWeighingFactors(pFeatures, 
getBalanceMetrics(pFeatures));
+               List<Double> weightingFactors = getWeightingFactors(pFeatures, 
getBalanceMetrics(pFeatures));
 
                int min_rows = Integer.MAX_VALUE;
                for (MatrixObject pFeature : pFeatures) {
@@ -82,7 +82,7 @@ public class SubsampleToMinFederatedScheme extends 
DataPartitionFederatedScheme
                        pLabels.get(i).updateDataCharacteristics(update);
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures), weighingFactors);
+               return new Result(pFeatures, pLabels, pFeatures.size(), 
getBalanceMetrics(pFeatures), weightingFactors);
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a99e8ee..e64fdf8 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -41,11 +41,10 @@ import static org.apache.sysds.parser.Statement.PS_MODE;
 import static org.apache.sysds.parser.Statement.PS_MODEL;
 import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
-import static org.apache.sysds.parser.Statement.PS_FED_SCHEME;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
 import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
-import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
+import static org.apache.sysds.parser.Statement.PS_FED_WEIGHTING;
 import static org.apache.sysds.parser.Statement.PS_SEED;
 import static org.apache.sysds.parser.Statement.PS_VAL_FEATURES;
 import static org.apache.sysds.parser.Statement.PS_VAL_LABELS;
@@ -127,7 +126,9 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private void runFederated(ExecutionContext ec) {
-               Timing tExecutionTime = DMLScript.STATISTICS ? new Timing(true) 
: null;
+               if(DMLScript.STATISTICS)
+                       Statistics.getPSExecutionTimer().start();
+
                Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
                LOG.info("PARAMETER SERVER");
                LOG.info("[+] Running in federated mode");
@@ -135,12 +136,11 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                // get inputs
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
-               String valFunc = getValFunction();
                PSUpdateType updateType = getUpdateType();
                PSFrequency freq = getFrequency();
                FederatedPSScheme federatedPSScheme = getFederatedScheme();
                PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
-               boolean weighing = getWeighing();
+               boolean weighting = getWeighting();
                int seed = getSeed();
 
                if( LOG.isInfoEnabled() ) {
@@ -148,7 +148,7 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                        LOG.info("[+] Frequency: " + freq);
                        LOG.info("[+] Data Partitioning: " + federatedPSScheme);
                        LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
-                       LOG.info("[+] Weighing: " + weighing);
+                       LOG.info("[+] Weighting: " + weighting);
                        LOG.info("[+] Seed: " + seed);
                }
                if (tSetup != null)
@@ -179,12 +179,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
                ExecutionContext aggServiceEC = 
ParamservUtils.copyExecutionContext(newEC, 1).get(0);
                // Create the parameter server
                ListObject model = ec.getListObject(getParam(PS_MODEL));
-               ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, 
updateType, freq, workerNum, model, aggServiceEC, valFunc,
-                               getNumBatchesPerEpoch(runtimeBalancing, 
result._balanceMetrics), ec.getMatrixObject(getParam(PS_VAL_FEATURES)), 
ec.getMatrixObject(getParam(PS_VAL_LABELS)));
+               MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) 
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
+               MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? 
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
+               ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, 
updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
+                               getNumBatchesPerEpoch(runtimeBalancing, 
result._balanceMetrics), val_features, val_labels);
                // Create the local workers
                int finalNumBatchesPerEpoch = 
getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
                List<FederatedPSControlThread> threads = IntStream.range(0, 
workerNum)
-                       .mapToObj(i -> new FederatedPSControlThread(i, updFunc, 
freq, runtimeBalancing, weighing,
+                       .mapToObj(i -> new FederatedPSControlThread(i, updFunc, 
freq, runtimeBalancing, weighting,
                                getEpochs(), getBatchSize(), 
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
                        .collect(Collectors.toList());
                if(workerNum != threads.size()) {
@@ -194,7 +196,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                for (int i = 0; i < threads.size(); i++) {
                        threads.get(i).setFeatures(result._pFeatures.get(i));
                        threads.get(i).setLabels(result._pLabels.get(i));
-                       threads.get(i).setup(result._weighingFactors.get(i));
+                       threads.get(i).setup(result._weightingFactors.get(i));
                }
                if (DMLScript.STATISTICS)
                        Statistics.accPSSetupTime((long) tSetup.stop());
@@ -206,7 +208,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                        // Fetch the final model from ps
                        ec.setVariable(output.getName(), ps.getResult());
                        if (DMLScript.STATISTICS)
-                               Statistics.accPSExecutionTime((long) 
tExecutionTime.stop());
+                               Statistics.accPSExecutionTime((long) 
Statistics.getPSExecutionTimer().stop());
                } catch (InterruptedException | ExecutionException e) {
                        throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
                } finally {
@@ -293,6 +295,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
        }
 
        private void runLocally(ExecutionContext ec, PSModeType mode) {
+               if(DMLScript.STATISTICS)
+                       Statistics.getPSExecutionTimer().start();
+
                Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
                int workerNum = getWorkerNum(mode);
                BasicThreadFactory factory = new BasicThreadFactory.Builder()
@@ -314,9 +319,15 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                PSFrequency freq = getFrequency();
                PSUpdateType updateType = getUpdateType();
 
+               double rows_per_worker = Math.ceil((float) 
ec.getMatrixObject(getParam(PS_FEATURES)).getNumRows() / workerNum);
+               int num_batches_per_epoch = (int) Math.ceil(rows_per_worker / 
getBatchSize());
+
                // Create the parameter server
                ListObject model = ec.getListObject(getParam(PS_MODEL));
-               ParamServer ps = createPS(mode, aggFunc, updateType, freq, 
workerNum, model, aggServiceEC);
+               MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null) 
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
+               MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ? 
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
+               ParamServer ps = createPS(mode, aggFunc, updateType, freq, 
workerNum, model, aggServiceEC, getValFunction(),
+                               num_batches_per_epoch, val_features, 
val_labels);
 
                // Create the local workers
                List<LocalPSWorker> workers = IntStream.range(0, workerNum)
@@ -344,6 +355,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                                ret.get(); //error handling
                        // Fetch the final model from ps
                        ec.setVariable(output.getName(), ps.getResult());
+                       if (DMLScript.STATISTICS)
+                               Statistics.accPSExecutionTime((long) 
Statistics.getPSExecutionTimer().stop());
                } catch (InterruptedException | ExecutionException e) {
                        throw new 
DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
                } finally {
@@ -529,11 +542,11 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
        private FederatedPSScheme getFederatedScheme() {
                FederatedPSScheme federated_scheme = DEFAULT_FEDERATED_SCHEME;
-               if (getParameterMap().containsKey(PS_FED_SCHEME)) {
+               if (getParameterMap().containsKey(PS_SCHEME)) {
                        try {
-                               federated_scheme = 
FederatedPSScheme.valueOf(getParam(PS_FED_SCHEME));
+                               federated_scheme = 
FederatedPSScheme.valueOf(getParam(PS_SCHEME));
                        } catch (IllegalArgumentException e) {
-                               throw new 
DMLRuntimeException(String.format("Paramserv function in federated mode: not 
support data partition scheme '%s'", getParam(PS_FED_SCHEME)));
+                               throw new 
DMLRuntimeException(String.format("Paramserv function in federated mode: not 
support data partition scheme '%s'", getParam(PS_SCHEME)));
                        }
                }
                return federated_scheme;
@@ -548,7 +561,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
         */
        private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, 
DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
                int numBatchesPerEpoch;
-               if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+               if(runtimeBalancing == PSRuntimeBalancing.CYCLE_MIN || 
runtimeBalancing == PSRuntimeBalancing.BASELINE) {
                        numBatchesPerEpoch = (int) 
Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
                } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
                                || runtimeBalancing == 
PSRuntimeBalancing.SCALE_BATCH) {
@@ -561,8 +574,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                return numBatchesPerEpoch;
        }
 
-       private boolean getWeighing() {
-               return getParameterMap().containsKey(PS_FED_WEIGHING) && 
Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
+       private boolean getWeighting() {
+               return getParameterMap().containsKey(PS_FED_WEIGHTING) && 
Boolean.parseBoolean(getParam(PS_FED_WEIGHTING));
        }
 
        private String getValFunction() {
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 320f610..8fcdf02 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -38,6 +38,7 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
@@ -117,6 +118,7 @@ public class Statistics
        private static final LongAdder sparkBroadcastCount = new LongAdder();
 
        // Paramserv function stats (time is in milli sec)
+       private static final Timing psExecutionTimer = new Timing(false);
        private static final LongAdder psExecutionTime = new LongAdder();
        private static final LongAdder psNumWorkers = new LongAdder();
        private static final LongAdder psSetupTime = new LongAdder();
@@ -130,7 +132,7 @@ public class Statistics
        // Federated parameter server specifics (time is in milli sec)
        private static final LongAdder fedPSDataPartitioningTime = new 
LongAdder();
        private static final LongAdder fedPSWorkerComputingTime = new 
LongAdder();
-       private static final LongAdder fedPSGradientWeighingTime = new 
LongAdder();
+       private static final LongAdder fedPSGradientWeightingTime = new 
LongAdder();
        private static final LongAdder fedPSCommunicationTime = new LongAdder();
 
        //PARFOR optimization stats (low frequency updates)
@@ -571,6 +573,14 @@ public class Statistics
                psNumWorkers.add(n);
        }
 
+       public static Timing getPSExecutionTimer() {
+               return psExecutionTimer;
+       }
+
+       public static double getPSExecutionTime() {
+               return psExecutionTime.doubleValue();
+       }
+
        public static void accPSExecutionTime(long n) {
                psExecutionTime.add(n);
        }
@@ -603,6 +613,10 @@ public class Statistics
                psRpcRequestTime.add(t);
        }
 
+       public static double getPSValidationTime() {
+               return psValidationTime.doubleValue();
+       }
+
        public static void accPSValidationTime(long t) {
                psValidationTime.add(t);
        }
@@ -615,8 +629,8 @@ public class Statistics
                fedPSWorkerComputingTime.add(t);
        }
 
-       public static void accFedPSGradientWeighingTime(long t) {
-               fedPSGradientWeighingTime.add(t);
+       public static void accFedPSGradientWeightingTime(long t) {
+               fedPSGradientWeightingTime.add(t);
        }
 
        public static void accFedPSCommunicationTime(long t) { 
fedPSCommunicationTime.add(t);}
@@ -1049,7 +1063,7 @@ public class Statistics
                                        sb.append(String.format("PS fed data 
partitioning time:\t%.3f secs.\n", fedPSDataPartitioningTime.doubleValue() / 
1000));
                                        sb.append(String.format("PS fed comm 
time (cum):\t\t%.3f secs.\n", fedPSCommunicationTime.doubleValue() / 1000));
                                        sb.append(String.format("PS fed worker 
comp time (cum):\t%.3f secs.\n", fedPSWorkerComputingTime.doubleValue() / 
1000));
-                                       sb.append(String.format("PS fed grad 
weigh time (cum):\t%.3f secs.\n", fedPSGradientWeighingTime.doubleValue() / 
1000));
+                                       sb.append(String.format("PS fed grad. 
weigh. time (cum):\t%.3f secs.\n", fedPSGradientWeightingTime.doubleValue() / 
1000));
                                        sb.append(String.format("PS fed global 
model agg time:\t%.3f secs.\n", psAggregationTime.doubleValue() / 1000));
                                }
                                else {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 5d7c7e2..9221a53 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -54,7 +54,7 @@ public class FederatedParamservTest extends AutomatedTestBase {
        private final String _freq;
        private final String _scheme;
        private final String _runtime_balancing;
-       private final String _weighing;
+       private final String _weighting;
        private final String _data_distribution;
        private final int _seed;
 
@@ -66,35 +66,35 @@ public class FederatedParamservTest extends AutomatedTestBase {
                        // basic functionality
                        //{"TwoNN",     4, 60000, 32, 4, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE" ,                "false","BALANCED",             
200},
 
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,             "true", "IMBALANCED",  
 200},
-                       {"CNN",         2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "SHUFFLE",                              "NONE" ,                
"true", "IMBALANCED",   200},
-                       {"CNN",         2, 4, 1, 4, 0.01,               "ASP", 
"BATCH", "REPLICATE_TO_MAX",     "RUN_MIN" ,     "true", "IMBALANCED",   200},
-                       {"TwoNN",       2, 4, 1, 4, 0.01,               "ASP", 
"EPOCH", "BALANCE_TO_AVG",               "CYCLE_MAX" ,   "true", "IMBALANCED",  
 200},
-                       {"TwoNN",       5, 1000, 100, 2, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE" ,                "true", "BALANCED",             
200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",             "true", "IMBALANCED",  
 200},
+                       {"CNN",         2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "SHUFFLE",                              "NONE",                 
"true", "IMBALANCED",   200},
+                       {"CNN",         2, 4, 1, 4, 0.01,               "ASP", 
"BATCH", "REPLICATE_TO_MAX",     "CYCLE_MIN",    "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "ASP", 
"EPOCH", "BALANCE_TO_AVG",               "CYCLE_MAX",    "true", "IMBALANCED",  
 200},
+                       {"TwoNN",       5, 1000, 100, 2, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE",                 "true", "BALANCED",             
200},
 
                        /*
                                // runtime balancing
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   
200},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   
200},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   
200},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   
200},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   
200},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MIN",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MIN",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX",    "true", "IMBALANCED",   
200},
+                               {"TwoNN",       2, 4, 1, 4, 0.01,               
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX",    "true", "IMBALANCED",   
200},
 
                                // data partitioning
-                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SHUFFLE",                              "CYCLE_AVG" ,   "true", 
"IMBALANCED",   200},
-                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "REPLICATE_TO_MAX",             "NONE" ,                "true", 
"IMBALANCED",   200},
-                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SUBSAMPLE_TO_MIN",             "NONE" ,                "true", 
"IMBALANCED",   200},
-                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "BALANCE_TO_AVG",               "NONE" ,                "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SHUFFLE",                              "CYCLE_AVG",    "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "REPLICATE_TO_MAX",             "NONE",                 "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "SUBSAMPLE_TO_MIN",             "NONE",                 "true", 
"IMBALANCED",   200},
+                               {"TwoNN",       2, 4, 1, 1, 0.01,               
"BSP", "BATCH", "BALANCE_TO_AVG",               "NONE",                 "true", 
"IMBALANCED",   200},
 
                                // balanced tests
-                               {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "NONE" ,                "true", "BALANCED",    
         200}
+                               {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", 
"EPOCH", "KEEP_DATA_ON_WORKER",  "NONE",                 "true", "BALANCED",    
         200}
                        */
                });
        }
 
        public FederatedParamservTest(String networkType, int 
numFederatedWorkers, int dataSetSize, int batch_size,
-               int epochs, double eta, String utype, String freq, String 
scheme, String runtime_balancing, String weighing, String data_distribution, 
int seed) {
+               int epochs, double eta, String utype, String freq, String 
scheme, String runtime_balancing, String weighting, String data_distribution, 
int seed) {
 
                _networkType = networkType;
                _numFederatedWorkers = numFederatedWorkers;
@@ -106,7 +106,7 @@ public class FederatedParamservTest extends AutomatedTestBase {
                _freq = freq;
                _scheme = scheme;
                _runtime_balancing = runtime_balancing;
-               _weighing = weighing;
+               _weighting = weighting;
                _data_distribution = data_distribution;
                _seed = seed;
        }
@@ -192,7 +192,7 @@ public class FederatedParamservTest extends AutomatedTestBase {
                                        "freq=" + _freq,
                                        "scheme=" + _scheme,
                                        "runtime_balancing=" + 
_runtime_balancing,
-                                       "weighing=" + _weighing,
+                                       "weighting=" + _weighting,
                                        "network_type=" + _networkType,
                                        "channels=" + C,
                                        "hin=" + Hin,
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 79628ef..6663ca6 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -161,7 +161,7 @@ train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
 train_paramserv = function(matrix[double] X, matrix[double] y,
   matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
   string utype, string freq, int batch_size, string scheme, string 
runtime_balancing,
-  string weighing, double eta, int C, int Hin, int Win, int seed = -1)
+  string weighting, double eta, int C, int Hin, int Win, int seed = -1)
   return (list[unknown] model)
 {
   N = nrow(X)
@@ -208,7 +208,7 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
     
agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
     val="./src/test/scripts/functions/federated/paramserv/CNN.dml::validate",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
-    scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, 
hyperparams=hyperparams, seed=seed)
+    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, 
hyperparams=hyperparams, seed=seed)
 }
 
 /*
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index c7ad305..7efd588 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -27,13 +27,13 @@ features = read($features)
 labels = read($labels)
 
 if($network_type == "TwoNN") {
-  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed)
+  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed)
   print("Test results:")
   [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100, cols=784), 
matrix(0, rows=100, cols=10), model, list())
   print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test + 
"\n")
 }
 else {
-  model = CNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin, 
$win, $seed)
+  model = CNN::train_paramserv(features, labels, matrix(0, rows=100, 
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq, 
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin, 
$win, $seed)
   print("Test results:")
   hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
   [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784), 
matrix(0, rows=100, cols=10), model, hyperparams)
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index e7fc6d9..f42cdca 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -125,7 +125,7 @@ train = function(matrix[double] X, matrix[double] y,
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int num_workers, int epochs, string utype, string freq, int 
batch_size, string scheme, string runtime_balancing, string weighing,
+                 int num_workers, int epochs, string utype, string freq, int 
batch_size, string scheme, string runtime_balancing, string weighting,
                  double eta, int seed = -1)
     return (list[unknown] model) {
 
@@ -156,7 +156,7 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
     
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
     val="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::validate",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
-    scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, 
hyperparams=hyperparams, seed=seed)
+    scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting, 
hyperparams=hyperparams, seed=seed)
 }
 
 /*

Reply via email to