This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new e112345 [SYSTEMDS-3018] Fix federated paramserv setup of model update functions
e112345 is described below
commit e112345edaebced1f419c2e4cd7abad08dba6599
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Sep 13 13:42:05 2021 +0200
[SYSTEMDS-3018] Fix federated paramserv setup of model update functions
This patch fixes inconsistencies in federated paramserv with model
averaging.
---
.../controlprogram/paramserv/FederatedPSControlThread.java | 2 +-
.../federated/paramserv/AvgModelFederatedParamservTest.java | 8 ++------
2 files changed, 3 insertions(+), 7 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 85bb745..ea8f0e8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -503,7 +503,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
// recreate aggregation instruction and output if needed
Instruction aggregationInstruction = null;
DataIdentifier aggregationOutput = null;
- if(_localUpdate && _numBatchesToCompute > 1) {
+ if(_localUpdate && _numBatchesToCompute > 1 | modelAvg)
{
func =
ec.getProgram().getFunctionProgramBlock(namespace, aggFunc, opt);
inputs = func.getInputParams();
outputs = func.getOutputParams();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index d96097e..66482f3 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -32,7 +32,6 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;
import org.junit.Assert;
-import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -118,15 +117,11 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
}
@Test
- @Ignore
- // TODO FIX ME
public void AvgmodelfederatedParamservSingleNode() {
AvgmodelfederatedParamserv(ExecMode.SINGLE_NODE, true);
}
@Test
- @Ignore
- // TODO FIX ME
public void AvgmodelfederatedParamservHybrid() {
AvgmodelfederatedParamserv(ExecMode.HYBRID, true);
}
@@ -149,7 +144,8 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
List<Thread> threads = new ArrayList<>();
for(int i = 0; i < _numFederatedWorkers; i++) {
ports.add(getRandomAvailablePort());
-
threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+
threads.add(startLocalFedWorkerThread(ports.get(i),
+ i==(_numFederatedWorkers-1) ?
FED_WORKER_WAIT : FED_WORKER_WAIT_S));
}
// generate test data