This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 192e81c [SYSTEMDS-3067] Optional model averaging in local/federated paramserv
192e81c is described below
commit 192e81c95ff46c6467171184370f3a2efed26848
Author: Atefeh Asayesh <[email protected]>
AuthorDate: Mon Aug 23 19:02:25 2021 +0200
[SYSTEMDS-3067] Optional model averaging in local/federated paramserv
Closes #1358.
---
.../ParameterizedBuiltinFunctionExpression.java | 4 +-
.../java/org/apache/sysds/parser/Statement.java | 7 +-
.../runtime/compress/CompressedMatrixBlock.java | 4 +-
.../paramserv/FederatedPSControlThread.java | 70 ++-
.../controlprogram/paramserv/LocalPSWorker.java | 75 ++--
.../controlprogram/paramserv/LocalParamServer.java | 8 +-
.../runtime/controlprogram/paramserv/PSWorker.java | 10 +-
.../controlprogram/paramserv/ParamServer.java | 125 +++++-
.../controlprogram/paramserv/ParamservUtils.java | 35 ++
.../controlprogram/paramserv/SparkPSWorker.java | 17 +-
.../cp/ParamservBuiltinCPInstruction.java | 49 ++-
.../paramserv/AvgModelFederatedParamservTest.java | 245 +++++++++++
.../paramserv/ParamservLocalNNAveragingTest.java | 75 ++++
.../paramserv/AvgModelFederatedParamservTest.dml | 61 +++
.../functions/federated/paramserv/CNNModelAvg.dml | 474 +++++++++++++++++++++
.../federated/paramserv/TwoNNModelAvg.dml | 307 +++++++++++++
.../paramserv/mnist_lenet_paramserv_avg.dml | 372 ++++++++++++++++
.../paramserv/paramserv-averaging-test.dml | 49 +++
18 files changed, 1861 insertions(+), 126 deletions(-)
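
For reference, a minimal DML sketch of how the new flag is meant to be passed to the paramserv builtin; the update/aggregation function names and the hyperparameter list below are illustrative placeholders, not part of this patch:

  # Hypothetical call; with modelAvg=TRUE the workers push locally updated
  # models and the server averages them instead of aggregating gradients.
  params = list(lr=0.01)
  trained = paramserv(model=model, features=X, labels=y,
    upd="./network.dml::gradients", agg="./network.dml::aggregation",
    mode="LOCAL", utype="BSP", freq="BATCH", epochs=4, batchsize=32,
    k=2, scheme="DISJOINT_CONTIGUOUS", hyperparams=params, modelAvg=TRUE)

Omitting modelAvg keeps the default behavior (modelAvg=false, i.e., gradient aggregation).
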
diff --git
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 26a54ac..1f5fd16 100644
---
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -102,7 +102,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
new ParameterizedBuiltinFunctionExpression(ctx,
pbifop,varParams, fileName);
return retVal;
}
-
+
public ParameterizedBuiltinFunctionExpression(ParserRuleContext ctx,
Builtins op, LinkedHashMap<String,Expression> varParams,
String filename) {
@@ -315,7 +315,7 @@ public class ParameterizedBuiltinFunctionExpression extends
DataIdentifier
Statement.PS_VAL_FEATURES, Statement.PS_VAL_LABELS,
Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN,
Statement.PS_VAL_FUN, Statement.PS_MODE,
Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS,
Statement.PS_BATCH_SIZE, Statement.PS_PARALLELISM,
Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
- Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS,
Statement.PS_CHECKPOINTING, Statement.PS_SEED);
+ Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS,
Statement.PS_CHECKPOINTING, Statement.PS_SEED, Statement.PS_MODELAVG);
checkInvalidParameters(getOpCode(), getVarParams(), valid);
// check existence and correctness of parameters
diff --git a/src/main/java/org/apache/sysds/parser/Statement.java
b/src/main/java/org/apache/sysds/parser/Statement.java
index d15fd44..f9f5911 100644
--- a/src/main/java/org/apache/sysds/parser/Statement.java
+++ b/src/main/java/org/apache/sysds/parser/Statement.java
@@ -33,10 +33,10 @@ public abstract class Statement implements ParseInfo
// parameter names for seq()
public static final String SEQ_FROM = "from";
public static final String SEQ_TO = "to";
- public static final String SEQ_INCR = "incr";
+ public static final String SEQ_INCR = "incr";
- public static final String SOURCE = "source";
- public static final String SETWD = "setwd";
+ public static final String SOURCE = "source";
+ public static final String SETWD = "setwd";
public static final String MATRIX_DATA_TYPE = "matrix";
public static final String FRAME_DATA_TYPE = "frame";
@@ -72,6 +72,7 @@ public abstract class Statement implements ParseInfo
public static final String PS_MODE = "mode";
public static final String PS_GRADIENTS = "gradients";
public static final String PS_SEED = "seed";
+ public static final String PS_MODELAVG = "modelAvg";
public enum PSModeType {
FEDERATED, LOCAL, REMOTE_SPARK
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index bae7bae..033368e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -1224,11 +1224,11 @@ public class CompressedMatrixBlock extends MatrixBlock {
return getUncompressed();
}
- private void printDecompressWarning(String operation) {
+ private static void printDecompressWarning(String operation) {
LOG.warn("Operation '" + operation + "' not supported yet -
decompressing for ULA operations.");
}
- private void printDecompressWarning(String operation, MatrixBlock m2) {
+ private static void printDecompressWarning(String operation,
MatrixBlock m2) {
if(isCompressed(m2))
LOG.warn("Operation '" + operation + "' not supported
yet - decompressing for ULA operations.");
else
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index c77ddf4..de99e20 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -44,6 +44,7 @@ import
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
@@ -84,15 +85,16 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
public FederatedPSControlThread(int workerID, String updFunc,
Statement.PSFrequency freq,
PSRuntimeBalancing runtimeBalancing, boolean weighting, int
epochs, long batchSize,
- int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer
ps)
+ int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer
ps, boolean modelAvg)
{
- super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+ super(workerID, updFunc, freq, epochs, batchSize, ec, ps,
modelAvg);
_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
_runtimeBalancing = runtimeBalancing;
_weighting = weighting;
// generate the ID for the model
_modelVarID = FederationUtils.getNextFedDataID();
+ _modelAvg = modelAvg;
}
/**
@@ -150,7 +152,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
aggProgramBlock.setInstructions(new
ArrayList<>(Collections.singletonList(_ps.getAggInst())));
pbs.add(aggProgramBlock);
}
-
+
programSerialized = InstructionUtils.concatStrings(
PROG_BEGIN, NEWLINE,
ProgramConverter.serializeProgram(_ec.getProgram(),
pbs, new HashMap<>()),
@@ -159,17 +161,10 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
// write program and meta data to worker
Future<FederatedResponse> udfResponse =
_featuresData.executeFederatedOperation(
new FederatedRequest(RequestType.EXEC_UDF,
_featuresData.getVarID(),
- new SetupFederatedWorker(_batchSize,
- dataSize,
- _possibleBatchesPerLocalEpoch,
- programSerialized,
- _inst.getNamespace(),
- _inst.getFunctionName(),
- _ps.getAggInst().getFunctionName(),
- _ec.getListObject("hyperparams"),
- _modelVarID
- )
- ));
+ new SetupFederatedWorker(_batchSize, dataSize,
_possibleBatchesPerLocalEpoch,
+ programSerialized,
_inst.getNamespace(), _inst.getFunctionName(),
+ _ps.getAggInst().getFunctionName(),
_ec.getListObject("hyperparams"),
+ _modelVarID, _modelAvg)));
try {
FederatedResponse response = udfResponse.get();
@@ -195,10 +190,11 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
private final String _aggregationFunctionName;
private final ListObject _hyperParams;
private final long _modelVarID;
+ private final boolean _modelAvg;
protected SetupFederatedWorker(long batchSize, long dataSize,
int possibleBatchesPerLocalEpoch,
String programString, String namespace, String
gradientsFunctionName, String aggregationFunctionName,
- ListObject hyperParams, long modelVarID)
+ ListObject hyperParams, long modelVarID, boolean
modelAvg)
{
super(new long[]{});
_batchSize = batchSize;
@@ -210,6 +206,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
_aggregationFunctionName = aggregationFunctionName;
_hyperParams = hyperParams;
_modelVarID = modelVarID;
+ _modelAvg = modelAvg;
}
@Override
@@ -226,6 +223,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ec.setVariable(Statement.PS_FED_AGGREGATION_FNAME, new
StringObject(_aggregationFunctionName));
ec.setVariable(Statement.PS_HYPER_PARAMS, _hyperParams);
ec.setVariable(Statement.PS_FED_MODEL_VARID, new
IntObject(_modelVarID));
+ ec.setVariable(Statement.PS_MODELAVG, new
BooleanObject(_modelAvg));
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
@@ -277,7 +275,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ec.removeVariable(Statement.PS_FED_AGGREGATION_FNAME);
ec.removeVariable(Statement.PS_FED_MODEL_VARID);
ParamservUtils.cleanupListObject(ec,
Statement.PS_HYPER_PARAMS);
-
+
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
}
@@ -334,7 +332,6 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
});
accFedPSGradientWeightingTime(tWeighting);
}
-
// Push the gradients to ps
_ps.push(_workerID, gradients);
}
@@ -344,31 +341,26 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
}
/**
- * Computes all epochs and updates after each batch
+ * Computes all epochs and updates after each batch
*/
protected void computeWithBatchUpdates() {
for (int epochCounter = 0; epochCounter < _epochs;
epochCounter++) {
int currentLocalBatchNumber = (_cycleStartAt0) ? 0 :
_numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
-
for (int batchCounter = 0; batchCounter <
_numBatchesPerEpoch; batchCounter++) {
int localStartBatchNum =
getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
ListObject model = pullModel();
ListObject gradients =
computeGradientsForNBatches(model, 1, localStartBatchNum);
- weightAndPushGradients(gradients);
- ParamservUtils.cleanupListObject(model);
+ if (_modelAvg)
+ model = _ps.updateLocalModel(_ec,
gradients, model);
+ else
+ ParamservUtils.cleanupListObject(model);
+ weightAndPushGradients(_modelAvg ? model :
gradients);
ParamservUtils.cleanupListObject(gradients);
}
}
}
/**
- * Computes all epochs and updates after N batches
- */
- protected void computeWithNBatchUpdates() {
- throw new NotImplementedException();
- }
-
- /**
* Computes all epochs and updates after each epoch
*/
protected void computeWithEpochUpdates() {
@@ -376,6 +368,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
int localStartBatchNum = (_cycleStartAt0) ? 0 :
_numBatchesPerEpoch * epochCounter % _possibleBatchesPerLocalEpoch;
// Pull the global parameters from ps
+			// TODO double check if model averaging is handled correctly (internally?)
ListObject model = pullModel();
ListObject gradients =
computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum,
true);
weightAndPushGradients(gradients);
@@ -469,6 +462,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
String namespace = ((StringObject)
ec.getVariable(Statement.PS_FED_NAMESPACE)).getStringValue();
String gradientsFunc = ((StringObject)
ec.getVariable(Statement.PS_FED_GRADIENTS_FNAME)).getStringValue();
String aggFunc = ((StringObject)
ec.getVariable(Statement.PS_FED_AGGREGATION_FNAME)).getStringValue();
+ boolean modelAvg = ((BooleanObject)
ec.getVariable(Statement.PS_MODELAVG)).getBooleanValue();
// recreate gradient instruction and output
boolean opt =
!ec.getProgram().containsFunctionProgramBlock(namespace, gradientsFunc, false);
@@ -481,7 +475,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
ArrayList<String> outputNames =
outputs.stream().map(DataIdentifier::getName)
.collect(Collectors.toCollection(ArrayList::new));
Instruction gradientsInstruction = new
FunctionCallCPInstruction(namespace, gradientsFunc,
- opt, boundInputs,func.getInputParamNames(),
outputNames, "gradient function");
+ opt, boundInputs, func.getInputParamNames(),
outputNames, "gradient function");
DataIdentifier gradientsOutput = outputs.get(0);
// recreate aggregation instruction and output if needed
@@ -505,7 +499,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
int currentLocalBatchNumber = _localStartBatchNum;
// prepare execution context
ec.setVariable(Statement.PS_MODEL, model);
- for (int batchCounter = 0; batchCounter <
_numBatchesToCompute; batchCounter++) {
+ for(int batchCounter = 0; batchCounter <
_numBatchesToCompute; batchCounter++) {
int localBatchNum =
getNextLocalBatchNum(currentLocalBatchNumber++, possibleBatchesPerLocalEpoch);
// slice batch from feature and label matrix
@@ -521,13 +515,14 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
// calculate gradients for batch
gradientsInstruction.processInstruction(ec);
ListObject gradients =
ec.getListObject(gradientsOutput.getName());
-
+
// accrue the computed gradients - In the
single batch case this is just a list copy
// is this equivalent for momentum based and
AMS prob?
- accGradients =
ParamservUtils.accrueGradients(accGradients, gradients, false);
-
+ accGradients = modelAvg ? null :
+
ParamservUtils.accrueGradients(accGradients, gradients, false);
+
// update the local model with gradients if
needed
- if(_localUpdate && batchCounter <
_numBatchesToCompute - 1) {
+ if((_localUpdate && batchCounter <
_numBatchesToCompute - 1) | modelAvg) {
// Invoke the aggregate function
assert aggregationInstruction != null;
aggregationInstruction.processInstruction(ec);
@@ -540,17 +535,18 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
}
// clean up
- ParamservUtils.cleanupListObject(ec,
gradientsOutput.getName());
ParamservUtils.cleanupData(ec,
Statement.PS_FEATURES);
ParamservUtils.cleanupData(ec,
Statement.PS_LABELS);
}
// model clean up
ParamservUtils.cleanupListObject(ec,
ec.getVariable(Statement.PS_FED_MODEL_VARID).toString());
- ParamservUtils.cleanupListObject(ec,
Statement.PS_MODEL);
+ // TODO double check cleanup gradients and models
+
// stop timing
DoubleObject gradientsTime = new
DoubleObject(tGradients.stop());
- return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new
Object[]{accGradients, gradientsTime});
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+ new Object[]{modelAvg ? model : accGradients,
gradientsTime});
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
index 8ba81f6..fef617b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -43,16 +43,16 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
protected LocalPSWorker() {}
public LocalPSWorker(int workerID, String updFunc,
Statement.PSFrequency freq,
- int epochs, long batchSize, ExecutionContext ec, ParamServer ps)
+ int epochs, long batchSize, ExecutionContext ec, ParamServer
ps, boolean modelAvg)
{
- super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
+ super(workerID, updFunc, freq, epochs, batchSize, ec, ps,
modelAvg);
}
@Override
public String getWorkerName() {
return String.format("Local worker_%d", _workerID);
}
-
+
@Override
public Void call() throws Exception {
incWorkerNumber();
@@ -81,7 +81,7 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
}
private void computeEpoch(long dataSize, int batchIter) {
- for (int i = 0; i < _epochs; i++) {
+ for(int i = 0; i < _epochs; i++) {
// Pull the global parameters from ps
ListObject params = pullModel();
Future<ListObject> accGradients =
ConcurrentUtils.constantFuture(null);
@@ -89,32 +89,32 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
try {
for (int j = 0; j < batchIter; j++) {
ListObject gradients =
computeGradients(params, dataSize, batchIter, i, j);
-
+
boolean localUpdate = j < batchIter - 1;
-
- // Accumulate the intermediate
gradients (async for overlap w/ model updates
+
+ // Accumulate the intermediate
gradients (async for overlap w/ model updates
// and gradient computation, sequential
over gradient matrices to avoid deadlocks)
ListObject accGradientsPrev =
accGradients.get();
- accGradients = _tpool.submit(() ->
ParamservUtils.accrueGradients(
- accGradientsPrev, gradients,
false, !localUpdate));
-
+ accGradients = _modelAvg ?
ConcurrentUtils.constantFuture(null) :
+ _tpool.submit(() ->
ParamservUtils.accrueGradients(
+ accGradientsPrev,
gradients, false, !localUpdate));
+
// Update the local model with gradients
- if(localUpdate)
+ if(localUpdate | _modelAvg)
params = updateModel(params,
gradients, i, j, batchIter);
-
+
accNumBatches(1);
}
-
- // Push the gradients to ps
- pushGradients(accGradients.get());
- ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
+ pushGradients(_modelAvg ? params :
accGradients.get());
+ if (!_modelAvg)
+ ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
}
catch(ExecutionException | InterruptedException ex) {
throw new DMLRuntimeException(ex);
}
-
+
accNumEpochs(1);
- if (LOG.isDebugEnabled()) {
+ if(LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
}
}
@@ -126,9 +126,9 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
globalParams = _ps.updateLocalModel(_ec, gradients,
globalParams);
accLocalModelUpdateTime(tUpd);
-
- if (LOG.isDebugEnabled()) {
- LOG.debug(String.format("%s: local global parameter
[size:%d kb] updated. "
+
+ if(LOG.isDebugEnabled()) {
+ LOG.debug(String.format("%s: local global parameter
[size:%d kb] updated. "
+ "[Epoch:%d Total epoch:%d Iteration:%d
Total iteration:%d]",
getWorkerName(), globalParams.getDataSize(), i
+ 1, _epochs, j + 1, batchIter));
}
@@ -136,19 +136,24 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
}
private void computeBatch(long dataSize, int totalIter) {
- for (int i = 0; i < _epochs; i++) {
- for (int j = 0; j < totalIter; j++) {
+ for(int i = 0; i < _epochs; i++) {
+ for(int j = 0; j < totalIter; j++) {
ListObject globalParams = pullModel();
-
ListObject gradients =
computeGradients(globalParams, dataSize, totalIter, i, j);
-
- // Push the gradients to ps
- pushGradients(gradients);
- ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
+ if(_modelAvg) {
+				// Update locally & push the locally updated model to ps
+ ListObject model =
updateModel(globalParams, gradients, i, j, totalIter);
+ pushGradients(model);
+ }
+ else {
+ // Push the gradients to ps
+ pushGradients(gradients);
+ ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
+ }
accNumBatches(1);
}
-
+
accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
@@ -159,7 +164,7 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
private ListObject pullModel() {
// Pull the global parameters from ps
ListObject globalParams = _ps.pull(_workerID);
- if (LOG.isDebugEnabled()) {
+ if(LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: successfully pull the
global parameters "
+ "[size:%d kb] from ps.", getWorkerName(),
globalParams.getDataSize() / 1024));
}
@@ -169,7 +174,7 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
private void pushGradients(ListObject gradients) {
// Push the gradients to ps
_ps.push(_workerID, gradients);
- if (LOG.isDebugEnabled()) {
+ if(LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: successfully push the
gradients "
+ "[size:%d kb] to ps.", getWorkerName(),
gradients.getDataSize() / 1024));
}
@@ -189,11 +194,11 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
_ec.setVariable(Statement.PS_FEATURES, bFeatures);
_ec.setVariable(Statement.PS_LABELS, bLabels);
- if (LOG.isDebugEnabled()) {
+ if(LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: got batch data [size:%d
kb] of index from %d to %d [last index: %d]. "
+ "[Epoch:%d Total epoch:%d Iteration:%d
Total iteration:%d]", getWorkerName(),
- bFeatures.getDataSize() / 1024 +
bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
- j + 1, batchIter));
+ bFeatures.getDataSize() / 1024 +
bLabels.getDataSize() / 1024, begin, end, dataSize,
+ i + 1, _epochs, j + 1, batchIter));
}
// Invoke the update function
@@ -208,7 +213,7 @@ public class LocalPSWorker extends PSWorker implements
Callable<Void> {
ParamservUtils.cleanupData(_ec, Statement.PS_LABELS);
return gradients;
}
-
+
@Override
protected void incWorkerNumber() {
if (DMLScript.STATISTICS)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
index 7bd96f2..ebf6698 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -33,17 +33,17 @@ public class LocalParamServer extends ParamServer {
public static LocalParamServer create(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch,
- MatrixObject valFeatures, MatrixObject valLabels)
+ MatrixObject valFeatures, MatrixObject valLabels, boolean
modelAvg)
{
return new LocalParamServer(model, aggFunc, updateType, freq,
ec,
- workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels);
+ workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels, modelAvg);
}
private LocalParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch,
- MatrixObject valFeatures, MatrixObject valLabels)
+ MatrixObject valFeatures, MatrixObject valLabels, boolean
modelAvg)
{
- super(model, aggFunc, updateType, freq, ec, workerNum, valFunc,
numBatchesPerEpoch, valFeatures, valLabels);
+ super(model, aggFunc, updateType, freq, ec, workerNum, valFunc,
numBatchesPerEpoch, valFeatures, valLabels, modelAvg);
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
index 99ec9e2..a1e55f3 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.java
@@ -37,7 +37,7 @@ import
org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
-public abstract class PSWorker implements Serializable
+public abstract class PSWorker implements Serializable
{
private static final long serialVersionUID = -3510485051178200118L;
@@ -45,7 +45,7 @@ public abstract class PSWorker implements Serializable
// Note: we use a non-static variable to obtain the live maintenance
thread pool
// which is important in scenarios w/ multiple scripts in a single JVM
(e.g., tests)
protected ExecutorService _tpool = LazyWriteBuffer.getUtilThreadPool();
-
+
protected int _workerID;
protected int _epochs;
protected long _batchSize;
@@ -57,10 +57,11 @@ public abstract class PSWorker implements Serializable
protected MatrixObject _labels;
protected String _updFunc;
protected Statement.PSFrequency _freq;
+ protected boolean _modelAvg;
protected PSWorker() {}
- protected PSWorker(int workerID, String updFunc, Statement.PSFrequency
freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
+ protected PSWorker(int workerID, String updFunc, Statement.PSFrequency
freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps, boolean
modelAvg) {
_workerID = workerID;
_updFunc = updFunc;
_freq = freq;
@@ -68,6 +69,7 @@ public abstract class PSWorker implements Serializable
_batchSize = batchSize;
_ec = ec;
_ps = ps;
+ _modelAvg = modelAvg;
setupUpdateFunction(updFunc, ec);
}
@@ -148,7 +150,7 @@ public abstract class PSWorker implements Serializable
protected void accNumEpochs(int n) {
//do nothing
}
-
+
protected void accNumBatches(int n) {
//do nothing
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 9f5b126..dc5b85f 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -28,6 +28,7 @@ import java.util.concurrent.BlockingQueue;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -41,17 +42,20 @@ import
org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.utils.Statistics;
-public abstract class ParamServer
+public abstract class ParamServer
{
protected static final Log LOG =
LogFactory.getLog(ParamServer.class.getName());
protected static final boolean ACCRUE_BSP_GRADIENTS = true;
-
+
// worker input queues and global model
protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
private ListObject _model;
@@ -72,16 +76,17 @@ public abstract class ParamServer
private String _accuracyOutput;
private int _syncCounter = 0;
- private int _epochCounter = 0 ;
+ private int _epochCounter = 0;
private int _numBatchesPerEpoch;
private int _numWorkers;
+ private boolean _modelAvg;
+ private ListObject _accModels = null;
protected ParamServer() {}
protected ParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
- Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc,
- int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels)
+ Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels, boolean modelAvg)
{
// init worker queues and global model
_modelMap = new HashMap<>(workerNum);
@@ -90,7 +95,7 @@ public abstract class ParamServer
_modelMap.put(i, new ArrayBlockingQueue<>(1));
});
_model = model;
-
+
// init aggregation service
_ec = ec;
_updateType = updateType;
@@ -103,7 +108,8 @@ public abstract class ParamServer
}
_numBatchesPerEpoch = numBatchesPerEpoch;
_numWorkers = workerNum;
-
+ _modelAvg = modelAvg;
+
// broadcast initial model
broadcastModel(true);
}
@@ -179,9 +185,17 @@ public abstract class ParamServer
return _model;
}
- protected synchronized void updateGlobalModel(int workerID, ListObject
gradients) {
+ protected synchronized void updateGlobalModel(int workerID, ListObject
params) {
+ if(_modelAvg) {
+ updateAverageModel(workerID, params);
+ }
+ else
+ updateGlobalGradients(workerID, params);
+ }
+
+ protected synchronized void updateGlobalGradients(int workerID,
ListObject gradients) {
try {
- if (LOG.isDebugEnabled()) {
+ if(LOG.isDebugEnabled()) {
LOG.debug(String.format("Successfully pulled
the gradients [size:%d kb] of worker_%d.",
gradients.getDataSize() / 1024,
workerID));
}
@@ -221,7 +235,7 @@ public abstract class ParamServer
_epochCounter++;
_syncCounter = 0;
}
-
+
// Broadcast the updated model
resetFinishedStates();
broadcastModel(true);
@@ -256,7 +270,7 @@ public abstract class ParamServer
default:
throw new
DMLRuntimeException("Unsupported update: " + _updateType.name());
}
- }
+ }
catch (Exception e) {
throw new DMLRuntimeException("Aggregation or
validation service failed: ", e);
}
@@ -293,8 +307,93 @@ public abstract class ParamServer
ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
return newModel;
}
-
- private boolean allFinished() {
+
+ protected synchronized void updateAverageModel(int workerID, ListObject
model) {
+ try {
+ if(LOG.isDebugEnabled()) {
+ LOG.debug(String.format("Successfully pulled
the models [size:%d kb] of worker_%d.",
+ model.getDataSize() / 1024, workerID));
+ }
+ Timing tAgg = DMLScript.STATISTICS ? new Timing(true) :
null;
+
+			// first, weight the models based on the number of workers
+ ListObject weightParams = weightModels(model,
_numWorkers);
+ switch(_updateType) {
+ case BSP: {
+ setFinishedState(workerID);
+					// second, accumulate the weighted models into the accrued models
+ _accModels =
ParamservUtils.accrueGradients(_accModels, weightParams, true);
+
+ if(allFinished()) {
+ _model = setParams(_ec,
_accModels, _model);
+ if (DMLScript.STATISTICS &&
tAgg != null)
+
Statistics.accPSAggregationTime((long) tAgg.stop());
+ _accModels = null; //reset for
next accumulation
+
+						// This if has grown to be quite complex; its function is rather simple: validate at the end of each epoch.
+						// In the BSP batch case that occurs after the sync counter reaches the number of batches, and in the BSP epoch case every time.
+ if(_numBatchesPerEpoch != -1 &&
(_freq == Statement.PSFrequency.EPOCH || (_freq == Statement.PSFrequency.BATCH
&& ++_syncCounter % _numBatchesPerEpoch == 0))) {
+
+ if(LOG.isInfoEnabled())
+ LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
+ time_epoch();
+ if(_validationPossible)
{
+ validate();
+ }
+ _epochCounter++;
+ _syncCounter = 0;
+ }
+ // Broadcast the updated model
+ resetFinishedStates();
+ broadcastModel(true);
+ if(LOG.isDebugEnabled())
+ LOG.debug("Global
parameter is broadcasted successfully ");
+ }
+ break;
+ }
+ case ASP:
+ throw new NotImplementedException();
+
+ default:
+ throw new
DMLRuntimeException("Unsupported update: " + _updateType.name());
+ }
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException("Aggregation or
validation service failed: ", e);
+ }
+ }
+
+ protected ListObject weightModels(ListObject params, int numWorkers) {
+ double _averagingFactor = 1d / numWorkers;
+
+ if( _averagingFactor != 1) {
+ double final_averagingFactor = _averagingFactor;
+ params.getData().parallelStream().forEach((matrix) -> {
+ MatrixObject matrixObject = (MatrixObject)
matrix;
+ MatrixBlock input =
matrixObject.acquireReadAndRelease().scalarOperations(
+ new
RightScalarOperator(Multiply.getMultiplyFnObject(), final_averagingFactor), new
MatrixBlock());
+ matrixObject.acquireModify(input);
+ matrixObject.release();
+ });
+ }
+ return params;
+ }
+
+	/* A service method for replacing the model with the accrued (averaged) worker models
+ *
+ * @param ec execution context
+ * @param accModels list of models
+ * @param model old model
+ * @return new model (accModels)
+ */
+ protected ListObject setParams(ExecutionContext ec, ListObject
accModels, ListObject model) {
+ ec.setVariable(Statement.PS_MODEL, model);
+ ec.setVariable(Statement.PS_GRADIENTS, accModels);
+ return accModels;
+ }
+
+ private boolean allFinished() {
return !ArrayUtils.contains(_finishedStates, false);
}
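
In effect, weightModels pre-scales each worker model by 1/numWorkers and updateAverageModel then sums the scaled models, so the BSP update reduces to a plain average that replaces the global model. A rough DML sketch of that arithmetic for two workers (W1, W2 are placeholder weight matrices, not identifiers from this patch):

  # each worker model is scaled by 1/N before accumulation, so the sum equals the average
  N = 2
  W_global = (1/N) * W1 + (1/N) * W2   # equivalent to (W1 + W2) / N
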
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index cb4271d..da1e9f7 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -479,4 +479,39 @@ public class ParamservUtils {
ParamservUtils.cleanupListObject(gradients);
return accGradients;
}
+
+ /**
+	 * Accumulate the given models into the accrued models
+ *
+ * @param accModels accrued models list object
+ * @param model given models list object
+ * @param cleanup clean up the given models list object
+ * @return new accrued models list object
+ */
+ public static ListObject accrueModels(ListObject accModels, ListObject
model, boolean cleanup) {
+ return accrueModels(accModels, model, false, cleanup);
+ }
+
+ /**
+ * Accumulate the given models into the accrued models
+ *
+ * @param accModels accrued models list object
+ * @param model given models list object
+ * @param par parallel execution
+ * @param cleanup clean up the given models list object
+ * @return new accrued models list object
+ */
+ public static ListObject accrueModels(ListObject accModels, ListObject
model, boolean par, boolean cleanup) {
+ if (accModels == null)
+ return ParamservUtils.copyList(model, cleanup);
+ IntStream range = IntStream.range(0, accModels.getLength());
+ (par ? range.parallel() : range).forEach(i -> {
+ MatrixBlock mb1 = ((MatrixObject)
accModels.getData().get(i)).acquireReadAndRelease();
+ MatrixBlock mb2 = ((MatrixObject)
model.getData().get(i)).acquireReadAndRelease();
+ mb1.binaryOperationsInPlace(new
BinaryOperator(Plus.getPlusFnObject()), mb2);
+ });
+ if (cleanup)
+ ParamservUtils.cleanupListObject(model);
+ return accModels;
+ }
}
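
The new accrueModels helper mirrors the existing accrueGradients: an element-wise, in-place sum of each matrix of the incoming model list into the accrued list. The same accumulation, sketched in DML for one pair of weight/bias matrices (acc_W, acc_b, W, b are placeholders):

  # running element-wise sum of incoming (already 1/N-weighted) worker models
  acc_W = acc_W + W
  acc_b = acc_b + b
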
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
index 870fd9c..1f3cd1a 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
@@ -53,8 +53,8 @@ public class SparkPSWorker extends LocalPSWorker implements
VoidFunction<Tuple2<
private final LongAccumulator _aRPC; // accumulator for rpc request
private final LongAccumulator _nBatches; //number of executed batches
private final LongAccumulator _nEpochs; //number of executed epoches
-
- public SparkPSWorker(String updFunc, String aggFunc,
Statement.PSFrequency freq, int epochs, long batchSize, String program,
HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator
aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator
aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches,
LongAccumulator aEpochs) {
+
+ public SparkPSWorker(String updFunc, String aggFunc,
Statement.PSFrequency freq, int epochs, long batchSize, String program,
HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator
aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator
aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches,
LongAccumulator aEpochs, boolean modelAvg) {
_updFunc = updFunc;
_aggFunc = aggFunc;
_freq = freq;
@@ -72,13 +72,14 @@ public class SparkPSWorker extends LocalPSWorker implements
VoidFunction<Tuple2<
_aRPC = aRPC;
_nBatches = aBatches;
_nEpochs = aEpochs;
+ _modelAvg = modelAvg;
}
@Override
public String getWorkerName() {
return String.format("Spark worker_%d", _workerID);
}
-
+
@Override
public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>
input) throws Exception {
Timing tSetup = new Timing(true);
@@ -116,13 +117,13 @@ public class SparkPSWorker extends LocalPSWorker
implements VoidFunction<Tuple2<
setFeatures(ParamservUtils.newMatrixObject(input._2._1, false));
setLabels(ParamservUtils.newMatrixObject(input._2._2, false));
}
-
+
@Override
protected void incWorkerNumber() {
_aWorker.add(1);
}
-
+
@Override
protected void accLocalModelUpdateTime(Timing time) {
if( time != null )
@@ -140,17 +141,17 @@ public class SparkPSWorker extends LocalPSWorker
implements VoidFunction<Tuple2<
if( time != null )
_aGrad.add((long) time.stop());
}
-
+
@Override
protected void accNumEpochs(int n) {
_nEpochs.add(n);
}
-
+
@Override
protected void accNumBatches(int n) {
_nBatches.add(n);
}
-
+
private void accSetupTime(Timing time) {
if( time != null )
_aSetup.add((long) time.stop());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index bc6ee67..53987a0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -39,6 +39,7 @@ import static
org.apache.sysds.parser.Statement.PS_HYPER_PARAMS;
import static org.apache.sysds.parser.Statement.PS_LABELS;
import static org.apache.sysds.parser.Statement.PS_MODE;
import static org.apache.sysds.parser.Statement.PS_MODEL;
+import static org.apache.sysds.parser.Statement.PS_MODELAVG;
import static org.apache.sysds.parser.Statement.PS_PARALLELISM;
import static org.apache.sysds.parser.Statement.PS_SCHEME;
import static org.apache.sysds.parser.Statement.PS_UPDATE_FUN;
@@ -89,7 +90,7 @@ import org.apache.sysds.utils.Statistics;
public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruction {
private static final Log LOG =
LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
-
+
public static final int DEFAULT_BATCH_SIZE = 64;
private static final PSFrequency DEFAULT_UPDATE_FREQUENCY =
PSFrequency.EPOCH;
private static final PSScheme DEFAULT_SCHEME =
PSScheme.DISJOINT_CONTIGUOUS;
@@ -97,6 +98,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
private static final FederatedPSScheme DEFAULT_FEDERATED_SCHEME =
FederatedPSScheme.KEEP_DATA_ON_WORKER;
private static final PSModeType DEFAULT_MODE = PSModeType.LOCAL;
private static final PSUpdateType DEFAULT_TYPE = PSUpdateType.ASP;
+ private static final Boolean DEFAULT_MODELAVG = false;
public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String,
String> paramsMap, CPOperand out, String opcode, String istr) {
super(op, paramsMap, out, opcode, istr);
@@ -106,7 +108,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
public void processInstruction(ExecutionContext ec) {
// check if the input is federated
if(ec.getMatrixObject(getParam(PS_FEATURES)).isFederated() ||
-
ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
+ ec.getMatrixObject(getParam(PS_LABELS)).isFederated()) {
runFederated(ec);
}
// if not federated check mode
@@ -181,13 +183,14 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
ListObject model = ec.getListObject(getParam(PS_MODEL));
MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null)
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ?
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
+ boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
ParamServer ps = createPS(PSModeType.FEDERATED, aggFunc,
updateType, freq, workerNum, model, aggServiceEC, getValFunction(),
- getNumBatchesPerEpoch(runtimeBalancing,
result._balanceMetrics), val_features, val_labels);
+ getNumBatchesPerEpoch(runtimeBalancing,
result._balanceMetrics), val_features, val_labels, modelAvg);
// Create the local workers
int finalNumBatchesPerEpoch =
getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
List<FederatedPSControlThread> threads = IntStream.range(0,
workerNum)
.mapToObj(i -> new FederatedPSControlThread(i, updFunc,
freq, runtimeBalancing, weighting,
- getEpochs(), getBatchSize(),
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps))
+ getEpochs(), getBatchSize(),
finalNumBatchesPerEpoch, federatedWorkerECs.get(i), ps, modelAvg))
.collect(Collectors.toList());
if(workerNum != threads.size()) {
throw new
DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning
does not match threads!");
@@ -223,6 +226,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
int workerNum = getWorkerNum(mode);
String updFunc = getParam(PS_UPDATE_FUN);
String aggFunc = getParam(PS_AGGREGATION_FUN);
+ boolean modelAvg = Boolean.parseBoolean(getParam(PS_MODELAVG));
// Get the compiled execution context
LocalVariableMap newVarsMap = createVarsMap(sec);
@@ -234,7 +238,7 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
// Create the parameter server
ListObject model = sec.getListObject(getParam(PS_MODEL));
- ParamServer ps = createPS(mode, aggFunc, getUpdateType(),
getFrequency(), workerNum, model, aggServiceEC);
+ ParamServer ps = createPS(mode, aggFunc, getUpdateType(),
getFrequency(), workerNum, model, aggServiceEC, modelAvg);
// Get driver host
String host =
sec.getSparkContext().getConf().get("spark.driver.host");
@@ -260,11 +264,11 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
LongAccumulator aRPC =
sec.getSparkContext().sc().longAccumulator("rpcRequest");
LongAccumulator aBatch =
sec.getSparkContext().sc().longAccumulator("numBatches");
LongAccumulator aEpoch =
sec.getSparkContext().sc().longAccumulator("numEpochs");
-
+
// Create remote workers
SparkPSWorker worker = new
SparkPSWorker(getParam(PS_UPDATE_FUN), getParam(PS_AGGREGATION_FUN),
getFrequency(), getEpochs(), getBatchSize(), program,
clsMap, sec.getSparkContext().getConf(),
- server.getPort(), aSetup, aWorker, aUpdate, aIndex,
aGrad, aRPC, aBatch, aEpoch);
+ server.getPort(), aSetup, aWorker, aUpdate, aIndex,
aGrad, aRPC, aBatch, aEpoch, modelAvg);
if (DMLScript.STATISTICS)
Statistics.accPSSetupTime((long) tSetup.stop());
@@ -326,13 +330,14 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
ListObject model = ec.getListObject(getParam(PS_MODEL));
MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null)
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ?
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
- ParamServer ps = createPS(mode, aggFunc, updateType, freq,
workerNum, model, aggServiceEC, getValFunction(),
- num_batches_per_epoch, val_features,
val_labels);
+ boolean modelAvg = getModelAvg();
+ ParamServer ps = createPS(mode, aggFunc, updateType, freq,
workerNum, model, aggServiceEC,
+ getValFunction(), num_batches_per_epoch, val_features,
val_labels, modelAvg);
// Create the local workers
List<LocalPSWorker> workers = IntStream.range(0, workerNum)
.mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
- getEpochs(), getBatchSize(), workerECs.get(i),
ps))
+ getEpochs(), getBatchSize(), workerECs.get(i),
ps, modelAvg))
.collect(Collectors.toList());
// Do data partition
@@ -468,21 +473,21 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
* @return parameter server
*/
private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
- PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec)
+ PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec, boolean modelAvg)
{
- return createPS(mode, aggFunc, updateType, freq, workerNum,
model, ec, null, -1, null, null);
+ return createPS(mode, aggFunc, updateType, freq, workerNum,
model, ec, null, -1, null, null, modelAvg);
}
// When this creation is used the parameter server is able to validate
after each epoch
private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec, String valFunc,
- int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels)
+ int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels, boolean modelAvg)
{
switch (mode) {
case FEDERATED:
case LOCAL:
case REMOTE_SPARK:
- return LocalParamServer.create(model, aggFunc,
updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels);
+ return LocalParamServer.create(model, aggFunc,
updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels, modelAvg);
default:
throw new DMLRuntimeException("Unsupported
parameter server: " + mode.name());
}
@@ -575,17 +580,25 @@ public class ParamservBuiltinCPInstruction extends
ParameterizedBuiltinCPInstruc
}
private boolean getWeighting() {
- return getParameterMap().containsKey(PS_FED_WEIGHTING) &&
Boolean.parseBoolean(getParam(PS_FED_WEIGHTING));
+ return getParameterMap().containsKey(PS_FED_WEIGHTING)
+ && Boolean.parseBoolean(getParam(PS_FED_WEIGHTING));
}
private String getValFunction() {
- if (getParameterMap().containsKey(PS_VAL_FUN)) {
+ if (getParameterMap().containsKey(PS_VAL_FUN))
return getParam(PS_VAL_FUN);
- }
return null;
}
private int getSeed() {
- return (getParameterMap().containsKey(PS_SEED)) ?
Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
+ return getParameterMap().containsKey(PS_SEED) ?
+ Integer.parseInt(getParam(PS_SEED)) :
+ (int) System.currentTimeMillis();
+ }
+
+ private boolean getModelAvg() {
+ if(!getParameterMap().containsKey(PS_MODELAVG))
+ return DEFAULT_MODELAVG;
+ return Boolean.parseBoolean(getParam(PS_MODELAVG));
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
new file mode 100644
index 0000000..3583cee
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.paramserv;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+@RunWith(value = Parameterized.class)
[email protected]
+public class AvgModelFederatedParamservTest extends AutomatedTestBase {
+ private static final Log LOG =
LogFactory.getLog(AvgModelFederatedParamservTest.class.getName());
+ private final static String TEST_DIR = "functions/federated/paramserv/";
+ private final static String TEST_NAME =
"AvgModelFederatedParamservTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
AvgModelFederatedParamservTest.class.getSimpleName() + "/";
+
+ private final String _networkType;
+ private final int _numFederatedWorkers;
+ private final int _dataSetSize;
+ private final int _epochs;
+ private final int _batch_size;
+ private final double _eta;
+ private final String _utype;
+ private final String _freq;
+ private final String _scheme;
+ private final String _runtime_balancing;
+ private final String _weighting;
+ private final String _data_distribution;
+ private final int _seed;
+
+ // parameters
+ @Parameterized.Parameters
+ public static Collection<Object[]> parameters() {
+ return Arrays.asList(new Object[][] {
+ // Network type, number of federated workers, data set
size, batch size, epochs, learning rate, update type, update frequency
+ // basic functionality
+ //{"TwoNN", 4, 60000, 32, 4, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "false","BALANCED",
200},
+
+ // One important point is that we do the model
averaging in the case of BSP
+ {"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "BASELINE", "true", "IMBALANCED",
200},
+ {"CNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "SHUFFLE", "BASELINE",
"true", "IMBALANCED", 200},
+ {"TwoNN", 5, 1000, 100, 2, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED",
200},
+
+ /*
+ // runtime balancing
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MIN", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_AVG", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "BATCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED",
200},
+ {"TwoNN", 2, 4, 1, 4, 0.01,
"BSP", "EPOCH", "KEEP_DATA_ON_WORKER", "CYCLE_MAX", "true", "IMBALANCED",
200},
+
+ // data partitioning
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SHUFFLE", "CYCLE_AVG", "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "REPLICATE_TO_MAX", "NONE", "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "SUBSAMPLE_TO_MIN", "NONE", "true",
"IMBALANCED", 200},
+ {"TwoNN", 2, 4, 1, 1, 0.01,
"BSP", "BATCH", "BALANCE_TO_AVG", "NONE", "true",
"IMBALANCED", 200},
+
+ // balanced tests
+ {"CNN", 5, 1000, 100, 2, 0.01, "BSP",
"EPOCH", "KEEP_DATA_ON_WORKER", "NONE", "true", "BALANCED",
200}
+ */
+ });
+ }
+
+ public AvgModelFederatedParamservTest(String networkType, int
numFederatedWorkers, int dataSetSize, int batch_size,
+ int epochs, double eta, String utype, String freq, String
scheme, String runtime_balancing, String weighting, String data_distribution,
int seed) {
+
+ _networkType = networkType;
+ _numFederatedWorkers = numFederatedWorkers;
+ _dataSetSize = dataSetSize;
+ _batch_size = batch_size;
+ _epochs = epochs;
+ _eta = eta;
+ _utype = utype;
+ _freq = freq;
+ _scheme = scheme;
+ _runtime_balancing = runtime_balancing;
+ _weighting = weighting;
+ _data_distribution = data_distribution;
+ _seed = seed;
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+ }
+
+ @Test
+ public void AvgmodelfederatedParamservSingleNode() {
+ AvgmodelfederatedParamserv(ExecMode.SINGLE_NODE, true);
+ }
+
+ @Test
+ public void AvgmodelfederatedParamservHybrid() {
+ AvgmodelfederatedParamserv(ExecMode.HYBRID, true);
+ }
+
+ private void AvgmodelfederatedParamserv(ExecMode mode, boolean
modelAvg) {
+		// Warning: Statistics accumulate in unit tests
+ // config
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ setOutputBuffering(false);
+
+ int C = 1, Hin = 28, Win = 28;
+ int numLabels = 10;
+
+ ExecMode platformOld = setExecMode(mode);
+
+ try {
+ // start threads
+ List<Integer> ports = new ArrayList<>();
+ List<Thread> threads = new ArrayList<>();
+ for(int i = 0; i < _numFederatedWorkers; i++) {
+ ports.add(getRandomAvailablePort());
+
threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+ }
+
+ // generate test data
+ double[][] features =
generateDummyMNISTFeatures(_dataSetSize, C, Hin, Win);
+ double[][] labels =
generateDummyMNISTLabels(_dataSetSize, numLabels);
+ String featuresName = "";
+ String labelsName = "";
+
+ // federate test data balanced or imbalanced
+ if(_data_distribution.equals("IMBALANCED")) {
+ featuresName = "X_IMBALANCED_" +
_numFederatedWorkers;
+ labelsName = "y_IMBALANCED_" +
_numFederatedWorkers;
+ double[][] ranges = {{0,1}, {1,4}};
+
rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features,
_numFederatedWorkers, ports, ranges);
+
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels,
_numFederatedWorkers, ports, ranges);
+ }
+ else {
+ featuresName = "X_BALANCED_" +
_numFederatedWorkers;
+ labelsName = "y_BALANCED_" +
_numFederatedWorkers;
+ double[][] ranges =
generateBalancedFederatedRowRanges(_numFederatedWorkers, features.length);
+
rowFederateLocallyAndWriteInputMatrixWithMTD(featuresName, features,
_numFederatedWorkers, ports, ranges);
+
rowFederateLocallyAndWriteInputMatrixWithMTD(labelsName, labels,
_numFederatedWorkers, ports, ranges);
+ }
+
+ try {
+ //wait for all workers to be setup
+ Thread.sleep(FED_WORKER_WAIT);
+ }
+ catch(InterruptedException e) {
+ e.printStackTrace();
+ }
+
+ // dml name
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ // generate program args
+ List<String> programArgsList = new
ArrayList<>(Arrays.asList("-stats",
+ "-nvargs",
+ "features=" + input(featuresName),
+ "labels=" + input(labelsName),
+ "epochs=" + _epochs,
+ "batch_size=" + _batch_size,
+ "eta=" + _eta,
+ "utype=" + _utype,
+ "freq=" + _freq,
+ "scheme=" + _scheme,
+ "runtime_balancing=" + _runtime_balancing,
+ "weighting=" + _weighting,
+ "network_type=" + _networkType,
+ "channels=" + C,
+ "hin=" + Hin,
+ "win=" + Win,
+ "seed=" + _seed,
+ "modelAvg=" +
Boolean.toString(modelAvg).toUpperCase()));
+
+ programArgs = programArgsList.toArray(new String[0]);
+ LOG.debug(runTest(null));
+ Assert.assertEquals(0,
Statistics.getNoOfExecutedSPInst());
+
+ // shut down threads
+ for(int i = 0; i < _numFederatedWorkers; i++) {
+ TestUtils.shutdownThreads(threads.get(i));
+ }
+ }
+ finally {
+ resetExecMode(platformOld);
+ }
+ }
+
+ /**
+	 * Generates a feature matrix that has the same format as the MNIST dataset,
+ * but is completely random and normalized
+ *
+ * @param numExamples Number of examples to generate
+ * @param C Channels in the input data
+ * @param Hin Height in Pixels of the input data
+ * @param Win Width in Pixels of the input data
+ * @return a dummy MNIST feature matrix
+ */
+ private double[][] generateDummyMNISTFeatures(int numExamples, int C,
int Hin, int Win) {
+ // Seed -1 takes the time in milliseconds as a seed
+ // Sparsity 1 means no sparsity
+ return getRandomMatrix(numExamples, C*Hin*Win, 0, 1, 1, -1);
+ }
+
+ /**
+	 * Generates a label matrix that has the same format as the MNIST dataset, but is completely random and consists
+ * of one hot encoded vectors as rows
+ *
+ * @param numExamples Number of examples to generate
+ * @param numLabels Number of labels to generate
+	 * @return a dummy MNIST label matrix
+ */
+ private double[][] generateDummyMNISTLabels(int numExamples, int
numLabels) {
+ // Seed -1 takes the time in milliseconds as a seed
+ // Sparsity 1 means no sparsity
+ return getRandomMatrix(numExamples, numLabels, 0, 1, 1, -1);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
new file mode 100644
index 0000000..396f24f
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservLocalNNAveragingTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.paramserv;
+
+import org.apache.sysds.parser.Statement;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.junit.Test;
+
[email protected]
+public class ParamservLocalNNAveragingTest extends AutomatedTestBase {
+
+ private static final String TEST_NAME = "paramserv-averaging-test";
+
+ private static final String TEST_DIR = "functions/paramserv/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
ParamservLocalNNAveragingTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {}));
+ }
+
+ @Test
+ public void testParamservBSPBatchDisjointContiguous() {
+ runDMLTest(10, 2, Statement.PSUpdateType.BSP,
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, true);
+ }
+
+ @Test
+ public void testParamservBSPEpoch() {
+ runDMLTest(10, 2, Statement.PSUpdateType.BSP,
Statement.PSFrequency.EPOCH, 32, Statement.PSScheme.DISJOINT_CONTIGUOUS, true);
+ }
+
+ @Test
+ public void testParamservBSPBatchDisjointRoundRobin() {
+ runDMLTest(10, 2, Statement.PSUpdateType.BSP,
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_ROUND_ROBIN, true);
+ }
+
+ @Test
+ public void testParamservBSPBatchDisjointRandom() {
+ runDMLTest(10, 2, Statement.PSUpdateType.BSP,
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.DISJOINT_RANDOM, true);
+ }
+
+ @Test
+ public void testParamservBSPBatchOverlapReshuffle() {
+ runDMLTest(10, 2, Statement.PSUpdateType.BSP,
Statement.PSFrequency.BATCH, 32, Statement.PSScheme.OVERLAP_RESHUFFLE, true);
+ }
+
+ private void runDMLTest(int epochs, int workers, Statement.PSUpdateType
utype, Statement.PSFrequency freq, int batchsize, Statement.PSScheme scheme,
boolean modelAvg) {
+ TestConfiguration config =
getTestConfiguration(ParamservLocalNNAveragingTest.TEST_NAME);
+ loadTestConfiguration(config);
+ programArgs = new String[] { "-stats", "-nvargs", "mode=LOCAL",
"epochs=" + epochs,
+ "workers=" + workers, "utype=" + utype, "freq=" + freq,
"batchsize=" + batchsize,
+ "scheme=" + scheme, "modelAvg=" +modelAvg };
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME +
ParamservLocalNNAveragingTest.TEST_NAME + ".dml";
+ runTest(true, false, null, null, -1);
+ }
+}
diff --git
a/src/test/scripts/functions/federated/paramserv/AvgModelFederatedParamservTest.dml
b/src/test/scripts/functions/federated/paramserv/AvgModelFederatedParamservTest.dml
new file mode 100644
index 0000000..b802186
--- /dev/null
+++
b/src/test/scripts/functions/federated/paramserv/AvgModelFederatedParamservTest.dml
@@ -0,0 +1,61 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/federated/paramserv/TwoNN.dml") as TwoNN
+source("src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml") as
TwoNNModelAvg
+source("src/test/scripts/functions/federated/paramserv/CNN.dml") as CNN
+source("src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml") as
CNNModelAvg
+
+
+# create federated input matrices
+features = read($features)
+labels = read($labels)
+
+if($network_type == "TwoNN") {
+ if(!as.logical($modelAvg)) {
+ model = TwoNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed)
+ print("Test results:")
+ [loss_test, accuracy_test] = TwoNN::validate(matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), model, list())
+ print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test
+ "\n")
+ }
+ else if (as.logical($modelAvg)){
+ model = TwoNNModelAvg::train_paramserv(features, labels, matrix(0,
rows=100, cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $seed, $modelAvg)
+ print("Test results:")
+ [loss_test, accuracy_test] = TwoNNModelAvg::validate(matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), model, list())
+ print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test
+ "\n")
+ }
+}
+else if($network_type == "CNN") {
+ if(!as.logical($modelAvg)) {
+ model = CNN::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin,
$win, $seed)
+ print("Test results:")
+ hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+ [loss_test, accuracy_test] = CNN::validate(matrix(0, rows=100, cols=784),
matrix(0, rows=100, cols=10), model, hyperparams)
+ print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test
+ "\n")
+ }
+ else if (as.logical($modelAvg)){
+ model = CNNModelAvg::train_paramserv(features, labels, matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), 0, $epochs, $utype, $freq,
$batch_size, $scheme, $runtime_balancing, $weighting, $eta, $channels, $hin,
$win, $seed, $modelAvg)
+ print("Test results:")
+ hyperparams = list(learning_rate=$eta, C=$channels, Hin=$hin, Win=$win)
+ [loss_test, accuracy_test] = CNNModelAvg::validate(matrix(0, rows=100,
cols=784), matrix(0, rows=100, cols=10), model, hyperparams)
+ print("[+] test loss: " + loss_test + ", test accuracy: " + accuracy_test
+ "\n")
+ }
+}
diff --git a/src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml
b/src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml
new file mode 100644
index 0000000..d14cc18
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml
@@ -0,0 +1,474 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a convolutional
neural network of the "LeNet" architecture
+ * on different execution schemes and with different inputs, for example a
federated input matrix.
+ */
+
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+/*
+ * Trains a convolutional net using the "LeNet" architectur single threaded
the conventional way.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win)
+ * - y: Target matrix, of shape (N, K)
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ * - y_val: Target validation matrix, of shape (N, K)
+ * - C: Number of input channels (dimensionality of input depth)
+ * - Hin: Input height
+ * - Win: Input width
+ * - epochs: Total number of full training loops over the full data set
+ * - batch_size: Batch size
+ * - learning_rate: The learning rate for the SGD
+ *
+ * Outputs:
+ * - model_trained: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ * - b1: 1st layer biases vector, of shape (F1, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ * - b2: 2nd layer biases vector, of shape (F2, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape
(F2*(Hin/4)*(Win/4), N3)
+ * - b3: 3rd layer biases vector, of shape (1, N3)
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ * - b4: 4th layer biases vector, of shape (1, K)
+ */
+train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
+ matrix[double] y_val, int epochs, int batch_size, double eta, int C, int Hin,
+ int Win, int seed = -1, boolean modelAvg) return (list[unknown] model)
+{
+ N = nrow(X)
+ K = ncol(y)
+
+ # Create network:
+ ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 ->
relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target
dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed) # inputs: (N, C*Hin*Win)
+ lseed = ifelse(seed==-1, -1, seed + 1);
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed) # inputs: (N,
F1*(Hin/2)*(Win/2))
+ lseed = ifelse(seed==-1, -1, seed + 2);
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed) # inputs:
(N, F2*(Hin/2/2)*(Win/2/2))
+ lseed = ifelse(seed==-1, -1, seed + 3);
+ [W4, b4] = affine::init(N3, K, seed = lseed) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # different initialization, since being fed into softmax,
instead of relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ mu = 0.9 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+ model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2,
vb3, vb4)
+
+ # Regularization
+ lambda = 5e-04
+
+ # Create the hyper parameter list
+ hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin,
Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2,
N3=N3)
+ # Calculate iterations
+ iters = ceil(N / batch_size)
+
+ for (e in 1:epochs) {
+ for(i in 1:iters) {
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+ y_batch = y[beg:end,]
+
+ gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+ model = aggregation(model, hyperparams, gradients_list)
+ }
+ }
+}
+
+/*
+ * Trains a convolutional net using the "LeNet" architecture using a parameter
server with specified properties.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win)
+ * - Y: Target matrix, of shape (N, K)
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win)
+ * - Y_val: Target validation matrix, of shape (N, K)
+ * - C: Number of input channels (dimensionality of input depth)
+ * - Hin: Input height
+ * - Win: Input width
+ * - epochs: Total number of full training loops over the full data set
+ * - batch_size: Batch size
+ * - learning_rate: The learning rate for the SGD
+ * - workers: Number of workers to create
+ * - utype: parameter server framework to use
+ * - scheme: update schema
+ * - mode: local or distributed
+ * - modelAvg: Optional boolean parameter to select between updating or
averaging the model in paramserver side.
+ *
+ * Outputs:
+ * - model_trained: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ * - b1: 1st layer biases vector, of shape (F1, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ * - b2: 2nd layer biases vector, of shape (F2, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape
(F2*(Hin/4)*(Win/4), N3)
+ * - b3: 3rd layer biases vector, of shape (1, N3)
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ * - b4: 4th layer biases vector, of shape (1, K)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+ matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
+ string utype, string freq, int batch_size, string scheme, string
runtime_balancing,
+ string weighting, double eta, int C, int Hin, int Win, int seed = -1,
boolean modelAvg)
+ return (list[unknown] model)
+{
+
+
+ N = nrow(X)
+ K = ncol(y)
+
+ # Create network:
+ ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 ->
relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target
dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf, seed = seed) # inputs: (N, C*Hin*Win)
+ lseed = ifelse(seed==-1, -1, seed + 1);
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf, seed = lseed) # inputs: (N,
F1*(Hin/2)*(Win/2))
+ lseed = ifelse(seed==-1, -1, seed + 2);
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, seed = lseed) # inputs:
(N, F2*(Hin/2/2)*(Win/2/2))
+ lseed = ifelse(seed==-1, -1, seed + 3);
+ [W4, b4] = affine::init(N3, K, seed = lseed) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # different initialization, since being fed into softmax,
instead of relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ learning_rate = eta # learning rate
+ mu = 0.9 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+ # Regularization
+ lambda = 5e-04
+ # Create the model list
+ model = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2,
vb3, vb4)
+ # Create the hyper parameter list
+ hyperparams = list(learning_rate=eta, mu=mu, decay=decay, C=C, Hin=Hin,
Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2,
N3=N3)
+
+ # Use paramserv function
+ model = paramserv(model=model, features=X, labels=y, val_features=X_val,
val_labels=y_val,
+
upd="./src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml::gradients",
+
agg="./src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml::aggregation",
+
val="./src/test/scripts/functions/federated/paramserv/CNNModelAvg.dml::validate",
+ k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+ scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting,
hyperparams=hyperparams, seed=seed, modelAvg=modelAvg)
+}
+
+/*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win)
+ * - C: Number of input channels (dimensionality of input depth)
+ * - Hin: Input height
+ * - Win: Input width
+ * - batch_size: Batch size
+ * - model: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf)
+ * - b1: 1st layer biases vector, of shape (F1, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf)
+ * - b2: 2nd layer biases vector, of shape (F2, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape
(F2*(Hin/4)*(Win/4), N3)
+ * - b3: 3rd layer biases vector, of shape (1, N3)
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K)
+ * - b4: 4th layer biases vector, of shape (1, K)
+ *
+ * Outputs:
+ * - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
list[unknown] model)
+ return (matrix[double] probs) {
+
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ N = nrow(X)
+
+ # Network:
+ ## input -> conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 ->
relu3 -> affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+ F1 = nrow(W1) # num conv filters in conv1
+ F2 = nrow(W2) # num conv filters in conv2
+ N3 = ncol(W3) # num nodes in affine3
+ K = ncol(W4) # num nodes in affine4, equal to number of target dimensions
(num classes)
+
+ # Compute predictions over mini-batches
+ probs = matrix(0, rows=N, cols=K)
+ iters = ceil(N / batch_size)
+ for(i in 1:iters, check=0) {
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win,
Hf, Wf, stride, stride,
+ pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1,
2, 2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1,
Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2,
2, 2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outr3, W4, b4)
+ probs_batch = softmax::forward(outa4)
+
+ # Store predictions
+ probs[beg:end,] = probs_batch
+ }
+}
+
+/*
+ * Evaluates a convolutional net using the "LeNet" architecture.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples. The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ * - probs: Class probabilities, of shape (N, K)
+ * - y: Target matrix, of shape (N, K)
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1)
+ * - accuracy: Scalar accuracy, of shape (1)
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+ return (double loss, double accuracy) {
+
+ # Compute loss & accuracy
+ loss = cross_entropy_loss::forward(probs, y)
+ correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+ accuracy = mean(correct_pred)
+}
+
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for
validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels,
+ list[unknown] model, list[unknown] hyperparams)
+ return (double loss, double accuracy)
+{
+ [loss, accuracy] = eval(predict(val_features,
as.integer(as.scalar(hyperparams["C"])),
+ as.integer(as.scalar(hyperparams["Hin"])),
as.integer(as.scalar(hyperparams["Win"])),
+ 32, model), val_labels)
+}
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments
+# and return the gradients of type list
+gradients = function(list[unknown] model,
+ list[unknown] hyperparams,
+ matrix[double] features,
+ matrix[double] labels)
+ return (list[unknown] gradients) {
+
+ C = as.integer(as.scalar(hyperparams["C"]))
+ Hin = as.integer(as.scalar(hyperparams["Hin"]))
+ Win = as.integer(as.scalar(hyperparams["Win"]))
+ Hf = as.integer(as.scalar(hyperparams["Hf"]))
+ Wf = as.integer(as.scalar(hyperparams["Wf"]))
+ stride = as.integer(as.scalar(hyperparams["stride"]))
+ pad = as.integer(as.scalar(hyperparams["pad"]))
+ lambda = as.double(as.scalar(hyperparams["lambda"]))
+ F1 = as.integer(as.scalar(hyperparams["F1"]))
+ F2 = as.integer(as.scalar(hyperparams["F2"]))
+ N3 = as.integer(as.scalar(hyperparams["N3"]))
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf,
Wf,
+ stride, stride, pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2,
2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1,
Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2,
2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3 -> dropout
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outd3, W4, b4)
+ probs = softmax::forward(outa4)
+
+ # Compute loss & accuracy for training data
+ loss = cross_entropy_loss::forward(probs, labels)
+ accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+ # print("[+] Completed forward pass on batch: train loss: " + loss + ",
train accuracy: " + accuracy)
+
+ # Compute data backward pass
+ ## loss
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ ## layer 4: affine4 -> softmax
+ douta4 = softmax::backward(dprobs, outa4)
+ [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+ ## layer 3: affine3 -> relu3 -> dropout
+ doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+ douta3 = relu::backward(doutr3, outa3)
+ [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+ ## layer 2: conv2 -> relu2 -> pool2
+ doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2,
Woutc2, 2, 2, 2, 2, 0, 0)
+ doutc2 = relu::backward(doutr2, outc2)
+ [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2,
F1,
+ Houtp1, Woutp1, Hf, Wf, stride,
stride, pad, pad)
+ ## layer 1: conv1 -> relu1 -> pool1
+ doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1,
Woutc1, 2, 2, 2, 2, 0, 0)
+ doutc1 = relu::backward(doutr1, outc1)
+ [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features,
W1, b1, C, Hin, Win,
+ Hf, Wf, stride, stride, pad, pad)
+
+ # Compute regularization backward pass
+ dW1_reg = l2_reg::backward(W1, lambda)
+ dW2_reg = l2_reg::backward(W2, lambda)
+ dW3_reg = l2_reg::backward(W3, lambda)
+ dW4_reg = l2_reg::backward(W4, lambda)
+ dW1 = dW1 + dW1_reg
+ dW2 = dW2 + dW2_reg
+ dW3 = dW3 + dW3_reg
+ dW4 = dW4 + dW4_reg
+
+ gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and return always a model of type list
+aggregation = function(list[unknown] model,
+ list[unknown] hyperparams,
+ list[unknown] gradients)
+ return (list[unknown] model_result) {
+
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ dW1 = as.matrix(gradients[1])
+ dW2 = as.matrix(gradients[2])
+ dW3 = as.matrix(gradients[3])
+ dW4 = as.matrix(gradients[4])
+ db1 = as.matrix(gradients[5])
+ db2 = as.matrix(gradients[6])
+ db3 = as.matrix(gradients[7])
+ db4 = as.matrix(gradients[8])
+ vW1 = as.matrix(model[9])
+ vW2 = as.matrix(model[10])
+ vW3 = as.matrix(model[11])
+ vW4 = as.matrix(model[12])
+ vb1 = as.matrix(model[13])
+ vb2 = as.matrix(model[14])
+ vb3 = as.matrix(model[15])
+ vb4 = as.matrix(model[16])
+ learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+ mu = as.double(as.scalar(hyperparams["mu"]))
+
+ # Optimize with SGD w/ Nesterov momentum
+ [W1, vW1] = sgd_nesterov::update(W1, dW1, learning_rate, mu, vW1)
+ [b1, vb1] = sgd_nesterov::update(b1, db1, learning_rate, mu, vb1)
+ [W2, vW2] = sgd_nesterov::update(W2, dW2, learning_rate, mu, vW2)
+ [b2, vb2] = sgd_nesterov::update(b2, db2, learning_rate, mu, vb2)
+ [W3, vW3] = sgd_nesterov::update(W3, dW3, learning_rate, mu, vW3)
+ [b3, vb3] = sgd_nesterov::update(b3, db3, learning_rate, mu, vb3)
+ [W4, vW4] = sgd_nesterov::update(W4, dW4, learning_rate, mu, vW4)
+ [b4, vb4] = sgd_nesterov::update(b4, db4, learning_rate, mu, vb4)
+
+ model_result = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4,
vb1, vb2, vb3, vb4)
+}
diff --git a/src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml
b/src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml
new file mode 100644
index 0000000..049bea8
--- /dev/null
+++ b/src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml
@@ -0,0 +1,307 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * This file implements all needed functions to evaluate a simple feed forward
neural network
+ * on different execution schemes and with different inputs, for example a
federated input matrix.
+ */
+
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd.dml") as sgd
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers single
threaded the conventional way.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix of shape (N, D)
+ * - y: Target matrix of shape (N, K)
+ * - X_val: Input validation data matrix of shape (N_val, D)
+ * - y_val: Targed validation matrix of shape (N_val, K)
+ * - epochs: Total number of full training loops over the full data set
+ * - batch_size: Batch size
+ * - learning_rate: The learning rate for the SGD
+ * - Optional boolean parameter to select between updating or averaging the
model in paramserver side.
+ *
+ * Outputs:
+ * - model_trained: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ * - b1: 1st layer biases vector, of shape (200, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ * - b2: 2nd layer biases vector, of shape (200, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ * - b3: 3rd layer biases vector, of shape (K, 1)
+ */
+train = function(matrix[double] X, matrix[double] y,
+ matrix[double] X_val, matrix[double] y_val,
+ int epochs, int batch_size, double eta,
+ int seed = -1 , boolean modelAvg )
+ return (list[unknown] model) {
+
+ N = nrow(X) # num examples
+ D = ncol(X) # num features
+ K = ncol(y) # num classes
+
+ # Create the network:
+ ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+ [W1, b1] = affine::init(D, 200, seed = seed)
+ lseed = ifelse(seed==-1, -1, seed + 1);
+ [W2, b2] = affine::init(200, 200, seed = lseed)
+ lseed = ifelse(seed==-1, -1, seed + 2);
+ [W3, b3] = affine::init(200, K, seed = lseed)
+ W3 = W3 / sqrt(2) # different initialization, since being fed into softmax,
instead of relu
+ model = list(W1, W2, W3, b1, b2, b3)
+
+ # Create the hyper parameter list
+ hyperparams = list(learning_rate=eta)
+ # Calculate iterations
+ iters = ceil(N / batch_size)
+
+ for (e in 1:epochs) {
+ for(i in 1:iters) {
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+ y_batch = y[beg:end,]
+
+ gradients_list = gradients(model, hyperparams, X_batch, y_batch)
+ model = aggregation(model, hyperparams, gradients_list)
+ }
+ }
+}
+
+/*
+ * Trains a simple feed forward neural network with two hidden layers
+ * using a parameter server with specified properties.
+ *
+ * The input matrix has one example per row (N) and D features.
+ * The targets, y, have K classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix of shape (N, D)
+ * - y: Target matrix of shape (N, K)
+ * - X_val: Input validation data matrix of shape (N_val, D)
+ * - y_val: Targed validation matrix of shape (N_val, K)
+ * - epochs: Total number of full training loops over the full data set
+ * - batch_size: Batch size
+ * - learning_rate: The learning rate for the SGD
+ * - workers: Number of workers to create
+ * - utype: parameter server framework to use
+ * - scheme: update schema
+ * - mode: local or distributed
+ *
+ * Outputs:
+ * - model_trained: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ * - b1: 1st layer biases vector, of shape (200, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ * - b2: 2nd layer biases vector, of shape (200, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ * - b3: 3rd layer biases vector, of shape (K, 1)
+ */
+train_paramserv = function(matrix[double] X, matrix[double] y,
+ matrix[double] X_val, matrix[double] y_val,
+ int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing, string weighting,
+ double eta, int seed = -1, boolean modelAvg)
+ return (list[unknown] model) {
+
+ N = nrow(X) # num examples
+ D = ncol(X) # num features
+ K = ncol(y) # num classes
+
+ # Create the network:
+ ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+ [W1, b1] = affine::init(D, 200, seed = seed)
+ lseed = ifelse(seed==-1, -1, seed + 1);
+ [W2, b2] = affine::init(200, 200, seed = lseed)
+ lseed = ifelse(seed==-1, -1, seed + 2);
+ [W3, b3] = affine::init(200, K, seed = lseed)
+ # W3 = W3 / sqrt(2) # different initialization, since being fed into
softmax, instead of relu
+
+ # [W1, b1] = affine::init(D, 200)
+ # [W2, b2] = affine::init(200, 200)
+ # [W3, b3] = affine::init(200, K)
+
+ # Create the model list
+ model = list(W1, W2, W3, b1, b2, b3)
+ # Create the hyper parameter list
+ hyperparams = list(learning_rate=eta)
+
+ # Use paramserv function
+ model = paramserv(model=model, features=X, labels=y, val_features=X_val,
val_labels=y_val,
+
upd="./src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml::gradients",
+
agg="./src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml::aggregation",
+
val="./src/test/scripts/functions/federated/paramserv/TwoNNModelAvg.dml::validate",
+ k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
+ scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting,
hyperparams=hyperparams, seed=seed, modelAvg=modelAvg)
+}
+
+/*
+ * Computes the class probability predictions of a simple feed forward neural
network.
+ *
+ * Inputs:
+ * - X: The input data matrix of shape (N, D)
+ * - model: List containing
+ * - W1: 1st layer weights (parameters) matrix, of shape (D, 200)
+ * - b1: 1st layer biases vector, of shape (200, 1)
+ * - W2: 2nd layer weights (parameters) matrix, of shape (200, 200)
+ * - b2: 2nd layer biases vector, of shape (200, 1)
+ * - W3: 3rd layer weights (parameters) matrix, of shape (200, K)
+ * - b3: 3rd layer biases vector, of shape (K, 1)
+ *
+ * Outputs:
+ * - probs: Class probabilities, of shape (N, K)
+ */
+predict = function(matrix[double] X,
+ list[unknown] model)
+ return (matrix[double] probs) {
+
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ b1 = as.matrix(model[4])
+ b2 = as.matrix(model[5])
+ b3 = as.matrix(model[6])
+
+ out1relu = relu::forward(affine::forward(X, W1, b1))
+ out2relu = relu::forward(affine::forward(out1relu, W2, b2))
+ probs = softmax::forward(affine::forward(out2relu, W3, b3))
+}
+
+/*
+ * Evaluates a simple feed forward neural network.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples. The targets, y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ * - probs: Class probabilities, of shape (N, K).
+ * - y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+eval = function(matrix[double] probs, matrix[double] y)
+ return (double loss, double accuracy) {
+
+ # Compute loss & accuracy
+ loss = cross_entropy_loss::forward(probs, y)
+ correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+ accuracy = mean(correct_pred)
+}
+
+/*
+ * Gives the accuracy and loss for a model and given feature and label matrices
+ *
+ * This function is a combination of the predict and eval function used for
validation.
+ * For inputs see eval and predict.
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+validate = function(matrix[double] val_features, matrix[double] val_labels,
list[unknown] model, list[unknown] hyperparams)
+ return (double loss, double accuracy) {
+ [loss, accuracy] = eval(predict(val_features, model), val_labels)
+}
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments
+# and return the gradients of type list
+gradients = function(list[unknown] model,
+ list[unknown] hyperparams,
+ matrix[double] features,
+ matrix[double] labels)
+ return (list[unknown] gradients) {
+
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ b1 = as.matrix(model[4])
+ b2 = as.matrix(model[5])
+ b3 = as.matrix(model[6])
+
+ # Compute forward pass
+ ## input -> affine1 -> relu1 -> affine2 -> relu2 -> affine3 -> softmax
+ out1 = affine::forward(features, W1, b1)
+ out1relu = relu::forward(out1)
+ out2 = affine::forward(out1relu, W2, b2)
+ out2relu = relu::forward(out2)
+ out3 = affine::forward(out2relu, W3, b3)
+ probs = softmax::forward(out3)
+
+ # Compute loss & accuracy for training data
+ loss = cross_entropy_loss::forward(probs, labels)
+ accuracy = mean(rowIndexMax(probs) == rowIndexMax(labels))
+ # print("[+] Completed forward pass on batch: train loss: " + loss + ",
train accuracy: " + accuracy)
+
+ # Compute data backward pass
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ dout3 = softmax::backward(dprobs, out3)
+ [dout2relu, dW3, db3] = affine::backward(dout3, out2relu, W3, b3)
+ dout2 = relu::backward(dout2relu, out2)
+ [dout1relu, dW2, db2] = affine::backward(dout2, out1relu, W2, b2)
+ dout1 = relu::backward(dout1relu, out1)
+ [dfeatures, dW1, db1] = affine::backward(dout1, features, W1, b1)
+
+ gradients = list(dW1, dW2, dW3, db1, db2, db3)
+}
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and return always a model of type list
+aggregation = function(list[unknown] model,
+ list[unknown] hyperparams,
+ list[unknown] gradients)
+ return (list[unknown] model_result) {
+
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ b1 = as.matrix(model[4])
+ b2 = as.matrix(model[5])
+ b3 = as.matrix(model[6])
+ dW1 = as.matrix(gradients[1])
+ dW2 = as.matrix(gradients[2])
+ dW3 = as.matrix(gradients[3])
+ db1 = as.matrix(gradients[4])
+ db2 = as.matrix(gradients[5])
+ db3 = as.matrix(gradients[6])
+ learning_rate = as.double(as.scalar(hyperparams["learning_rate"]))
+
+ # Optimize with SGD
+ W3 = sgd::update(W3, dW3, learning_rate)
+ b3 = sgd::update(b3, db3, learning_rate)
+ W2 = sgd::update(W2, dW2, learning_rate)
+ b2 = sgd::update(b2, db2, learning_rate)
+ W1 = sgd::update(W1, dW1, learning_rate)
+ b1 = sgd::update(b1, db1, learning_rate)
+
+ model_result = list(W1, W2, W3, b1, b2, b3)
+}
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml
b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml
new file mode 100644
index 0000000..bd5fd7d
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml
@@ -0,0 +1,372 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("scripts/nn/layers/affine.dml") as affine
+source("scripts/nn/layers/conv2d_builtin.dml") as conv2d
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("scripts/nn/layers/dropout.dml") as dropout
+source("scripts/nn/layers/l2_reg.dml") as l2_reg
+source("scripts/nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("scripts/nn/layers/relu.dml") as relu
+source("scripts/nn/layers/softmax.dml") as softmax
+source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] Y,
+ matrix[double] X_val, matrix[double] Y_val,
+ int C, int Hin, int Win, int epochs, int workers,
+ string utype, string freq, int batchsize, string scheme,
string mode, boolean modelAvg)
+ return (matrix[double] W1, matrix[double] b1,
+ matrix[double] W2, matrix[double] b2,
+ matrix[double] W3, matrix[double] b3,
+ matrix[double] W4, matrix[double] b4) {
+ /*
+ * Trains a convolutional net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector. The targets, Y, have K
+ * classes, and are one-hot encoded.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win).
+ * - Y: Target matrix, of shape (N, K).
+ * - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+ * - Y_val: Target validation matrix, of shape (N, K).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - epochs: Total number of full training loops over the full data set.
+ * - modelAv: Optional boolean parameter to select between updating or
averaging the model in paramserver side.
+ *
+ * Outputs:
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+ * - b1: 1st layer biases vector, of shape (F1, 1).
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+ * - b2: 2nd layer biases vector, of shape (F2, 1).
+ * - W3: 3rd layer weights (parameters) matrix, of shape
(F2*(Hin/4)*(Win/4), N3).
+ * - b3: 3rd layer biases vector, of shape (1, N3).
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+ * - b4: 4th layer biases vector, of shape (1, K).
+ */
+ N = nrow(X)
+ K = ncol(Y)
+
+ # Create network:
+ # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 ->
affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+
+ F1 = 32 # num conv filters in conv1
+ F2 = 64 # num conv filters in conv2
+ N3 = 512 # num nodes in affine3
+ # Note: affine4 has K nodes, which is equal to the number of target
dimensions (num classes)
+
+ [W1, b1] = conv2d::init(F1, C, Hf, Wf, -1) # inputs: (N, C*Hin*Win)
+ [W2, b2] = conv2d::init(F2, F1, Hf, Wf, -1) # inputs: (N,
F1*(Hin/2)*(Win/2))
+ [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3, -1) # inputs: (N,
F2*(Hin/2/2)*(Win/2/2))
+ [W4, b4] = affine::init(N3, K, -1) # inputs: (N, N3)
+ W4 = W4 / sqrt(2) # different initialization, since being fed into softmax,
instead of relu
+
+ # Initialize SGD w/ Nesterov momentum optimizer
+ lr = 0.01 # learning rate
+ mu = 0.9 #0.5 # momentum
+ decay = 0.95 # learning rate decay constant
+ vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+ vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+ vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+ vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+ # Regularization
+ lambda = 5e-04
+
+ # Create the model list
+ modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1,
vb2, vb3, vb4)
+
+ # Create the hyper parameter list
+ params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf,
Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
+
+ # Use paramserv function
+ modelList2 = paramserv(model=modelList, features=X, labels=Y,
val_features=X_val, val_labels=Y_val,
upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml::gradients",
agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml::aggregation",
mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batchsize,
k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE",
modelAvg=modelAvg)
+
+ W1 = as.matrix(modelList2[1])
+ W2 = as.matrix(modelList2[2])
+ W3 = as.matrix(modelList2[3])
+ W4 = as.matrix(modelList2[4])
+ b1 = as.matrix(modelList2[5])
+ b2 = as.matrix(modelList2[6])
+ b3 = as.matrix(modelList2[7])
+ b4 = as.matrix(modelList2[8])
+}
+
+# Should always use 'features' (batch features), 'labels' (batch labels),
+# 'hyperparams', 'model' as the arguments
+# and return the gradients of type list
+gradients = function(list[unknown] model,
+ list[unknown] hyperparams,
+ matrix[double] features,
+ matrix[double] labels)
+ return (list[unknown] gradients) {
+
+ C = as.integer(as.scalar(hyperparams["C"]))
+ Hin = as.integer(as.scalar(hyperparams["Hin"]))
+ Win = as.integer(as.scalar(hyperparams["Win"]))
+ Hf = as.integer(as.scalar(hyperparams["Hf"]))
+ Wf = as.integer(as.scalar(hyperparams["Wf"]))
+ stride = as.integer(as.scalar(hyperparams["stride"]))
+ pad = as.integer(as.scalar(hyperparams["pad"]))
+ lambda = as.double(as.scalar(hyperparams["lambda"]))
+ F1 = as.integer(as.scalar(hyperparams["F1"]))
+ F2 = as.integer(as.scalar(hyperparams["F2"]))
+ N3 = as.integer(as.scalar(hyperparams["N3"]))
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(features, W1, b1, C, Hin, Win, Hf,
Wf,
+ stride, stride, pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, 2,
2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1,
Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, 2,
2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3 -> dropout
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outd3, W4, b4)
+ probs = softmax::forward(outa4)
+
+ # Compute data backward pass
+ ## loss:
+ dprobs = cross_entropy_loss::backward(probs, labels)
+ ## layer 4: affine4 -> softmax
+ douta4 = softmax::backward(dprobs, outa4)
+ [doutd3, dW4, db4] = affine::backward(douta4, outr3, W4, b4)
+ ## layer 3: affine3 -> relu3 -> dropout
+ doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+ douta3 = relu::backward(doutr3, outa3)
+ [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+ ## layer 2: conv2 -> relu2 -> pool2
+ doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2,
Woutc2, 2, 2, 2, 2, 0, 0)
+ doutc2 = relu::backward(doutr2, outc2)
+ [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2,
F1,
+ Houtp1, Woutp1, Hf, Wf, stride,
stride, pad, pad)
+ ## layer 1: conv1 -> relu1 -> pool1
+ doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1,
Woutc1, 2, 2, 2, 2, 0, 0)
+ doutc1 = relu::backward(doutr1, outc1)
+ [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, features,
W1, b1, C, Hin, Win,
+ Hf, Wf, stride, stride, pad, pad)
+
+ # Compute regularization backward pass
+ dW1_reg = l2_reg::backward(W1, lambda)
+ dW2_reg = l2_reg::backward(W2, lambda)
+ dW3_reg = l2_reg::backward(W3, lambda)
+ dW4_reg = l2_reg::backward(W4, lambda)
+ dW1 = dW1 + dW1_reg
+ dW2 = dW2 + dW2_reg
+ dW3 = dW3 + dW3_reg
+ dW4 = dW4 + dW4_reg
+
+ gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
+}
+
+# Should use the arguments named 'model', 'gradients', 'hyperparams'
+# and return always a model of type list
+aggregation = function(list[unknown] model,
+ list[unknown] hyperparams,
+ list[unknown] gradients)
+ return (list[unknown] modelResult) {
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ dW1 = as.matrix(gradients[1])
+ dW2 = as.matrix(gradients[2])
+ dW3 = as.matrix(gradients[3])
+ dW4 = as.matrix(gradients[4])
+ db1 = as.matrix(gradients[5])
+ db2 = as.matrix(gradients[6])
+ db3 = as.matrix(gradients[7])
+ db4 = as.matrix(gradients[8])
+ vW1 = as.matrix(model[9])
+ vW2 = as.matrix(model[10])
+ vW3 = as.matrix(model[11])
+ vW4 = as.matrix(model[12])
+ vb1 = as.matrix(model[13])
+ vb2 = as.matrix(model[14])
+ vb3 = as.matrix(model[15])
+ vb4 = as.matrix(model[16])
+ lr = as.double(as.scalar(hyperparams["lr"]))
+ mu = as.double(as.scalar(hyperparams["mu"]))
+
+ # Optimize with SGD w/ Nesterov momentum
+ [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+ [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+ [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+ [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+ [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+ [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+ [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+ [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+
+ modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4,
vb1, vb2, vb3, vb4)
+ }
+
+predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
+ matrix[double] W1, matrix[double] b1,
+ matrix[double] W2, matrix[double] b2,
+ matrix[double] W3, matrix[double] b3,
+ matrix[double] W4, matrix[double] b4)
+ return (matrix[double] probs) {
+ /*
+ * Computes the class probability predictions of a convolutional
+ * net using the "LeNet" architecture.
+ *
+ * The input matrix, X, has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * Inputs:
+ * - X: Input data matrix, of shape (N, C*Hin*Win).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+ * - b1: 1st layer biases vector, of shape (F1, 1).
+ * - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+ * - b2: 2nd layer biases vector, of shape (F2, 1).
+ * - W3: 3rd layer weights (parameters) matrix, of shape
(F2*(Hin/4)*(Win/4), N3).
+ * - b3: 3rd layer biases vector, of shape (1, N3).
+ * - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+ * - b4: 4th layer biases vector, of shape (1, K).
+ *
+ * Outputs:
+ * - probs: Class probabilities, of shape (N, K).
+ */
+ N = nrow(X)
+
+ # Network:
+ # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 ->
affine4 -> softmax
+ Hf = 5 # filter height
+ Wf = 5 # filter width
+ stride = 1
+ pad = 2 # For same dimensions, (Hf - stride) / 2
+
+ F1 = nrow(W1) # num conv filters in conv1
+ F2 = nrow(W2) # num conv filters in conv2
+ N3 = ncol(W3) # num nodes in affine3
+ K = ncol(W4) # num nodes in affine4, equal to number of target dimensions
(num classes)
+
+ # Compute predictions over mini-batches
+ probs = matrix(0, rows=N, cols=K)
+ iters = ceil(N / batch_size)
+ parfor(i in 1:iters, check=0) {
+ # Get next batch
+ beg = ((i-1) * batch_size) %% N + 1
+ end = min(N, beg + batch_size - 1)
+ X_batch = X[beg:end,]
+
+ # Compute forward pass
+ ## layer 1: conv1 -> relu1 -> pool1
+ [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win,
Hf, Wf, stride, stride,
+ pad, pad)
+ outr1 = relu::forward(outc1)
+ [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1,
2, 2, 2, 2, 0, 0)
+ ## layer 2: conv2 -> relu2 -> pool2
+ [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1,
Woutp1, Hf, Wf,
+ stride, stride, pad, pad)
+ outr2 = relu::forward(outc2)
+ [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2,
2, 2, 2, 2, 0, 0)
+ ## layer 3: affine3 -> relu3
+ outa3 = affine::forward(outp2, W3, b3)
+ outr3 = relu::forward(outa3)
+ ## layer 4: affine4 -> softmax
+ outa4 = affine::forward(outr3, W4, b4)
+ probs_batch = softmax::forward(outa4)
+
+ # Store predictions
+ probs[beg:end,] = probs_batch
+ }
+}
+
+eval = function(matrix[double] probs, matrix[double] Y)
+ return (double loss, double accuracy) {
+ /*
+ * Evaluates a convolutional net using the "LeNet" architecture.
+ *
+ * The probs matrix contains the class probability predictions
+ * of K classes over N examples. The targets, Y, have K classes,
+ * and are one-hot encoded.
+ *
+ * Inputs:
+ * - probs: Class probabilities, of shape (N, K).
+ * - Y: Target matrix, of shape (N, K).
+ *
+ * Outputs:
+ * - loss: Scalar loss, of shape (1).
+ * - accuracy: Scalar accuracy, of shape (1).
+ */
+ # Compute loss & accuracy
+ loss = cross_entropy_loss::forward(probs, Y)
+ correct_pred = rowIndexMax(probs) == rowIndexMax(Y)
+ accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+ return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {
+ /*
+ * Generate a dummy dataset similar to the MNIST dataset.
+ *
+ * Outputs:
+ * - X: Input data matrix, of shape (N, D).
+ * - Y: Target matrix, of shape (N, K).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ */
+ # Generate dummy input data
+ N = 1024 # num examples
+ C = 1 # num input channels
+ Hin = 28 # input height
+ Win = 28 # input width
+ K = 10 # num target classes
+ X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+ classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+ Y = table(seq(1, N), classes) # one-hot encoding
+}
diff --git a/src/test/scripts/functions/paramserv/paramserv-averaging-test.dml
b/src/test/scripts/functions/paramserv/paramserv-averaging-test.dml
new file mode 100644
index 0000000..073e216
--- /dev/null
+++ b/src/test/scripts/functions/paramserv/paramserv-averaging-test.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv_avg.dml")
as mnist_lenet_avg
+source("src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml") as
mnist_lenet
+source("scripts/nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+
+# Generate the training data
+[images, labels, C, Hin, Win] = mnist_lenet_avg::generate_dummy_data()
+n = nrow(images)
+
+# Generate the training data
+[X, Y, C, Hin, Win] = mnist_lenet_avg::generate_dummy_data()
+
+# Split into training and validation
+val_size = n * 0.1
+X = images[(val_size+1):n,]
+X_val = images[1:val_size,]
+Y = labels[(val_size+1):n,]
+Y_val = labels[1:val_size,]
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet_avg::train(X, Y, X_val, Y_val,
C, Hin, Win, $epochs, $workers, $utype, $freq, $batchsize, $scheme, $mode,
$modelAvg)
+
+# Compute validation loss & accuracy
+probs_val = mnist_lenet_avg::predict(X_val, C, Hin, Win, $batchsize, W1, b1,
W2, b2, W3, b3, W4, b4)
+loss_val = cross_entropy_loss::forward(probs_val, Y_val)
+accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))
+
+# Output results
+print("Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)