atefeh-asayesh commented on a change in pull request #1336:
URL: https://github.com/apache/systemds/pull/1336#discussion_r676041282



##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
##########
@@ -179,88 +207,118 @@ public ListObject getResult() {
                return _model;
        }
 
-       protected synchronized void updateGlobalModel(int workerID, ListObject 
gradients) {
+       protected synchronized void updModel_avgModel(int workerID, ListObject 
params){
+               if (_modelAvg == true){
+                       updateAverageModel(workerID, params);}
+               else if (_modelAvg == false)
+                       updateGlobalModel(workerID, params);
+
+}
+       protected  void updateAverageModel(int workerID, ListObject models) {
                try {
                        if (LOG.isDebugEnabled()) {
                                LOG.debug(String.format("Successfully pulled 
the gradients [size:%d kb] of worker_%d.",
-                                       gradients.getDataSize() / 1024, 
workerID));
+                                               models.getDataSize() / 1024, 
workerID));
                        }
 
                        switch(_updateType) {
                                case BSP: {
                                        setFinishedState(workerID);
-
-                                       // Accumulate the intermediate gradients
-                                       if( ACCRUE_BSP_GRADIENTS )
-                                               _accGradients = 
ParamservUtils.accrueGradients(_accGradients, gradients, true);
-                                       else
-                                               updateGlobalModel(gradients);
+                                       _accModel = 
ParamservUtils.accrueModels(_accModel, models, true);
 
                                        if (allFinished()) {
-                                               // Update the global model with 
accrued gradients
-                                               if( ACCRUE_BSP_GRADIENTS ) {
-                                                       
updateGlobalModel(_accGradients);
-                                                       _accGradients = null;
-                                               }
+                                               averageGlobalModel(_accModel);
+                                               _accModel = null;
 
                                                // This if has grown to be 
quite complex its function is rather simple. Validate at the end of each epoch
                                                // In the BSP batch case that 
occurs after the sync counter reaches the number of batches and in the
                                                // BSP epoch case every time
                                                if (_numBatchesPerEpoch != -1 &&
-                                                       (_freq == 
Statement.PSFrequency.EPOCH ||
-                                                       (_freq == 
Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch == 0))) {
+                                                               (_freq == 
Statement.PSFrequency.EPOCH ||
+                                                                               
(_freq == Statement.PSFrequency.BATCH && ++_syncCounter % _numBatchesPerEpoch 
== 0))) {
 
                                                        if(LOG.isInfoEnabled())
                                                                LOG.info("[+] 
PARAMSERV: completed EPOCH " + _epochCounter);
 
                                                        time_epoch();
-
                                                        if(_validationPossible)
                                                                validate();
-
                                                        _epochCounter++;
                                                        _syncCounter = 0;
+
                                                }
-                                               
                                                // Broadcast the updated model
                                                resetFinishedStates();
+
                                                broadcastModel(true);
                                                if (LOG.isDebugEnabled())
-                                                       LOG.debug("Global 
parameter is broadcasted successfully.");
+                                                       LOG.debug("Global 
Averaging parameter is broadcasted successfully ");
                                        }
                                        break;
                                }
-                               case ASP: {
-                                       updateGlobalModel(gradients);
-                                       // This works similarly to the one for 
BSP, but divides the sync counter by
-                                       // the number of workers, creating 
"Pseudo Epochs"
-                                       if (_numBatchesPerEpoch != -1 &&
-                                               ((_freq == 
Statement.PSFrequency.EPOCH && ((float) ++_syncCounter % _numWorkers) == 0) ||
-                                               (_freq == 
Statement.PSFrequency.BATCH && ((float) ++_syncCounter / _numWorkers) % (float) 
_numBatchesPerEpoch == 0))) {
-
-                                               if(LOG.isInfoEnabled())
-                                                       LOG.info("[+] 
PARAMSERV: completed PSEUDO EPOCH (ASP) " + _epochCounter);
-
-                                               time_epoch();
-
-                                               if(_validationPossible)
-                                                       validate();
-
-                                               _epochCounter++;
-                                               _syncCounter = 0;
-                                       }
-
-                                       broadcastModel(workerID);
-                                       break;
-                               }
+                               case ASP:
+                                       throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name()+"in the case of 
averaging model");
                                default:
                                        throw new 
DMLRuntimeException("Unsupported update: " + _updateType.name());
                        }
-               } 
+               }
                catch (Exception e) {
                        throw new DMLRuntimeException("Aggregation or 
validation service failed: ", e);
                }
        }
+       private void averageGlobalModel(ListObject accModel) {
+               Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+               _model = averageModel(_ec,accModel, _model);
+
+               if (DMLScript.STATISTICS && tAgg != null)
+                       Statistics.accPSAggregationTime((long) tAgg.stop());
+       }
+       
/*********************************************************************************************************************
+        * A service method for averaging model with models
+        *
+        * @param ec execution context
+        * @param accModels list of models
+        * @param model old model
+        * @return new model (accModel)
+        */
+
+       public static  ListObject averageModel(ExecutionContext ec, ListObject 
accModels,ListObject model) {

Review comment:
       Thanks for the comment. I explained it in the previous comment.

##########
File path: 
src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
##########
@@ -479,4 +479,51 @@ public static ListObject accrueGradients(ListObject 
accGradients, ListObject gra
                        ParamservUtils.cleanupListObject(gradients);
                return accGradients;
        }
+
+
+       /**
+                * Accumulate the given models into the accrued accrueModels
+        *
+        * @param accModels accrued models list object
+        * @param models given models list object
+        * @param cleanup clean up the given models list object
+        * @return new accrued models list object
+        */
+       public static ListObject accrueModels(ListObject accModels, ListObject 
models, boolean cleanup) {
+               return accrueModels(accModels, models, false, cleanup);
+       }
+
+       /**
+        * Accumulate the given models into the accrued models
+        *
+        * @param accModels accrued models list object
+        * @param models given models list object
+        * @param par parallel execution
+        * @param cleanup clean up the given models list object
+        * @return new accrued models list object
+        */
+       public static ListObject accrueModels(ListObject accModels, ListObject 
models, boolean par, boolean cleanup) {
+               if (accModels == null)
+                       return ParamservUtils.copyList(models, cleanup);
+               IntStream range = IntStream.range(0, accModels.getLength());
+               (par ? range.parallel() : range).forEach(i -> {
+                       MatrixBlock mb1 = ((MatrixObject) 
accModels.getData().get(i)).acquireReadAndRelease();
+                       MatrixBlock mb2 = ((MatrixObject) 
models.getData().get(i)).acquireReadAndRelease();
+                       mb1.binaryOperationsInPlace(new 
BinaryOperator(Plus.getPlusFnObject()), mb2);
+               });
+               if (cleanup)
+                       ParamservUtils.cleanupListObject(models);
+               return accModels;
+       }
+
+
+       //*******************************************   ATEFEH 
********************************************************

Review comment:
       I forgot to remove it. Well spotted; it has now been removed.

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java
##########
@@ -118,28 +118,31 @@ public Instruction preprocessInstruction(ExecutionContext 
ec) {
        }
 
        @Override
-       public void processInstruction(ExecutionContext ec) {
+       public void

Review comment:
       I made a mistake. Fixed it.

##########
File path: 
src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -87,6 +67,9 @@
 import org.apache.sysds.runtime.util.ProgramConverter;
 import org.apache.sysds.utils.Statistics;
 
+import static java.lang.Boolean.parseBoolean;
+import static org.apache.sysds.parser.Statement.*;

Review comment:
       Fixed it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to