This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 9c99ae67d0 [SYSTEMDS-3562] Add rewrite to place checkpoints inside loop
9c99ae67d0 is described below
commit 9c99ae67d079b599b1a79bfdafa66077c1eefa24
Author: Arnab Phani <[email protected]>
AuthorDate: Sat May 6 15:36:10 2023 +0200
[SYSTEMDS-3562] Add rewrite to place checkpoints inside loop
This patch adds a new LOP rewrite that places checkpoints for the
variables which are read and updated in each iteration of a loop.
This rewrite improves the performance of pnmf factorization by 1.4x.
Closes #1819
---
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 23 ++++
src/main/java/org/apache/sysds/lops/Data.java | 8 ++
.../apache/sysds/lops/OperatorOrderingUtils.java | 28 +++-
.../lops/compile/linearization/ILinearize.java | 4 +-
.../org/apache/sysds/lops/rewrite/LopRewriter.java | 1 +
.../lops/rewrite/RewriteAddChkpointInLoop.java | 153 +++++++++++++++++++++
.../sysds/lops/rewrite/RewriteAddChkpointLop.java | 49 ++-----
.../apache/sysds/lops/rewrite/RewriteFixIDs.java | 31 ++++-
.../functions/async/CheckpointSharedOpsTest.java | 14 +-
.../functions/async/CheckpointSharedOps2.dml | 27 ++++
10 files changed, 287 insertions(+), 51 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 210a9f152b..ed93ea7366 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1553,6 +1553,29 @@ public class HopRewriteUtils {
return sb instanceof WhileStatementBlock
|| sb instanceof ForStatementBlock; //incl parfor
}
+
+ // Return true if this loop contains only basic blocks
+ public static boolean isLastLevelLoopStatementBlock (StatementBlock sb)
{
+ if (!isLoopStatementBlock(sb))
+ return false;
+ if (sb instanceof WhileStatementBlock) {
+ WhileStatement wstmt = (WhileStatement)
sb.getStatement(0);
+ if (wstmt.getBody().isEmpty())
+ return false;
+ for(StatementBlock csb : wstmt.getBody())
+ if (!isLastLevelStatementBlock(csb))
+ return false;
+ }
+ else if (sb instanceof ForStatementBlock) {
+ ForStatement fstmt = (ForStatement) sb.getStatement(0);
+ if (fstmt.getBody().isEmpty())
+ return false;
+ for(StatementBlock csb : fstmt.getBody())
+ if(!isLastLevelStatementBlock(csb))
+ return false;
+ }
+ return true;
+ }
public static long getMaxNrowInput(Hop hop) {
return getMaxInputDim(hop, true);
diff --git a/src/main/java/org/apache/sysds/lops/Data.java
b/src/main/java/org/apache/sysds/lops/Data.java
index 1489c69842..7879379a55 100644
--- a/src/main/java/org/apache/sysds/lops/Data.java
+++ b/src/main/java/org/apache/sysds/lops/Data.java
@@ -240,6 +240,14 @@ public class Data extends Lop
&& !literal_var;
}
+	/**
+	 * @return true if the operation of this data lop is a transient write
+	 */
+	public boolean isTransientWrite() {
+		return OpOpData.TRANSIENTWRITE == _op;
+	}
+
+	/**
+	 * @return true if the operation of this data lop is a transient read
+	 */
+	public boolean isTransientRead() {
+		return OpOpData.TRANSIENTREAD == _op;
+	}
+
/**
* Method to get CP instructions for reading/writing scalars and
matrices from/to HDFS.
* This method generates CP read/write instructions.
diff --git a/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
index 35926961f2..add21f160f 100644
--- a/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
+++ b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
@@ -25,6 +25,7 @@ import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.StatementBlock;
import java.util.ArrayList;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -55,7 +56,7 @@ public class OperatorOrderingUtils
// Gather the Spark operators which return intermediates to local
(actions/single_block)
// In addition count the number of Spark OPs underneath every Operator
- public static int collectSparkRoots(Lop root, Map<Long, Integer>
sparkOpCount, List<Lop> sparkRoots) {
+ public static int collectSparkRoots(Lop root, Map<Long, Integer>
sparkOpCount, HashSet<Lop> sparkRoots) {
if (sparkOpCount.containsKey(root.getID())) //visited before
return sparkOpCount.get(root.getID());
@@ -71,7 +72,6 @@ public class OperatorOrderingUtils
// Triggering point: Spark action/operator with all CP consumers
if (isSparkTriggeringOp(root)) {
sparkRoots.add(root);
- root.setAsynchronous(true); //candidate for async.
execution
}
return total;
@@ -82,7 +82,7 @@ public class OperatorOrderingUtils
+	// Returns true for Spark-executed MapMult, MMCJ, MMRJ, MMZip and
+	// WeightedDivMMR lops, i.e., the operators the checkpoint rewrite
+	// considers worth persisting when they are shared between Spark jobs.
public static boolean isPersistableSparkOp(Lop lop) {
return lop.isExecSpark() && (lop instanceof MapMult
|| lop instanceof MMCJ || lop instanceof MMRJ
- || lop instanceof MMZip);
+ || lop instanceof MMZip || lop instanceof
WeightedDivMMR);
}
private static boolean isSparkTriggeringOp(Lop lop) {
@@ -109,6 +109,28 @@ public class OperatorOrderingUtils
return isSparkOp && isBc && (lop.getDataType() ==
Types.DataType.MATRIX);
}
+	/**
+	 * Counts, for every lop reachable from the given Spark job roots, how many
+	 * of these roots reach it. Broadcast inputs are not traversed. The counts
+	 * are accumulated into the passed map, keyed by lop ID.
+	 *
+	 * @param sparkRoots roots of the Spark jobs
+	 * @param operatorJobCount lop ID to number of jobs containing the lop (updated in place)
+	 */
+	public static void markSharedSparkOps(HashSet<Lop> sparkRoots, Map<Long, Integer> operatorJobCount) {
+		sparkRoots.forEach(jobRoot -> {
+			collectSharedSparkOps(jobRoot, operatorJobCount);
+			// Clear the visit flags so that the next job is traversed independently
+			jobRoot.resetVisitStatus();
+		});
+	}
+
+	// Depth-first traversal that increments the counter of each lop at most
+	// once per job root (visited lops are skipped)
+	private static void collectSharedSparkOps(Lop root, Map<Long, Integer> operatorJobCount) {
+		if (!root.isVisited()) {
+			Lop bcInput = root.getBroadcastInput();
+			for (Lop input : root.getInputs()) {
+				if (input != bcInput)
+					collectSharedSparkOps(input, operatorJobCount);
+			}
+			operatorJobCount.merge(root.getID(), 1, Integer::sum);
+			root.setVisited();
+		}
+	}
+
private static boolean addNode(ArrayList<Lop> lops, Lop node) {
if (lops.contains(node))
return false;
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index 656a0262d6..e5ee982c4f 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -25,6 +25,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@@ -186,8 +187,9 @@ public interface ILinearize {
// Step 1: Collect the Spark roots and #Spark
instructions in each subDAG
Map<Long, Integer> sparkOpCount = new HashMap<>();
List<Lop> roots =
v.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
- List<Lop> sparkRoots = new ArrayList<>();
+ HashSet<Lop> sparkRoots = new HashSet<>();
roots.forEach(r ->
OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
+ sparkRoots.forEach(sr -> sr.setAsynchronous(true));
// Step 2: Depth-first linearization of Spark roots.
// Maintain the default order (by ID) to trigger
independent Spark jobs first
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 4567cf1c4e..0457558c3b 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -43,6 +43,7 @@ public class LopRewriter
_lopSBRuleSet.add(new RewriteAddPrefetchLop());
_lopSBRuleSet.add(new RewriteAddBroadcastLop());
_lopSBRuleSet.add(new RewriteAddChkpointLop());
+ _lopSBRuleSet.add(new RewriteAddChkpointInLoop());
// TODO: A rewrite pass to remove less effective chkpoints
// Last rewrite to reset Lop IDs in a depth-first manner
_lopSBRuleSet.add(new RewriteFixIDs());
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
new file mode 100644
index 0000000000..5e445a6bce
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.Data;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * LOP rewrite that places checkpoints for Spark intermediates of variables
+ * which are read and updated in each iteration of a last-level loop (a while,
+ * for, or parfor loop whose body contains only last-level statement blocks).
+ * Only the first block of the loop body is rewritten.
+ */
+public class RewriteAddChkpointInLoop extends LopRewriteRule
+{
+	@Override
+	public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
+		// List.of(null) throws a NullPointerException; use a null-tolerant singleton
+		if (sb == null)
+			return Collections.singletonList(sb);
+		if (!ConfigurationManager.isCheckpointEnabled()
+			|| !HopRewriteUtils.isLastLevelLoopStatementBlock(sb))
+			return List.of(sb);
+		// TODO: support If-Else block inside loop. Consumers inside branches.
+
+		// This rewrite adds checkpoints for the Spark intermediates, which
+		// are updated in each iteration of a loop. Without the checkpoints,
+		// CP consumers in the loop body will trigger long Spark jobs containing
+		// all previous iterations. Note, a checkpoint is counterproductive if
+		// there is no consumer in the loop body, i.e. all iterations combine
+		// to form a single Spark job triggered from outside the loop.
+
+		// Find the variables which are read and updated in each iteration
+		Set<String> readUpdatedVars = sb.variablesRead().getVariableNames().stream()
+			.filter(v -> sb.variablesUpdated().containsVariable(v))
+			.collect(Collectors.toSet());
+		if (readUpdatedVars.isEmpty())
+			return List.of(sb);
+
+		// Collect the Spark roots in the loop body (assuming single block);
+		// the body is non-empty as guaranteed by isLastLevelLoopStatementBlock
+		StatementBlock csb = sb instanceof WhileStatementBlock
+			? ((WhileStatement) sb.getStatement(0)).getBody().get(0)
+			: ((ForStatement) sb.getStatement(0)).getBody().get(0);
+		ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(csb);
+		if (lops == null || lops.isEmpty())
+			return List.of(sb);
+		List<Lop> roots = lops.stream()
+			.filter(OperatorOrderingUtils::isLopRoot)
+			.collect(Collectors.toList());
+		HashSet<Lop> sparkRoots = new HashSet<>();
+		Map<Long, Integer> sparkOpCount = new HashMap<>();
+		roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
+		// An operator can only be shared if there are at least two Spark jobs
+		if (sparkRoots.size() < 2)
+			return List.of(sb);
+
+		// Mark the Spark intermediates which are read and updated in each iteration
+		Map<Long, Integer> operatorJobCount = new HashMap<>();
+		findOverlappingJobs(sparkRoots, readUpdatedVars, operatorJobCount);
+		if (operatorJobCount.isEmpty())
+			return List.of(sb);
+
+		// Add checkpoint Lops after the shared operators (in place in the Lop DAG)
+		addChkpointLop(lops, operatorJobCount);
+		// TODO: A rewrite pass to remove less effective checkpoints
+		return List.of(sb);
+	}
+
+	@Override
+	public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+		return sbs;
+	}
+
+	// Insert a checkpoint between every shared candidate lop and all its consumers
+	private void addChkpointLop(List<Lop> nodes, Map<Long, Integer> operatorJobCount) {
+		for (Lop l : nodes) {
+			if (operatorJobCount.getOrDefault(l.getID(), 0) <= 1 || !isCheckpointCandidate(l))
+				continue;
+			// TODO: Check if this lop leads to one of those variables
+			// This operation is shared between Spark jobs
+			List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
+			// Construct a chkpoint lop that takes this Spark node as an input
+			Lop checkpoint = new Checkpoint(l, l.getDataType(), l.getValueType(),
+				Checkpoint.getDefaultStorageLevelString(), false);
+			for (Lop out : oldOuts) {
+				//Rewire l -> out to l -> checkpoint -> out
+				checkpoint.addOutput(out);
+				out.replaceInput(l, checkpoint);
+				l.removeOutput(out);
+			}
+		}
+	}
+
+	// Only matrices/frames can be checkpointed; never checkpoint a checkpoint,
+	// a lop without consumers, or a lop whose consumers are all checkpoints already
+	private static boolean isCheckpointCandidate(Lop l) {
+		if (!l.getDataType().isMatrixOrFrame() || l instanceof Checkpoint)
+			return false;
+		// allMatch is true for an empty output list, which excludes dangling lops too
+		return !l.getOutputs().stream().allMatch(out -> out instanceof Checkpoint);
+	}
+
+	private void findOverlappingJobs(HashSet<Lop> sparkRoots, Set<String> ruVars, Map<Long, Integer> operatorJobCount) {
+		HashSet<Lop> sharedRoots = new HashSet<>();
+		for (String var : ruVars) {
+			// Find the Spark jobs which read this variable
+			for (Lop root : sparkRoots) {
+				if (ifJobContains(root, var))
+					sharedRoots.add(root);
+				root.resetVisitStatus();
+			}
+			// Mark the operators shared by these Spark jobs. Counts are computed
+			// per variable: summing counts of different variables would mark an
+			// operator used by a single job as shared.
+			if (sharedRoots.size() > 1) {
+				Map<Long, Integer> varJobCount = new HashMap<>();
+				OperatorOrderingUtils.markSharedSparkOps(sharedRoots, varJobCount);
+				varJobCount.forEach((id, count) -> {
+					if (count > 1)
+						operatorJobCount.merge(id, count, Math::max);
+				});
+			}
+			sharedRoots.clear();
+		}
+	}
+
+	// Check if this Spark job has the passed variable as a transient-read leaf node
+	private boolean ifJobContains(Lop root, String var) {
+		if (root.isVisited())
+			return false;
+
+		for (Lop input : root.getInputs()) {
+			if (!(input instanceof Data) && (!input.isExecSpark() || root.getBroadcastInput() == input))
+				continue; //consider only Spark operator chains
+			if (ifJobContains(input, var)) {
+				root.setVisited();
+				return true;
+			}
+		}
+
+		root.setVisited();
+		// DML identifiers are case-sensitive, hence an exact label comparison
+		return root instanceof Data && ((Data) root).isTransientRead()
+			&& var.equals(root.getOutputParameters().getLabel());
+	}
+
+}
diff --git
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
index 1cf5761423..6a1a1192ea 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -27,6 +27,7 @@ import org.apache.sysds.parser.StatementBlock;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -44,7 +45,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
return List.of(sb);
// Collect the Spark roots and #Spark instructions in each
subDAG
- List<Lop> sparkRoots = new ArrayList<>();
+ HashSet<Lop> sparkRoots = new HashSet<>();
Map<Long, Integer> sparkOpCount = new HashMap<>();
List<Lop> roots =
lops.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r,
sparkOpCount, sparkRoots));
@@ -55,10 +56,10 @@ public class RewriteAddChkpointLop extends LopRewriteRule
// shared among multiple Spark jobs. Only consider operators
with
// Spark consumers for now.
Map<Long, Integer> operatorJobCount = new HashMap<>();
- markPersistableSparkOps(sparkRoots, operatorJobCount);
- // TODO: A rewrite pass to remove less effective chkpoints
- @SuppressWarnings("unused")
- List<Lop> nodesWithChkpt = addChkpointLop(lops,
operatorJobCount);
+ //markPersistableSparkOps(sparkRoots, operatorJobCount);
+ OperatorOrderingUtils.markSharedSparkOps(sparkRoots,
operatorJobCount);
+ // TODO: A rewrite pass to remove less effective checkpoints
+ addChkpointLop(lops, operatorJobCount);
//New node is added inplace in the Lop DAG
return List.of(sb);
}
@@ -68,12 +69,13 @@ public class RewriteAddChkpointLop extends LopRewriteRule
return sbs;
}
- private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long,
Integer> operatorJobCount) {
- List<Lop> nodesWithChkpt = new ArrayList<>();
-
+ private void addChkpointLop(List<Lop> nodes, Map<Long, Integer>
operatorJobCount) {
for (Lop l : nodes) {
- nodesWithChkpt.add(l);
- if(operatorJobCount.containsKey(l.getID()) &&
operatorJobCount.get(l.getID()) > 1) {
+ // Increment the job counter if this node benefits from
persisting
+ // and reachable from multiple job roots
+ if(operatorJobCount.containsKey(l.getID())
+ && operatorJobCount.get(l.getID()) > 1
+ &&
OperatorOrderingUtils.isPersistableSparkOp(l)) {
// This operation is expensive and shared
between Spark jobs
List<Lop> oldOuts = new
ArrayList<>(l.getOutputs());
// Construct a chkpoint lop that takes this
Spark node as a input
@@ -85,34 +87,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
out.replaceInput(l, chkpoint);
l.removeOutput(out);
}
- // Place it immediately after the Spark lop in
the node list
- nodesWithChkpt.add(chkpoint);
}
}
- return nodesWithChkpt;
- }
-
- // Count the number of jobs a Spark operator is part of
- private static void markPersistableSparkOps(List<Lop> sparkRoots,
Map<Long, Integer> operatorJobCount) {
- for (Lop root : sparkRoots) {
- collectPersistableSparkOps(root, operatorJobCount);
- root.resetVisitStatus();
- }
- }
-
- private static void collectPersistableSparkOps(Lop root, Map<Long,
Integer> operatorJobCount) {
- if (root.isVisited())
- return;
-
- for (Lop input : root.getInputs())
- if (root.getBroadcastInput() != input)
- collectPersistableSparkOps(input,
operatorJobCount);
-
- // Increment the job counter if this node benefits from
persisting
- // and reachable from multiple job roots
- if (OperatorOrderingUtils.isPersistableSparkOp(root))
- operatorJobCount.merge(root.getID(), 1, Integer::sum);
-
- root.setVisited();
}
}
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
index 00d205b553..828eaef0c1 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
@@ -20,8 +20,12 @@
package org.apache.sysds.lops.rewrite;
import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
import java.util.List;
@@ -35,12 +39,16 @@ public class RewriteFixIDs extends LopRewriteRule
&& !ConfigurationManager.isCheckpointEnabled())
return List.of(sb);
- // Reset the IDs in a depth-first manner
- if (sb.getLops() != null && !sb.getLops().isEmpty()) {
- for (Lop root : sb.getLops())
- assignNewID(root);
- sb.getLops().forEach(Lop::resetVisitStatus);
+ if (HopRewriteUtils.isLastLevelLoopStatementBlock(sb)) {
+ // Some rewrites add new Lops in the last-level loop
body
+ StatementBlock csb = sb instanceof WhileStatementBlock
+ ? ((WhileStatement)
sb.getStatement(0)).getBody().get(0)
+ : ((ForStatement)
sb.getStatement(0)).getBody().get(0);
+ assignNewIDStatementBlock(csb);
}
+ else
+ assignNewIDStatementBlock(sb);
+
return List.of(sb);
}
@@ -49,7 +57,16 @@ public class RewriteFixIDs extends LopRewriteRule
return sbs;
}
- private void assignNewID(Lop lop) {
+	// Renumber all Lops of a statement block in a depth-first manner
+	// (inputs receive their new IDs before their consumers)
+	private void assignNewIDStatementBlock(StatementBlock sb) {
+		List<Lop> roots = sb.getLops();
+		if (roots == null || roots.isEmpty())
+			return;
+		roots.forEach(this::assignNewIDLop);
+		// Clear the visit flags set while renumbering
+		for (Lop root : roots)
+			root.resetVisitStatus();
+	}
+
+ private void assignNewIDLop(Lop lop) {
if (lop.isVisited())
return;
@@ -59,7 +76,7 @@ public class RewriteFixIDs extends LopRewriteRule
return;
}
for (Lop input : lop.getInputs())
- assignNewID(input);
+ assignNewIDLop(input);
lop.setNewID();
lop.setVisited();
diff --git
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
index eda92023b9..bceb4e2090 100644
---
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
@@ -55,6 +55,12 @@ public class CheckpointSharedOpsTest extends
AutomatedTestBase {
runTest(TEST_NAME+"1");
}
+	@Test
+	public void testPnmf() {
+		// The variables updated inside the pnmf loop are read again in every
+		// iteration, so checkpoints are expected inside the loop body.
+		final String testName = TEST_NAME + "2";
+		runTest(testName);
+	}
+
public void runTest(String testname) {
Types.ExecMode oldPlatform = setExecMode(Types.ExecMode.HYBRID);
@@ -86,10 +92,12 @@ public class CheckpointSharedOpsTest extends
AutomatedTestBase {
OptimizerUtils.ASYNC_CHECKPOINT_SPARK = false;
//compare matrices
- boolean matchVal = TestUtils.compareMatrices(R, R_mp,
1e-6, "Origin", "withPrefetch");
+ boolean matchVal = TestUtils.compareMatrices(R, R_mp,
1e-3, "Origin", "withChkpoint");
if (!matchVal)
- System.out.println("Value w/o Prefetch "+R+" w/
Prefetch "+R_mp);
- Assert.assertTrue("Violated checkpoint count: " + numCP
+ " < " + numCP_maxp, numCP < numCP_maxp);
+ System.out.println("Value w/o Checkpoint "+R+"
w/ Checkpoint "+R_mp);
+ //compare checkpoint instruction count
+ if (!testname.equalsIgnoreCase(TEST_NAME+"2"))
+ Assert.assertTrue("Violated checkpoint count: "
+ numCP + " < " + numCP_maxp, numCP < numCP_maxp);
} finally {
resetExecMode(oldPlatform);
InfrastructureAnalyzer.setLocalMaxMemory(oldmem);
diff --git a/src/test/scripts/functions/async/CheckpointSharedOps2.dml
b/src/test/scripts/functions/async/CheckpointSharedOps2.dml
new file mode 100644
index 0000000000..61e2fefc2f
--- /dev/null
+++ b/src/test/scripts/functions/async/CheckpointSharedOps2.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=13000, cols=150, seed=42); #sp_rand
+[W, H] = pnmf(X=X, rnk=100, verbose=FALSE);
+
+# Reconstruction checksum: printed for debugging and written for comparison
+R = sum(W %*% H);
+print(R);
+write(R, $1, format="text");
+