This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new eb3e384770 [SYSTEMDS-3483] Compile-time checkpoint placement for
shared Spark OPs
eb3e384770 is described below
commit eb3e38477097d967419075f4b61b2da25321dd2c
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Jan 4 09:56:12 2023 +0100
[SYSTEMDS-3483] Compile-time checkpoint placement for shared Spark OPs
This patch brings the initial implementation of placing checkpoint Lops
after expensive Spark operations, which are shared among multiple Spark
jobs. In addition, this patch fixes bugs and extends statistics.
Closes #1758
---
.../lops/compile/linearization/ILinearize.java | 73 +++++++++++++++++++--
.../controlprogram/context/MatrixObjectFuture.java | 11 +++-
.../java/org/apache/sysds/utils/Statistics.java | 2 +-
.../apache/sysds/utils/stats/SparkStatistics.java | 13 +++-
...OrderTest.java => CheckpointSharedOpsTest.java} | 76 ++++++++--------------
.../functions/async/MaxParallelizeOrderTest.java | 6 --
.../functions/async/CheckpointSharedOps1.dml | 42 ++++++++++++
7 files changed, 159 insertions(+), 64 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index 6bc520d31a..4b2df15e63 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -47,8 +47,11 @@ import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.MMCJ;
+import org.apache.sysds.lops.MMRJ;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
+import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.lops.PickByCount;
@@ -190,17 +193,22 @@ public interface ILinearize {
roots.forEach(r -> collectSparkRoots(r, sparkOpCount,
sparkRoots));
// Step 2: Depth-first linearization. Place the Spark
OPs first.
- // Sort the Spark roots based on number of Spark
operators descending
+ // Maintain the default order (by ID) to trigger
independent Spark chains first
ArrayList<Lop> operatorList = new ArrayList<>();
- Lop[] sortedSPRoots = sparkRoots.toArray(new Lop[0]);
- Arrays.sort(sortedSPRoots, (l1, l2) ->
sparkOpCount.get(l2.getID()) - sparkOpCount.get(l1.getID()));
- Arrays.stream(sortedSPRoots).forEach(r -> depthFirst(r,
operatorList, sparkOpCount, true));
+ sparkRoots.forEach(r -> depthFirst(r, operatorList,
sparkOpCount, true));
// Step 3: Place the rest of the operators (CP). Sort
the CP roots based on
// #Spark operators in ascending order, i.e. execute
the independent CP chains first
roots.forEach(r -> depthFirst(r, operatorList,
sparkOpCount, false));
roots.forEach(Lop::resetVisitStatus);
- final_v = operatorList;
+
+ // Step 4: Add Chkpoint lops after the expensive Spark
operators, which
+ // are shared among multiple Spark jobs. Only consider
operators with
+ // Spark consumers for now.
+ Map<Long, Integer> operatorJobCount = new HashMap<>();
+ markPersistableSparkOps(sparkRoots, operatorJobCount);
+ final_v = addChkpointLop(operatorList,
operatorJobCount);
+ // TODO: A rewrite pass to remove less effective
chkpoints
}
else
// Fall back to depth if none of the operators returns
results back to local
@@ -209,7 +217,6 @@ public interface ILinearize {
// Step 4: Add Prefetch and Broadcast lops if necessary
List<Lop> v_pf = ConfigurationManager.isPrefetchEnabled() ?
addPrefetchLop(final_v) : final_v;
List<Lop> v_bc = ConfigurationManager.isBroadcastEnabled() ?
addBroadcastLop(v_pf) : v_pf;
- // TODO: Merge into a single traversal
return v_bc;
}
@@ -238,6 +245,28 @@ public interface ILinearize {
return total;
}
+ // Count the number of jobs a Spark operator is part of
+ private static void markPersistableSparkOps(List<Lop> sparkRoots,
Map<Long, Integer> operatorJobCount) {
+ for (Lop root : sparkRoots) {
+ collectPersistableSparkOps(root, operatorJobCount);
+ root.resetVisitStatus();
+ }
+ }
+
+ private static void collectPersistableSparkOps(Lop root, Map<Long,
Integer> operatorJobCount) {
+ if (root.isVisited())
+ return;
+
+ for (Lop input : root.getInputs())
+ collectPersistableSparkOps(input, operatorJobCount);
+
+ // Increment the job counter if this node benefits from
persisting
+ if (isPersistableSparkOp(root))
+ operatorJobCount.merge(root.getID(), 1, Integer::sum);
+
+ root.setVisited();
+ }
+
// Place the operators in a depth-first manner, but order
// the DAGs based on number of Spark operators
private static void depthFirst(Lop root, ArrayList<Lop> opList,
Map<Long, Integer> sparkOpCount, boolean sparkFirst) {
@@ -270,6 +299,38 @@ public interface ILinearize {
|| lop instanceof CoVariance || lop instanceof MMTSJ ||
lop.isAllOutputsCP());
}
+ // Dictionary of Spark operators which are expensive enough to
+ // benefit from persisting if shared among jobs.
+ private static boolean isPersistableSparkOp(Lop lop) {
+ return lop.isExecSpark() && (lop instanceof MapMult
+ || lop instanceof MMCJ || lop instanceof MMRJ
+ || lop instanceof MMZip);
+ }
+
+ private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long,
Integer> operatorJobCount) {
+ List<Lop> nodesWithChkpt = new ArrayList<>();
+
+ for (Lop l : nodes) {
+ nodesWithChkpt.add(l);
+ if(operatorJobCount.containsKey(l.getID()) &&
operatorJobCount.get(l.getID()) > 1) {
+ //This operation is expensive and shared
between Spark jobs
+ List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
+ //Construct a chkpoint lop that takes this
Spark node as an input
+ Lop chkpoint = new Checkpoint(l,
l.getDataType(), l.getValueType(),
+
Checkpoint.getDefaultStorageLevelString(), false);
+ for (Lop out : oldOuts) {
+ //Rewire l -> out to l -> chkpoint ->
out
+ chkpoint.addOutput(out);
+ out.replaceInput(l, chkpoint);
+ l.removeOutput(out);
+ }
+ //Place it immediately after the Spark lop in
the node list
+ nodesWithChkpt.add(chkpoint);
+ }
+ }
+ return nodesWithChkpt;
+ }
+
private static List<Lop> addPrefetchLop(List<Lop> nodes) {
List<Lop> nodesWithPrefetch = new ArrayList<>();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
index 6850a05cd3..d74cea17a9 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/MatrixObjectFuture.java
@@ -26,6 +26,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.lineage.LineageCache;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.utils.stats.SparkStatistics;
import java.util.concurrent.Future;
@@ -61,6 +62,7 @@ public class MatrixObjectFuture extends MatrixObject
if( DMLScript.STATISTICS ){
long t1 = System.nanoTime();
CacheStatistics.incrementAcquireRTime(t1-t0);
+ SparkStatistics.incAsyncSparkOpCount(1);
}
return ret;
}
@@ -91,7 +93,14 @@ public class MatrixObjectFuture extends MatrixObject
}
private synchronized void releaseIntern() {
- _futureData = null;
+ try {
+ if(isCachingActive() &&
_futureData.get().getInMemorySize() > CACHING_THRESHOLD)
+ _futureData = null;
+ //TODO: write to disk and other cache
maintenance
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
}
public synchronized void clearData(long tid) {
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index fbbce8049b..89ad98e337 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -639,7 +639,7 @@ public class Statistics
sb.append("LinCache hits (Mem/FS/Del): \t" +
LineageCacheStatistics.displayHits() + ".\n");
sb.append("LinCache MultiLevel (Ins/SB/Fn):" +
LineageCacheStatistics.displayMultiLevelHits() + ".\n");
sb.append("LinCache GPU (Hit/Async/Sync): \t" +
LineageCacheStatistics.displayGpuStats() + ".\n");
- sb.append("LinCache Spark (Col/RDD): \t\t" +
LineageCacheStatistics.displaySparkStats() + ".\n");
+ sb.append("LinCache Spark (Col/RDD): \t" +
LineageCacheStatistics.displaySparkStats() + ".\n");
sb.append("LinCache writes (Mem/FS/Del): \t" +
LineageCacheStatistics.displayWtrites() + ".\n");
sb.append("LinCache FStimes (Rd/Wr): \t" +
LineageCacheStatistics.displayFSTime() + " sec.\n");
sb.append("LinCache Computetime (S/M): \t" +
LineageCacheStatistics.displayComputeTime() + " sec.\n");
diff --git a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
index 0d1b5b05b9..ae21ea0672 100644
--- a/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/SparkStatistics.java
@@ -34,6 +34,7 @@ public class SparkStatistics {
private static final LongAdder asyncPrefetchCount = new LongAdder();
private static final LongAdder asyncBroadcastCount = new LongAdder();
private static final LongAdder asyncTriggerCheckpointCount = new
LongAdder();
+ private static final LongAdder asyncSparkOpCount = new LongAdder();
public static boolean createdSparkContext() {
return ctxCreateTime > 0;
@@ -80,6 +81,10 @@ public class SparkStatistics {
asyncTriggerCheckpointCount.add(c);
}
+ public static void incAsyncSparkOpCount(long c) {
+ asyncSparkOpCount.add(c);
+ }
+
public static long getSparkCollectCount() {
return collectCount.longValue();
}
@@ -88,6 +93,10 @@ public class SparkStatistics {
return asyncPrefetchCount.longValue();
}
+ public static long getAsyncSparkOpCount() {
+ return asyncSparkOpCount.longValue();
+ }
+
public static long getAsyncBroadcastCount() {
return asyncBroadcastCount.longValue();
}
@@ -122,8 +131,8 @@ public class SparkStatistics {
parallelizeTime.longValue()*1e-9,
broadcastTime.longValue()*1e-9,
collectTime.longValue()*1e-9));
- sb.append("Spark async. count (pf,bc,cp): \t" +
- String.format("%d/%d/%d.\n",
getAsyncPrefetchCount(), getAsyncBroadcastCount(),
getasyncTriggerCheckpointCount()));
+ sb.append("Spark async. count (pf,bc,op): \t" +
+ String.format("%d/%d/%d.\n",
getAsyncPrefetchCount(), getAsyncBroadcastCount(), getAsyncSparkOpCount()));
return sb.toString();
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
similarity index 56%
copy from
src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
copy to
src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
index ee89824c64..1a899d3d66 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
@@ -17,28 +17,30 @@
* under the License.
*/
-package org.apache.sysds.test.functions.async;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-
-import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.recompile.Recompiler;
-import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-import org.apache.sysds.runtime.matrix.data.MatrixValue;
-import org.apache.sysds.test.AutomatedTestBase;
-import org.apache.sysds.test.TestConfiguration;
-import org.apache.sysds.test.TestUtils;
-import org.junit.Test;
-
-public class MaxParallelizeOrderTest extends AutomatedTestBase {
+ package org.apache.sysds.test.functions.async;
+
+ import java.util.ArrayList;
+ import java.util.HashMap;
+ import java.util.List;
+
+ import org.apache.sysds.common.Types;
+ import org.apache.sysds.hops.OptimizerUtils;
+ import org.apache.sysds.hops.recompile.Recompiler;
+ import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+ import org.apache.sysds.runtime.matrix.data.MatrixValue;
+ import org.apache.sysds.test.AutomatedTestBase;
+ import org.apache.sysds.test.TestConfiguration;
+ import org.apache.sysds.test.TestUtils;
+ import org.apache.sysds.utils.Statistics;
+ import org.junit.Assert;
+ import org.junit.Test;
+
+public class CheckpointSharedOpsTest extends AutomatedTestBase {
protected static final String TEST_DIR = "functions/async/";
- protected static final String TEST_NAME = "MaxParallelizeOrder";
- protected static final int TEST_VARIANTS = 4;
- protected static String TEST_CLASS_DIR = TEST_DIR +
MaxParallelizeOrderTest.class.getSimpleName() + "/";
+ protected static final String TEST_NAME = "CheckpointSharedOps";
+ protected static final int TEST_VARIANTS = 2;
+ protected static String TEST_CLASS_DIR = TEST_DIR +
CheckpointSharedOpsTest.class.getSimpleName() + "/";
@Override
public void setUp() {
@@ -48,30 +50,13 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
}
@Test
- public void testlmds() {
+ public void test1() {
+ // Shared cpmm/rmm between two jobs
runTest(TEST_NAME+"1");
}
- @Test
- public void testl2svm() {
- runTest(TEST_NAME+"2");
- }
-
- @Test
- public void testSparkAction() {
- runTest(TEST_NAME+"3");
- }
-
- @Test
- public void testSparkTransformations() {
- runTest(TEST_NAME+"4");
- }
-
public void runTest(String testname) {
- boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
- boolean old_sum_product =
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
- boolean old_trans_exec_type =
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
- ExecMode oldPlatform = setExecMode(ExecMode.HYBRID);
+ Types.ExecMode oldPlatform = setExecMode(Types.ExecMode.HYBRID);
long oldmem = InfrastructureAnalyzer.getLocalMaxMemory();
long mem = 1024*1024*8;
@@ -92,25 +77,20 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R =
readDMLScalarFromOutputDir("R");
+ long numCP =
Statistics.getCPHeavyHitterCount("sp_chkpoint");
- OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
- if (testname.equalsIgnoreCase(TEST_NAME+"4"))
- OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE
= false;
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<MatrixValue.CellIndex, Double> R_mp =
readDMLScalarFromOutputDir("R");
- OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+ long numCP_maxp =
Statistics.getCPHeavyHitterCount("sp_chkpoint");
OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
- OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
//compare matrices
boolean matchVal = TestUtils.compareMatrices(R, R_mp,
1e-6, "Origin", "withPrefetch");
if (!matchVal)
System.out.println("Value w/o Prefetch "+R+" w/
Prefetch "+R_mp);
+ Assert.assertTrue("Violated checkpoint count: " + numCP
+ " < " + numCP_maxp, numCP < numCP_maxp);
} finally {
- OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
old_simplification;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES =
old_sum_product;
- OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE =
old_trans_exec_type;
resetExecMode(oldPlatform);
InfrastructureAnalyzer.setLocalMaxMemory(oldmem);
Recompiler.reinitRecompiler();
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
index ee89824c64..54b6626c0f 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
@@ -68,9 +68,6 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
}
public void runTest(String testname) {
- boolean old_simplification =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
- boolean old_sum_product =
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
- boolean old_trans_exec_type =
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
ExecMode oldPlatform = setExecMode(ExecMode.HYBRID);
long oldmem = InfrastructureAnalyzer.getLocalMaxMemory();
@@ -108,9 +105,6 @@ public class MaxParallelizeOrderTest extends
AutomatedTestBase {
if (!matchVal)
System.out.println("Value w/o Prefetch "+R+" w/
Prefetch "+R_mp);
} finally {
- OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
old_simplification;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES =
old_sum_product;
- OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE =
old_trans_exec_type;
resetExecMode(oldPlatform);
InfrastructureAnalyzer.setLocalMaxMemory(oldmem);
Recompiler.reinitRecompiler();
diff --git a/src/test/scripts/functions/async/CheckpointSharedOps1.dml
b/src/test/scripts/functions/async/CheckpointSharedOps1.dml
new file mode 100644
index 0000000000..aa04c68025
--- /dev/null
+++ b/src/test/scripts/functions/async/CheckpointSharedOps1.dml
@@ -0,0 +1,42 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=1500, cols=1500, seed=42); #sp_rand
+v = rand(rows=1500, cols=1, seed=42); #cp_rand
+v2 = rand(rows=1500, cols=1, seed=43); #cp_rand
+
+# CP instructions
+v = ((v + v) * 1 - v) / (1+1);
+v = ((v + v) * 2 - v) / (2+1);
+
+# Spark operations
+sp1 = X + ceil(X);
+sp2 = t(sp1) %*% sp1; #shared among Job 1 and 2
+
+# Job1: SP unary triggers the DAG of SP operations
+sp3 = sp2 + sum(v);
+R1 = sum(sp3);
+
+# Job2: SP unary triggers the DAG of SP operations
+sp4 = sp2 + sum(v2);
+R2 = sum(sp4);
+
+R = R1 + R2;
+write(R, $1, format="text");