This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b65212  [SYSTEMDS-3088] Add cleanup for the prefetch threads
4b65212 is described below

commit 4b65212fa7cb651d28c017f9720a3d05b460738c
Author: arnabp <[email protected]>
AuthorDate: Wed Aug 11 09:46:31 2021 +0200

    [SYSTEMDS-3088] Add cleanup for the prefetch threads
    
    This patch adds the missing shutdown of the threads created
    for asynchronous triggering of spark operations. Moreover,
    now we use a CachedThreadPool to manage the varying number
    of prefetch instructions efficiently.
---
 src/main/java/org/apache/sysds/api/DMLScript.java              |  4 ++++
 src/main/java/org/apache/sysds/lops/compile/Dag.java           |  1 +
 .../sysds/runtime/instructions/cp/PrefetchCPInstruction.java   |  6 +++---
 .../sysds/runtime/instructions/spark/utils/SparkUtils.java     | 10 +++++++++-
 4 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index e2e67a5..09c0a6e 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -67,6 +67,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedWorker;
 import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
@@ -514,6 +515,9 @@ public class DMLScript
                
                //0) cleanup federated workers if necessary
                FederatedData.clearFederatedWorkers();
+               
+               //0) shutdown prefetch/broadcast thread pool if necessary
+               SparkUtils.shutdownPool();
 
                //1) cleanup scratch space (everything for current uuid)
                //(required otherwise export to hdfs would skip assumed 
unnecessary writes if same name)
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java 
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 4cccb1e..823c14d 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -233,6 +233,7 @@ public class Dag<N extends Lop>
                for (Lop l : nodes) {
                        nodesWithPrefetch.add(l);
                        if (isPrefetchNeeded(l)) {
+                               //TODO: No prefetch if the parent is placed 
right after the spark OP
                                List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
                                //Construct a Prefetch lop that takes this 
Spark node as a input
                                UnaryCP prefetch = new UnaryCP(l, 
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 2e2d091..f19f151 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -49,8 +49,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction 
{
                // If the next instruction which takes this output as an input 
comes before
                // the prefetch thread triggers, that instruction will start 
the operations.
                // In that case this Prefetch instruction will act like a NOOP. 
-               if (SparkUtils.triggerRDDThread == null)
-                       SparkUtils.triggerRDDThread = 
Executors.newSingleThreadExecutor();
-               SparkUtils.triggerRDDThread.submit(new 
TriggerRDDOperationsTask(ec.getMatrixObject(output)));
+               if (SparkUtils.triggerRDDPool == null)
+                       SparkUtils.triggerRDDPool = 
Executors.newCachedThreadPool();
+               SparkUtils.triggerRDDPool.submit(new 
TriggerRDDOperationsTask(ec.getMatrixObject(output)));
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
index 2c15b91..479cbd7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
@@ -63,7 +63,7 @@ import java.util.stream.LongStream;
 
 public class SparkUtils 
 {      
-       public static ExecutorService triggerRDDThread = null;
+       public static ExecutorService triggerRDDPool = null;
 
        //internal configuration
        public static final StorageLevel DEFAULT_TMP = 
Checkpoint.DEFAULT_STORAGE_LEVEL;
@@ -296,6 +296,14 @@ public class SparkUtils
                        mo.acquireReadAndRelease();
        }
        
+       public static void shutdownPool() {
+               if (triggerRDDPool != null) {
+                       //shutdown prefetch/broadcast thread pool
+                       triggerRDDPool.shutdown();
+                       triggerRDDPool = null;
+               }
+       }
+       
        private static class CheckSparsityFunction implements 
VoidFunction<Tuple2<MatrixIndexes,MatrixBlock>>
        {
                private static final long serialVersionUID = 
4150132775681848807L;

Reply via email to