This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 5182796632 [SYSTEMDS-3474] Lineage-based reuse of future-based
instructions
5182796632 is described below
commit 5182796632bfb9173f5d2e7b2e7d20e434270bda
Author: Arnab Phani <[email protected]>
AuthorDate: Tue Dec 6 12:36:09 2022 +0100
[SYSTEMDS-3474] Lineage-based reuse of future-based instructions
This patch enables caching and reuse of future-based Spark
actions.
Closes #1747
---
.../controlprogram/context/ExecutionContext.java | 7 +-
.../controlprogram/context/MatrixObjectFuture.java | 9 ++-
.../instructions/cp/PrefetchCPInstruction.java | 4 +-
.../instructions/cp/TriggerPrefetchTask.java | 8 +-
.../spark/AggregateUnarySPInstruction.java | 5 +-
.../instructions/spark/CpmmSPInstruction.java | 5 +-
.../instructions/spark/MapmmSPInstruction.java | 5 +-
.../instructions/spark/TsmmSPInstruction.java | 5 +-
.../apache/sysds/runtime/lineage/LineageCache.java | 53 ++++--------
.../sysds/runtime/lineage/LineageCacheConfig.java | 4 +
.../functions/async/LineageReuseSparkTest.java | 44 +---------
...geReuseSparkTest.java => ReuseAsyncOpTest.java} | 93 +++++++++-------------
.../{LineageReuseSpark2.dml => ReuseAsyncOp1.dml} | 0
.../{LineageReuseSpark2.dml => ReuseAsyncOp2.dml} | 41 ++++++----
14 files changed, 127 insertions(+), 156 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 0fa1340569..5e3e90f469 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -602,16 +602,21 @@ public class ExecutionContext {
mo.release();
}
- public void setMatrixOutput(String varName, Future<MatrixBlock> fmb) {
+ public void setMatrixOutputAndLineage(String varName,
Future<MatrixBlock> fmb, LineageItem li) {
if (isAutoCreateVars() && !containsVariable(varName)) {
MatrixObject fmo = new
MatrixObjectFuture(Types.ValueType.FP64,
OptimizerUtils.getUniqueTempFileName(), fmb);
}
MatrixObject mo = getMatrixObject(varName);
MatrixObjectFuture fmo = new MatrixObjectFuture(mo, fmb);
+ fmo.setCacheLineage(li);
setVariable(varName, fmo);
}
+ public void setMatrixOutput(String varName, Future<MatrixBlock> fmb) {
+ setMatrixOutputAndLineage(varName, fmb, null);
+ }
+
public void setMatrixOutput(String varName, MatrixBlock outputData,
UpdateType flag) {
if( isAutoCreateVars() && !containsVariable(varName) )
setVariable(varName, createMatrixObject(outputData));
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
index 3cbc7eff09..3c5581937b 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.controlprogram.context;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import java.util.concurrent.Future;
@@ -59,8 +60,14 @@ public class MatrixObjectFuture extends MatrixObject
throw new DMLRuntimeException("MatrixObject not
available to read.");
if(_data != null)
throw new DMLRuntimeException("_data must be
null for future matrix object/block.");
+ MatrixBlock out = null;
acquire(false, false);
- return _futureData.get();
+ long t1 = System.nanoTime();
+ out = _futureData.get();
+ if (hasValidLineage())
+ LineageCache.putValueAsyncOp(getCacheLineage(),
this, out, t1);
+ // FIXME: start time should indicate the actual
start of the execution
+ return out;
}
catch(Exception e) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index 192e165391..fa0d1c0e83 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -46,7 +46,7 @@ public class PrefetchCPInstruction extends UnaryCPInstruction
{
public void processInstruction(ExecutionContext ec) {
// TODO: handle non-matrix objects
ec.setVariable(output.getName(), ec.getMatrixObject(input1));
- LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ?
this.getLineageItem(ec).getValue() : null;
+ LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ?
getLineageItem(ec).getValue() : null;
// Note, a Prefetch instruction doesn't guarantee an
asynchronous execution.
// If the next instruction which takes this output as an input
comes before
@@ -54,6 +54,8 @@ public class PrefetchCPInstruction extends UnaryCPInstruction
{
// In that case this Prefetch instruction will act like a NOOP.
if (CommonThreadPool.triggerRemoteOPsPool == null)
CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
+ // Saving the lineage item inside the matrix object would
replace the pre-attached
+ // lineage item (e.g. mapmm). Hence, it is passed separately.
CommonThreadPool.triggerRemoteOPsPool.submit(new
TriggerPrefetchTask(ec.getMatrixObject(output), li));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index b7c69d01f5..78857c5a17 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -24,6 +24,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.stats.SparkStatistics;
public class TriggerPrefetchTask implements Runnable {
@@ -43,6 +44,7 @@ public class TriggerPrefetchTask implements Runnable {
@Override
public void run() {
boolean prefetched = false;
+ MatrixBlock mb = null;
long t1 = System.nanoTime();
synchronized (_prefetchMO) {
// Having this check inside the critical section
@@ -50,14 +52,14 @@ public class TriggerPrefetchTask implements Runnable {
if (_prefetchMO.isPendingRDDOps() ||
_prefetchMO.isFederated()) {
// TODO: Add robust runtime constraints for
federated prefetch
// Execute and bring the result to local
- _prefetchMO.acquireReadAndRelease();
+ mb = _prefetchMO.acquireReadAndRelease();
prefetched = true;
}
}
// Save the collected intermediate in the lineage cache
- if (_inputLi != null)
- LineageCache.putValueAsyncOp(_inputLi, _prefetchMO,
prefetched, t1);
+ if (_inputLi != null && mb != null)
+ LineageCache.putValueAsyncOp(_inputLi, _prefetchMO, mb,
t1);
if (DMLScript.STATISTICS && prefetched) {
if (_prefetchMO.isFederated())
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 48d41fd602..50816aefe4 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -39,6 +39,8 @@ import
org.apache.sysds.runtime.instructions.spark.functions.AggregateDropCorrec
import
org.apache.sysds.runtime.instructions.spark.functions.FilterDiagMatrixBlocksFunction;
import
org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
@@ -117,7 +119,8 @@ public class AggregateUnarySPInstruction extends
UnarySPInstruction {
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
RDDAggregateTask task = new
RDDAggregateTask(_optr, _aop, in, mc);
Future<MatrixBlock> future_out =
CommonThreadPool.triggerRemoteOPsPool.submit(task);
- sec.setMatrixOutput(output.getName(),
future_out);
+ LineageItem li =
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() :
null;
+
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 653596806d..79832eabe2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -37,6 +37,8 @@ import
org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlock
import org.apache.sysds.runtime.instructions.spark.functions.ReorgMapFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
@@ -113,7 +115,8 @@ public class CpmmSPInstruction extends
AggregateBinarySPInstruction {
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
CpmmMatrixVectorTask task = new
CpmmMatrixVectorTask(in1, in2);
Future<MatrixBlock> future_out =
CommonThreadPool.triggerRemoteOPsPool.submit(task);
- sec.setMatrixOutput(output.getName(),
future_out);
+ LineageItem li =
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() :
null;
+
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index 29f28b604e..b0285b1bba 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -50,6 +50,8 @@ import
org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import
org.apache.sysds.runtime.instructions.spark.functions.FilterNonEmptyBlocksFunction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
@@ -146,7 +148,8 @@ public class MapmmSPInstruction extends
AggregateBinarySPInstruction {
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
RDDMapmmTask task = new
RDDMapmmTask(in1, in2, type);
Future<MatrixBlock> future_out =
CommonThreadPool.triggerRemoteOPsPool.submit(task);
- sec.setMatrixOutput(output.getName(),
future_out);
+ LineageItem li =
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() :
null;
+
sec.setMatrixOutputAndLineage(output.getName(), future_out, li);
}
catch(Exception ex) { throw new
DMLRuntimeException(ex); }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index 17cef61158..acba784bf2 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -31,6 +31,8 @@ import
org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -74,7 +76,8 @@ public class TsmmSPInstruction extends UnarySPInstruction {
CommonThreadPool.triggerRemoteOPsPool =
Executors.newCachedThreadPool();
TsmmTask task = new TsmmTask(in, _type);
Future<MatrixBlock> future_out =
CommonThreadPool.triggerRemoteOPsPool.submit(task);
- sec.setMatrixOutput(output.getName(),
future_out);
+ LineageItem li =
!LineageCacheConfig.ReuseCacheType.isNone() ? getLineageItem(ec).getValue() :
null;
+ sec.setMatrixOutputAndLineage(output.getName(),
future_out, li);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index 8e8d962199..87ddd7b8da 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.lineage;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.MutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
@@ -575,16 +576,13 @@ public class LineageCache
continue;
}
- if (data instanceof MatrixObjectFuture) {
+ if (data instanceof MatrixObjectFuture || inst
instanceof PrefetchCPInstruction) {
// We don't want to call get() on the
future immediately after the execution
+ // For the async. instructions, caching
is handled separately by the tasks
removePlaceholder(item);
continue;
}
- if (inst instanceof PrefetchCPInstruction ||
inst instanceof BroadcastCPInstruction)
- // For the async. instructions, caching
is handled separately by the tasks
- continue;
-
if (data instanceof MatrixObject &&
((MatrixObject) data).hasRDDHandle()) {
// Avoid triggering pre-matured Spark
instruction chains
removePlaceholder(item);
@@ -643,49 +641,28 @@ public class LineageCache
}
}
- public static void putValueAsyncOp(LineageItem instLI, Data data,
boolean prefetched, long starttime)
+ // This method is called from inside the asynchronous operators and
directly puts the output of
+ // an asynchronous instruction into the lineage cache. As the
consumer, a different operator,
+ // materializes the intermediate, so we skip the placeholder placing logic.
+ public static void putValueAsyncOp(LineageItem instLI, Data data,
MatrixBlock mb, long starttime)
{
if (ReuseCacheType.isNone())
return;
- if (!prefetched) //prefetching was not successful
+ if
(!ArrayUtils.contains(LineageCacheConfig.getReusableOpcodes(),
instLI.getOpcode()))
+ return;
+ if(!(data instanceof MatrixObject) && !(data instanceof
ScalarObject)) {
return;
+ }
synchronized( _cache )
{
- if (!probe(instLI))
- return;
-
long computetime = System.nanoTime() - starttime;
- LineageCacheEntry centry = _cache.get(instLI);
- if(!(data instanceof MatrixObject) && !(data instanceof
ScalarObject)) {
- // Reusable instructions can return a frame
(rightIndex). Remove placeholders.
- removePlaceholder(instLI);
- return;
- }
+ // Make space, place data and manage queue
+ putIntern(instLI, DataType.MATRIX, mb, null,
computetime);
- MatrixBlock mb = (data instanceof MatrixObject) ?
- ((MatrixObject)data).acquireReadAndRelease() :
null;
- long size = mb != null ? mb.getInMemorySize() :
((ScalarObject)data).getSize();
-
- // remove the placeholder if the entry is bigger than
the cache.
- if (size > LineageCacheEviction.getCacheLimit()) {
- removePlaceholder(instLI);
- return;
- }
-
- // place the data
- if (data instanceof MatrixObject)
- centry.setValue(mb, computetime);
- else if (data instanceof ScalarObject)
- centry.setValue((ScalarObject)data,
computetime);
-
- if (DMLScript.STATISTICS &&
LineageCacheEviction._removelist.containsKey(centry._key)) {
+ if (DMLScript.STATISTICS &&
LineageCacheEviction._removelist.containsKey(instLI))
// Add to missed compute time
-
LineageCacheStatistics.incrementMissedComputeTime(centry._computeTime);
- }
-
- //maintain order for eviction
- LineageCacheEviction.addEntry(centry);
+
LineageCacheStatistics.incrementMissedComputeTime(computetime);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 72ea3835a2..fe32f364e5 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -197,6 +197,10 @@ public class LineageCacheConfig
public static void setReusableOpcodes(String... ops) {
REUSE_OPCODES = ops;
}
+
+ public static String[] getReusableOpcodes() {
+ return REUSE_OPCODES;
+ }
public static void resetReusableOpcodes() {
REUSE_OPCODES = OPCODES;
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index 57d7892b3b..5b49bb82fa 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -62,11 +62,6 @@ public class LineageReuseSparkTest extends AutomatedTestBase
{
runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
}
- @Test
- public void testReusePrefetch() {
- runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
- }
-
public void runTest(String testname, ExecMode execMode, int testId) {
boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean old_sum_product =
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -91,18 +86,10 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
programArgs = proArgs.toArray(new
String[proArgs.size()]);
Lineage.resetInternalState();
- if (testId == 2) enablePrefetch();
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
- disablePrefetch();
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
- long numTsmm = 0;
- long numMapmm = 0;
- if (testId == 1) {
- numTsmm =
Statistics.getCPHeavyHitterCount("sp_tsmm");
- numMapmm =
Statistics.getCPHeavyHitterCount("sp_mapmm");
- }
- long numPrefetch = 0;
- if (testId == 2) numPrefetch =
Statistics.getCPHeavyHitterCount("prefetch");
+ long numTsmm =
Statistics.getCPHeavyHitterCount("sp_tsmm");
+ long numMapmm =
Statistics.getCPHeavyHitterCount("sp_mapmm");
proArgs.clear();
proArgs.add("-explain");
@@ -114,18 +101,10 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
programArgs = proArgs.toArray(new
String[proArgs.size()]);
Lineage.resetInternalState();
- if (testId == 2) enablePrefetch();
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
- disablePrefetch();
HashMap<MatrixValue.CellIndex, Double> R_reused =
readDMLScalarFromOutputDir("R");
- long numTsmm_r = 0;
- long numMapmm_r = 0;
- if (testId == 1) {
- numTsmm_r =
Statistics.getCPHeavyHitterCount("sp_tsmm");
- numMapmm_r =
Statistics.getCPHeavyHitterCount("sp_mapmm");
- }
- long numPrefetch_r = 0;
- if (testId == 2) numPrefetch_r =
Statistics.getCPHeavyHitterCount("prefetch");
+ long numTsmm_r =
Statistics.getCPHeavyHitterCount("sp_tsmm");
+ long numMapmm_r =
Statistics.getCPHeavyHitterCount("sp_mapmm");
//compare matrices
boolean matchVal = TestUtils.compareMatrices(R,
R_reused, 1e-6, "Origin", "withPrefetch");
@@ -135,9 +114,6 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
Assert.assertTrue("Violated sp_tsmm reuse
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
Assert.assertTrue("Violated sp_mapmm reuse
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
}
- if (testId == 2)
- Assert.assertTrue("Violated prefetch reuse
count: " + numPrefetch_r + " < " + numPrefetch, numPrefetch_r<numPrefetch);
-
} finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
old_simplification;
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES =
old_sum_product;
@@ -147,16 +123,4 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
Recompiler.reinitRecompiler();
}
}
-
- private void enablePrefetch() {
- OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
- OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
- OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
- }
-
- private void disablePrefetch() {
- OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
- OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
- OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
- }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
similarity index 67%
copy from
src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
copy to
src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
index 57d7892b3b..7666a30184 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/ReuseAsyncOpTest.java
@@ -19,30 +19,29 @@
package org.apache.sysds.test.functions.async;
- import java.util.ArrayList;
- import java.util.HashMap;
- import java.util.List;
-
- import org.apache.sysds.common.Types.ExecMode;
- import org.apache.sysds.hops.OptimizerUtils;
- import org.apache.sysds.hops.recompile.Recompiler;
- import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
- import org.apache.sysds.runtime.lineage.Lineage;
- import org.apache.sysds.runtime.lineage.LineageCacheConfig;
- import org.apache.sysds.runtime.matrix.data.MatrixValue;
- import org.apache.sysds.test.AutomatedTestBase;
- import org.apache.sysds.test.TestConfiguration;
- import org.apache.sysds.test.TestUtils;
- import org.apache.sysds.utils.Statistics;
- import org.junit.Assert;
- import org.junit.Test;
-
-public class LineageReuseSparkTest extends AutomatedTestBase {
-
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.recompile.Recompiler;
+import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class ReuseAsyncOpTest extends AutomatedTestBase {
protected static final String TEST_DIR = "functions/async/";
- protected static final String TEST_NAME = "LineageReuseSpark";
+ protected static final String TEST_NAME = "ReuseAsyncOp";
protected static final int TEST_VARIANTS = 2;
- protected static String TEST_CLASS_DIR = TEST_DIR +
LineageReuseSparkTest.class.getSimpleName() + "/";
+ protected static String TEST_CLASS_DIR = TEST_DIR +
ReuseAsyncOpTest.class.getSimpleName() + "/";
@Override
public void setUp() {
@@ -52,18 +51,14 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
}
@Test
- public void testlmdsHB() {
+ public void testReusePrefetch() {
+ // Reuse prefetch results
runTest(TEST_NAME+"1", ExecMode.HYBRID, 1);
}
@Test
- public void testlmdsSP() {
- // Only reuse the actions
- runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
- }
-
- @Test
- public void testReusePrefetch() {
+ public void testlmds() {
+ // Reuse future-based tsmm and mapmm
runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
}
@@ -91,18 +86,13 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
programArgs = proArgs.toArray(new
String[proArgs.size()]);
Lineage.resetInternalState();
- if (testId == 2) enablePrefetch();
+ enableAsync(); //enable max_reuse and prefetch
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
- disablePrefetch();
+ disableAsync();
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
- long numTsmm = 0;
- long numMapmm = 0;
- if (testId == 1) {
- numTsmm =
Statistics.getCPHeavyHitterCount("sp_tsmm");
- numMapmm =
Statistics.getCPHeavyHitterCount("sp_mapmm");
- }
- long numPrefetch = 0;
- if (testId == 2) numPrefetch =
Statistics.getCPHeavyHitterCount("prefetch");
+ long numTsmm =
Statistics.getCPHeavyHitterCount("sp_tsmm");
+ long numMapmm =
Statistics.getCPHeavyHitterCount("sp_mapmm");
+ long numPrefetch =
Statistics.getCPHeavyHitterCount("prefetch");
proArgs.clear();
proArgs.add("-explain");
@@ -114,28 +104,23 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
programArgs = proArgs.toArray(new
String[proArgs.size()]);
Lineage.resetInternalState();
- if (testId == 2) enablePrefetch();
+ enableAsync(); //enable max_reuse and prefetch
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
- disablePrefetch();
+ disableAsync();
HashMap<MatrixValue.CellIndex, Double> R_reused =
readDMLScalarFromOutputDir("R");
- long numTsmm_r = 0;
- long numMapmm_r = 0;
- if (testId == 1) {
- numTsmm_r =
Statistics.getCPHeavyHitterCount("sp_tsmm");
- numMapmm_r =
Statistics.getCPHeavyHitterCount("sp_mapmm");
- }
- long numPrefetch_r = 0;
- if (testId == 2) numPrefetch_r =
Statistics.getCPHeavyHitterCount("prefetch");
+ long numTsmm_r =
Statistics.getCPHeavyHitterCount("sp_tsmm");
+ long numMapmm_r =
Statistics.getCPHeavyHitterCount("sp_mapmm");
+ long numPrefetch_r =
Statistics.getCPHeavyHitterCount("prefetch");
//compare matrices
boolean matchVal = TestUtils.compareMatrices(R,
R_reused, 1e-6, "Origin", "withPrefetch");
if (!matchVal)
System.out.println("Value w/o reuse "+R+" w/
reuse "+R_reused);
- if (testId == 1) {
+ if (testId == 2) {
Assert.assertTrue("Violated sp_tsmm reuse
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
Assert.assertTrue("Violated sp_mapmm reuse
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
}
- if (testId == 2)
+ if (testId == 1)
Assert.assertTrue("Violated prefetch reuse
count: " + numPrefetch_r + " < " + numPrefetch, numPrefetch_r<numPrefetch);
} finally {
@@ -148,13 +133,13 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
}
}
- private void enablePrefetch() {
+ private void enableAsync() {
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
}
- private void disablePrefetch() {
+ private void disableAsync() {
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml
b/src/test/scripts/functions/async/ReuseAsyncOp1.dml
similarity index 100%
copy from src/test/scripts/functions/async/LineageReuseSpark2.dml
copy to src/test/scripts/functions/async/ReuseAsyncOp1.dml
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml
b/src/test/scripts/functions/async/ReuseAsyncOp2.dml
similarity index 59%
rename from src/test/scripts/functions/async/LineageReuseSpark2.dml
rename to src/test/scripts/functions/async/ReuseAsyncOp2.dml
index 63792332b4..f4675f69b9 100644
--- a/src/test/scripts/functions/async/LineageReuseSpark2.dml
+++ b/src/test/scripts/functions/async/ReuseAsyncOp2.dml
@@ -18,23 +18,36 @@
# under the License.
#
#-------------------------------------------------------------
-X = rand(rows=10000, cols=200, seed=42); #sp_rand
-v = rand(rows=200, cols=1, seed=42); #cp_rand
-# Spark transformation operations
-for (i in 1:10) {
- while(FALSE){}
- sp1 = X + ceil(X);
- sp2 = sp1 %*% v; #output fits in local
- # Place a prefetch after mapmm and reuse
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda,
Integer N) return (Matrix[double] beta)
+{
+ # Reuse sp_tsmm and sp_mapmm if not future-based
+ A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+ b = t(X) %*% y;
+ beta = solve(A, b);
+}
+
+no_lamda = 10;
- # CP instructions
- v2 = ((v + v) * 1 - v) / (1+1);
- v2 = ((v + v) * 2 - v) / (2+1);
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
- # CP binary triggers the DAG of SP operations
- cp = sp2 + sum(v2);
- R = sum(cp);
+X = rand(rows=10000, cols=200, seed=42);
+y = rand(rows=10000, cols=1, seed=43);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+ beta = SimlinRegDS(X, y, lamda, N);
+ #beta = lmDS(X=X, y=y, reg=lamda);
+ R[,i] = beta;
+ lamda = lamda + stp;
+ i = i + 1;
}
+R = sum(R);
write(R, $1, format="text");
+