This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 5b96745  [SYSTEMDS-2550] Minor fixes (typo, tests) federated parameter 
server
5b96745 is described below

commit 5b96745aa42136e9f3ac4b62fe40dcad8c8ab9e2
Author: Tobias Rieger <[email protected]>
AuthorDate: Thu Jul 29 22:22:47 2021 +0200

    [SYSTEMDS-2550] Minor fixes (typo, tests) federated parameter server
    
    Closes #1352.
---
 .../runtime/controlprogram/paramserv/FederatedPSControlThread.java  | 6 +++---
 .../test/functions/federated/paramserv/FederatedParamservTest.java  | 1 -
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 536b529..c77ddf4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -321,7 +321,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                return _ps.pull(_workerID);
        }
 
-       protected void weighAndPushGradients(ListObject gradients) {
+       protected void weightAndPushGradients(ListObject gradients) {
                // scale gradients - must only include MatrixObjects
                if(_weighting && _weightingFactor != 1) {
                        Timing tWeighting = DMLScript.STATISTICS ? new 
Timing(true) : null;
@@ -354,7 +354,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                                int localStartBatchNum = 
getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
                                ListObject model = pullModel();
                                ListObject gradients = 
computeGradientsForNBatches(model, 1, localStartBatchNum);
-                               weighAndPushGradients(gradients);
+                               weightAndPushGradients(gradients);
                                ParamservUtils.cleanupListObject(model);
                                ParamservUtils.cleanupListObject(gradients);
                        }
@@ -378,7 +378,7 @@ public class FederatedPSControlThread extends PSWorker 
implements Callable<Void>
                        // Pull the global parameters from ps
                        ListObject model = pullModel();
                        ListObject gradients = 
computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum, 
true);
-                       weighAndPushGradients(gradients);
+                       weightAndPushGradients(gradients);
                        ParamservUtils.cleanupListObject(model);
                        ParamservUtils.cleanupListObject(gradients);
                }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 9221a53..c316214 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -64,7 +64,6 @@ public class FederatedParamservTest extends AutomatedTestBase 
{
                return Arrays.asList(new Object[][] {
                        // Network type, number of federated workers, data set 
size, batch size, epochs, learning rate, update type, update frequency
                        // basic functionality
-                       //{"TwoNN",     4, 60000, 32, 4, 0.01,  "BSP", "BATCH", 
"KEEP_DATA_ON_WORKER",  "NONE" ,                "false","BALANCED",             
200},
 
                        {"TwoNN",       2, 4, 1, 4, 0.01,               "BSP", 
"BATCH", "KEEP_DATA_ON_WORKER",  "BASELINE",             "true", "IMBALANCED",  
 200},
                        {"CNN",         2, 4, 1, 4, 0.01,               "BSP", 
"EPOCH", "SHUFFLE",                              "NONE",                 
"true", "IMBALANCED",   200},

Reply via email to