This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 912908316a [SYSTEMDS-3474] Lineage-based reuse of prefetch instruction
912908316a is described below
commit 912908316a29dbbecfd89c121117cfc32f740a2a
Author: Arnab Phani <[email protected]>
AuthorDate: Fri Dec 2 20:27:45 2022 +0100
[SYSTEMDS-3474] Lineage-based reuse of prefetch instruction
This patch enables caching and reusing prefetch instruction
outputs. This is the first step towards reusing asynchronous
operators.
Closes #1746
---
.../lops/compile/linearization/ILinearize.java | 3 +-
.../instructions/cp/PrefetchCPInstruction.java | 7 ++-
.../instructions/cp/TriggerPrefetchTask.java | 15 ++++++
.../apache/sysds/runtime/lineage/LineageCache.java | 52 ++++++++++++++++++
.../sysds/runtime/lineage/LineageCacheConfig.java | 2 +-
.../functions/async/LineageReuseSparkTest.java | 63 +++++++++++++++++-----
.../scripts/functions/async/LineageReuseSpark2.dml | 40 ++++++++++++++
7 files changed, 166 insertions(+), 16 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index 7eee970e2b..70ab1533df 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -44,6 +44,7 @@ import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
+import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
@@ -359,7 +360,7 @@ public interface ILinearize {
				&& !(lop instanceof CoVariance)
				// Not qualified for prefetching
				&& !(lop instanceof Checkpoint) && !(lop instanceof ReBlock)
-				&& !(lop instanceof CSVReBlock)
+				&& !(lop instanceof CSVReBlock) && !(lop instanceof DataGen)
				// Cannot filter Transformation cases from Actions (FIXME)
				&& !(lop instanceof MMTSJ) && !(lop instanceof UAggOuterChain)
				&& !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
index db9bbb1b84..192e165391 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java
@@ -23,6 +23,8 @@ import java.util.concurrent.Executors;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -42,8 +44,9 @@ public class PrefetchCPInstruction extends UnaryCPInstruction
{
@Override
public void processInstruction(ExecutionContext ec) {
- //TODO: handle non-matrix objects
+ // TODO: handle non-matrix objects
ec.setVariable(output.getName(), ec.getMatrixObject(input1));
+		LineageItem li = !LineageCacheConfig.ReuseCacheType.isNone() ? this.getLineageItem(ec).getValue() : null;
		// Note, a Prefetch instruction doesn't guarantee an asynchronous execution.
		// If the next instruction which takes this output as an input comes before
@@ -51,6 +54,6 @@ public class PrefetchCPInstruction extends UnaryCPInstruction
{
		// In that case this Prefetch instruction will act like a NOOP.
		if (CommonThreadPool.triggerRemoteOPsPool == null)
			CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
-		CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerPrefetchTask(ec.getMatrixObject(output)));
+		CommonThreadPool.triggerRemoteOPsPool.submit(new TriggerPrefetchTask(ec.getMatrixObject(output), li));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
index 26a1c8e8bb..b7c69d01f5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/TriggerPrefetchTask.java
@@ -22,18 +22,28 @@ package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
+import org.apache.sysds.runtime.lineage.LineageCache;
+import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.utils.stats.SparkStatistics;
public class TriggerPrefetchTask implements Runnable {
MatrixObject _prefetchMO;
+ LineageItem _inputLi;
public TriggerPrefetchTask(MatrixObject mo) {
_prefetchMO = mo;
+ _inputLi = null;
+ }
+
+ public TriggerPrefetchTask(MatrixObject mo, LineageItem li) {
+ _prefetchMO = mo;
+ _inputLi = li;
}
@Override
public void run() {
boolean prefetched = false;
+ long t1 = System.nanoTime();
synchronized (_prefetchMO) {
// Having this check inside the critical section
// safeguards against concurrent rmVar.
@@ -44,6 +54,11 @@ public class TriggerPrefetchTask implements Runnable {
prefetched = true;
}
}
+
+ // Save the collected intermediate in the lineage cache
+ if (_inputLi != null)
+			LineageCache.putValueAsyncOp(_inputLi, _prefetchMO, prefetched, t1);
+
if (DMLScript.STATISTICS && prefetched) {
if (_prefetchMO.isFederated())
FederatedStatistics.incAsyncPrefetchCount(1);
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
index ecd734c588..8e8d962199 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCache.java
@@ -38,6 +38,7 @@ import
org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.instructions.CPInstructionParser;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.cp.BroadcastCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
@@ -45,6 +46,7 @@ import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import
org.apache.sysds.runtime.instructions.cp.MultiReturnBuiltinCPInstruction;
import
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.PrefetchCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
@@ -579,6 +581,10 @@ public class LineageCache
continue;
}
+				if (inst instanceof PrefetchCPInstruction || inst instanceof BroadcastCPInstruction)
+					// For the async. instructions, caching is handled separately by the tasks
+					continue;
+
+
if (data instanceof MatrixObject &&
((MatrixObject) data).hasRDDHandle()) {
// Avoid triggering pre-matured Spark
instruction chains
removePlaceholder(item);
@@ -637,6 +643,52 @@ public class LineageCache
}
}
+	public static void putValueAsyncOp(LineageItem instLI, Data data, boolean prefetched, long starttime)
+	{
+		if (ReuseCacheType.isNone())
+			return;
+		if (!prefetched) //prefetching was not successful
+			return;
+
+		synchronized( _cache )
+		{
+			if (!probe(instLI))
+				return;
+
+			long computetime = System.nanoTime() - starttime;
+			LineageCacheEntry centry = _cache.get(instLI);
+			if(!(data instanceof MatrixObject) && !(data instanceof ScalarObject)) {
+				// Reusable instructions can return a frame (rightIndex). Remove placeholders.
+				removePlaceholder(instLI);
+				return;
+			}
+
+			MatrixBlock mb = (data instanceof MatrixObject) ?
+				((MatrixObject)data).acquireReadAndRelease() : null;
+			long size = mb != null ? mb.getInMemorySize() : ((ScalarObject)data).getSize();
+
+			// remove the placeholder if the entry is bigger than the cache.
+			if (size > LineageCacheEviction.getCacheLimit()) {
+				removePlaceholder(instLI);
+				return;
+			}
+
+			// place the data
+			if (data instanceof MatrixObject)
+				centry.setValue(mb, computetime);
+			else if (data instanceof ScalarObject)
+				centry.setValue((ScalarObject)data, computetime);
+
+			if (DMLScript.STATISTICS && LineageCacheEviction._removelist.containsKey(centry._key)) {
+				// Add to missed compute time
+				LineageCacheStatistics.incrementMissedComputeTime(centry._computeTime);
+			}
+
+			//maintain order for eviction
+			LineageCacheEviction.addEntry(centry);
+		}
+	}
+
+
public static void putValue(List<DataIdentifier> outputs,
LineageItem[] liInputs, String name, ExecutionContext ec, long
computetime)
{
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 5a7c46dfe7..72ea3835a2 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -54,7 +54,7 @@ public class LineageCacheConfig
"^", "uamax", "uark+", "uacmean", "eigen", "ctableexpand",
"replace",
"^2", "uack+", "tak+*", "uacsqk+", "uark+", "n+", "uarimax",
"qsort",
"qpick", "transformapply", "uarmax", "n+", "-*", "castdtm",
"lowertri",
- "mapmm", "cpmm"
+ "mapmm", "cpmm", "prefetch"
//TODO: Reuse everything.
};
private static String[] REUSE_OPCODES = new String[] {};
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
index e958093122..57d7892b3b 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/LineageReuseSparkTest.java
@@ -28,6 +28,7 @@ package org.apache.sysds.test.functions.async;
import org.apache.sysds.hops.recompile.Recompiler;
import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
@@ -40,7 +41,7 @@ public class LineageReuseSparkTest extends AutomatedTestBase {
protected static final String TEST_DIR = "functions/async/";
protected static final String TEST_NAME = "LineageReuseSpark";
- protected static final int TEST_VARIANTS = 1;
+ protected static final int TEST_VARIANTS = 2;
protected static String TEST_CLASS_DIR = TEST_DIR +
LineageReuseSparkTest.class.getSimpleName() + "/";
@Override
@@ -52,16 +53,21 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
@Test
public void testlmdsHB() {
- runTest(TEST_NAME+"1", ExecMode.HYBRID);
+ runTest(TEST_NAME+"1", ExecMode.HYBRID, 1);
}
@Test
public void testlmdsSP() {
// Only reuse the actions
- runTest(TEST_NAME+"1", ExecMode.SPARK);
+ runTest(TEST_NAME+"1", ExecMode.SPARK, 1);
}
- public void runTest(String testname, ExecMode execMode) {
+ @Test
+ public void testReusePrefetch() {
+ runTest(TEST_NAME+"2", ExecMode.HYBRID, 2);
+ }
+
+ public void runTest(String testname, ExecMode execMode, int testId) {
boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
boolean old_sum_product =
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
boolean old_trans_exec_type =
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
@@ -85,31 +91,52 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
programArgs = proArgs.toArray(new
String[proArgs.size()]);
Lineage.resetInternalState();
+ if (testId == 2) enablePrefetch();
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+ disablePrefetch();
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
- long numTsmm =
Statistics.getCPHeavyHitterCount("sp_tsmm");
- long numMapmm =
Statistics.getCPHeavyHitterCount("sp_mapmm");
-
+ long numTsmm = 0;
+ long numMapmm = 0;
+ if (testId == 1) {
+ numTsmm =
Statistics.getCPHeavyHitterCount("sp_tsmm");
+ numMapmm =
Statistics.getCPHeavyHitterCount("sp_mapmm");
+ }
+ long numPrefetch = 0;
+ if (testId == 2) numPrefetch =
Statistics.getCPHeavyHitterCount("prefetch");
+
+ proArgs.clear();
proArgs.add("-explain");
proArgs.add("-stats");
proArgs.add("-lineage");
- proArgs.add("reuse_hybrid");
+
proArgs.add(LineageCacheConfig.ReuseCacheType.REUSE_FULL.name().toLowerCase());
proArgs.add("-args");
proArgs.add(output("R"));
programArgs = proArgs.toArray(new
String[proArgs.size()]);
Lineage.resetInternalState();
+ if (testId == 2) enablePrefetch();
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+ disablePrefetch();
HashMap<MatrixValue.CellIndex, Double> R_reused =
readDMLScalarFromOutputDir("R");
- long numTsmm_r =
Statistics.getCPHeavyHitterCount("sp_tsmm");
- long numMapmm_r=
Statistics.getCPHeavyHitterCount("sp_mapmm");
+ long numTsmm_r = 0;
+ long numMapmm_r = 0;
+ if (testId == 1) {
+ numTsmm_r =
Statistics.getCPHeavyHitterCount("sp_tsmm");
+ numMapmm_r =
Statistics.getCPHeavyHitterCount("sp_mapmm");
+ }
+ long numPrefetch_r = 0;
+ if (testId == 2) numPrefetch_r =
Statistics.getCPHeavyHitterCount("prefetch");
//compare matrices
boolean matchVal = TestUtils.compareMatrices(R,
R_reused, 1e-6, "Origin", "withPrefetch");
if (!matchVal)
System.out.println("Value w/o reuse "+R+" w/
reuse "+R_reused);
- Assert.assertTrue("Violated sp_tsmm: reuse count:
"+numTsmm_r+" < "+numTsmm, numTsmm_r < numTsmm);
- Assert.assertTrue("Violated sp_mapmm: reuse count:
"+numMapmm_r+" < "+numMapmm, numMapmm_r < numMapmm);
+ if (testId == 1) {
+ Assert.assertTrue("Violated sp_tsmm reuse
count: " + numTsmm_r + " < " + numTsmm, numTsmm_r < numTsmm);
+ Assert.assertTrue("Violated sp_mapmm reuse
count: " + numMapmm_r + " < " + numMapmm, numMapmm_r < numMapmm);
+ }
+ if (testId == 2)
+ Assert.assertTrue("Violated prefetch reuse
count: " + numPrefetch_r + " < " + numPrefetch, numPrefetch_r<numPrefetch);
} finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
old_simplification;
@@ -120,4 +147,16 @@ public class LineageReuseSparkTest extends
AutomatedTestBase {
Recompiler.reinitRecompiler();
}
}
+
+ private void enablePrefetch() {
+ OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
+ OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+ }
+
+ private void disablePrefetch() {
+ OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
+ OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
+ OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+ }
}
diff --git a/src/test/scripts/functions/async/LineageReuseSpark2.dml
b/src/test/scripts/functions/async/LineageReuseSpark2.dml
new file mode 100644
index 0000000000..63792332b4
--- /dev/null
+++ b/src/test/scripts/functions/async/LineageReuseSpark2.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# Spark transformation operations
+for (i in 1:10) {
+ while(FALSE){}
+ sp1 = X + ceil(X);
+ sp2 = sp1 %*% v; #output fits in local
+ # Place a prefetch after mapmm and reuse
+
+ # CP instructions
+ v2 = ((v + v) * 1 - v) / (1+1);
+ v2 = ((v + v) * 2 - v) / (2+1);
+
+ # CP binary triggers the DAG of SP operations
+ cp = sp2 + sum(v2);
+ R = sum(cp);
+}
+
+write(R, $1, format="text");