This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new e112345  [SYSTEMDS-3018] Fix federated paramserv setup of model update functions
e112345 is described below

commit e112345edaebced1f419c2e4cd7abad08dba6599
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Sep 13 13:42:05 2021 +0200

    [SYSTEMDS-3018] Fix federated paramserv setup of model update functions
    
    This patch fixes inconsistencies in federated paramserv with model
    averaging.
---
 .../controlprogram/paramserv/FederatedPSControlThread.java        | 2 +-
 .../federated/paramserv/AvgModelFederatedParamservTest.java       | 8 ++------
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index 85bb745..ea8f0e8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -503,7 +503,7 @@ public class FederatedPSControlThread extends PSWorker implements Callable<Void>
                        // recreate aggregation instruction and output if needed
                        Instruction aggregationInstruction = null;
                        DataIdentifier aggregationOutput = null;
-                       if(_localUpdate && _numBatchesToCompute > 1) {
+                       if(_localUpdate && _numBatchesToCompute > 1 | modelAvg) {
                                func = ec.getProgram().getFunctionProgramBlock(namespace, aggFunc, opt);
                                inputs = func.getInputParams();
                                outputs = func.getOutputParams();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
index d96097e..66482f3 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/AvgModelFederatedParamservTest.java
@@ -32,7 +32,6 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.utils.Statistics;
 import org.junit.Assert;
-import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -118,15 +117,11 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
        }
 
        @Test
-       @Ignore
-       // TODO FIX ME
        public void AvgmodelfederatedParamservSingleNode() {
                AvgmodelfederatedParamserv(ExecMode.SINGLE_NODE, true);
        }
 
        @Test
-       @Ignore
-       // TODO FIX ME
        public void AvgmodelfederatedParamservHybrid() {
                AvgmodelfederatedParamserv(ExecMode.HYBRID, true);
        }
@@ -149,7 +144,8 @@ public class AvgModelFederatedParamservTest extends AutomatedTestBase {
                        List<Thread> threads = new ArrayList<>();
                        for(int i = 0; i < _numFederatedWorkers; i++) {
                                ports.add(getRandomAvailablePort());
-                               threads.add(startLocalFedWorkerThread(ports.get(i), FED_WORKER_WAIT_S));
+                               threads.add(startLocalFedWorkerThread(ports.get(i),
+                                       i==(_numFederatedWorkers-1) ? FED_WORKER_WAIT : FED_WORKER_WAIT_S));
                        }
 
                        // generate test data

Reply via email to