This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 5b96745 [SYSTEMDS-2550] Minor fixes (typo, tests) federated parameter server
5b96745 is described below
commit 5b96745aa42136e9f3ac4b62fe40dcad8c8ab9e2
Author: Tobias Rieger <[email protected]>
AuthorDate: Thu Jul 29 22:22:47 2021 +0200
[SYSTEMDS-2550] Minor fixes (typo, tests) federated parameter server
Closes #1352.
---
.../runtime/controlprogram/paramserv/FederatedPSControlThread.java | 6 +++---
.../test/functions/federated/paramserv/FederatedParamservTest.java | 1 -
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 536b529..c77ddf4 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -321,7 +321,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
return _ps.pull(_workerID);
}
- protected void weighAndPushGradients(ListObject gradients) {
+ protected void weightAndPushGradients(ListObject gradients) {
// scale gradients - must only include MatrixObjects
if(_weighting && _weightingFactor != 1) {
Timing tWeighting = DMLScript.STATISTICS ? new
Timing(true) : null;
@@ -354,7 +354,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
int localStartBatchNum =
getNextLocalBatchNum(currentLocalBatchNumber++, _possibleBatchesPerLocalEpoch);
ListObject model = pullModel();
ListObject gradients =
computeGradientsForNBatches(model, 1, localStartBatchNum);
- weighAndPushGradients(gradients);
+ weightAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
}
@@ -378,7 +378,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
// Pull the global parameters from ps
ListObject model = pullModel();
ListObject gradients =
computeGradientsForNBatches(model, _numBatchesPerEpoch, localStartBatchNum,
true);
- weighAndPushGradients(gradients);
+ weightAndPushGradients(gradients);
ParamservUtils.cleanupListObject(model);
ParamservUtils.cleanupListObject(gradients);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index 9221a53..c316214 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -64,7 +64,6 @@ public class FederatedParamservTest extends AutomatedTestBase
{
return Arrays.asList(new Object[][] {
// Network type, number of federated workers, data set
size, batch size, epochs, learning rate, update type, update frequency
// basic functionality
- //{"TwoNN", 4, 60000, 32, 4, 0.01, "BSP", "BATCH",
"KEEP_DATA_ON_WORKER", "NONE" , "false","BALANCED",
200},
{"TwoNN", 2, 4, 1, 4, 0.01, "BSP",
"BATCH", "KEEP_DATA_ON_WORKER", "BASELINE", "true", "IMBALANCED",
200},
{"CNN", 2, 4, 1, 4, 0.01, "BSP",
"EPOCH", "SHUFFLE", "NONE",
"true", "IMBALANCED", 200},