This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 460c394 [SYSTEMDS-2550] Federated parameter server scaling and weight
handling
460c394 is described below
commit 460c3945899ce0fc7fd0c0fd92bb5f3d32a25f7a
Author: Tobias Rieger <[email protected]>
AuthorDate: Sat Jan 9 22:57:15 2021 +0100
[SYSTEMDS-2550] Federated parameter server scaling and weight handling
Closes #1141.
---
.../ParameterizedBuiltinFunctionExpression.java | 8 +-
.../java/org/apache/sysds/parser/Statement.java | 6 +-
.../runtime/compress/colgroup/ColGroupValue.java | 2 +-
.../paramserv/FederatedPSControlThread.java | 165 ++++++++++++---------
.../paramserv/dp/BalanceToAvgFederatedScheme.java | 31 ++--
.../paramserv/dp/DataPartitionFederatedScheme.java | 43 ++++--
.../paramserv/dp/FederatedDataPartitioner.java | 7 +-
.../dp/KeepDataOnWorkerFederatedScheme.java | 13 +-
.../dp/ReplicateToMaxFederatedScheme.java | 26 +++-
.../paramserv/dp/ShuffleFederatedScheme.java | 24 ++-
.../dp/SubsampleToMinFederatedScheme.java | 26 +++-
.../cp/ParamservBuiltinCPInstruction.java | 116 +++++++++------
.../paramserv/FederatedParamservTest.java | 58 ++++----
.../scripts/functions/federated/paramserv/CNN.dml | 4 +-
.../federated/paramserv/FederatedParamservTest.dml | 6 +-
.../functions/federated/paramserv/TwoNN.dml | 4 +-
16 files changed, 342 insertions(+), 197 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 5171f21..05bfc48 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -289,8 +289,8 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
Set<String> valid = CollectionUtils.asSet(Statement.PS_MODEL,
Statement.PS_FEATURES, Statement.PS_LABELS,
Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS,
Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
Statement.PS_MODE, Statement.PS_UPDATE_TYPE,
Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
- Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM,
Statement.PS_SCHEME, Statement.PS_RUNTIME_BALANCING,
- Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING);
+ Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM,
Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
+ Statement.PS_FED_WEIGHING, Statement.PS_HYPER_PARAMS,
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
checkInvalidParameters(getOpCode(), getVarParams(), valid);
// check existence and correctness of parameters
@@ -308,9 +308,11 @@ public class ParameterizedBuiltinFunctionExpression
extends DataIdentifier
checkDataValueType(true, fname, Statement.PS_BATCH_SIZE,
DataType.SCALAR, ValueType.INT64, conditional);
checkDataValueType(true, fname, Statement.PS_PARALLELISM,
DataType.SCALAR, ValueType.INT64, conditional);
checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
- checkStringParam(true, fname, Statement.PS_RUNTIME_BALANCING,
conditional);
+ checkStringParam(true, fname,
Statement.PS_FED_RUNTIME_BALANCING, conditional);
+ checkStringParam(true, fname, Statement.PS_FED_WEIGHING,
conditional);
checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS,
DataType.LIST, ValueType.UNKNOWN, conditional);
checkStringParam(true, fname, Statement.PS_CHECKPOINTING,
conditional);
+ checkDataValueType(true, fname, Statement.PS_SEED,
DataType.SCALAR, ValueType.INT64, conditional);
// set output characteristics
output.setDataType(DataType.LIST);
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java
b/src/main/java/org/apache/sysds/parser/Statement.java
index 6767d85..9104246 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -70,6 +70,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_AGGREGATION_FUN = "agg";
public static final String PS_MODE = "mode";
public static final String PS_GRADIENTS = "gradients";
+ public static final String PS_SEED = "seed";
public enum PSModeType {
FEDERATED, LOCAL, REMOTE_SPARK
}
@@ -87,9 +88,10 @@ public abstract class Statement implements ParseInfo
public enum PSFrequency {
BATCH, EPOCH
}
- public static final String PS_RUNTIME_BALANCING = "runtime_balancing";
+ public static final String PS_FED_WEIGHING = "weighing";
+ public static final String PS_FED_RUNTIME_BALANCING =
"runtime_balancing";
public enum PSRuntimeBalancing {
- NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH,
SCALE_BATCH_AND_WEIGH
+ NONE, RUN_MIN, CYCLE_AVG, CYCLE_MAX, SCALE_BATCH
}
public static final String PS_EPOCHS = "epochs";
public static final String PS_BATCH_SIZE = "batchsize";
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
index 54a45d0..f09b5c2 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupValue.java
@@ -228,7 +228,7 @@ public abstract class ColGroupValue extends ColGroup
implements Cloneable {
return val;
}
- protected final double sumValuesSparse(int valIx, SparseRow[] rows,
double[] dictVals, int rowsIndex) {
+ protected static double sumValuesSparse(int valIx, SparseRow[] rows,
double[] dictVals, int rowsIndex) {
throw new NotImplementedException("This Method was implemented
incorrectly");
// final int numCols = getNumCols();
// final int valOff = valIx * numCols;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 393b131..48249db 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -24,6 +24,8 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
+import org.apache.sysds.parser.Statement.PSFrequency;
+import org.apache.sysds.parser.Statement.PSRuntimeBalancing;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
@@ -37,13 +39,17 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.IntObject;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.util.ProgramConverter;
import java.util.ArrayList;
@@ -58,21 +64,29 @@ import static
org.apache.sysds.runtime.util.ProgramConverter.*;
public class FederatedPSControlThread extends PSWorker implements
Callable<Void> {
private static final long serialVersionUID = 6846648059569648791L;
protected static final Log LOG =
LogFactory.getLog(ParamServer.class.getName());
-
- Statement.PSRuntimeBalancing _runtimeBalancing;
- FederatedData _featuresData;
- FederatedData _labelsData;
- final long _localStartBatchNumVarID;
- final long _modelVarID;
- int _numBatchesPerGlobalEpoch;
- int _possibleBatchesPerLocalEpoch;
- boolean _cycleStartAt0 = false;
-
- public FederatedPSControlThread(int workerID, String updFunc,
Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int
epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec,
ParamServer ps) {
+
+ private FederatedData _featuresData;
+ private FederatedData _labelsData;
+ private final long _localStartBatchNumVarID;
+ private final long _modelVarID;
+
+ // runtime balancing
+ private PSRuntimeBalancing _runtimeBalancing;
+ private int _numBatchesPerEpoch;
+ private int _possibleBatchesPerLocalEpoch;
+ private boolean _weighing;
+ private double _weighingFactor = 1;
+ private boolean _cycleStartAt0 = false;
+
+ public FederatedPSControlThread(int workerID, String updFunc,
Statement.PSFrequency freq,
+ PSRuntimeBalancing runtimeBalancing, boolean weighing, int
epochs, long batchSize,
+ int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer
ps)
+ {
super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
- _numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
+ _numBatchesPerEpoch = numBatchesPerGlobalEpoch;
_runtimeBalancing = runtimeBalancing;
+ _weighing = weighing;
// generate the IDs for model and batch counter. These get
overwritten on the federated worker each time
_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
_modelVarID = FederationUtils.getNextFedDataID();
@@ -80,65 +94,72 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
/**
* Sets up the federated worker and control thread
+ *
+ * @param weighingFactor Gradients from this worker will be multiplied
by this factor if weighing is enabled
*/
- public void setup() {
+ public void setup(double weighingFactor) {
// prepare features and labels
_featuresData = (FederatedData)
_features.getFedMapping().getMap().values().toArray()[0];
_labelsData = (FederatedData)
_labels.getFedMapping().getMap().values().toArray()[0];
- // calculate number of batches and get data size
+ // weighing factor is always set, but only used when weighing
is specified
+ _weighingFactor = weighingFactor;
+
+ // different runtime balancing calculations
long dataSize = _features.getNumRows();
- _possibleBatchesPerLocalEpoch = (int) Math.ceil((double)
dataSize / _batchSize);
- if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN
- || _runtimeBalancing ==
Statement.PSRuntimeBalancing.CYCLE_AVG
- || _runtimeBalancing ==
Statement.PSRuntimeBalancing.CYCLE_MAX)) {
- _numBatchesPerGlobalEpoch =
_possibleBatchesPerLocalEpoch;
+
+ // calculate scaled batch size if balancing via batch size.
+ // In some cases there will be some cycling
+ if(_runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
+ _batchSize = (int) Math.ceil((double) dataSize /
_numBatchesPerEpoch);
}
- if(_runtimeBalancing ==
Statement.PSRuntimeBalancing.SCALE_BATCH
- || _runtimeBalancing ==
Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
- throw new NotImplementedException();
+ // Calculate possible batches with batch size
+ _possibleBatchesPerLocalEpoch = (int) Math.ceil((double)
dataSize / _batchSize);
+
+ // If no runtime balancing is specified, just run possible
number of batches
// WARNING: Will get stuck on mismatch
+ if(_runtimeBalancing == PSRuntimeBalancing.NONE) {
+ _numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;
}
+ LOG.info("Setup config for worker " + this.getWorkerName());
+ LOG.info("Batch size: " + _batchSize + " possible batches: " +
_possibleBatchesPerLocalEpoch
+ + " batches to run: " + _numBatchesPerEpoch + "
weighing factor: " + _weighingFactor);
+
// serialize program
// create program blocks for the instruction filtering
String programSerialized;
- ArrayList<ProgramBlock> programBlocks = new ArrayList<>();
+ ArrayList<ProgramBlock> pbs = new ArrayList<>();
BasicProgramBlock gradientProgramBlock = new
BasicProgramBlock(_ec.getProgram());
gradientProgramBlock.setInstructions(new
ArrayList<>(Arrays.asList(_inst)));
- programBlocks.add(gradientProgramBlock);
+ pbs.add(gradientProgramBlock);
- if(_freq == Statement.PSFrequency.EPOCH) {
+ if(_freq == PSFrequency.EPOCH) {
BasicProgramBlock aggProgramBlock = new
BasicProgramBlock(_ec.getProgram());
aggProgramBlock.setInstructions(new
ArrayList<>(Arrays.asList(_ps.getAggInst())));
- programBlocks.add(aggProgramBlock);
+ pbs.add(aggProgramBlock);
}
- StringBuilder sb = new StringBuilder();
- sb.append(PROG_BEGIN);
- sb.append( NEWLINE );
- sb.append(ProgramConverter.serializeProgram(_ec.getProgram(),
- programBlocks,
- new HashMap<>(),
- false
- ));
- sb.append(PROG_END);
- programSerialized = sb.toString();
+ programSerialized = InstructionUtils.concatStrings(
+ PROG_BEGIN, NEWLINE,
+ ProgramConverter.serializeProgram(_ec.getProgram(),
pbs, new HashMap<>(), false),
+ PROG_END);
// write program and meta data to worker
Future<FederatedResponse> udfResponse =
_featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.EXEC_UDF,
_featuresData.getVarID(),
new SetupFederatedWorker(_batchSize,
- dataSize,
- _possibleBatchesPerLocalEpoch,
- programSerialized,
- _inst.getNamespace(),
- _inst.getFunctionName(),
-
_ps.getAggInst().getFunctionName(),
-
_ec.getListObject("hyperparams"),
- _localStartBatchNumVarID,
- _modelVarID
+ dataSize,
+ _possibleBatchesPerLocalEpoch,
+ programSerialized,
+ _inst.getNamespace(),
+ _inst.getFunctionName(),
+ _ps.getAggInst().getFunctionName(),
+ _ec.getListObject("hyperparams"),
+ _localStartBatchNumVarID,
+ _modelVarID
)
));
@@ -286,12 +307,23 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
return _ps.pull(_workerID);
}
- protected void pushGradients(ListObject gradients) {
+ protected void scaleAndPushGradients(ListObject gradients) {
+ // scale gradients - must only include MatrixObjects
+ if(_weighing && _weighingFactor != 1) {
+ gradients.getData().parallelStream().forEach((matrix)
-> {
+ MatrixObject matrixObject = (MatrixObject)
matrix;
+ MatrixBlock input =
matrixObject.acquireReadAndRelease().scalarOperations(
+ new
RightScalarOperator(Multiply.getMultiplyFnObject(), _weighingFactor), new
MatrixBlock());
+ matrixObject.acquireModify(input);
+ matrixObject.release();
+ });
+ }
+
// Push the gradients to ps
_ps.push(_workerID, gradients);
}
- static protected int getNextLocalBatchNum(int currentLocalBatchNumber,
int possibleBatchesPerLocalEpoch) {
+ protected static int getNextLocalBatchNum(int currentLocalBatchNumber,
int possibleBatchesPerLocalEpoch) {
return currentLocalBatchNumber % possibleBatchesPerLocalEpoch;
}
@@ -300,18 +332,18 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
*/
protected void computeWithBatchUpdates() {
for (int epochCounter = 0; epochCounter < _epochs;
epochCounter++) {
- int currentLocalBatchNumber = (_cycleStartAt0) ? 0 :
_numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+ int currentLocalBatchNumber = (_cycleStartAt0) ? 0 :
_numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
- for (int batchCounter = 0; batchCounter <
_numBatchesPerGlobalEpoch; batchCounter++) {
+ for (int batchCounter = 0; batchCounter <
_numBatchesPerEpoch; batchCounter++) {
int localStartBatchNum =
getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
ListObject model = pullModel();
ListObject gradients =
computeGradientsForNBatches(model, 1, localStartBatchNum);
- pushGradients(gradients);
+ scaleAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
+ LOG.info("[+] " + this.getWorkerName() + "
completed BATCH " + localStartBatchNum);
}
- if( LOG.isInfoEnabled() )
- LOG.info("[+] " + this.getWorkerName() + "
completed epoch " + epochCounter);
+ LOG.info("[+] " + this.getWorkerName() + " ---
completed EPOCH " + epochCounter);
}
}
@@ -327,15 +359,14 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
*/
protected void computeWithEpochUpdates() {
for (int epochCounter = 0; epochCounter < _epochs;
epochCounter++) {
- int localStartBatchNum = (_cycleStartAt0) ? 0 :
_numBatchesPerGlobalEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
+ int localStartBatchNum = (_cycleStartAt0) ? 0 :
_numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
// Pull the global parameters from ps
ListObject model = pullModel();
- ListObject gradients =
computeGradientsForNBatches(model, _numBatchesPerGlobalEpoch,
localStartBatchNum, true);
- pushGradients(gradients);
-
- if( LOG.isInfoEnabled() )
- LOG.info("[+] " + this.getWorkerName() + "
completed epoch " + epochCounter);
+ ListObject gradients =
computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum,
true);
+ scaleAndPushGradients(gradients);
+
+ LOG.info("[+] " + this.getWorkerName() + " ---
completed EPOCH " + epochCounter);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
}
@@ -424,12 +455,12 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ArrayList<DataIdentifier> inputs =
func.getInputParams();
ArrayList<DataIdentifier> outputs =
func.getOutputParams();
CPOperand[] boundInputs = inputs.stream()
- .map(input -> new
CPOperand(input.getName(), input.getValueType(), input.getDataType()))
- .toArray(CPOperand[]::new);
+ .map(input -> new CPOperand(input.getName(),
input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
ArrayList<String> outputNames =
outputs.stream().map(DataIdentifier::getName)
-
.collect(Collectors.toCollection(ArrayList::new));
+
.collect(Collectors.toCollection(ArrayList::new));
Instruction gradientsInstruction = new
FunctionCallCPInstruction(namespace, gradientsFunctionName, false, boundInputs,
- func.getInputParamNames(), outputNames,
"gradient function");
+ func.getInputParamNames(), outputNames,
"gradient function");
DataIdentifier gradientsOutput = outputs.get(0);
// recreate aggregation instruction and output if needed
@@ -440,12 +471,12 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
inputs = func.getInputParams();
outputs = func.getOutputParams();
boundInputs = inputs.stream()
- .map(input -> new
CPOperand(input.getName(), input.getValueType(), input.getDataType()))
- .toArray(CPOperand[]::new);
+ .map(input -> new
CPOperand(input.getName(), input.getValueType(), input.getDataType()))
+ .toArray(CPOperand[]::new);
outputNames =
outputs.stream().map(DataIdentifier::getName)
-
.collect(Collectors.toCollection(ArrayList::new));
+
.collect(Collectors.toCollection(ArrayList::new));
aggregationInstruction = new
FunctionCallCPInstruction(namespace, aggregationFuctionName, false, boundInputs,
- func.getInputParamNames(),
outputNames, "aggregation function");
+ func.getInputParamNames(), outputNames,
"aggregation function");
aggregationOutput = outputs.get(0);
}
@@ -492,8 +523,6 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ParamservUtils.cleanupData(ec,
Statement.PS_FEATURES);
ParamservUtils.cleanupData(ec,
Statement.PS_LABELS);
ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
- if( LOG.isInfoEnabled() )
- LOG.info("[+]" + " completed batch " +
localBatchNum);
}
// model clean up
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
index 460faba..34e94f0 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/BalanceToAvgFederatedScheme.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.controlprogram.paramserv.dp;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
@@ -35,13 +34,25 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Balance to Avg Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data
which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the
global average number of examples is taken
+ * and the worker subsamples or replicates data to match that number of
examples. See the other federated schemes.
+ *
+ * Then all entries in the federation map of the input matrix are separated
into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class BalanceToAvgFederatedScheme extends DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject
labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int
seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ BalanceMetrics balanceMetricsBefore =
getBalanceMetrics(pFeatures);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures,
balanceMetricsBefore);
- int average_num_rows = (int)
Math.round(pFeatures.stream().map(CacheableData::getNumRows).mapToInt(Long::intValue).average().orElse(Double.NaN));
+ int average_num_rows = (int) balanceMetricsBefore._avgRows;
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
@@ -49,7 +60,7 @@ public class BalanceToAvgFederatedScheme extends
DataPartitionFederatedScheme {
FederatedData labelsData = (FederatedData)
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse =
featuresData.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new
balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()}, average_num_rows)));
+ featuresData.getVarID(), new
balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()}, seed, average_num_rows)));
try {
FederatedResponse response = udfResponse.get();
@@ -66,7 +77,7 @@ public class BalanceToAvgFederatedScheme extends
DataPartitionFederatedScheme {
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weighingFactors);
}
/**
@@ -74,10 +85,12 @@ public class BalanceToAvgFederatedScheme extends
DataPartitionFederatedScheme {
*/
private static class balanceDataOnFederatedWorker extends FederatedUDF {
private static final long serialVersionUID =
6631958250346625546L;
+ private final int _seed;
private final int _average_num_rows;
-
- protected balanceDataOnFederatedWorker(long[] inIDs, int
average_num_rows) {
+
+ protected balanceDataOnFederatedWorker(long[] inIDs, int seed,
int average_num_rows) {
super(inIDs);
+ _seed = seed;
_average_num_rows = average_num_rows;
}
@@ -88,14 +101,14 @@ public class BalanceToAvgFederatedScheme extends
DataPartitionFederatedScheme {
if(features.getNumRows() > _average_num_rows) {
// generate subsampling matrix
- MatrixBlock subsampleMatrixBlock =
ParamservUtils.generateSubsampleMatrix(_average_num_rows,
Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock subsampleMatrixBlock =
ParamservUtils.generateSubsampleMatrix(_average_num_rows,
Math.toIntExact(features.getNumRows()), _seed);
subsampleTo(features, subsampleMatrixBlock);
subsampleTo(labels, subsampleMatrixBlock);
}
else if(features.getNumRows() < _average_num_rows) {
int num_rows_needed = _average_num_rows -
Math.toIntExact(features.getNumRows());
// generate replication matrix
- MatrixBlock replicateMatrixBlock =
ParamservUtils.generateReplicationMatrix(num_rows_needed,
Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock replicateMatrixBlock =
ParamservUtils.generateReplicationMatrix(num_rows_needed,
Math.toIntExact(features.getNumRows()), _seed);
replicateTo(features, replicateMatrixBlock);
replicateTo(labels, replicateMatrixBlock);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
index f5c9638..e00923e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/DataPartitionFederatedScheme.java
@@ -45,16 +45,31 @@ public abstract class DataPartitionFederatedScheme {
public final List<MatrixObject> _pLabels;
public final int _workerNum;
public final BalanceMetrics _balanceMetrics;
+ public final List<Double> _weighingFactors;
- public Result(List<MatrixObject> pFeatures, List<MatrixObject>
pLabels, int workerNum, BalanceMetrics balanceMetrics) {
- this._pFeatures = pFeatures;
- this._pLabels = pLabels;
- this._workerNum = workerNum;
- this._balanceMetrics = balanceMetrics;
+
+ public Result(List<MatrixObject> pFeatures, List<MatrixObject>
pLabels, int workerNum, BalanceMetrics balanceMetrics, List<Double>
weighingFactors) {
+ _pFeatures = pFeatures;
+ _pLabels = pLabels;
+ _workerNum = workerNum;
+ _balanceMetrics = balanceMetrics;
+ _weighingFactors = weighingFactors;
}
}
- public abstract Result doPartitioning(MatrixObject features,
MatrixObject labels);
+ public static final class BalanceMetrics {
+ public final long _minRows;
+ public final long _avgRows;
+ public final long _maxRows;
+
+ public BalanceMetrics(long minRows, long avgRows, long maxRows)
{
+ _minRows = minRows;
+ _avgRows = avgRows;
+ _maxRows = maxRows;
+ }
+ }
+
+ public abstract Result partition(MatrixObject features, MatrixObject
labels, int seed);
/**
* Takes a row federated Matrix and slices it into a matrix for each
worker
@@ -110,16 +125,12 @@ public abstract class DataPartitionFederatedScheme {
return new BalanceMetrics(minRows, sum / slices.size(),
maxRows);
}
- public static final class BalanceMetrics {
- public final long _minRows;
- public final long _avgRows;
- public final long _maxRows;
-
- public BalanceMetrics(long minRows, long avgRows, long maxRows)
{
- this._minRows = minRows;
- this._avgRows = avgRows;
- this._maxRows = maxRows;
- }
+ static List<Double> getWeighingFactors(List<MatrixObject> pFeatures,
BalanceMetrics balanceMetrics) {
+ List<Double> weighingFactors = new ArrayList<>();
+ pFeatures.forEach((feature) -> {
+ weighingFactors.add((double) feature.getNumRows() /
balanceMetrics._avgRows);
+ });
+ return weighingFactors;
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
index d1ebb6c..ce2f954 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/FederatedDataPartitioner.java
@@ -24,10 +24,11 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
public class FederatedDataPartitioner {
-
private final DataPartitionFederatedScheme _scheme;
+ private final int _seed;
- public FederatedDataPartitioner(Statement.FederatedPSScheme scheme) {
+ public FederatedDataPartitioner(Statement.FederatedPSScheme scheme, int
seed) {
+ _seed = seed;
switch (scheme) {
case KEEP_DATA_ON_WORKER:
_scheme = new KeepDataOnWorkerFederatedScheme();
@@ -50,6 +51,6 @@ public class FederatedDataPartitioner {
}
public DataPartitionFederatedScheme.Result doPartitioning(MatrixObject
features, MatrixObject labels) {
- return _scheme.doPartitioning(features, labels);
+ return _scheme.partition(features, labels, _seed);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
index e306f25..afbaf4d 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/KeepDataOnWorkerFederatedScheme.java
@@ -22,11 +22,20 @@ package
org.apache.sysds.runtime.controlprogram.paramserv.dp;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import java.util.List;
+/**
+ * Keep Data on Worker Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data
which is already on the workers.
+ * All entries in the federation map of the input matrix are separated into
MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class KeepDataOnWorkerFederatedScheme extends
DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject
labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int
seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures));
+ BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures,
balanceMetrics);
+ return new Result(pFeatures, pLabels, pFeatures.size(),
balanceMetrics, weighingFactors);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
index 068cfa9..a1b8f6c 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ReplicateToMaxFederatedScheme.java
@@ -34,11 +34,23 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Replicate to Max Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data
which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the
global maximum number of examples is taken
+ * and the worker replicates data to match that number of examples. The
generation is done by multiplying with a
+ * Permutation Matrix with a global seed. These selected examples are appended
to the original data.
+ *
+ * Then all entries in the federation map of the input matrix are separated
into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject
labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int
seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures,
getBalanceMetrics(pFeatures));
int max_rows = 0;
for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@ public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme
FederatedData labelsData = (FederatedData)
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse =
featuresData.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new
replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()}, max_rows)));
+ featuresData.getVarID(), new
replicateDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()}, seed, max_rows)));
try {
FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@ public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weighingFactors);
}
/**
@@ -76,10 +88,12 @@ public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme
*/
private static class replicateDataOnFederatedWorker extends
FederatedUDF {
private static final long serialVersionUID =
-6930898456315100587L;
+ private final int _seed;
private final int _max_rows;
-
- protected replicateDataOnFederatedWorker(long[] inIDs, int
max_rows) {
+
+ protected replicateDataOnFederatedWorker(long[] inIDs, int
seed, int max_rows) {
super(inIDs);
+ _seed = seed;
_max_rows = max_rows;
}
@@ -92,7 +106,7 @@ public class ReplicateToMaxFederatedScheme extends
DataPartitionFederatedScheme
if(features.getNumRows() < _max_rows) {
int num_rows_needed = _max_rows -
Math.toIntExact(features.getNumRows());
// generate replication matrix
- MatrixBlock replicateMatrixBlock =
ParamservUtils.generateReplicationMatrix(num_rows_needed,
Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock replicateMatrixBlock =
ParamservUtils.generateReplicationMatrix(num_rows_needed,
Math.toIntExact(features.getNumRows()), _seed);
replicateTo(features, replicateMatrixBlock);
replicateTo(labels, replicateMatrixBlock);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
index 65ef69d..1920593 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/ShuffleFederatedScheme.java
@@ -33,11 +33,23 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Shuffle Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data
which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case it is
shuffled by generating a permutation
+ * matrix with a global seed and performing a matrix multiplication.
+ *
+ * Then all entries in the federation map of the input matrix are separated
into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class ShuffleFederatedScheme extends DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject
labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int
seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ BalanceMetrics balanceMetrics = getBalanceMetrics(pFeatures);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures,
balanceMetrics);
for(int i = 0; i < pFeatures.size(); i++) {
// Works, because the map contains a single entry
@@ -45,7 +57,7 @@ public class ShuffleFederatedScheme extends
DataPartitionFederatedScheme {
FederatedData labelsData = (FederatedData)
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse =
featuresData.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new
shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()})));
+ featuresData.getVarID(), new
shuffleDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()}, seed)));
try {
FederatedResponse response = udfResponse.get();
@@ -57,7 +69,7 @@ public class ShuffleFederatedScheme extends
DataPartitionFederatedScheme {
}
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(),
balanceMetrics, weighingFactors);
}
/**
@@ -65,9 +77,11 @@ public class ShuffleFederatedScheme extends
DataPartitionFederatedScheme {
*/
private static class shuffleDataOnFederatedWorker extends FederatedUDF {
private static final long serialVersionUID =
3228664618781333325L;
+ private final int _seed;
- protected shuffleDataOnFederatedWorker(long[] inIDs) {
+ protected shuffleDataOnFederatedWorker(long[] inIDs, int seed) {
super(inIDs);
+ _seed = seed;
}
@Override
@@ -76,7 +90,7 @@ public class ShuffleFederatedScheme extends
DataPartitionFederatedScheme {
MatrixObject labels = (MatrixObject) data[1];
// generate permutation matrix
- MatrixBlock permutationMatrixBlock =
ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()),
System.currentTimeMillis());
+ MatrixBlock permutationMatrixBlock =
ParamservUtils.generatePermutation(Math.toIntExact(features.getNumRows()),
_seed);
shuffle(features, permutationMatrixBlock);
shuffle(labels, permutationMatrixBlock);
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
index 9b62cc8..937c37e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/dp/SubsampleToMinFederatedScheme.java
@@ -34,11 +34,23 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
import java.util.List;
import java.util.concurrent.Future;
+/**
+ * Subsample to Min Federated scheme
+ *
+ * When the parameter server runs in federated mode it cannot pull in the data
which is already on the workers.
+ * Therefore, a UDF is sent to manipulate the data locally. In this case the
global minimum number of examples is taken
+ * and the worker subsamples data to match that number of examples. The
subsampling is done by multiplying with a
+ * permutation matrix with a global seed.
+ *
+ * Then all entries in the federation map of the input matrix are separated
into MatrixObjects and returned as a list.
+ * Only supports row federated matrices atm.
+ */
public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme {
@Override
- public Result doPartitioning(MatrixObject features, MatrixObject
labels) {
+ public Result partition(MatrixObject features, MatrixObject labels, int
seed) {
List<MatrixObject> pFeatures = sliceFederatedMatrix(features);
List<MatrixObject> pLabels = sliceFederatedMatrix(labels);
+ List<Double> weighingFactors = getWeighingFactors(pFeatures,
getBalanceMetrics(pFeatures));
int min_rows = Integer.MAX_VALUE;
for (MatrixObject pFeature : pFeatures) {
@@ -51,7 +63,7 @@ public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme
FederatedData labelsData = (FederatedData)
pLabels.get(i).getFedMapping().getMap().values().toArray()[0];
Future<FederatedResponse> udfResponse =
featuresData.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- featuresData.getVarID(), new
subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()}, min_rows)));
+ featuresData.getVarID(), new
subsampleDataOnFederatedWorker(new long[]{featuresData.getVarID(),
labelsData.getVarID()}, seed, min_rows)));
try {
FederatedResponse response = udfResponse.get();
@@ -68,7 +80,7 @@ public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme
pLabels.get(i).updateDataCharacteristics(update);
}
- return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures));
+ return new Result(pFeatures, pLabels, pFeatures.size(),
getBalanceMetrics(pFeatures), weighingFactors);
}
/**
@@ -76,10 +88,12 @@ public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme
*/
private static class subsampleDataOnFederatedWorker extends
FederatedUDF {
private static final long serialVersionUID =
2213790859544004286L;
+ private final int _seed;
private final int _min_rows;
-
- protected subsampleDataOnFederatedWorker(long[] inIDs, int
min_rows) {
+
+ protected subsampleDataOnFederatedWorker(long[] inIDs, int
seed, int min_rows) {
super(inIDs);
+ _seed = seed;
_min_rows = min_rows;
}
@@ -91,7 +105,7 @@ public class SubsampleToMinFederatedScheme extends
DataPartitionFederatedScheme
// subsample down to minimum
if(features.getNumRows() > _min_rows) {
// generate subsampling matrix
- MatrixBlock subsampleMatrixBlock =
ParamservUtils.generateSubsampleMatrix(_min_rows,
Math.toIntExact(features.getNumRows()), System.currentTimeMillis());
+ MatrixBlock subsampleMatrixBlock =
ParamservUtils.generateSubsampleMatrix(_min_rows,
Math.toIntExact(features.getNumRows()), _seed);
subsampleTo(features, subsampleMatrixBlock);
subsampleTo(labels, subsampleMatrixBlock);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index a2b8d9f..a66e039 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -19,6 +19,17 @@
package org.apache.sysds.runtime.instructions.cp;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
import static org.apache.sysds.parser.Statement.PS_AGGREGATION_FUN;
import static org.apache.sysds.parser.Statement.PS_BATCH_SIZE;
import static org.apache.sysds.parser.Statement.PS_EPOCHS;
@@ -32,18 +43,9 @@ import static
org.apache.sysds.parser.Statement.PS_PARALLELISM;
import static org.apache.sysds.parser.Statement.PS_SCHEME;
import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
import static org.apache.sysds.parser.Statement.PS_UPDATE_TYPE;
-import static org.apache.sysds.parser.Statement.PS_RUNTIME_BALANCING;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
+import static org.apache.sysds.parser.Statement.PS_FED_RUNTIME_BALANCING;
+import static org.apache.sysds.parser.Statement.PS_FED_WEIGHING;
+import static org.apache.sysds.parser.Statement.PS_SEED;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.commons.logging.Log;
@@ -121,37 +123,36 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
}
private void runFederated(ExecutionContext ec) {
- System.out.println("PARAMETER SERVER");
- System.out.println("[+] Running in federated mode");
+ LOG.info("PARAMETER SERVER");
+ LOG.info("[+] Running in federated mode");
// get inputs
- PSFrequency freq = getFrequency();
- PSUpdateType updateType = getUpdateType();
- PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
- FederatedPSScheme federatedPSScheme = getFederatedScheme();
String updFunc = getParam(PS_UPDATE_FUN);
String aggFunc = getParam(PS_AGGREGATION_FUN);
-
+ PSUpdateType updateType = getUpdateType();
+ PSFrequency freq = getFrequency();
+ FederatedPSScheme federatedPSScheme = getFederatedScheme();
+ PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
+ boolean weighing = getWeighing();
+ int seed = getSeed();
+
+ if( LOG.isInfoEnabled() ) {
+ LOG.info("[+] Update Type: " + updateType);
+ LOG.info("[+] Frequency: " + freq);
+ LOG.info("[+] Data Partitioning: " + federatedPSScheme);
+ LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
+ LOG.info("[+] Weighing: " + weighing);
+ LOG.info("[+] Seed: " + seed);
+ }
+
// partition federated data
- DataPartitionFederatedScheme.Result result = new
FederatedDataPartitioner(federatedPSScheme)
-
.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)),
ec.getMatrixObject(getParam(PS_LABELS)));
- List<MatrixObject> pFeatures = result._pFeatures;
- List<MatrixObject> pLabels = result._pLabels;
+ DataPartitionFederatedScheme.Result result = new
FederatedDataPartitioner(federatedPSScheme, seed)
+
.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)),
ec.getMatrixObject(getParam(PS_LABELS)));
int workerNum = result._workerNum;
- // calculate runtime balancing
- int numBatchesPerEpoch = 0;
- if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
- numBatchesPerEpoch = (int)
Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
- } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) {
- numBatchesPerEpoch = (int)
Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
- } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
- numBatchesPerEpoch = (int)
Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());
- }
-
// setup threading
BasicThreadFactory factory = new BasicThreadFactory.Builder()
-
.namingPattern("workers-pool-thread-%d").build();
+ .namingPattern("workers-pool-thread-%d").build();
ExecutorService es = Executors.newFixedThreadPool(workerNum,
factory);
// Get the compiled execution context
@@ -166,10 +167,11 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
ListObject model = ec.getListObject(getParam(PS_MODEL));
ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc,
updateType, workerNum, model, aggServiceEC);
// Create the local workers
- int finalNumBatchesPerEpoch = numBatchesPerEpoch;
+ int finalNumBatchesPerEpoch =
getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
List<FederatedPSControlThread> threads = IntStream.range(0,
workerNum)
- .mapToObj(i -> new FederatedPSControlThread(i,
updFunc, freq, runtimeBalancing, getEpochs(), getBatchSize(),
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
- .collect(Collectors.toList());
+ .mapToObj(i -> new FederatedPSControlThread(i, updFunc,
freq, runtimeBalancing, weighing,
+ getEpochs(), getBatchSize(),
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
+ .collect(Collectors.toList());
if(workerNum != threads.size()) {
throw new
DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning
does not match threads!");
@@ -177,9 +179,9 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
// Set features and labels for the control threads and write
the program and instructions and hyperparams to the federated workers
for (int i = 0; i < threads.size(); i++) {
- threads.get(i).setFeatures(pFeatures.get(i));
- threads.get(i).setLabels(pLabels.get(i));
- threads.get(i).setup();
+ threads.get(i).setFeatures(result._pFeatures.get(i));
+ threads.get(i).setLabels(result._pLabels.get(i));
+ threads.get(i).setup(result._weighingFactors.get(i));
}
try {
@@ -395,14 +397,14 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
}
private PSRuntimeBalancing getRuntimeBalancing() {
- if (!getParameterMap().containsKey(PS_RUNTIME_BALANCING)) {
+ if (!getParameterMap().containsKey(PS_FED_RUNTIME_BALANCING)) {
return DEFAULT_RUNTIME_BALANCING;
}
try {
- return
PSRuntimeBalancing.valueOf(getParam(PS_RUNTIME_BALANCING));
+ return
PSRuntimeBalancing.valueOf(getParam(PS_FED_RUNTIME_BALANCING));
} catch (IllegalArgumentException e) {
throw new DMLRuntimeException(String.format("Paramserv
function: "
- + "not support '%s' runtime
balancing.", getParam(PS_RUNTIME_BALANCING)));
+ + "not support '%s' runtime balancing.",
getParam(PS_FED_RUNTIME_BALANCING)));
}
}
@@ -507,4 +509,32 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
}
return federated_scheme;
}
+
+ /**
+ * Calculates the number of batches per epoch depending on the balance
metrics and the runtime balancing
+ *
+ * @param runtimeBalancing the runtime balancing
+ * @param balanceMetrics the balance metrics calculated during data
partitioning
+ * @return numBatchesPerEpoch
+ */
+ private int getNumBatchesPerEpoch(PSRuntimeBalancing runtimeBalancing,
DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
+ int numBatchesPerEpoch = 0;
+ if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
+ numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._minRows / (float) getBatchSize());
+ } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
+ || runtimeBalancing ==
PSRuntimeBalancing.SCALE_BATCH) {
+ numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._avgRows / (float) getBatchSize());
+ } else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
+ numBatchesPerEpoch = (int)
Math.ceil(balanceMetrics._maxRows / (float) getBatchSize());
+ }
+ return numBatchesPerEpoch;
+ }
+
+ private boolean getWeighing() {
+ return getParameterMap().containsKey(PS_FED_WEIGHING) &&
Boolean.parseBoolean(getParam(PS_FED_WEIGHING));
+ }
+
+ private int getSeed() {
+ return (getParameterMap().containsKey(PS_SEED)) ?
Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 6a52fc4..a00e8dc 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -54,44 +54,45 @@ public class FederatedParamservTest extends
AutomatedTestBase {
private final String _freq;
private final String _scheme;
private final String _runtime_balancing;
+ private final String _weighing;
private final String _data_distribution;
+ private final int _seed;
// parameters
@Parameterized.Parameters
public static Collection<Object[]> parameters() {
return Arrays.asList(new Object[][] {
// Network type, number of federated workers, data set
size, batch size, epochs, learning rate, update type, update frequency
-
// basic functionality
- {"TwoNN", 2, 4, 1, 4, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "CYCLE_AVG", "IMBALANCED"},
- {"CNN", 2, 4, 1, 4, 0.01, "BSP", "EPOCH",
"SHUFFLE", "NONE" , "IMBALANCED"},
- {"CNN", 2, 4, 1, 4, 0.01, "ASP", "BATCH",
"REPLICATE_TO_MAX", "RUN_MIN" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01, "ASP", "EPOCH",
"BALANCE_TO_AVG", "CYCLE_MAX", "IMBALANCED"},
- {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"},
-
- /*
- // runtime balancing
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "IMBALANCED"},
-
- // data partitioning
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "SHUFFLE", "CYCLE_AVG" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "REPLICATE_TO_MAX", "NONE" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "IMBALANCED"},
- {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "BALANCE_TO_AVG", "NONE" , "IMBALANCED"},
-
- // balanced tests
- {"CNN", 5, 1000, 100, 2, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "NONE" , "BALANCED"}
- */
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"CNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "SHUFFLE", "NONE" ,
"true", "IMBALANCED", 200},
+ {"CNN", 2, 4, 1, 4, 0.01, "ASP",
"BATCH", "REPLICATE_TO_MAX", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "ASP",
"EPOCH", "BALANCE_TO_AVG", "CYCLE_MAX" , "true", "IMBALANCED",
200},
+ {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED",
200},
+
+ /* // runtime balancing
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "RUN_MIN" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX" , "true", "IMBALANCED", 200},
+
+ // data partitioning
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "SHUFFLE", "CYCLE_AVG" , "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "REPLICATE_TO_MAX", "NONE" , "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "SUBSAMPLE_TO_MIN", "NONE" , "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01, "BSP",
"BATCH", "BALANCE_TO_AVG", "NONE" , "true",
"IMBALANCED", 200},
+
+ // balanced tests
+ {"CNN", 5, 1000, 100, 2, 0.01, "BSP", "EPOCH",
"KEEP_DATA_ON_WORKER", "NONE" , "true", "BALANCED",
200} */
+
});
}
public FederatedParamservTest(String networkType, int
numFederatedWorkers, int dataSetSize, int batch_size,
- int epochs, double eta, String utype, String freq, String
scheme, String runtime_balancing, String data_distribution) {
+ int epochs, double eta, String utype, String freq, String
scheme, String runtime_balancing, String weighing, String data_distribution,
int seed) {
+
_networkType = networkType;
_numFederatedWorkers = numFederatedWorkers;
_dataSetSize = dataSetSize;
@@ -102,7 +103,9 @@ public class FederatedParamservTest extends
AutomatedTestBase {
_freq = freq;
_scheme = scheme;
_runtime_balancing = runtime_balancing;
+ _weighing = weighing;
_data_distribution = data_distribution;
+ _seed = seed;
}
@Override
@@ -185,11 +188,12 @@ public class FederatedParamservTest extends
AutomatedTestBase {
"freq=" + _freq,
"scheme=" + _scheme,
"runtime_balancing=" +
_runtime_balancing,
+ "weighing=" + _weighing,
"network_type=" + _networkType,
"channels=" + C,
"hin=" + Hin,
"win=" + Win,
- "seed=" + 25));
+ "seed=" + _seed));
programArgs = programArgsList.toArray(new String[0]);
LOG.debug(runTest(null));
diff --git a/src/test/scripts/functions/federated/paramserv/CNN.dml
b/src/test/scripts/functions/federated/paramserv/CNN.dml
index 69c7e76..0f9ae63 100644
--- a/src/test/scripts/functions/federated/paramserv/CNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/CNN.dml
@@ -163,7 +163,7 @@ train = function(matrix[double] X, matrix[double] y,
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
- int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing,
+ int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing, string weighing,
double eta, int C, int Hin, int Win,
int seed = -1)
return (list[unknown] model) {
@@ -211,7 +211,7 @@ train_paramserv = function(matrix[double] X, matrix[double]
y,
upd="./src/test/scripts/functions/federated/paramserv/CNN.dml::gradients",
agg="./src/test/scripts/functions/federated/paramserv/CNN.dml::aggregation",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
- scheme=scheme, runtime_balancing=runtime_balancing,
hyperparams=hyperparams)
+ scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing,
hyperparams=hyperparams, seed=seed)
}
/*
diff --git
a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
index 10d2cc7..5176cca 100644
--- a/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
+++ b/src/test/scripts/functions/federated/paramserv/FederatedParamservTest.dml
@@ -26,10 +26,12 @@
source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
features = read($features)
labels = read($labels)
+print($weighing)
+
if($network_type == "TwoNN") {
- model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0),
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme,
$runtime_balancing, $eta, $seed)
+ model = TwoNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0),
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme,
$runtime_balancing, $weighing, $eta, $seed)
}
else {
- model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0),
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme,
$runtime_balancing, $eta, $channels, $hin, $win, $seed)
+ model = CNN::train_paramserv(features, labels, matrix(0, rows=0, cols=0),
matrix(0, rows=0, cols=0), 0, $epochs, $utype, $freq, $batch_size, $scheme,
$runtime_balancing, $weighing, $eta, $channels, $hin, $win, $seed)
}
print(toString(model))
\ No newline at end of file
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
index 9bd49d8..a6dc6f2 100644
--- a/src/test/scripts/functions/federated/paramserv/TwoNN.dml
+++ b/src/test/scripts/functions/federated/paramserv/TwoNN.dml
@@ -125,7 +125,7 @@ train = function(matrix[double] X, matrix[double] y,
*/
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
- int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing,
+ int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing, string weighing,
double eta, int seed = -1)
return (list[unknown] model) {
@@ -155,7 +155,7 @@ train_paramserv = function(matrix[double] X, matrix[double]
y,
upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
- scheme=scheme, runtime_balancing=runtime_balancing,
hyperparams=hyperparams)
+ scheme=scheme, runtime_balancing=runtime_balancing, weighing=weighing,
hyperparams=hyperparams, seed=seed)
}
/*