atefeh-asayesh commented on a change in pull request #1336:
URL: https://github.com/apache/systemds/pull/1336#discussion_r676041206
##########
File path:
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
return _model;
}
- protected synchronized void updateGlobalModel(int workerID, ListObject
gradients) {
+ protected synchronized void updModel_avgModel(int workerID, ListObject
params){
+ if (_modelAvg == true){
+ updateAverageModel(workerID, params);}
+ else if (_modelAvg == false)
+ updateGlobalModel(workerID, params);
+
+}
+ protected void updateAverageModel(int workerID, ListObject models) {
try {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("Successfully pulled
the gradients [size:%d kb] of worker_%d.",
- gradients.getDataSize() / 1024,
workerID));
+ models.getDataSize() / 1024,
workerID));
}
switch(_updateType) {
case BSP: {
setFinishedState(workerID);
-
- // Accumulate the intermediate gradients
- if( ACCRUE_BSP_GRADIENTS )
- _accGradients =
ParamservUtils.accrueGradients(_accGradients, gradients, true);
- else
- updateGlobalModel(gradients);
+ _accModel =
ParamservUtils.accrueModels(_accModel, models, true);
if (allFinished()) {
- // Update the global model with
accrued gradients
- if( ACCRUE_BSP_GRADIENTS ) {
-
updateGlobalModel(_accGradients);
- _accGradients = null;
- }
+ averageGlobalModel(_accModel);
+ _accModel = null;
// This if has grown to be
quite complex its function is rather simple. Validate at the end of each epoch
// In the BSP batch case that
occurs after the sync counter reaches the number of batches and in the
// BSP epoch case every time
if (_numBatchesPerEpoch != -1 &&
- (_freq ==
Statement.PSFrequency.EPOCH ||
- (_freq ==
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+ (_freq ==
Statement.PSFrequency.EPOCH ||
+
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch
== 0))) {
if(LOG.isInfoEnabled())
LOG.info("[+]
PARAMSERV: completed EPOCH " + _epochCounter);
time_epoch();
-
if(_validationPossible)
validate();
-
_epochCounter++;
_syncCounter = 0;
+
}
-
// Broadcast the updated model
resetFinishedStates();
+
broadcastModel(true);
if (LOG.isDebugEnabled())
- LOG.debug("Global
parameter is broadcasted successfully.");
+ LOG.debug("Global
Averaging parameter is broadcasted successfully ");
}
break;
}
- case ASP: {
Review comment:
ASP has not been deleted. We have a condition parameter "modelAvg" for
model averaging, and when it is true, we call the "updateAverageModel"
function; ASP is not implemented in that function. However, ASP is still
handled in the "updateGlobalModel" function, as it was before, so both the
gradient update and model averaging are still working.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]