This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 53c8157746 [SYSTEMDS-3519] Extend Prefetch instruction for GPU to CP
53c8157746 is described below
commit 53c81577465ee6d42b10f5a29d03cd504f978d56
Author: Arnab Phani <[email protected]>
AuthorDate: Mon Sep 25 13:20:27 2023 +0200
[SYSTEMDS-3519] Extend Prefetch instruction for GPU to CP
This patch enables the prefetch instruction to asynchronously copy
intermediates from GPU to local (host) memory. Because the outputs of
prefetch are reusable, this change also allows removing synchronization
barriers between GPU and CPU by reusing the prefetched matrix blocks.
---
.../java/org/apache/sysds/conf/ConfigurationManager.java | 4 ++--
src/main/java/org/apache/sysds/conf/DMLConfig.java | 6 +++---
src/main/java/org/apache/sysds/hops/OptimizerUtils.java | 10 +++++-----
.../sysds/lops/compile/linearization/ILinearize.java | 1 +
.../java/org/apache/sysds/lops/rewrite/LopRewriter.java | 2 +-
.../apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java | 10 ++++++++++
.../runtime/controlprogram/caching/CacheableData.java | 7 ++++++-
.../runtime/instructions/cp/TriggerPrefetchTask.java | 3 ++-
.../org/apache/sysds/runtime/lineage/LineageCache.java | 8 ++++++--
.../sysds/runtime/lineage/LineageCacheStatistics.java | 16 +++++++---------
src/main/java/org/apache/sysds/utils/Statistics.java | 2 +-
.../org/apache/sysds/utils/stats/SparkStatistics.java | 2 +-
.../test/functions/async/MaxParallelizeOrderTest.java | 4 ++--
.../sysds/test/functions/async/PrefetchRDDTest.java | 4 ++--
.../sysds/test/functions/async/ReuseAsyncOpTest.java | 4 ++--
.../sysds/test/functions/lineage/GPUFullReuseTest.java | 15 +++++++++------
16 files changed, 60 insertions(+), 38 deletions(-)
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index a4b5c0ffec..1ac4d13974 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -262,8 +262,8 @@ public class ConfigurationManager{
}
public static boolean isPrefetchEnabled() {
- return
(getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_PREFETCH)
- || OptimizerUtils.ASYNC_PREFETCH_SPARK);
+ return (getDMLConfig().getBooleanValue(DMLConfig.ASYNC_PREFETCH)
+ || OptimizerUtils.ASYNC_PREFETCH);
}
public static boolean isMaxPrallelizeEnabled() {
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 103cf0a01e..767026a161 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -131,7 +131,7 @@ public class DMLConfig
public static final int DEFAULT_FEDERATED_PORT = 4040; // borrowed
default Spark Port
public static final int DEFAULT_NUMBER_OF_FEDERATED_WORKER_THREADS = 8;
/** Asynchronous triggering of Spark OPs and operator placement **/
- public static final String ASYNC_SPARK_PREFETCH =
"sysds.async.prefetch"; // boolean: enable asynchronous prefetching spark
intermediates
+ public static final String ASYNC_PREFETCH = "sysds.async.prefetch"; //
boolean: enable asynchronous prefetching spark/gpu intermediates
public static final String ASYNC_SPARK_BROADCAST =
"sysds.async.broadcast"; // boolean: enable asynchronous broadcasting CP
intermediates
public static final String ASYNC_SPARK_CHECKPOINT =
"sysds.async.checkpoint"; // boolean: enable compile-time persisting of Spark
intermediates
//internal config
@@ -207,7 +207,7 @@ public class DMLConfig
_defaultVals.put(FEDERATED_MONITOR_FREQUENCY, "3");
_defaultVals.put(FEDERATED_COMPRESSION, "none");
_defaultVals.put(PRIVACY_CONSTRAINT_MOCK, null);
- _defaultVals.put(ASYNC_SPARK_PREFETCH, "false" );
+ _defaultVals.put(ASYNC_PREFETCH, "false" );
_defaultVals.put(ASYNC_SPARK_BROADCAST, "false" );
_defaultVals.put(ASYNC_SPARK_CHECKPOINT, "false" );
}
@@ -463,7 +463,7 @@ public class DMLConfig
FLOATING_POINT_PRECISION, GPU_EVICTION_POLICY,
LOCAL_SPARK_NUM_THREADS, EVICTION_SHADOW_BUFFERSIZE,
GPU_MEMORY_ALLOCATOR, GPU_MEMORY_UTILIZATION_FACTOR,
USE_SSL_FEDERATED_COMMUNICATION,
DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT,
FEDERATED_TIMEOUT, FEDERATED_MONITOR_FREQUENCY, FEDERATED_COMPRESSION,
- ASYNC_SPARK_PREFETCH, ASYNC_SPARK_BROADCAST,
ASYNC_SPARK_CHECKPOINT
+ ASYNC_PREFETCH, ASYNC_SPARK_BROADCAST,
ASYNC_SPARK_CHECKPOINT
};
StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 1a7cc1c02b..8da0ff110d 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -285,16 +285,16 @@ public class OptimizerUtils
public static boolean ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
/**
- * Enable prefetch and broadcast. Prefetch asynchronously calls
acquireReadAndRelease() to trigger a chain of spark
- * transformations, which would would otherwise make the next
instruction wait till completion. Broadcast allows
+ * Enable prefetch and broadcast. Prefetch asynchronously calls
acquireReadAndRelease() to trigger remote
+ * operations, which would otherwise make the next instruction wait
till completion. Broadcast allows
* asynchronously transferring the data to all the nodes.
*/
- public static boolean ASYNC_PREFETCH_SPARK = false;
+ public static boolean ASYNC_PREFETCH = false; //both Spark and GPU
public static boolean ASYNC_BROADCAST_SPARK = false;
public static boolean ASYNC_CHECKPOINT_SPARK = false;
/**
- * Heuristic-based instruction ordering to maximize inter-operator
parallelism.
+ * Heuristic-based instruction ordering to maximize inter-operator
PARALLELISM.
* Place the Spark operator chains first and trigger them to execute in
parallel.
*/
public static boolean MAX_PARALLELIZE_ORDER = false;
@@ -308,7 +308,7 @@ public class OptimizerUtils
/**
* Rule-based operator placement policy for GPU.
*/
- public static boolean RULE_BASED_GPU_EXEC = false;
+ public static boolean RULE_BASED_GPU_EXEC = true;
//////////////////////
// Optimizer levels //
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index 3c0aa61692..364dd662b8 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -210,6 +210,7 @@ public class ILinearize {
final_v = depthFirst(v);
return final_v;
+ //TODO: Support GPU operator chains
}
// Place the operators in a depth-first manner, but order
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 2b054d9b2b..8d2c0a63f8 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -40,11 +40,11 @@ public class LopRewriter
public LopRewriter() {
_lopSBRuleSet = new ArrayList<>();
// Add rewrite rules (single and multi-statement block)
+ _lopSBRuleSet.add(new RewriteUpdateGPUPlacements());
_lopSBRuleSet.add(new RewriteAddPrefetchLop());
_lopSBRuleSet.add(new RewriteAddBroadcastLop());
_lopSBRuleSet.add(new RewriteAddChkpointLop());
_lopSBRuleSet.add(new RewriteAddChkpointInLoop());
- _lopSBRuleSet.add(new RewriteUpdateGPUPlacements());
// TODO: A rewrite pass to remove less effective chkpoints
// Last rewrite to reset Lop IDs in a depth-first manner
_lopSBRuleSet.add(new RewriteFixIDs());
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
index 91b7f81e71..4567e88d52 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
@@ -90,6 +90,10 @@ public class RewriteAddPrefetchLop extends LopRewriteRule
}
private boolean isPrefetchNeeded(Lop lop) {
+ return isPrefetchFromSparkNeeded(lop) ||
isPrefetchFromGPUNeeded(lop);
+ }
+
+ private boolean isPrefetchFromSparkNeeded(Lop lop) {
// Run Prefetch for a Spark instruction if the instruction is a
Transformation
// and the output is consumed by only CP instructions.
boolean transformOP = lop.getExecType() == Types.ExecType.SPARK
&& lop.getAggType() != AggBinaryOp.SparkAggType.SINGLE_BLOCK
@@ -119,4 +123,10 @@ public class RewriteAddPrefetchLop extends LopRewriteRule
&& (lop.isAllOutputsCP() ||
OperatorOrderingUtils.isCollectForBroadcast(lop))
&& lop.getDataType() == Types.DataType.MATRIX;
}
+
+ private boolean isPrefetchFromGPUNeeded(Lop lop) {
+ // Prefetch a GPU intermediate if all the outputs are CP.
+ return lop.getDataType() == Types.DataType.MATRIX
+ && lop.isExecGPU() && lop.isAllOutputsCP();
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index a10b13284f..6961173af2 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -1280,7 +1280,12 @@ public abstract class CacheableData<T extends
CacheBlock<?>> extends Data
public boolean isPendingRDDOps() {
return isEmpty(true) && _data == null && (_rddHandle != null &&
_rddHandle.hasBackReference());
}
-
+
+ public boolean isDeviceToHostCopy() {
+ boolean isGpuOP = isEmpty(true) && _data == null && _gpuObjects
!= null;
+ return isGpuOP && _gpuObjects.values().stream().anyMatch(gobj
-> (gobj != null && gobj.isDirty()));
+ }
+
protected void setEmpty() {
_cacheStatus = CacheStatus.EMPTY;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index 78857c5a17..f1d5a8d3f6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -49,7 +49,8 @@ public class TriggerPrefetchTask implements Runnable {
synchronized (_prefetchMO) {
// Having this check inside the critical section
// safeguards against concurrent rmVar.
- if (_prefetchMO.isPendingRDDOps() ||
_prefetchMO.isFederated()) {
+ if (_prefetchMO.isPendingRDDOps() ||
_prefetchMO.isDeviceToHostCopy()
+ || _prefetchMO.isFederated()) {
// TODO: Add robust runtime constraints for
federated prefetch
// Execute and bring the result to local
mb = _prefetchMO.acquireReadAndRelease();
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index d05ea234a1..7aacfb5151 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -641,8 +641,9 @@ public class LineageCache
// Scalar gpu intermediates is always copied
back to host.
// No need to cache the GPUobj for scalar
intermediates.
+ instLI = ec.getLineageItem(((GPUInstruction)
inst)._output);
if (liGPUObj == null)
- liData = Arrays.asList(Pair.of(instLI,
ec.getVariable(((GPUInstruction)inst)._output)));
+ liData = Arrays.asList(Pair.of(instLI,
ec.getVariable(((GPUInstruction) inst)._output)));
}
else if (inst instanceof ComputationSPInstruction
&& (ec.getVariable(((ComputationSPInstruction)
inst).output) instanceof MatrixObject)
@@ -1463,6 +1464,8 @@ public class LineageCache
LineageCacheStatistics.incrementSavedComputeTime(e._computeTime);
if (e.isGPUObject()) LineageCacheStatistics.incrementGpuHits();
+ if (inst.getOpcode().equals("prefetch") &&
DMLScript.USE_ACCELERATOR)
+ LineageCacheStatistics.incrementGpuPrefetch();
if (e.isRDDPersist()) {
if
(SparkExecutionContext.isRDDCached(e.getRDDObject().getRDD().id()))
LineageCacheStatistics.incrementRDDPersistHits(); //persisted in the executors
@@ -1470,7 +1473,8 @@ public class LineageCache
LineageCacheStatistics.incrementRDDHits();
//only locally cached
}
if (e.isMatrixValue() || e.isScalarValue()) {
- if (inst instanceof ComputationSPInstruction ||
inst.getOpcode().equals("prefetch"))
+ if (inst instanceof ComputationSPInstruction
+ || (inst.getOpcode().equals("prefetch") &&
!DMLScript.USE_ACCELERATOR))
// Single_block Spark instructions (sync/async)
and prefetch
LineageCacheStatistics.incrementSparkCollectHits();
else
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
index 00fd36c378..be7460f6fd 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheStatistics.java
@@ -44,7 +44,7 @@ public class LineageCacheStatistics {
private static final LongAdder _ctimeProbe = new LongAdder();
// Bellow entries are specific to gpu lineage cache
private static final LongAdder _numHitsGpu = new LongAdder();
- private static final LongAdder _numAsyncEvictGpu= new LongAdder();
+ private static final LongAdder _numPrefetchGpu= new LongAdder();
private static final LongAdder _numSyncEvictGpu = new LongAdder();
private static final LongAdder _numRecycleGpu = new LongAdder();
private static final LongAdder _numDelGpu = new LongAdder();
@@ -74,7 +74,7 @@ public class LineageCacheStatistics {
_ctimeProbe.reset();
_evtimeGpu.reset();
_numHitsGpu.reset();
- _numAsyncEvictGpu.reset();
+ _numPrefetchGpu.reset();
_numSyncEvictGpu.reset();
_numRecycleGpu.reset();
_numDelGpu.reset();
@@ -210,9 +210,9 @@ public class LineageCacheStatistics {
_numHitsGpu.increment();
}
- public static void incrementGpuAsyncEvicts() {
- // Number of gpu cache entries moved to cpu cache via the
background thread
- _numAsyncEvictGpu.increment();
+ public static void incrementGpuPrefetch() {
+ // Number of reuse of GPU to host prefetches (asynchronous)
+ _numPrefetchGpu.increment();
}
public static void incrementGpuSyncEvicts() {
@@ -318,9 +318,7 @@ public class LineageCacheStatistics {
StringBuilder sb = new StringBuilder();
sb.append(_numHitsGpu.longValue());
sb.append("/");
- sb.append(_numAsyncEvictGpu.longValue());
- sb.append("/");
- sb.append(_numSyncEvictGpu.longValue());
+ sb.append(_numPrefetchGpu.longValue());
return sb.toString();
}
@@ -339,7 +337,7 @@ public class LineageCacheStatistics {
}
public static boolean ifGpuStats() {
- return (_numHitsGpu.longValue() + _numAsyncEvictGpu.longValue()
+ return (_numHitsGpu.longValue() + _numPrefetchGpu.longValue()
+ _numSyncEvictGpu.longValue() +
_numRecycleGpu.longValue()
+ _numDelGpu.longValue() + _evtimeGpu.longValue()) != 0;
}
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 6978507179..01c9682d90 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -639,7 +639,7 @@ public class Statistics
sb.append("LinCache hits (Mem/FS/Del): \t" +
LineageCacheStatistics.displayHits() + ".\n");
sb.append("LinCache MultiLevel (Ins/SB/Fn):" +
LineageCacheStatistics.displayMultiLevelHits() + ".\n");
if (LineageCacheStatistics.ifGpuStats()) {
- sb.append("LinCache GPU
(Hit/Async/Sync): \t" + LineageCacheStatistics.displayGpuStats() + ".\n");
+ sb.append("LinCache GPU (Hit/PF): \t" +
LineageCacheStatistics.displayGpuStats() + ".\n");
sb.append("LinCache GPU (Recyc/Del):
\t" + LineageCacheStatistics.displayGpuPointerStats() + ".\n");
sb.append("LinCache GPU evict time: \t"
+ LineageCacheStatistics.displayGpuEvictTime() + " sec.\n");
}
diff --git a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
index ae21ea0672..8c110a0b92 100644
--- a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
@@ -131,7 +131,7 @@ public class SparkStatistics {
parallelizeTime.longValue()*1e-9,
broadcastTime.longValue()*1e-9,
collectTime.longValue()*1e-9));
- sb.append("Spark async. count (pf,bc,op): \t" +
+ sb.append("Async. OP count (pf,bc,op): \t" +
String.format("%d/%d/%d.\n",
getAsyncPrefetchCount(), getAsyncBroadcastCount(), getAsyncSparkOpCount()));
return sb.toString();
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
index eeb3f6d5e7..51a166c10d 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
@@ -96,13 +96,13 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
- OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+ OptimizerUtils.ASYNC_PREFETCH = true;
OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
if (testname.equalsIgnoreCase(TEST_NAME+"4"))
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE
= false;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R_mp =
readDMLScalarFromOutputDir("R");
- OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+ OptimizerUtils.ASYNC_PREFETCH = false;
OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index 886a850d22..46ab3444df 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -103,9 +103,9 @@ public class PrefetchRDDTest extends AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
- OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+ OptimizerUtils.ASYNC_PREFETCH = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
- OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+ OptimizerUtils.ASYNC_PREFETCH = false;
OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
HashMap<MatrixValue.CellIndex, Double> R_pf =
readDMLScalarFromOutputDir("R");
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
index 7666a30184..7d700cc3b6 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
@@ -136,12 +136,12 @@ public class ReuseAsyncOpTest extends AutomatedTestBase {
private void enableAsync() {
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
- OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+ OptimizerUtils.ASYNC_PREFETCH = true;
}
private void disableAsync() {
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
- OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+ OptimizerUtils.ASYNC_PREFETCH = false;
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java
b/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java
index 1a0665c187..47bfd8c4e8 100644
---
a/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/lineage/GPUFullReuseTest.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
+import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
@@ -43,21 +44,21 @@ public class GPUFullReuseTest extends AutomatedTestBase{
protected static final String TEST_NAME = "LineageReuseGPU";
protected static final int TEST_VARIANTS = 4;
protected String TEST_CLASS_DIR = TEST_DIR +
GPUFullReuseTest.class.getSimpleName() + "/";
-
+
@BeforeClass
public static void checkGPU() {
// Skip all the tests if no GPU is available
// FIXME: Fails to skip if gpu available but no libraries
Assume.assumeTrue(TestUtils.isGPUAvailable() ==
cudaError.cudaSuccess);
}
-
+
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
for( int i=1; i<=TEST_VARIANTS; i++ )
addTestConfiguration(TEST_NAME+i, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
}
-
+
@Test
public void ReuseAggBin() { //reuse AggregateBinary and sum
testLineageTraceExec(TEST_NAME+"1");
@@ -90,9 +91,10 @@ public class GPUFullReuseTest extends AutomatedTestBase{
proArgs.add(output("R"));
programArgs = proArgs.toArray(new String[proArgs.size()]);
fullDMLScriptName = getScript();
-
+
Lineage.resetInternalState();
//run the test
+ OptimizerUtils.ASYNC_PREFETCH = true;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R_orig =
readDMLMatrixFromOutputDir("R");
@@ -104,14 +106,15 @@ public class GPUFullReuseTest extends AutomatedTestBase{
proArgs.add(output("R"));
programArgs = proArgs.toArray(new String[proArgs.size()]);
fullDMLScriptName = getScript();
-
+
Lineage.resetInternalState();
//run the test
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+ OptimizerUtils.ASYNC_PREFETCH = false;
AutomatedTestBase.TEST_GPU = false;
HashMap<MatrixValue.CellIndex, Double> R_reused =
readDMLMatrixFromOutputDir("R");
- //compare results
+ //compare results
TestUtils.compareMatrices(R_orig, R_reused, 1e-6, "Origin",
"Reused");
if( testname.endsWith("3") ) { //function reuse