This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch branch-2.2.0
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit fa7de58f28f8b54624176c133820d2112e69f636
Author: arnabp <[email protected]>
AuthorDate: Tue Nov 9 23:03:02 2021 +0100

    [SYSTEMDS-3212] Move Prefetch threadpool to CommonThreadPool
---
 src/main/java/org/apache/sysds/api/DMLScript.java             |  4 ++--
 .../sysds/runtime/instructions/cp/BroadcastCPInstruction.java |  8 ++++----
 .../sysds/runtime/instructions/cp/PrefetchCPInstruction.java  |  8 ++++----
 .../sysds/runtime/instructions/spark/utils/SparkUtils.java    | 11 -----------
 .../java/org/apache/sysds/runtime/util/CommonThreadPool.java  |  9 +++++++++
 .../apache/sysds/test/functions/async/PrefetchRDDTest.java    |  4 +++-
 6 files changed, 22 insertions(+), 22 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 4564f05..1b5b3df 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -67,13 +67,13 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker;
 import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
 import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
-import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
 import org.apache.sysds.runtime.util.LocalFileUtils;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.utils.Explain;
 import org.apache.sysds.utils.NativeHelper;
@@ -519,7 +519,7 @@ public class DMLScript
                FederatedData.clearFederatedWorkers();
                
                //0) shutdown prefetch/broadcast thread pool if necessary
-               SparkUtils.shutdownPool();
+               CommonThreadPool.shutdownAsyncRDDPool();
 
                //1) cleanup scratch space (everything for current uuid)
                //(required otherwise export to hdfs would skip assumed unnecessary writes if same name)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
index 51b8ba5..d29ef4c 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java
@@ -23,8 +23,8 @@ import java.util.concurrent.Executors;
 
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 
 public class BroadcastCPInstruction extends UnaryCPInstruction {
        private BroadcastCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
@@ -44,8 +44,8 @@ public class BroadcastCPInstruction extends UnaryCPInstruction {
        public void processInstruction(ExecutionContext ec) {
                ec.setVariable(output.getName(), ec.getMatrixObject(input1));
 
-               if (SparkUtils.triggerRDDPool == null)
-                       SparkUtils.triggerRDDPool = Executors.newCachedThreadPool();
-               SparkUtils.triggerRDDPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
+               if (CommonThreadPool.triggerRDDPool == null)
+                       CommonThreadPool.triggerRDDPool = Executors.newCachedThreadPool();
+               CommonThreadPool.triggerRDDPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output)));
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index f19f151..00f8ac2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -23,8 +23,8 @@ import java.util.concurrent.Executors;
 
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 
 public class PrefetchCPInstruction extends UnaryCPInstruction {
        private PrefetchCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
@@ -49,8 +49,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction {
                // If the next instruction which takes this output as an input comes before
                // the prefetch thread triggers, that instruction will start the operations.
                // In that case this Prefetch instruction will act like a NOOP. 
-               if (SparkUtils.triggerRDDPool == null)
-                       SparkUtils.triggerRDDPool = Executors.newCachedThreadPool();
-               SparkUtils.triggerRDDPool.submit(new TriggerRDDOperationsTask(ec.getMatrixObject(output)));
+               if (CommonThreadPool.triggerRDDPool == null)
+                       CommonThreadPool.triggerRDDPool = Executors.newCachedThreadPool();
+               CommonThreadPool.triggerRDDPool.submit(new TriggerRDDOperationsTask(ec.getMatrixObject(output)));
        }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
index 479cbd7..6c83d75 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java
@@ -57,14 +57,11 @@ import scala.Tuple2;
 
 import java.util.Iterator;
 import java.util.List;
-import java.util.concurrent.ExecutorService;
 import java.util.stream.Collectors;
 import java.util.stream.LongStream;
 
 public class SparkUtils 
 {      
-       public static ExecutorService triggerRDDPool = null;
-
        //internal configuration
        public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL;
 
@@ -296,14 +293,6 @@ public class SparkUtils
                        mo.acquireReadAndRelease();
        }
        
-       public static void shutdownPool() {
-               if (triggerRDDPool != null) {
-                       //shutdown prefetch/broadcast thread pool
-                       triggerRDDPool.shutdown();
-                       triggerRDDPool = null;
-               }
-       }
-       
        private static class CheckSparsityFunction implements VoidFunction<Tuple2<MatrixIndexes,MatrixBlock>>
        {
                private static final long serialVersionUID = 4150132775681848807L;
diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
index 019f1b8..0c552f3 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java
@@ -49,6 +49,7 @@ public class CommonThreadPool implements ExecutorService
        private static final int size = InfrastructureAnalyzer.getLocalParallelism();
        private static final ExecutorService shared = ForkJoinPool.commonPool();
        private final ExecutorService _pool;
+       public static ExecutorService triggerRDDPool = null;
 
        public CommonThreadPool(ExecutorService pool) {
                _pool = pool;
@@ -78,6 +79,14 @@ public class CommonThreadPool implements ExecutorService
                shared.shutdownNow();
        }
 
+       public static void shutdownAsyncRDDPool() {
+               if (triggerRDDPool != null) {
+                       //shutdown prefetch/broadcast thread pool
+                       triggerRDDPool.shutdown();
+                       triggerRDDPool = null;
+               }
+       }
+       
        @Override
        public void shutdown() {
                if( _pool != shared )
diff --git a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index b625139..7da3f52 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -101,7 +101,9 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        HashMap<MatrixValue.CellIndex, Double> R_pf = readDMLScalarFromOutputDir("R");
 
                        //compare matrices
-                       TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch");
+                       Boolean matchVal = TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch");
+                       if (!matchVal)
+                               System.out.println("Value w/o Prefetch "+R+" w/ Prefetch "+R_pf);
                        //assert Prefetch instructions and number of success.
                        long expected_numPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
                        long expected_successPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;

Reply via email to