This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 460c394  [SYSTEMDS-2550] Federated parameter server scaling and weight handling
460c394 is described below

commit 460c3945899ce0fc7fd0c0fd92bb5f3d32a25f7a
Author: Tobias Rieger <[email protected]>
AuthorDate: Sat Jan 9 22:57:15 2021 +0100

    [SYSTEMDS-2550] Federated parameter server scaling and weight handling
    
    Closes #1141.
---
 .../ParameterizedBuiltinFunctionExpression.java    |   8 +-
 .../java/org/apache/sysds/parser/Statement.java    |   6 +-
 .../runtime/compress/colgroup/ColGroupValue.java   |   2 +-
 .../paramserv/FederatedPSControlThread.java        | 165 ++++++++++++---------
 .../paramserv/dp/BalanceToAvgFederatedScheme.java  |  31 ++--
 .../paramserv/dp/DataPartitionFederatedScheme.java |  43 ++++--
 .../paramserv/dp/FederatedDataPartitioner.java     |   7 +-
 .../dp/KeepDataOnWorkerFederatedScheme.java        |  13 +-
 .../dp/ReplicateToMaxFederatedScheme.java          |  26 +++-
 .../paramserv/dp/ShuffleFederatedScheme.java       |  24 ++-
 .../dp/SubsampleToMinFederatedScheme.java          |  26 +++-
 .../cp/ParamservBuiltinCPInstruction.java          | 116 +++++++++------
 .../paramserv/FederatedParamservTest.java          |  58 ++++----
 .../scripts/functions/federated/paramserv/CNN.dml  |   4 +-
 .../federated/paramserv/FederatedParamservTest.dml |   6 +-
 .../functions/federated/paramserv/TwoNN.dml        |   4 +-
 16 files changed, 342 insertions(+), 197 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 5171f21..05bfc48 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -289,8 +289,8 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS,
                        Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
                        Statement.PS_MODE, Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
-                       Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING,
-                       Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
+                       Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
+                       Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, Statement.PS_SEED);
                checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
                // check existence and correctness of parameters
@@ -308,9 +308,11 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, DataType.SCALAR, ValueType.INT64, conditional);
                checkDataValueType(true, fname, Statement.PS_PARALLELISM, DataType.SCALAR, ValueType.INT64, conditional);
                checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
-               checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING, conditional);
+               checkStringParam(true, fname, Statement.PS_FED_RUNTIME_BALANCING, conditional);
+               checkStringParam(true, fname, Statement.PS_FED_WEIGHING, conditional);
                checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, DataType.LIST, ValueType.UNKNOWN, conditional);
                checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
+               checkDataValueType(true, fname, Statement.PS_SEED, DataType.SCALAR, ValueType.INT64, conditional);
 
                // set output characteristics
                output.setDataType(DataType.LIST);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java b/src/main/java/org/apache/sysds/parser/Statement.java
index 6767d85..9104246 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -70,6 +70,7 @@ public abstract class Statement implements ParseInfo
        public static final String PS_AGGREGATION_FUN = "agg";
        public static final String PS_MODE = "mode";
        public static final String PS_GRADIENTS = "gradients";
+       public static final String PS_SEED = "seed";
        public enum PSModeType {
                FEDERATED, LOCAL, REMOTE_SPARK
        }
@@ -87,9 +88,10 @@ public abstract class Statement implements ParseInfo
        public enum PSFrequency {
                BATCH, EPOCH
        }
-       public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
+       public static final String PS_FED_WEIGHING = "weighing";
+       public static final String PS_FED_RUNTIME_BALANCING = "runtime_balancing";
        public enum PSRuntimeBalancing {
-               NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH, SCALE_BATCH_AND_WEIGH
+               NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
        }
        public static final String PS_EPOCHS = "epochs";
        public static final String PS_BATCH_SIZE = "batchsize";
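
The modes above differ only in how a worker's per-epoch batch count is derived. A minimal sketch (hypothetical helper, not part of this patch) that mirrors the balancing hunks further down in ParamservBuiltinCPInstruction and FederatedPSControlThread:

    import org.apache.sysds.parser.Statement;

    public final class RuntimeBalancingSketch {
        // RUN_MIN/CYCLE_AVG/CYCLE_MAX divide the min/avg/max row count by the
        // batch size; SCALE_BATCH keeps the batch count and rescales the local
        // batch size instead; NONE lets each worker run all possible batches.
        public static int numBatchesPerEpoch(Statement.PSRuntimeBalancing mode,
            long minRows, long avgRows, long maxRows, long batchSize) {
            switch (mode) {
                case RUN_MIN:   return (int) Math.ceil(minRows / (double) batchSize);
                case CYCLE_AVG: return (int) Math.ceil(avgRows / (double) batchSize);
                case CYCLE_MAX: return (int) Math.ceil(maxRows / (double) batchSize);
                default:        return 0;
            }
        }
    }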
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 54a45d0..f09b5c2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -228,7 +228,7 @@ public abstract class ColGroupValue extends ColGroup implements Cloneable {
                return val;
        }
 
-       protected final double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
+       protected static double sumValuesSparse(int valIx, SparseRow[] rows, double[] dictVals, int rowsIndex) {
                throw new NotImplementedException("This Method was implemented incorrectly");
                // final int numCols = getNumCols();
                // final int valOff = valIx * numCols;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 393b131..48249db 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -24,6 +24,8 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.parser.DataIdentifier;
 import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.Statement.PSFrequency;
+import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
 import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
@@ -37,13 +39,17 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.IntObject;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.runtime.util.ProgramConverter;
 
 import java.util.ArrayList;
@@ -58,21 +64,29 @@ import static org.apache.sysds.runtime.util.ProgramConverter.*;
 public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
        private static final long serialVersionUID = 6846648059569648791L;
        protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
-       
-       Statement.PSRuntimeBalancing _runtimeBalancing;
-       FederatedData _featuresData;
-       FederatedData _labelsData;
-       final long _localStartBatchNumVarID;
-       final long _modelVarID;
-       int _numBatchesPerGlobalEpoch;
-       int _possibleBatchesPerLocalEpoch;
-       boolean _cycleStartAt0 = false;
-
-       public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
+
+       private FederatedData _featuresData;
+       private FederatedData _labelsData;
+       private final long _localStartBatchNumVarID;
+       private final long _modelVarID;
+
+       // runtime balancing
+       private PSRuntimeBalancing _runtimeBalancing;
+       private int _numBatchesPerEpoch;
+       private int _possibleBatchesPerLocalEpoch;
+       private boolean _weighing;
+       private double _weighingFactor = 1;
+       private boolean _cycleStartAt0 = false;
+
+       public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
+               PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize,
+               int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps)
+       {
                super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
 
-               _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
+               _numBatchesPerEpoch = numBatchesPerGlobalEpoch;
                _runtimeBalancing = runtimeBalancing;
+               _weighing = weighing;
                // generate the IDs for model and batch counter. These get overwritten on the federated worker each time
                _localStartBatchNumVarID = FederationUtils.getNextFedDataID();
                _modelVarID = FederationUtils.getNextFedDataID();
@@ -80,65 +94,72 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
 
        /**
         * Sets up the federated worker and control thread
+        *
+        * @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled
         */
-       public void setup() {
+       public void setup(double weighingFactor) {
                // prepare features and labels
                _featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
                _labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
 
-               // calculate number of batches and get data size
+               // weighing factor is always set, but only used when weighing is specified
+               _weighingFactor = weighingFactor;
+
+               // different runtime balancing calculations
                long dataSize = _features.getNumRows();
-               _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
-               if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN 
-                       || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG 
-                       || _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) {
-                       _numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch;
+
+               // calculate scaled batch size if balancing via batch size.
+               // In some cases there will be some cycling
+               if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+                       _batchSize = (int) Math.ceil((double) dataSize / _numBatchesPerEpoch);
                }
 
-               if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH 
-                       || _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
-                       throw new NotImplementedException();
+               // Calculate possible batches with batch size
+               _possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
+
+               // If no runtime balancing is specified, just run possible number of batches
+               // WARNING: Will get stuck on mismatch
+               if(_runtimeBalancing == PSRuntimeBalancing.NONE) {
+                       _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
                }
 
+               LOG.info("Setup config for worker " + this.getWorkerName());
+               LOG.info("Batch size: " + _batchSize + " possible batches: " + 
_possibleBatchesPerLocalEpoch
+                               + " batches to run: " + _numBatchesPerEpoch + " 
weighing factor: " + _weighingFactor);
+
                // serialize program
                // create program blocks for the instruction filtering
                String programSerialized;
-               ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
+               ArrayList<ProgramBlock> pbs = new ArrayList<>();
 
                BasicProgramBlock gradientProgramBlock = new BasicProgramBlock(_ec.getProgram());
                gradientProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_inst)));
-               programBlocks.add(gradientProgramBlock);
-               programBlocks.add(gradientProgramBlock);
+               pbs.add(gradientProgramBlock);
 
-               if(_freq == Statement.PSFrequency.EPOCH) {
+               if(_freq == PSFrequency.EPOCH) {
+                       BasicProgramBlock aggProgramBlock = new BasicProgramBlock(_ec.getProgram());
                        aggProgramBlock.setInstructions(new ArrayList<>(Arrays.asList(_ps.getAggInst())));
-                       programBlocks.add(aggProgramBlock);
-                       programBlocks.add(aggProgramBlock);
+                       pbs.add(aggProgramBlock);
                }
 
-               StringBuilder sb = new StringBuilder();
-               sb.append(PROG_BEGIN);
-               sb.append( NEWLINE );
-               sb.append(ProgramConverter.serializeProgram(_ec.getProgram(),
-                               programBlocks,
-                               new HashMap<>(),
-                               false
-               ));
-               sb.append(PROG_END);
-               programSerialized = sb.toString();
+               programSerialized = InstructionUtils.concatStrings(
+                       PROG_BEGIN, NEWLINE,
+                       ProgramConverter.serializeProgram(_ec.getProgram(), pbs, new HashMap<>(), false),
+                       PROG_END);
 
                // write program and meta data to worker
                Future<FederatedResponse> udfResponse = 
_featuresData.executeFederatedOperation(
                        new FederatedRequest(RequestType.EXEC_UDF, 
_featuresData.getVarID(),
                                new SetupFederatedWorker(_batchSize,
-                                               dataSize,
-                                               _possibleBatchesPerLocalEpoch,
-                                               programSerialized,
-                                               _inst.getNamespace(),
-                                               _inst.getFunctionName(),
-                                               _ps.getAggInst().getFunctionName(),
-                                               _ec.getListObject("hyperparams"),
-                                               _localStartBatchNumVarID,
-                                               _modelVarID
+                                       dataSize,
+                                       _possibleBatchesPerLocalEpoch,
+                                       programSerialized,
+                                       _inst.getNamespace(),
+                                       _inst.getFunctionName(),
+                                       _ps.getAggInst().getFunctionName(),
+                                       _ec.getListObject("hyperparams"),
+                                       _localStartBatchNumVarID,
+                                       _modelVarID
                                )
                ));
 
@@ -286,12 +307,23 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                return _ps.pull(_workerID);
        }
 
-       protected void pushGradients(ListObject gradients) {
+       protected void scaleAndPushGradients(ListObject gradients) {
+               // scale gradients - must only include MatrixObjects
+               if(_weighing && _weighingFactor != 1) {
+                       gradients.getData().parallelStream().forEach((matrix) -> {
+                               MatrixObject matrixObject = (MatrixObject) matrix;
+                               MatrixBlock input = matrixObject.acquireReadAndRelease().scalarOperations(
+                                       new RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new MatrixBlock());
+                               matrixObject.acquireModify(input);
+                               matrixObject.release();
+                       });
+               }
+
                // Push the gradients to ps
                _ps.push(_workerID, gradients);
        }
 
-       static protected int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
+       protected static int getNextLocalBatchNum(int currentLocalBatchNumber, int possibleBatchesPerLocalEpoch) {
                return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
        }
 
@@ -300,18 +332,18 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
         */
        protected void computeWithBatchUpdates() {
                for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
-                       int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+                       int currentLocalBatchNumber = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;

-                       for (int batchCounter = 0; batchCounter < _numBatchesPerGlobalEpoch; batchCounter++) {
+                       for (int batchCounter = 0; batchCounter < _numBatchesPerEpoch; batchCounter++) {
                                int localStartBatchNum = getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
                                ListObject model = pullModel();
                                ListObject gradients = computeGradientsForNBatches(model, 1, localStartBatchNum);
-                               pushGradients(gradients);
+                               scaleAndPushGradients(gradients);
                                ParamservUtils.cleanupListObject(model);
                                ParamservUtils.cleanupListObject(gradients);
+                               LOG.info("[+] " + this.getWorkerName() + " completed BATCH " + localStartBatchNum);
                        }
-                       if( LOG.isInfoEnabled() )
-                               LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+                       LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
                }
        }
 
@@ -327,15 +359,14 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
         */
        protected void computeWithEpochUpdates() {
                for (int epochCounter = 0; epochCounter < _epochs; epochCounter++) {
-                       int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+                       int localStartBatchNum = (_cycleStartAt0) ? 0 : _numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;

                        // Pull the global parameters from ps
                        ListObject model = pullModel();
-                       ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch, localStartBatchNum, true);
-                       pushGradients(gradients);
-                       
-                       if( LOG.isInfoEnabled() )
-                               LOG.info("[+] " + this.getWorkerName() + " completed epoch " + epochCounter);
+                       ListObject gradients = computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, true);
+                       scaleAndPushGradients(gradients);
+
+                       LOG.info("[+] " + this.getWorkerName() + " --- completed EPOCH " + epochCounter);
                        ParamservUtils.cleanupListObject(model);
                        ParamservUtils.cleanupListObject(gradients);
                }
@@ -424,12 +455,12 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                        ArrayList<DataIdentifier> inputs = 
func.getInputParams();
                        ArrayList<DataIdentifier> outputs = 
func.getOutputParams();
                        CPOperand[] boundInputs = inputs.stream()
-                                       .map(input -> new 
CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-                                       .toArray(CPOperand[]::new);
+                               .map(input -> new CPOperand(input.getName(), 
input.getValueType(), input.getDataType()))
+                               .toArray(CPOperand[]::new);
                        ArrayList<String> outputNames = 
outputs.stream().map(DataIdentifier::getName)
-                                       
.collect(Collectors.toCollection(ArrayList::new));
+                               
.collect(Collectors.toCollection(ArrayList::new));
                        Instruction gradientsInstruction = new 
FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
-                                       func.getInputParamNames(), outputNames, 
"gradient function");
+                               func.getInputParamNames(), outputNames, 
"gradient function");
                        DataIdentifier gradientsOutput = outputs.get(0);
 
                        // recreate aggregation instruction and output if needed
@@ -440,12 +471,12 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                                inputs = func.getInputParams();
                                outputs = func.getOutputParams();
                                boundInputs = inputs.stream()
-                                               .map(input -> new 
CPOperand(input.getName(), input.getValueType(), input.getDataType()))
-                                               .toArray(CPOperand[]::new);
+                                       .map(input -> new 
CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+                                       .toArray(CPOperand[]::new);
                                outputNames = 
outputs.stream().map(DataIdentifier::getName)
-                                               
.collect(Collectors.toCollection(ArrayList::new));
+                                       
.collect(Collectors.toCollection(ArrayList::new));
                                aggregationInstruction = new 
FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
-                                               func.getInputParamNames(), 
outputNames, "aggregation function");
+                                       func.getInputParamNames(), outputNames, 
"aggregation function");
                                aggregationOutput = outputs.get(0);
                        }
 
@@ -492,8 +523,6 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                                ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
                                ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
                                ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
-                               if( LOG.isInfoEnabled() )
-                                       LOG.info("[+]" + " completed batch " + localBatchNum);
                        }
 
                        // model clean up
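
The weighting in scaleAndPushGradients above is a plain scalar multiply of every gradient matrix by the worker's factor; the patch performs it with MatrixBlock.scalarOperations and a RightScalarOperator. A conceptual sketch on plain arrays (hypothetical helper, not the SystemDS API):

    // Scale each gradient matrix by the worker's weighting factor before
    // it is pushed to the parameter server.
    static double[][] scaleGradient(double[][] gradient, double weighingFactor) {
        double[][] scaled = new double[gradient.length][];
        for (int i = 0; i < gradient.length; i++) {
            scaled[i] = new double[gradient[i].length];
            for (int j = 0; j < gradient[i].length; j++)
                scaled[i][j] = gradient[i][j] * weighingFactor;
        }
        return scaled;
    }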
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index 460faba..34e94f0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 
 import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
@@ -35,13 +34,25 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Balance to Avg Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global average number of examples is taken
+ * and the worker subsamples or replicates data to match that number of examples. See the other federated schemes.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
 public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
        @Override
-       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+       public Result partition(MatrixObject features, MatrixObject labels, int seed) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+               BalanceMetrics balanceMetricsBefore = getBalanceMetrics(pFeatures);
+               List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetricsBefore);

-               int average_num_rows = (int) Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN));
+               int average_num_rows = (int) balanceMetricsBefore._avgRows;
 
                for(int i = 0; i < pFeatures.size(); i++) {
                        // Works, because the map contains a single entry
@@ -49,7 +60,7 @@ public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
                        FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];

                        Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                                       featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, average_num_rows)));
+                                       featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, average_num_rows)));
 
                        try {
                                FederatedResponse response = udfResponse.get();
@@ -66,7 +77,7 @@ public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
                        pLabels.get(i).updateDataCharacteristics(update);
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
        }
 
        /**
@@ -74,10 +85,12 @@ public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
         */
        private static class balanceDataOnFederatedWorker extends FederatedUDF {
                private static final long serialVersionUID = 6631958250346625546L;
+               private final int _seed;
                private final int _average_num_rows;
-               
-               protected balanceDataOnFederatedWorker(long[] inIDs, int average_num_rows) {
+
+               protected balanceDataOnFederatedWorker(long[] inIDs, int seed, int average_num_rows) {
                        super(inIDs);
+                       _seed = seed;
                        _average_num_rows = average_num_rows;
                }
 
@@ -88,14 +101,14 @@ public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
 
                        if(features.getNumRows() > _average_num_rows) {
                                // generate subsampling matrix
-                               MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_average_num_rows, Math.toIntExact(features.getNumRows()), _seed);
                                subsampleTo(features, subsampleMatrixBlock);
                                subsampleTo(labels, subsampleMatrixBlock);
                        }
                        else if(features.getNumRows() < _average_num_rows) {
                                int num_rows_needed = _average_num_rows - Math.toIntExact(features.getNumRows());
                                // generate replication matrix
-                               MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed);
                                replicateTo(features, replicateMatrixBlock);
                                replicateTo(labels, replicateMatrixBlock);
                        }
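
Both branches above now draw their randomness from the shared seed, which is the substantive change in this scheme: replacing System.currentTimeMillis() with one seed makes the selection reproducible and keeps features and labels aligned. A conceptual sketch of seeded row selection (plain Java, hypothetical helper):

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.Random;

    // Pick numSelect of numRows row indices; the fixed seed yields the
    // same selection when applied to features and to labels.
    static int[] sampleRowIndices(int numSelect, int numRows, int seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < numRows; i++)
            idx.add(i);
        Collections.shuffle(idx, new Random(seed));
        return idx.subList(0, numSelect).stream().mapToInt(Integer::intValue).toArray();
    }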
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index f5c9638..e00923e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -45,16 +45,31 @@ public abstract class DataPartitionFederatedScheme {
                public final List<MatrixObject> _pLabels;
                public final int _workerNum;
                public final BalanceMetrics _balanceMetrics;
+               public final List<Double> _weighingFactors;
 
-               public Result(List<MatrixObject> pFeatures, List<MatrixObject> 
pLabels, int workerNum, BalanceMetrics balanceMetrics) {
-                       this._pFeatures = pFeatures;
-                       this._pLabels = pLabels;
-                       this._workerNum = workerNum;
-                       this._balanceMetrics = balanceMetrics;
+
+               public Result(List<MatrixObject> pFeatures, List<MatrixObject> 
pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double> 
weighingFactors) {
+                       _pFeatures = pFeatures;
+                       _pLabels = pLabels;
+                       _workerNum = workerNum;
+                       _balanceMetrics = balanceMetrics;
+                       _weighingFactors = weighingFactors;
                }
        }
 
-       public abstract Result doPartitioning(MatrixObject features, MatrixObject labels);
+       public static final class BalanceMetrics {
+               public final long _minRows;
+               public final long _avgRows;
+               public final long _maxRows;
+
+               public BalanceMetrics(long minRows, long avgRows, long maxRows) {
+                       _minRows = minRows;
+                       _avgRows = avgRows;
+                       _maxRows = maxRows;
+               }
+       }
+
+       public abstract Result partition(MatrixObject features, MatrixObject labels, int seed);
 
        /**
         * Takes a row federated Matrix and slices it into a matrix for each worker
@@ -110,16 +125,12 @@ public abstract class DataPartitionFederatedScheme {
                return new BalanceMetrics(minRows, sum / slices.size(), maxRows);
        }
 
-       public static final class BalanceMetrics {
-               public final long _minRows;
-               public final long _avgRows;
-               public final long _maxRows;
-
-               public BalanceMetrics(long minRows, long avgRows, long maxRows) {
-                       this._minRows = minRows;
-                       this._avgRows = avgRows;
-                       this._maxRows = maxRows;
-               }
+       static List<Double> getWeighingFactors(List<MatrixObject> pFeatures, BalanceMetrics balanceMetrics) {
+               List<Double> weighingFactors = new ArrayList<>();
+               pFeatures.forEach((feature) -> {
+                       weighingFactors.add((double) feature.getNumRows() / balanceMetrics._avgRows);
+               });
+               return weighingFactors;
        }
 
        /**
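
getWeighingFactors above computes factor_i = localRows_i / avgRows per worker, so a worker holding more than the average number of rows contributes proportionally more. A standalone sketch with worked numbers (hypothetical helper): workers with 100, 200 and 300 rows average to 200 rows, giving factors 0.5, 1.0 and 1.5.

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    static List<Double> weighingFactors(long[] rowsPerWorker) {
        double avg = Arrays.stream(rowsPerWorker).average().orElse(1);
        List<Double> factors = new ArrayList<>();
        for (long rows : rowsPerWorker)
            factors.add(rows / avg); // e.g. {100, 200, 300} -> {0.5, 1.0, 1.5}
        return factors;
    }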
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
index d1ebb6c..ce2f954 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -24,10 +24,11 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 
 public class FederatedDataPartitioner {
-
        private final DataPartitionFederatedScheme _scheme;
+       private final int _seed;
 
-       public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
+       public FederatedDataPartitioner(Statement.FederatedPSScheme scheme, int seed) {
+               _seed = seed;
                switch (scheme) {
                        case KEEP_DATA_ON_WORKER:
                                _scheme = new KeepDataOnWorkerFederatedScheme();
@@ -50,6 +51,6 @@ public class FederatedDataPartitioner {
        }
 
        public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject features, MatrixObject labels) {
-               return _scheme.doPartitioning(features, labels);
+               return _scheme.partition(features, labels, _seed);
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index e306f25..afbaf4d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -22,11 +22,20 @@ package org.apache.sysds.runtime.controlprogram.paramserv.dp;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import java.util.List;
 
+/**
+ * Keep Data on Worker Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * All entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
 public class KeepDataOnWorkerFederatedScheme extends DataPartitionFederatedScheme {
        @Override
-       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+       public Result partition(MatrixObject features, MatrixObject labels, int seed) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
-               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+               BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+               List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetrics);
+               return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors);
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index 068cfa9..a1b8f6c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -34,11 +34,23 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Replicate to Max Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global maximum number of examples is taken
+ * and the worker replicates data to match that number of examples. The generation is done by multiplying with a
+ * Permutation Matrix with a global seed. These selected examples are appended to the original data.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
 public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme {
        @Override
-       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+       public Result partition(MatrixObject features, MatrixObject labels, int seed) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+               List<Double> weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures));
 
                int max_rows = 0;
                for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@ public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme
                        FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];

                        Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                                       featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, max_rows)));
+                                       featuresData.getVarID(), new replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, max_rows)));
 
                        try {
                                FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@ public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme
                        pLabels.get(i).updateDataCharacteristics(update);
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
        }
 
        /**
@@ -76,10 +88,12 @@ public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme
         */
        private static class replicateDataOnFederatedWorker extends FederatedUDF {
                private static final long serialVersionUID = -6930898456315100587L;
+               private final int _seed;
                private final int _max_rows;
-               
-               protected replicateDataOnFederatedWorker(long[] inIDs, int max_rows) {
+
+               protected replicateDataOnFederatedWorker(long[] inIDs, int seed, int max_rows) {
                        super(inIDs);
+                       _seed = seed;
                        _max_rows = max_rows;
                }
 
@@ -92,7 +106,7 @@ public class ReplicateToMaxFederatedScheme extends DataPartitionFederatedScheme
                        if(features.getNumRows() < _max_rows) {
                                int num_rows_needed = _max_rows - Math.toIntExact(features.getNumRows());
                                // generate replication matrix
-                               MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), _seed);
                                replicateTo(features, replicateMatrixBlock);
                                replicateTo(labels, replicateMatrixBlock);
                        }
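
Conceptually, the UDF above appends num_rows_needed seeded-random copies of existing rows until the worker reaches the global maximum; the patch realizes this by multiplying with a generated replication matrix. A plain-Java sketch of the same effect (hypothetical helper):

    import java.util.Arrays;
    import java.util.Random;

    static double[][] replicateToMax(double[][] data, int maxRows, int seed) {
        Random rand = new Random(seed);
        double[][] out = Arrays.copyOf(data, maxRows); // original rows, then empty slots
        for (int i = data.length; i < maxRows; i++)
            out[i] = data[rand.nextInt(data.length)].clone(); // append a random existing row
        return out;
    }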
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 65ef69d..1920593 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -33,11 +33,23 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Shuffle Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case it is shuffled by generating a permutation
+ * matrix with a global seed and doing a mat mult.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
 public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
        @Override
-       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+       public Result partition(MatrixObject features, MatrixObject labels, int seed) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+               BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+               List<Double> weighingFactors = getWeighingFactors(pFeatures, balanceMetrics);
 
                for(int i = 0; i < pFeatures.size(); i++) {
                        // Works, because the map contains a single entry
@@ -45,7 +57,7 @@ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
                        FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];

                        Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                                       featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()})));
+                                       featuresData.getVarID(), new shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed)));
 
                        try {
                                FederatedResponse response = udfResponse.get();
@@ -57,7 +69,7 @@ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
                        }
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+               return new Result(pFeatures, pLabels, pFeatures.size(), balanceMetrics, weighingFactors);
        }
 
        /**
@@ -65,9 +77,11 @@ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
         */
        private static class shuffleDataOnFederatedWorker extends FederatedUDF {
                private static final long serialVersionUID = 3228664618781333325L;
+               private final int _seed;
 
-               protected shuffleDataOnFederatedWorker(long[] inIDs) {
+               protected shuffleDataOnFederatedWorker(long[] inIDs, int seed) {
                        super(inIDs);
+                       _seed = seed;
                }
 
                @Override
@@ -76,7 +90,7 @@ public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
                        MatrixObject labels = (MatrixObject) data[1];
 
                        // generate permutation matrix
-                       MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                       MatrixBlock permutationMatrixBlock = ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()), _seed);
                        shuffle(features, permutationMatrixBlock);
                        shuffle(labels, permutationMatrixBlock);
                        return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
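
Multiplying with a permutation matrix P is just a row reordering, (P %*% X)[i] = X[p[i]], and reusing the same seed for features and labels keeps the pairs aligned. A plain-Java sketch of the equivalent seeded shuffle (hypothetical helper; the patch itself uses ParamservUtils.generatePermutation and a matrix multiply):

    import java.util.Random;

    static double[][] shuffleRows(double[][] x, int seed) {
        int n = x.length;
        int[] p = new int[n];
        for (int i = 0; i < n; i++)
            p[i] = i;
        Random rand = new Random(seed);
        for (int i = n - 1; i > 0; i--) { // seeded Fisher-Yates over the index vector
            int j = rand.nextInt(i + 1);
            int tmp = p[i]; p[i] = p[j]; p[j] = tmp;
        }
        double[][] out = new double[n][];
        for (int i = 0; i < n; i++)
            out[i] = x[p[i]];
        return out;
    }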
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index 9b62cc8..937c37e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -34,11 +34,23 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
 import java.util.List;
 import java.util.concurrent.Future;
 
+/**
+ * Subsample to Min Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the global minimum number of examples is taken
+ * and the worker subsamples data to match that number of examples. The subsampling is done by multiplying with a
+ * Permutation Matrix with a global seed.
+ *
+ * Then all entries in the federation map of the input matrix are separated into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
 public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme {
        @Override
-       public Result doPartitioning(MatrixObject features, MatrixObject labels) {
+       public Result partition(MatrixObject features, MatrixObject labels, int seed) {
                List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
                List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+               List<Double> weighingFactors = getWeighingFactors(pFeatures, getBalanceMetrics(pFeatures));
 
                int min_rows = Integer.MAX_VALUE;
                for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@ public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme
                        FederatedData labelsData = (FederatedData) pLabels.get(i).getFedMapping().getMap().values().toArray()[0];

                        Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                                       featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, min_rows)));
+                                       featuresData.getVarID(), new subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, min_rows)));
 
                        try {
                                FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@ public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme
                        pLabels.get(i).updateDataCharacteristics(update);
                }
 
-               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures));
+               return new Result(pFeatures, pLabels, pFeatures.size(), getBalanceMetrics(pFeatures), weighingFactors);
        }
 
        /**
@@ -76,10 +88,12 @@ public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme
         */
        private static class subsampleDataOnFederatedWorker extends FederatedUDF {
                private static final long serialVersionUID = 2213790859544004286L;
+               private final int _seed;
                private final int _min_rows;
-               
-               protected subsampleDataOnFederatedWorker(long[] inIDs, int min_rows) {
+
+               protected subsampleDataOnFederatedWorker(long[] inIDs, int seed, int min_rows) {
                        super(inIDs);
+                       _seed = seed;
                        _min_rows = min_rows;
                }
 
@@ -91,7 +105,7 @@ public class SubsampleToMinFederatedScheme extends DataPartitionFederatedScheme
                        // subsample down to minimum
                        if(features.getNumRows() > _min_rows) {
                                // generate subsampling matrix
-                               MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+                               MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(_min_rows, Math.toIntExact(features.getNumRows()), _seed);
                                subsampleTo(features, subsampleMatrixBlock);
                                subsampleTo(labels, subsampleMatrixBlock);
                        }
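
The subsampling UDF above keeps min_rows seeded-randomly chosen rows, again via a generated permutation-style matrix. The equivalent effect as a plain-Java sketch (hypothetical helper):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    import java.util.Random;

    static double[][] subsampleToMin(double[][] data, int minRows, int seed) {
        List<double[]> rows = new ArrayList<>(Arrays.asList(data));
        Collections.shuffle(rows, new Random(seed)); // same seed keeps labels aligned
        return rows.subList(0, minRows).toArray(new double[0][]);
    }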
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a2b8d9f..a66e039 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,6 +19,17 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
 import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN;
 import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE;
 import static org.apache.sysds.parser.Statement.PS_EPOCHS;
@@ -32,18 +43,9 @@ import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
 import static org.apache.sysds.parser.Statement.PS_SCHEME;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
 import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
-import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
+import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
+import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
+import static org.apache.sysds.parser.Statement.PS_SEED;
 
 import org.apache.commons.lang3.concurrent.BasicThreadFactory;
 import org.apache.commons.logging.Log;
@@ -121,37 +123,36 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
        }
 
        private void runFederated(ExecutionContext ec) {
-               System.out.println("PARAMETER SERVER");
-               System.out.println("[+] Running in federated mode");
+               LOG.info("PARAMETER SERVER");
+               LOG.info("[+] Running in federated mode");
 
                // get inputs
-               PSFrequency freq = getFrequency();
-               PSUpdateType updateType = getUpdateType();
-               PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
-               FederatedPSScheme federatedPSScheme = getFederatedScheme();
                String updFunc = getParam(PS_UPDATE_FUN);
                String aggFunc = getParam(PS_AGGREGATION_FUN);
-
+               PSUpdateType updateType = getUpdateType();
+               PSFrequency freq = getFrequency();
+               FederatedPSScheme federatedPSScheme = getFederatedScheme();
+               PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
+               boolean weighing = getWeighing();
+               int seed = getSeed();
+
+               if( LOG.isInfoEnabled() ) {
+                       LOG.info("[+] Update Type: " + updateType);
+                       LOG.info("[+] Frequency: " + freq);
+                       LOG.info("[+] Data Partitioning: " + federatedPSScheme);
+                       LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
+                       LOG.info("[+] Weighing: " + weighing);
+                       LOG.info("[+] Seed: " + seed);
+               }
+               
                // partition federated data
-               DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme)
-                               .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
-               List<MatrixObject> pFeatures = result._pFeatures;
-               List<MatrixObject> pLabels = result._pLabels;
+               DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed)
+                       .doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
                int workerNum = result._workerNum;
 
-               // calculate runtime balancing
-               int numBatchesPerEpoch = 0;
-               if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
-                       numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
-               } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) {
-                       numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
-               } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
-                       numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());
-               }
-
                // setup threading
                BasicThreadFactory factory = new BasicThreadFactory.Builder()
-                               .namingPattern("workers-pool-thread-%d").build();
+                       .namingPattern("workers-pool-thread-%d").build();
                ExecutorService es = Executors.newFixedThreadPool(workerNum, factory);
 
                // Get the compiled execution context
@@ -166,10 +167,11 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                ListObject model = ec.getListObject(getParam(PS_MODEL));
                ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc, updateType, workerNum, model, aggServiceEC);
                // Create the local workers
-               int finalNumBatchesPerEpoch = numBatchesPerEpoch;
+               int finalNumBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
                List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
-                               .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
-                               .collect(Collectors.toList());
+                       .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighing,
+                               getEpochs(), getBatchSize(), finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
+                       .collect(Collectors.toList());
 
                if(workerNum != threads.size()) {
                        throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
@@ -177,9 +179,9 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
 
                // Set features and labels for the control threads and write the program, instructions, and hyperparams to the federated workers
                for (int i = 0; i < threads.size(); i++) {
-                       threads.get(i).setFeatures(pFeatures.get(i));
-                       threads.get(i).setLabels(pLabels.get(i));
-                       threads.get(i).setup();
+                       threads.get(i).setFeatures(result._pFeatures.get(i));
+                       threads.get(i).setLabels(result._pLabels.get(i));
+                       threads.get(i).setup(result._weighingFactors.get(i));
                }
 
                try {
@@ -395,14 +397,14 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
        }
 
        private PSRuntimeBalancing getRuntimeBalancing() {
-               if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) {
+               if (!getParameterMap().containsKey(PS_FED_RUNTIME_BALANCING)) {
                        return DEFAULT_RUNTIME_BALANCING;
                }
                try {
-                       return PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING));
+                       return PSRuntimeBalancing.valueOf(getParam(PS_FED_RUNTIME_BALANCING));
                } catch (IllegalArgumentException e) {
                        throw new DMLRuntimeException(String.format("Paramserv function: "
-                                       + "not support '%s' runtime balancing.", getParam(PS_RUNTIME_BALANCING)));
+                               + "does not support '%s' runtime balancing.", getParam(PS_FED_RUNTIME_BALANCING)));
                }
        }
 
@@ -507,4 +509,32 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
                }
                return federated_scheme;
        }
+
+       /**
+        * Calculates the number of batches per epoch depending on the balance metrics and the runtime balancing
+        *
+        * @param runtimeBalancing the runtime balancing
+        * @param balanceMetrics the balance metrics calculated during data partitioning
+        * @return numBatchesPerEpoch
+        */
+       private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing, DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
+               int numBatchesPerEpoch = 0;
+               if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+                       numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
+               } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
+                               || runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+                       numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
+               } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
+                       numBatchesPerEpoch = (int) Math.ceil(balanceMetrics._maxRows / (float) getBatchSize());
+               }
+               return numBatchesPerEpoch;
+       }
+
+       private boolean getWeighing() {
+               return getParameterMap().containsKey(PS_FED_WEIGHING) && Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
+       }
+
+       private int getSeed() {
+               return (getParameterMap().containsKey(PS_SEED)) ? Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
+       }
 }
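
As a side note on the new getNumBatchesPerEpoch logic above: the following is a minimal, self-contained sketch of how the three balancing modes translate balance metrics into per-epoch batch counts. The metric values (min/avg/max rows) and the batch size are made-up illustration inputs, not values from this commit:

    // Illustration only: mirrors the mode-to-metric mapping of
    // getNumBatchesPerEpoch, with hypothetical balance metrics.
    public class BatchesPerEpochExample {
        enum Mode { NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH }

        static int numBatchesPerEpoch(Mode m, long min, long avg, long max, int batchSize) {
            switch( m ) {
                case RUN_MIN:     return (int) Math.ceil(min / (float) batchSize);
                case CYCLE_AVG:   // CYCLE_AVG and SCALE_BATCH both use the average
                case SCALE_BATCH: return (int) Math.ceil(avg / (float) batchSize);
                case CYCLE_MAX:   return (int) Math.ceil(max / (float) batchSize);
                default:          return 0; // NONE: no runtime balancing
            }
        }

        public static void main(String[] args) {
            // e.g. min=400, avg=700, max=1000 rows per worker, batch size 64
            for( Mode m : Mode.values() )
                System.out.println(m + " -> " + numBatchesPerEpoch(m, 400, 700, 1000, 64));
            // NONE -> 0, RUN_MIN -> 7, CYCLE_AVG -> 11, CYCLE_MAX -> 16, SCALE_BATCH -> 11
        }
    }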
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 6a52fc4..a00e8dc 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -54,44 +54,45 @@ public class FederatedParamservTest extends AutomatedTestBase {
        private final String _freq;
        private final String _scheme;
        private final String _runtime_balancing;
+       private final String _weighing;
        private final String _data_distribution;
+       private final int _seed;
 
        // parameters
        @Parameterized.Parameters
        public static Collection<Object[]> parameters() {
                return Arrays.asList(new Object[][] {
                        // Network type, number of federated workers, data set size, batch size, epochs, learning rate, update type, update frequency
-
                        // basic functionality
-                       {"TwoNN", 2, 4, 1, 4, 0.01,       "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
-                       {"CNN",   2, 4, 1, 4, 0.01,       "BSP", "EPOCH", "SHUFFLE",             "NONE" ,     "IMBALANCED"},
-                       {"CNN",   2, 4, 1, 4, 0.01,       "ASP", "BATCH", "REPLICATE_TO_MAX",    "RUN_MIN" ,  "IMBALANCED"},
-                       {"TwoNN", 2, 4, 1, 4, 0.01,       "ASP", "EPOCH", "BALANCE_TO_AVG",      "CYCLE_MAX", "IMBALANCED"},
-                       {"TwoNN", 5, 1000, 100, 2, 0.01,  "BSP", "BATCH", "KEEP_DATA_ON_WORKER", "NONE" ,     "BALANCED"},
-
-                       /*
-                               // runtime balancing
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "IMBALANCED"},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "IMBALANCED"},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "IMBALANCED"},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "IMBALANCED"},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "IMBALANCED"},
-                               {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "IMBALANCED"},
-
-                               // data partitioning
-                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", "BATCH", "SHUFFLE",                              "CYCLE_AVG" ,   "IMBALANCED"},
-                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", "BATCH", "REPLICATE_TO_MAX",             "NONE" ,                "IMBALANCED"},
-                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", "BATCH", "SUBSAMPLE_TO_MIN",             "NONE" ,                "IMBALANCED"},
-                               {"TwoNN", 2, 4, 1, 1, 0.01,             "BSP", "BATCH", "BALANCE_TO_AVG",               "NONE" ,                "IMBALANCED"},
-
-                               // balanced tests
-                               {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "NONE" ,                "BALANCED"}
-                        */
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   200},
+                       {"CNN",         2, 4, 1, 4, 0.01,               "BSP", "EPOCH", "SHUFFLE",                              "NONE" ,                "true", "IMBALANCED",   200},
+                       {"CNN",         2, 4, 1, 4, 0.01,               "ASP", "BATCH", "REPLICATE_TO_MAX",     "RUN_MIN" ,     "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "ASP", "EPOCH", "BALANCE_TO_AVG",               "CYCLE_MAX" ,   "true", "IMBALANCED",   200},
+                       {"TwoNN",       5, 1000, 100, 2, 0.01,  "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "NONE" ,                "true", "BALANCED",             200},
+
+                       /* // runtime balancing
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "RUN_MIN" ,     "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_AVG" ,   "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "BATCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "CYCLE_MAX" ,   "true", "IMBALANCED",   200},
+
+                       // data partitioning
+                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", "BATCH", "SHUFFLE",                              "CYCLE_AVG" ,   "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", "BATCH", "REPLICATE_TO_MAX",             "NONE" ,                "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", "BATCH", "SUBSAMPLE_TO_MIN",             "NONE" ,                "true", "IMBALANCED",   200},
+                       {"TwoNN",       2, 4, 1, 1, 0.01,               "BSP", "BATCH", "BALANCE_TO_AVG",               "NONE" ,                "true", "IMBALANCED",   200},
+
+                       // balanced tests
+                       {"CNN",         5, 1000, 100, 2, 0.01,  "BSP", "EPOCH", "KEEP_DATA_ON_WORKER",  "NONE" ,                "true", "BALANCED",             200} */
+
                });
        }
 
        public FederatedParamservTest(String networkType, int numFederatedWorkers, int dataSetSize, int batch_size,
-               int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String data_distribution) {
+               int epochs, double eta, String utype, String freq, String scheme, String runtime_balancing, String weighing, String data_distribution, int seed) {
+
                _networkType = networkType;
                _numFederatedWorkers = numFederatedWorkers;
                _dataSetSize = dataSetSize;
@@ -102,7 +103,9 @@ public class FederatedParamservTest extends AutomatedTestBase {
                _freq = freq;
                _scheme = scheme;
                _runtime_balancing = runtime_balancing;
+               _weighing = weighing;
                _data_distribution = data_distribution;
+               _seed = seed;
        }
 
        @Override
@@ -185,11 +188,12 @@ public class FederatedParamservTest extends AutomatedTestBase {
                                        "freq=" + _freq,
                                        "scheme=" + _scheme,
                                        "runtime_balancing=" + 
_runtime_balancing,
+                                       "weighing=" + _weighing,
                                        "network_type=" + _networkType,
                                        "channels=" + C,
                                        "hin=" + Hin,
                                        "win=" + Win,
-                                       "seed=" + 25));
+                                       "seed=" + _seed));
 
                        programArgs = programArgsList.toArray(new String[0]);
                        LOG.debug(runTest(null));
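
Because the positional Object[] rows above are dense, here is the first "basic functionality" row written out against the constructor parameters. The field names are taken from the constructor signature in this diff; this is purely an annotated restatement for readability, not new test code:

    // First test-parameter row, annotated field by field:
    Object[] row = new Object[] {
        "TwoNN",               // networkType
        2,                     // numFederatedWorkers
        4,                     // dataSetSize
        1,                     // batch_size
        4,                     // epochs
        0.01,                  // eta (learning rate)
        "BSP",                 // utype (update type)
        "BATCH",               // freq (update frequency)
        "KEEP_DATA_ON_WORKER", // scheme (data partitioning)
        "RUN_MIN",             // runtime_balancing
        "true",                // weighing
        "IMBALANCED",          // data_distribution
        200                    // seed
    };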
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 69c7e76..0f9ae63 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -163,7 +163,7 @@ train = function(matrix[double] X, matrix[double] y,
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing,
                  double eta, int C, int Hin, int Win,
                  int seed = -1)
     return (list[unknown] model) {
@@ -211,7 +211,7 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
     upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
     
agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
-    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
+    scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed)
 }
 
 /*
diff --git a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 10d2cc7..5176cca 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -26,10 +26,12 @@ source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
 features = read($features)
 labels = read($labels)
 
+print($weighing)
+
 if($network_type == "TwoNN") {
-  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $seed)
+  model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $seed)
 }
 else {
-  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $eta, $channels, $hin, $win, $seed)
+  model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0), matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme, $runtime_balancing, $weighing, $eta, $channels, $hin, $win, $seed)
 }
 print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index 9bd49d8..a6dc6f2 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -125,7 +125,7 @@ train = function(matrix[double] X, matrix[double] y,
  */
 train_paramserv = function(matrix[double] X, matrix[double] y,
                  matrix[double] X_val, matrix[double] y_val,
-                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing,
+                 int num_workers, int epochs, string utype, string freq, int batch_size, string scheme, string runtime_balancing, string weighing,
                  double eta, int seed = -1)
     return (list[unknown] model) {
 
@@ -155,7 +155,7 @@ train_paramserv = function(matrix[double] X, matrix[double] y,
    upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
    agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
     k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
-    scheme=scheme, runtime_balancing=runtime_balancing, hyperparams=hyperparams)
+    scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing, hyperparams=hyperparams, seed=seed)
 }
 
 /*
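
A note on the weighing flag threaded through all of the call sites above: the per-worker factors are produced during data partitioning (result._weighingFactors, handed to each control thread's setup()), but the schemes that compute them are not part of this excerpt. The following is a rough sketch of the idea only, under the assumption that a worker's factor is its local row count relative to the average; it is not the committed implementation:

    // Sketch under stated assumptions, not the committed implementation:
    // factor_i = localRows_i / avgRows, so gradients from data-rich workers
    // count more and gradients from data-poor workers count less on aggregation.
    static double[] weighingFactors(long[] localRows) {
        double avg = java.util.Arrays.stream(localRows).average().orElse(1.0);
        double[] factors = new double[localRows.length];
        for( int i = 0; i < factors.length; i++ )
            factors[i] = localRows[i] / avg;
        return factors;
    }
    // e.g. localRows = {100, 300} gives factors {0.5, 1.5}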
