mboehm7 commented on a change in pull request #1336:
URL: https://github.com/apache/systemds/pull/1336#discussion_r672207342
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
##########
@@ -337,6 +338,7 @@ protected void weighAndPushGradients(ListObject gradients) {
// Push the gradients to ps
_ps.push(_workerID, gradients);
+ //_ps.push(_workerID, modell)
Review comment:
what's this? Please, depending on the configuration, push either the
gradients or the model.
##########
File path: src/main/java/org/apache/sysds/parser/Statement.java
##########
@@ -72,6 +72,9 @@
public static final String PS_MODE = "mode";
public static final String PS_GRADIENTS = "gradients";
public static final String PS_SEED = "seed";
+ public static final String PS_MODELAVG = "modelAvg";
+ public static final String PS_MODELS = "models";
Review comment:
why do we need this besides modelAvg? Remove it if unnecessary.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -112,69 +120,89 @@ private void computeEpoch(long dataSize, int batchIter) {
catch(ExecutionException | InterruptedException ex) {
throw new DMLRuntimeException(ex);
}
-
+
accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
}
}
}
+ protected ListObject createLocalModel(ExecutionContext ec, ListObject
gradients, ListObject model) {
+ // Populate the variables table with the gradients and model
+ ec.setVariable(Statement.PS_GRADIENTS, gradients);
+ ec.setVariable(Statement.PS_MODEL, model);
+
+ // Invoke the aggregate function
+ _inst.processInstruction(ec);
+
+ // Get the new model
+ ListObject newModel = ec.getListObject(_outputName);
+
+ // Clean up the list according to the data referencing status
+ ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL,
newModel.getStatus());
+ ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+ return newModel;
+ }
+
private ListObject updateModel(ListObject globalParams, ListObject
gradients, int i, int j, int batchIter) {
Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
globalParams = _ps.updateLocalModel(_ec, gradients,
globalParams);
accLocalModelUpdateTime(tUpd);
-
+
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: local global parameter
[size:%d kb] updated. "
- + "[Epoch:%d Total epoch:%d Iteration:%d
Total iteration:%d]",
- getWorkerName(), globalParams.getDataSize(), i
+ 1, _epochs, j + 1, batchIter));
+ + "[Epoch:%d Total
epoch:%d Iteration:%d Total iteration:%d]",
+ getWorkerName(),
globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));
Review comment:
fix the corrupted formatting.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -112,69 +120,89 @@ private void computeEpoch(long dataSize, int batchIter) {
catch(ExecutionException | InterruptedException ex) {
throw new DMLRuntimeException(ex);
}
-
+
accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
}
}
}
+ protected ListObject createLocalModel(ExecutionContext ec, ListObject
gradients, ListObject model) {
+ // Populate the variables table with the gradients and model
+ ec.setVariable(Statement.PS_GRADIENTS, gradients);
+ ec.setVariable(Statement.PS_MODEL, model);
+
+ // Invoke the aggregate function
+ _inst.processInstruction(ec);
+
+ // Get the new model
+ ListObject newModel = ec.getListObject(_outputName);
+
+ // Clean up the list according to the data referencing status
+ ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL,
newModel.getStatus());
+ ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+ return newModel;
+ }
+
private ListObject updateModel(ListObject globalParams, ListObject
gradients, int i, int j, int batchIter) {
Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
globalParams = _ps.updateLocalModel(_ec, gradients,
globalParams);
accLocalModelUpdateTime(tUpd);
-
+
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: local global parameter
[size:%d kb] updated. "
- + "[Epoch:%d Total epoch:%d Iteration:%d
Total iteration:%d]",
- getWorkerName(), globalParams.getDataSize(), i
+ 1, _epochs, j + 1, batchIter));
+ + "[Epoch:%d Total
epoch:%d Iteration:%d Total iteration:%d]",
+ getWorkerName(),
globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));
}
return globalParams;
}
-
private void computeBatch(long dataSize, int totalIter) {
for (int i = 0; i < _epochs; i++) {
for (int j = 0; j < totalIter; j++) {
ListObject globalParams = pullModel();
ListObject gradients =
computeGradients(globalParams, dataSize, totalIter, i, j);
-
// Push the gradients to ps
pushGradients(gradients);
ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
-
+
accNumBatches(1);
}
-
+
accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
}
}
}
-
private ListObject pullModel() {
// Pull the global parameters from ps
ListObject globalParams = _ps.pull(_workerID);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: successfully pull the
global parameters "
- + "[size:%d kb] from ps.", getWorkerName(),
globalParams.getDataSize() / 1024));
+ + "[size:%d kb] from ps.",
getWorkerName(), globalParams.getDataSize() / 1024));
}
return globalParams;
}
-
private void pushGradients(ListObject gradients) {
// Push the gradients to ps
_ps.push(_workerID, gradients);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: successfully push the
gradients "
- + "[size:%d kb] to ps.", getWorkerName(),
gradients.getDataSize() / 1024));
+ + "[size:%d kb] to ps.",
getWorkerName(), gradients.getDataSize() / 1024));
+ }
+ }
+ private void pushModelToServer(ListObject modell) {
+ // Push the Model to ps
+ _ps.push(_workerID, modell);
Review comment:
please correct the spelling of these variables globally: it's model, not
modell.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
return _model;
}
- protected synchronized void updateGlobalModel(int workerID, ListObject
gradients) {
+ protected synchronized void updModel_avgModel(int workerID, ListObject
params){
+ if (_modelAvg == true){
+ updateAverageModel(workerID, params);}
+ else if (_modelAvg == false)
+ updateGlobalModel(workerID, params);
+
+}
+ protected void updateAverageModel(int workerID, ListObject models) {
try {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Successfully pulled
the gradients [size:%d kb] of worker_%d.",
- gradients.getDataSize() / 1024,
workerID));
+ models.getDataSize() / 1024,
workerID));
}
switch(_updateType) {
case BSP: {
setFinishedState(workerID);
-
- // Accumulate the intermediate gradients
- if( ACCRUE_BSP_GRADIENTS )
- _accGradients =
ParamservUtils.accrueGradients(_accGradients, gradients, true);
- else
- updateGlobalModel(gradients);
+ _accModel =
ParamservUtils.accrueModels(_accModel, models, true);
if (allFinished()) {
- // Update the global model with
accrued gradients
- if( ACCRUE_BSP_GRADIENTS ) {
-
updateGlobalModel(_accGradients);
- _accGradients = null;
- }
+ averageGlobalModel(_accModel);
+ _accModel = null;
// This if has grown to be
quite complex its function is rather simple. Validate at the end of each epoch
// In the BSP batch case that
occurs after the sync counter reaches the number of batches and in the
// BSP epoch case every time
if (_numBatchesPerEpoch != -1 &&
- (_freq ==
Statement.PSFrequency.EPOCH ||
- (_freq ==
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+ (_freq ==
Statement.PSFrequency.EPOCH ||
+
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch
== 0))) {
if(LOG.isInfoEnabled())
LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
time_epoch();
-
if(_validationPossible)
validate();
-
_epochCounter++;
_syncCounter = 0;
+
}
-
// Broadcast the updated model
resetFinishedStates();
+
broadcastModel(true);
if (LOG.isDebugEnabled())
- LOG.debug("Global
parameter is broadcasted successfully.");
+ LOG.debug("Global
Averaging parameter is broadcasted successfully ");
Review comment:
what is an "Averaging parameter"? The model broadcast should be
unaffected by the introduction of model averaging.
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -327,12 +317,12 @@ private void runLocally(ExecutionContext ec, PSModeType
mode) {
MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null)
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ?
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
ParamServer ps = createPS(mode, aggFunc, updateType, freq,
workerNum, model, aggServiceEC, getValFunction(),
- num_batches_per_epoch, val_features,
val_labels);
+ num_batches_per_epoch, val_features,
val_labels,parseBoolean(modelAvg));
// Create the local workers
List<LocalPSWorker> workers = IntStream.range(0, workerNum)
.mapToObj(i -> new LocalPSWorker(i, updFunc, freq,
- getEpochs(), getBatchSize(), workerECs.get(i),
ps))
+ getEpochs(), getBatchSize(), workerECs.get(i),
ps,parseBoolean(modelAvg)))
Review comment:
missing space before the additional arg
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -468,21 +458,21 @@ private int getWorkerNum(PSModeType mode) {
* @return parameter server
*/
private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
- PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec)
+ PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec,boolean modelAvg)
{
- return createPS(mode, aggFunc, updateType, freq, workerNum,
model, ec, null, -1, null, null);
+ return createPS(mode, aggFunc, updateType, freq, workerNum,
model, ec, null, -1, null, null,modelAvg );
}
// When this creation is used the parameter server is able to validate
after each epoch
private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec, String valFunc,
- int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels)
+ int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels,boolean modelAvg)
{
- switch (mode) {
+ switch (mode) {
Review comment:
wrong formatting.
##########
File path: src/test/scripts/functions/federated/paramserv/TwoNN.dml
##########
@@ -126,7 +126,7 @@ train = function(matrix[double] X, matrix[double] y,
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
int num_workers, int epochs, string utype, string freq, int
batch_size, string scheme, string runtime_balancing, string weighting,
- double eta, int seed = -1)
+ double eta, int seed = -1,boolean modelAvg)
Review comment:
formatting.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -191,9 +219,9 @@ private ListObject computeGradients(ListObject params, long
dataSize, int batchI
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: got batch data [size:%d
kb] of index from %d to %d [last index: %d]. "
- + "[Epoch:%d Total epoch:%d Iteration:%d
Total iteration:%d]", getWorkerName(),
- bFeatures.getDataSize() / 1024 +
bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
- j + 1, batchIter));
+ + "[Epoch:%d Total
epoch:%d Iteration:%d Total iteration:%d]", getWorkerName(),
Review comment:
see formatting above.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
##########
@@ -360,7 +362,26 @@ protected void computeWithBatchUpdates() {
}
}
}
+ //**************************************** ATEFEH
*********************************************************************
Review comment:
we do not use author tags, so please remove such comments containing your
name.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -89,19 +97,19 @@ private void computeEpoch(long dataSize, int batchIter) {
try {
for (int j = 0; j < batchIter; j++) {
ListObject gradients =
computeGradients(params, dataSize, batchIter, i, j);
-
+
boolean localUpdate = j < batchIter - 1;
-
- // Accumulate the intermediate
gradients (async for overlap w/ model updates
+
+ // Accumulate the intermediate
gradients (async for overlap w/ model updates
// and gradient computation, sequential
over gradient matrices to avoid deadlocks)
ListObject accGradientsPrev =
accGradients.get();
accGradients = _tpool.submit(() ->
ParamservUtils.accrueGradients(
Review comment:
we only need to accrue gradients if we aim to exchange them; if model
averaging is enabled, this can be avoided.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
##########
@@ -33,24 +33,32 @@ public LocalParamServer() {
public static LocalParamServer create(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch,
- MatrixObject valFeatures, MatrixObject valLabels)
+ MatrixObject valFeatures, MatrixObject valLabels,boolean
modelAvg)
{
return new LocalParamServer(model, aggFunc, updateType, freq,
ec,
- workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels);
+ workerNum, valFunc, numBatchesPerEpoch, valFeatures,
valLabels,modelAvg);
}
private LocalParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch,
- MatrixObject valFeatures, MatrixObject valLabels)
+ MatrixObject valFeatures, MatrixObject valLabels,boolean
modelAvg)
{
- super(model, aggFunc, updateType, freq, ec, workerNum, valFunc,
numBatchesPerEpoch, valFeatures, valLabels);
+ super(model, aggFunc, updateType, freq, ec, workerNum, valFunc,
numBatchesPerEpoch, valFeatures, valLabels,modelAvg);
}
@Override
public void push(int workerID, ListObject gradients) {
- updateGlobalModel(workerID, gradients);
+ updModel_avgModel(workerID, gradients);
}
+ /*
+ public void push(int workerID, ListObject values) {
Review comment:
remove such commented code.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -89,19 +97,19 @@ private void computeEpoch(long dataSize, int batchIter) {
try {
for (int j = 0; j < batchIter; j++) {
ListObject gradients =
computeGradients(params, dataSize, batchIter, i, j);
-
+
boolean localUpdate = j < batchIter - 1;
-
- // Accumulate the intermediate
gradients (async for overlap w/ model updates
+
+ // Accumulate the intermediate
gradients (async for overlap w/ model updates
// and gradient computation, sequential
over gradient matrices to avoid deadlocks)
ListObject accGradientsPrev =
accGradients.get();
accGradients = _tpool.submit(() ->
ParamservUtils.accrueGradients(
- accGradientsPrev, gradients,
false, !localUpdate));
-
+ accGradientsPrev,
gradients, false, !localUpdate));
Review comment:
same as above
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -112,69 +120,89 @@ private void computeEpoch(long dataSize, int batchIter) {
catch(ExecutionException | InterruptedException ex) {
throw new DMLRuntimeException(ex);
}
-
+
accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
}
}
}
+ protected ListObject createLocalModel(ExecutionContext ec, ListObject
gradients, ListObject model) {
+ // Populate the variables table with the gradients and model
+ ec.setVariable(Statement.PS_GRADIENTS, gradients);
+ ec.setVariable(Statement.PS_MODEL, model);
+
+ // Invoke the aggregate function
+ _inst.processInstruction(ec);
+
+ // Get the new model
+ ListObject newModel = ec.getListObject(_outputName);
+
+ // Clean up the list according to the data referencing status
+ ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL,
newModel.getStatus());
+ ParamservUtils.cleanupListObject(ec, Statement.PS_GRADIENTS);
+ return newModel;
+ }
+
private ListObject updateModel(ListObject globalParams, ListObject
gradients, int i, int j, int batchIter) {
Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
globalParams = _ps.updateLocalModel(_ec, gradients,
globalParams);
accLocalModelUpdateTime(tUpd);
-
+
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: local global parameter
[size:%d kb] updated. "
- + "[Epoch:%d Total epoch:%d Iteration:%d
Total iteration:%d]",
- getWorkerName(), globalParams.getDataSize(), i
+ 1, _epochs, j + 1, batchIter));
+ + "[Epoch:%d Total
epoch:%d Iteration:%d Total iteration:%d]",
+ getWorkerName(),
globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));
}
return globalParams;
}
-
private void computeBatch(long dataSize, int totalIter) {
for (int i = 0; i < _epochs; i++) {
for (int j = 0; j < totalIter; j++) {
ListObject globalParams = pullModel();
ListObject gradients =
computeGradients(globalParams, dataSize, totalIter, i, j);
-
// Push the gradients to ps
pushGradients(gradients);
ParamservUtils.cleanupListObject(_ec,
Statement.PS_MODEL);
-
+
accNumBatches(1);
}
-
+
accNumEpochs(1);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: finished %d
epoch.", getWorkerName(), i + 1));
}
}
}
-
private ListObject pullModel() {
// Pull the global parameters from ps
ListObject globalParams = _ps.pull(_workerID);
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: successfully pull the
global parameters "
- + "[size:%d kb] from ps.", getWorkerName(),
globalParams.getDataSize() / 1024));
+ + "[size:%d kb] from ps.",
getWorkerName(), globalParams.getDataSize() / 1024));
Review comment:
see above.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -35,24 +31,30 @@
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.utils.Statistics;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+
public class LocalPSWorker extends PSWorker implements Callable<Void> {
protected static final Log LOG =
LogFactory.getLog(LocalPSWorker.class.getName());
private static final long serialVersionUID = 5195390748495357295L;
+ private ListObject modell;
+ private String _outputName;
protected LocalPSWorker() {}
public LocalPSWorker(int workerID, String updFunc,
Statement.PSFrequency freq,
- int epochs, long batchSize, ExecutionContext ec, ParamServer ps)
+ int epochs, long batchSize,
ExecutionContext ec, ParamServer ps,boolean modelavg)
Review comment:
please avoid corrupting the existing formatting.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
return _model;
}
- protected synchronized void updateGlobalModel(int workerID, ListObject
gradients) {
+ protected synchronized void updModel_avgModel(int workerID, ListObject
params){
+ if (_modelAvg == true){
+ updateAverageModel(workerID, params);}
+ else if (_modelAvg == false)
+ updateGlobalModel(workerID, params);
+
+}
+ protected void updateAverageModel(int workerID, ListObject models) {
try {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Successfully pulled
the gradients [size:%d kb] of worker_%d.",
- gradients.getDataSize() / 1024,
workerID));
+ models.getDataSize() / 1024,
workerID));
}
switch(_updateType) {
case BSP: {
setFinishedState(workerID);
-
- // Accumulate the intermediate gradients
- if( ACCRUE_BSP_GRADIENTS )
- _accGradients =
ParamservUtils.accrueGradients(_accGradients, gradients, true);
- else
- updateGlobalModel(gradients);
+ _accModel =
ParamservUtils.accrueModels(_accModel, models, true);
if (allFinished()) {
- // Update the global model with
accrued gradients
- if( ACCRUE_BSP_GRADIENTS ) {
-
updateGlobalModel(_accGradients);
- _accGradients = null;
- }
+ averageGlobalModel(_accModel);
+ _accModel = null;
// This if has grown to be
quite complex its function is rather simple. Validate at the end of each epoch
// In the BSP batch case that
occurs after the sync counter reaches the number of batches and in the
// BSP epoch case every time
if (_numBatchesPerEpoch != -1 &&
- (_freq ==
Statement.PSFrequency.EPOCH ||
- (_freq ==
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+ (_freq ==
Statement.PSFrequency.EPOCH ||
+
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch
== 0))) {
if(LOG.isInfoEnabled())
LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
time_epoch();
-
if(_validationPossible)
validate();
-
_epochCounter++;
_syncCounter = 0;
+
Review comment:
do not introduce new lines before closing curly braces.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalPSWorker.java
##########
@@ -62,7 +64,13 @@ public Void call() throws Exception {
switch (_freq) {
case BATCH:
- computeBatch(dataSize, batchIter);
+ if (_modelAvg){
+
computeBatch_Avg(dataSize,batchIter);
+ }
+
+ else
+ computeBatch(dataSize,
batchIter);
Review comment:
the formatting seems off again.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/LocalParamServer.java
##########
@@ -33,24 +33,32 @@ public LocalParamServer() {
public static LocalParamServer create(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc, int numBatchesPerEpoch,
- MatrixObject valFeatures, MatrixObject valLabels)
+ MatrixObject valFeatures, MatrixObject valLabels,boolean
modelAvg)
Review comment:
all the additional parameters in these constructors are missing a space
before the new parameter
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -41,17 +35,43 @@
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.cp.DoubleObject;
-import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.ListObject;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.*;
Review comment:
avoid wild-card imports.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
return _model;
}
- protected synchronized void updateGlobalModel(int workerID, ListObject
gradients) {
+ protected synchronized void updModel_avgModel(int workerID, ListObject
params){
+ if (_modelAvg == true){
+ updateAverageModel(workerID, params);}
+ else if (_modelAvg == false)
+ updateGlobalModel(workerID, params);
+
+}
+ protected void updateAverageModel(int workerID, ListObject models) {
try {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Successfully pulled
the gradients [size:%d kb] of worker_%d.",
- gradients.getDataSize() / 1024,
workerID));
+ models.getDataSize() / 1024,
workerID));
}
switch(_updateType) {
case BSP: {
setFinishedState(workerID);
-
- // Accumulate the intermediate gradients
- if( ACCRUE_BSP_GRADIENTS )
- _accGradients =
ParamservUtils.accrueGradients(_accGradients, gradients, true);
- else
- updateGlobalModel(gradients);
+ _accModel =
ParamservUtils.accrueModels(_accModel, models, true);
if (allFinished()) {
- // Update the global model with
accrued gradients
- if( ACCRUE_BSP_GRADIENTS ) {
-
updateGlobalModel(_accGradients);
- _accGradients = null;
- }
+ averageGlobalModel(_accModel);
+ _accModel = null;
// This if has grown to be
quite complex its function is rather simple. Validate at the end of each epoch
// In the BSP batch case that
occurs after the sync counter reaches the number of batches and in the
// BSP epoch case every time
if (_numBatchesPerEpoch != -1 &&
- (_freq ==
Statement.PSFrequency.EPOCH ||
- (_freq ==
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+ (_freq ==
Statement.PSFrequency.EPOCH ||
+
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch
== 0))) {
if(LOG.isInfoEnabled())
LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
time_epoch();
-
if(_validationPossible)
validate();
-
_epochCounter++;
_syncCounter = 0;
+
}
-
// Broadcast the updated model
resetFinishedStates();
+
broadcastModel(true);
if (LOG.isDebugEnabled())
- LOG.debug("Global
parameter is broadcasted successfully.");
+ LOG.debug("Global
Averaging parameter is broadcasted successfully ");
}
break;
}
- case ASP: {
Review comment:
you can't just delete the support for ASP here!!!
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
##########
@@ -479,4 +479,51 @@ public static ListObject accrueGradients(ListObject
accGradients, ListObject gra
ParamservUtils.cleanupListObject(gradients);
return accGradients;
}
+
+
+ /**
+ * Accumulate the given models into the accrued accrueModels
+ *
+ * @param accModels accrued models list object
+ * @param models given models list object
+ * @param cleanup clean up the given models list object
+ * @return new accrued models list object
+ */
+ public static ListObject accrueModels(ListObject accModels, ListObject
models, boolean cleanup) {
+ return accrueModels(accModels, models, false, cleanup);
+ }
+
+ /**
+ * Accumulate the given models into the accrued models
+ *
+ * @param accModels accrued models list object
+ * @param models given models list object
+ * @param par parallel execution
+ * @param cleanup clean up the given models list object
+ * @return new accrued models list object
+ */
+ public static ListObject accrueModels(ListObject accModels, ListObject
models, boolean par, boolean cleanup) {
+ if (accModels == null)
+ return ParamservUtils.copyList(models, cleanup);
+ IntStream range = IntStream.range(0, accModels.getLength());
+ (par ? range.parallel() : range).forEach(i -> {
+ MatrixBlock mb1 = ((MatrixObject)
accModels.getData().get(i)).acquireReadAndRelease();
+ MatrixBlock mb2 = ((MatrixObject)
models.getData().get(i)).acquireReadAndRelease();
+ mb1.binaryOperationsInPlace(new
BinaryOperator(Plus.getPlusFnObject()), mb2);
+ });
+ if (cleanup)
+ ParamservUtils.cleanupListObject(models);
+ return accModels;
+ }
+
+
+ //******************************************* ATEFEH
********************************************************
Review comment:
delete
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
return _model;
}
- protected synchronized void updateGlobalModel(int workerID, ListObject
gradients) {
+ protected synchronized void updModel_avgModel(int workerID, ListObject
params){
+ if (_modelAvg == true){
+ updateAverageModel(workerID, params);}
+ else if (_modelAvg == false)
+ updateGlobalModel(workerID, params);
+
+}
+ protected void updateAverageModel(int workerID, ListObject models) {
try {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Successfully pulled
the gradients [size:%d kb] of worker_%d.",
- gradients.getDataSize() / 1024,
workerID));
+ models.getDataSize() / 1024,
workerID));
}
switch(_updateType) {
case BSP: {
setFinishedState(workerID);
-
- // Accumulate the intermediate gradients
- if( ACCRUE_BSP_GRADIENTS )
- _accGradients =
ParamservUtils.accrueGradients(_accGradients, gradients, true);
- else
- updateGlobalModel(gradients);
+ _accModel =
ParamservUtils.accrueModels(_accModel, models, true);
if (allFinished()) {
- // Update the global model with
accrued gradients
- if( ACCRUE_BSP_GRADIENTS ) {
-
updateGlobalModel(_accGradients);
- _accGradients = null;
- }
+ averageGlobalModel(_accModel);
+ _accModel = null;
// This if has grown to be
quite complex its function is rather simple. Validate at the end of each epoch
// In the BSP batch case that
occurs after the sync counter reaches the number of batches and in the
// BSP epoch case every time
if (_numBatchesPerEpoch != -1 &&
- (_freq ==
Statement.PSFrequency.EPOCH ||
- (_freq ==
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+ (_freq ==
Statement.PSFrequency.EPOCH ||
+
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch
== 0))) {
if(LOG.isInfoEnabled())
LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
time_epoch();
-
if(_validationPossible)
validate();
-
_epochCounter++;
_syncCounter = 0;
+
}
-
// Broadcast the updated model
resetFinishedStates();
+
broadcastModel(true);
if (LOG.isDebugEnabled())
- LOG.debug("Global
parameter is broadcasted successfully.");
+ LOG.debug("Global
Averaging parameter is broadcasted successfully ");
}
break;
}
- case ASP: {
- updateGlobalModel(gradients);
- // This works similarly to the one for
BSP, but divides the sync counter by
- // the number of workers, creating
"Pseudo Epochs"
- if (_numBatchesPerEpoch != -1 &&
- ((_freq ==
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
- (_freq ==
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float)
_numBatchesPerEpoch == 0))) {
-
- if(LOG.isInfoEnabled())
- LOG.info("[+]
PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
-
- time_epoch();
-
- if(_validationPossible)
- validate();
-
- _epochCounter++;
- _syncCounter = 0;
- }
-
- broadcastModel(workerID);
- break;
- }
+ case ASP:
+ throw new
DMLRuntimeException("Unsupported update: " + _updateType.name()+"in the case of
averaging model");
default:
throw new
DMLRuntimeException("Unsupported update: " + _updateType.name());
}
- }
+ }
catch (Exception e) {
throw new DMLRuntimeException("Aggregation or
validation service failed: ", e);
}
}
+ private void averageGlobalModel(ListObject accModel) {
+ Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+ _model = averageModel(_ec,accModel, _model);
+
+ if (DMLScript.STATISTICS && tAgg != null)
+ Statistics.accPSAggregationTime((long) tAgg.stop());
+ }
+
/*********************************************************************************************************************
+ * A service method for averaging model with models
+ *
+ * @param ec execution context
+ * @param accModels list of models
+ * @param model old model
+ * @return new model (accModel)
+ */
+
+ public static ListObject averageModel(ExecutionContext ec, ListObject
accModels,ListObject model) {
Review comment:
the formatting of the entire method is completely off; also we still
need the model update via gradients (please make sure both gradient update and
model averaging are still working).
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
##########
@@ -118,28 +118,31 @@ public Instruction preprocessInstruction(ExecutionContext
ec) {
}
@Override
- public void processInstruction(ExecutionContext ec) {
+ public void
Review comment:
???
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -327,12 +317,12 @@ private void runLocally(ExecutionContext ec, PSModeType
mode) {
MatrixObject val_features = (getParam(PS_VAL_FEATURES) != null)
? ec.getMatrixObject(getParam(PS_VAL_FEATURES)) : null;
MatrixObject val_labels = (getParam(PS_VAL_LABELS) != null) ?
ec.getMatrixObject(getParam(PS_VAL_LABELS)) : null;
ParamServer ps = createPS(mode, aggFunc, updateType, freq,
workerNum, model, aggServiceEC, getValFunction(),
- num_batches_per_epoch, val_features,
val_labels);
+ num_batches_per_epoch, val_features,
val_labels,parseBoolean(modelAvg));
Review comment:
Missing space before the additional argument.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -74,23 +94,31 @@
private int _syncCounter = 0;
private int _epochCounter = 0 ;
private int _numBatchesPerEpoch;
+ private boolean _modelAvg;
private int _numWorkers;
+ private ListObject _accModel = null;
+ private Object sum;
+ private BinaryOperator _op2;
+ private MatrixObject AvgModel;
+
protected ParamServer() {}
protected ParamServer(ListObject model, String aggFunc,
Statement.PSUpdateType updateType,
- Statement.PSFrequency freq, ExecutionContext ec, int workerNum,
String valFunc,
- int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels)
+ Statement.PSFrequency freq,
ExecutionContext ec, int workerNum, String valFunc,
Review comment:
fix the corrupted formatting.
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -127,12 +155,12 @@ protected void setupAggFunc(ExecutionContext ec, String
aggFunc) {
_outputName = outputs.get(0).getName();
CPOperand[] boundInputs = inputs.stream()
- .map(input -> new CPOperand(input.getName(),
input.getValueType(), input.getDataType()))
- .toArray(CPOperand[]::new);
+ .map(input -> new CPOperand(input.getName(),
input.getValueType(), input.getDataType()))
Review comment:
Same as above.
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -87,6 +67,9 @@
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;
+import static java.lang.Boolean.parseBoolean;
+import static org.apache.sysds.parser.Statement.*;
Review comment:
Again, no wildcard imports.
##########
File path:
src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
##########
@@ -19,23 +19,25 @@
package org.apache.sysds.test.functions.federated.paramserv;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.List;
-
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.codegen.SpoofFusedOp;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
+import org.dmg.pmml.True;
Review comment:
Why are you introducing this dependency?
##########
File path: src/test/scripts/functions/federated/paramserv/CNN.dml
##########
@@ -67,7 +67,7 @@ source("scripts/nn/optim/sgd_nesterov.dml") as sgd_nesterov
*/
train = function(matrix[double] X, matrix[double] y, matrix[double] X_val,
matrix[double] y_val, int epochs, int batch_size, double eta, int C, int Hin,
- int Win, int seed = -1) return (list[unknown] model)
+ int Win, int seed = -1,boolean modelAvg) return (list[unknown] model)
Review comment:
formatting
##########
File path:
src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
##########
@@ -197,12 +200,15 @@ private void federatedParamserv(ExecMode mode) {
"channels=" + C,
"hin=" + Hin,
"win=" + Win,
- "seed=" + _seed));
+ "seed=" + _seed,
+ "modelAvg="+ modelAvg));
programArgs = programArgsList.toArray(new String[0]);
- LOG.debug(runTest(null));
- Assert.assertEquals(0,
Statistics.getNoOfExecutedSPInst());
-
+
+ runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+ // Assert.assertEquals(0,
Statistics.getNoOfExecutedSPInst());
Review comment:
Please do not disable assertions in existing tests.
##########
File path: src/test/scripts/functions/federated/paramserv/CNN.dml
##########
@@ -161,9 +161,11 @@ train = function(matrix[double] X, matrix[double] y,
matrix[double] X_val,
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
string utype, string freq, int batch_size, string scheme, string
runtime_balancing,
- string weighting, double eta, int C, int Hin, int Win, int seed = -1)
+ string weighting, double eta, int C, int Hin, int Win, int seed = -1,boolean
modelAvg)
Review comment:
see above.
##########
File path: src/test/scripts/functions/federated/paramserv/CNN.dml
##########
@@ -161,9 +161,11 @@ train = function(matrix[double] X, matrix[double] y,
matrix[double] X_val,
train_paramserv = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val, int num_workers, int epochs,
string utype, string freq, int batch_size, string scheme, string
runtime_balancing,
- string weighting, double eta, int C, int Hin, int Win, int seed = -1)
+ string weighting, double eta, int C, int Hin, int Win, int seed = -1,boolean
modelAvg)
return (list[unknown] model)
{
+
Review comment:
see above.
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -588,4 +578,8 @@ private String getValFunction() {
private int getSeed() {
return (getParameterMap().containsKey(PS_SEED)) ?
Integer.parseInt(getParam(PS_SEED)) : (int) System.currentTimeMillis();
}
+ private boolean getModelAvg() {
+ return getParameterMap().containsKey(PS_MODELAVG) &&
parseBoolean(getParam(PS_MODELAVG));
+ }
+
Review comment:
Please do not add such extra blank lines.
##########
File path:
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -468,21 +458,21 @@ private int getWorkerNum(PSModeType mode) {
* @return parameter server
*/
private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
- PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec)
+ PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec,boolean modelAvg)
{
- return createPS(mode, aggFunc, updateType, freq, workerNum,
model, ec, null, -1, null, null);
+ return createPS(mode, aggFunc, updateType, freq, workerNum,
model, ec, null, -1, null, null,modelAvg );
}
// When this creation is used the parameter server is able to validate
after each epoch
private static ParamServer createPS(PSModeType mode, String aggFunc,
PSUpdateType updateType,
PSFrequency freq, int workerNum, ListObject model,
ExecutionContext ec, String valFunc,
- int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels)
+ int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject
valLabels,boolean modelAvg)
Review comment:
same as above.
##########
File path: src/test/scripts/functions/federated/paramserv/TwoNN.dml
##########
@@ -150,13 +150,15 @@ train_paramserv = function(matrix[double] X,
matrix[double] y,
model = list(W1, W2, W3, b1, b2, b3)
# Create the hyper parameter list
hyperparams = list(learning_rate=eta)
+
+while (FALSE) {}
# Use paramserv function
model = paramserv(model=model, features=X, labels=y, val_features=X_val,
val_labels=y_val,
upd="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::gradients",
agg="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::aggregation",
val="./src/test/scripts/functions/federated/paramserv/TwoNN.dml::validate",
k=num_workers, utype=utype, freq=freq, epochs=epochs, batchsize=batch_size,
- scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting,
hyperparams=hyperparams, seed=seed)
+ scheme=scheme, runtime_balancing=runtime_balancing, weighting=weighting,
hyperparams=hyperparams, seed=seed,modelAvg=modelAvg)
Review comment:
formatting.
##########
File path: src/test/scripts/functions/federated/paramserv/TwoNN.dml
##########
@@ -58,7 +58,7 @@ source("nn/optim/sgd.dml") as sgd
train = function(matrix[double] X, matrix[double] y,
matrix[double] X_val, matrix[double] y_val,
int epochs, int batch_size, double eta,
- int seed = -1)
+ int seed = -1 , boolean modelAvg )
Review comment:
formatting
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]