This is an automated email from the ASF dual-hosted git repository.

kinnerebner pushed a commit to branch paramserv_spark_fix
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit f586eaa8b95aefc7c67eea379b69405463632447
Author: Kevin Innerebner <[email protected]>
AuthorDate: Mon Jul 11 22:47:01 2022 +0200

    [MINOR] Fix Spark ParameterServer
    
    This patch fixes the Spark execution mode for the parameter server. In 
commit 28ff18fca2a9258168db7397d56236a5e0d9564b the handling of functions was 
changed; as a result, the parameter server in Spark mode no longer found the 
functions or sent them to the workers properly.
    
    Closes #1662
---
 .../runtime/controlprogram/paramserv/ParamServer.java  | 18 ++++++++++--------
 .../controlprogram/paramserv/ParamservUtils.java       |  3 +++
 .../controlprogram/paramserv/SparkPSWorker.java        |  3 +++
 .../instructions/cp/ParamservBuiltinCPInstruction.java |  4 ++--
 .../test/functions/paramserv/ParamservSparkNNTest.java |  5 ++++-
 5 files changed, 22 insertions(+), 11 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
index 3957965988..e88a19d964 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamServer.java
@@ -78,7 +78,8 @@ public abstract class ParamServer
 
        private int _numWorkers;
        private int _numBackupWorkers;
-       private boolean[] _discardWorkerRes;
+       // number of updates the respective worker is straggling behind
+       private int[] _numUpdatesStraggling;
        private boolean _modelAvg;
        private ListObject _accModels = null;
 
@@ -109,7 +110,7 @@ public abstract class ParamServer
                _numBatchesPerEpoch = numBatchesPerEpoch;
                _numWorkers = workerNum;
                _numBackupWorkers = numBackupWorkers;
-               _discardWorkerRes = new boolean[workerNum];
+               _numUpdatesStraggling = new int[workerNum];
                _modelAvg = modelAvg;
 
                // broadcast initial model
@@ -118,6 +119,8 @@ public abstract class ParamServer
 
        protected void setupAggFunc(ExecutionContext ec, String aggFunc) {
                String[] cfn = DMLProgram.splitFunctionKey(aggFunc);
+               if(cfn.length == 1)
+                       cfn = new String[] {null, cfn[0]};
                String ns = cfn[0];
                String fname = cfn[1];
                boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, 
fname, false);
@@ -240,10 +243,10 @@ public abstract class ParamServer
                                        break;
                                }
                                case SBP: {
-                                       if(_discardWorkerRes[workerID]) {
+                                       if(_numUpdatesStraggling[workerID] > 0) 
{
                                                LOG.info("[+] PRAMSERV: 
discarding result of backup-worker/straggler " + workerID);
                                                broadcastModel(workerID);
-                                               _discardWorkerRes[workerID] = 
false;
+                                               
_numUpdatesStraggling[workerID]--;
                                                break;
                                        }
                                        setFinishedState(workerID);
@@ -255,7 +258,6 @@ public abstract class ParamServer
                                                updateGlobalModel(gradients);
 
                                        if(enoughFinished()) {
-                                               // set flags to throwaway 
backup worker results
                                                tagStragglers();
                                                performGlobalGradientUpdate();
                                        }
@@ -300,7 +302,7 @@ public abstract class ParamServer
        private void tagStragglers() {
                for(int i = 0; i < _finishedStates.length; ++i) {
                        if(!_finishedStates[i])
-                               _discardWorkerRes[i] = true;
+                               _numUpdatesStraggling[i]++;
                }
        }
 
@@ -371,10 +373,10 @@ public abstract class ParamServer
                                case SBP: {
                                        // first weight the models based on 
number of workers
                                        ListObject weightParams = 
weightModels(model, _numWorkers - _numBackupWorkers);
-                                       if(_discardWorkerRes[workerID]) {
+                                       if(_numUpdatesStraggling[workerID] > 0) 
{
                                                LOG.info("[+] PRAMSERV: 
discarding result of backup-worker/straggler " + workerID);
                                                broadcastModel(workerID);
-                                               _discardWorkerRes[workerID] = 
false;
+                                               
_numUpdatesStraggling[workerID]--;
                                                break;
                                        }
                                        setFinishedState(workerID);
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index cfc3a200a5..2a6877d89e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -268,7 +268,10 @@ public class ParamservUtils {
                        String[] parts = 
DMLProgram.splitFunctionKey(e.getKey());
                        FunctionProgramBlock fpb = ProgramConverter
                                
.createDeepCopyFunctionProgramBlock(e.getValue(), new HashSet<>(), new 
HashSet<>());
+                       fpb._namespace = parts[0];
+                       fpb._functionName = parts[1];
                        newProg.addFunctionProgramBlock(parts[0], parts[1], 
fpb, opt);
+                       newProg.addProgramBlock(fpb);
                }
                return newProg;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
index 9e96b45a5b..7823d8811c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.java
@@ -76,6 +76,9 @@ public class SparkPSWorker extends LocalPSWorker implements 
VoidFunction<Tuple2<
                _nEpochs = aEpochs;
                _nbatches = nbatches;
                _modelAvg = modelAvg;
+               
+               // make SparkPSWorker serializable
+               _tpool = null;
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 1fa83b2a8d..ef45a9c2b3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -661,10 +661,10 @@ public class ParamservBuiltinCPInstruction extends 
ParameterizedBuiltinCPInstruc
 
        private int getNumBackupWorkers() {
                if(!getParameterMap().containsKey(PS_NUM_BACKUP_WORKERS)) {
-                       if (!getUpdateType().isSBP())
-                               LOG.warn("Specifying number of backup-workers 
without SBP mode has no effect");
                        return DEFAULT_NUM_BACKUP_WORKERS;
                }
+               if (!getUpdateType().isSBP())
+                       LOG.warn("Specifying number of backup-workers without 
SBP mode has no effect");
                return Integer.parseInt(getParam(PS_NUM_BACKUP_WORKERS));
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
index c7f0e39dff..630c3c1ebd 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/paramserv/ParamservSparkNNTest.java
@@ -29,7 +29,6 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 
 @net.jcip.annotations.NotThreadSafe
-@Ignore
 public class ParamservSparkNNTest extends AutomatedTestBase {
 
        private static final String TEST_NAME1 = "paramserv-test";
@@ -77,12 +76,16 @@ public class ParamservSparkNNTest extends AutomatedTestBase 
{
        }
 
        @Test
+       @Ignore
        public void testParamservWorkerFailed() {
+               // FIXME: `aggregation` function can't be found (optimized 
away?)
                runDMLTest(TEST_NAME2, true, DMLRuntimeException.class, 
"Invalid indexing by name in unnamed list: worker_err.");
        }
 
        @Test
+       @Ignore
        public void testParamservAggServiceFailed() {
+               // FIXME: `aggregation` function can't be found (optimized 
away?)
                runDMLTest(TEST_NAME3, true, DMLRuntimeException.class, 
"Invalid indexing by name in unnamed list: agg_service_err.");
        }
 

Reply via email to