This is an automated email from the ASF dual-hosted git repository. arnabp20 pushed a commit to branch branch-2.2.0 in repository https://gitbox.apache.org/repos/asf/systemds.git
commit fa7de58f28f8b54624176c133820d2112e69f636 Author: arnabp <[email protected]> AuthorDate: Tue Nov 9 23:03:02 2021 +0100 [SYSTEMDS-3212] Move Prefetch threadpool to CommonThreadPool --- src/main/java/org/apache/sysds/api/DMLScript.java | 4 ++-- .../sysds/runtime/instructions/cp/BroadcastCPInstruction.java | 8 ++++---- .../sysds/runtime/instructions/cp/PrefetchCPInstruction.java | 8 ++++---- .../sysds/runtime/instructions/spark/utils/SparkUtils.java | 11 ----------- .../java/org/apache/sysds/runtime/util/CommonThreadPool.java | 9 +++++++++ .../apache/sysds/test/functions/async/PrefetchRDDTest.java | 4 +++- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 4564f05..1b5b3df 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -67,13 +67,13 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler; import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool; -import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.lineage.LineageCacheConfig; import org.apache.sysds.runtime.lineage.LineageCacheConfig.LineageCachePolicy; import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; import org.apache.sysds.runtime.privacy.CheckedConstraintsLog; import org.apache.sysds.runtime.util.LocalFileUtils; +import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.utils.Explain; import org.apache.sysds.utils.NativeHelper; @@ -519,7 +519,7 @@ public class DMLScript FederatedData.clearFederatedWorkers(); //0) shutdown prefetch/broadcast thread pool if necessary - SparkUtils.shutdownPool(); + CommonThreadPool.shutdownAsyncRDDPool(); //1) cleanup scratch space (everything for current uuid) //(required otherwise export to hdfs would skip assumed unnecessary writes if same name) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java index 51b8ba5..d29ef4c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BroadcastCPInstruction.java @@ -23,8 +23,8 @@ import java.util.concurrent.Executors; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.util.CommonThreadPool; public class BroadcastCPInstruction extends UnaryCPInstruction { private BroadcastCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) { @@ -44,8 +44,8 @@ public class BroadcastCPInstruction extends UnaryCPInstruction { public void processInstruction(ExecutionContext ec) { ec.setVariable(output.getName(), ec.getMatrixObject(input1)); - if (SparkUtils.triggerRDDPool == null) - SparkUtils.triggerRDDPool = Executors.newCachedThreadPool(); - SparkUtils.triggerRDDPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output))); + if (CommonThreadPool.triggerRDDPool == null) + CommonThreadPool.triggerRDDPool = Executors.newCachedThreadPool(); + CommonThreadPool.triggerRDDPool.submit(new TriggerBroadcastTask(ec, ec.getMatrixObject(output))); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java index f19f151..00f8ac2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java @@ -23,8 +23,8 @@ import java.util.concurrent.Executors; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.InstructionUtils; -import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.util.CommonThreadPool; public class PrefetchCPInstruction extends UnaryCPInstruction { private PrefetchCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) { @@ -49,8 +49,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction { // If the next instruction which takes this output as an input comes before // the prefetch thread triggers, that instruction will start the operations. // In that case this Prefetch instruction will act like a NOOP. - if (SparkUtils.triggerRDDPool == null) - SparkUtils.triggerRDDPool = Executors.newCachedThreadPool(); - SparkUtils.triggerRDDPool.submit(new TriggerRDDOperationsTask(ec.getMatrixObject(output))); + if (CommonThreadPool.triggerRDDPool == null) + CommonThreadPool.triggerRDDPool = Executors.newCachedThreadPool(); + CommonThreadPool.triggerRDDPool.submit(new TriggerRDDOperationsTask(ec.getMatrixObject(output))); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java index 479cbd7..6c83d75 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java @@ -57,14 +57,11 @@ import scala.Tuple2; import java.util.Iterator; import java.util.List; -import java.util.concurrent.ExecutorService; import java.util.stream.Collectors; import java.util.stream.LongStream; public class SparkUtils { - public static ExecutorService triggerRDDPool = null; - //internal configuration public static final StorageLevel DEFAULT_TMP = Checkpoint.DEFAULT_STORAGE_LEVEL; @@ -296,14 +293,6 @@ public class SparkUtils mo.acquireReadAndRelease(); } - public static void shutdownPool() { - if (triggerRDDPool != null) { - //shutdown prefetch/broadcast thread pool - triggerRDDPool.shutdown(); - triggerRDDPool = null; - } - } - private static class CheckSparsityFunction implements VoidFunction<Tuple2<MatrixIndexes,MatrixBlock>> { private static final long serialVersionUID = 4150132775681848807L; diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java index 019f1b8..0c552f3 100644 --- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java +++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java @@ -49,6 +49,7 @@ public class CommonThreadPool implements ExecutorService private static final int size = InfrastructureAnalyzer.getLocalParallelism(); private static final ExecutorService shared = ForkJoinPool.commonPool(); private final ExecutorService _pool; + public static ExecutorService triggerRDDPool = null; public CommonThreadPool(ExecutorService pool) { _pool = pool; @@ -78,6 +79,14 @@ public class CommonThreadPool implements ExecutorService shared.shutdownNow(); } + public static void shutdownAsyncRDDPool() { + if (triggerRDDPool != null) { + //shutdown prefetch/broadcast thread pool + triggerRDDPool.shutdown(); + triggerRDDPool = null; + } + } + @Override public void shutdown() { if( _pool != shared ) diff --git a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java index b625139..7da3f52 100644 --- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java +++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java @@ -101,7 +101,9 @@ public class PrefetchRDDTest extends AutomatedTestBase { HashMap<MatrixValue.CellIndex, Double> R_pf = readDMLScalarFromOutputDir("R"); //compare matrices - TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch"); + Boolean matchVal = TestUtils.compareMatrices(R, R_pf, 1e-6, "Origin", "withPrefetch"); + if (!matchVal) + System.out.println("Value w/o Prefetch "+R+" w/ Prefetch "+R_pf); //assert Prefetch instructions and number of success. long expected_numPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0; long expected_successPF = !testname.equalsIgnoreCase(TEST_NAME+"3") ? 1 : 0;
