This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9c99ae67d0 [SYSTEMDS-3562] Add rewrite to place checkpoints inside loop
9c99ae67d0 is described below

commit 9c99ae67d079b599b1a79bfdafa66077c1eefa24
Author: Arnab Phani <[email protected]>
AuthorDate: Sat May 6 15:36:10 2023 +0200

    [SYSTEMDS-3562] Add rewrite to place checkpoints inside loop
    
    This patch adds a new LOP rewrite to place checkpoints for the
    variables which are updated in each iteration of a loop.
    This rewrite improves pnmf factorization by 1.4x.
    
    Closes #1819
---
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |  23 ++++
 src/main/java/org/apache/sysds/lops/Data.java      |   8 ++
 .../apache/sysds/lops/OperatorOrderingUtils.java   |  28 +++-
 .../lops/compile/linearization/ILinearize.java     |   4 +-
 .../org/apache/sysds/lops/rewrite/LopRewriter.java |   1 +
 .../lops/rewrite/RewriteAddChkpointInLoop.java     | 153 +++++++++++++++++++++
 .../sysds/lops/rewrite/RewriteAddChkpointLop.java  |  49 ++-----
 .../apache/sysds/lops/rewrite/RewriteFixIDs.java   |  31 ++++-
 .../functions/async/CheckpointSharedOpsTest.java   |  14 +-
 .../functions/async/CheckpointSharedOps2.dml       |  27 ++++
 10 files changed, 287 insertions(+), 51 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 210a9f152b..ed93ea7366 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1553,6 +1553,29 @@ public class HopRewriteUtils {
                return sb instanceof WhileStatementBlock
                        || sb instanceof ForStatementBlock; //incl parfor
        }
+
+       // Return true if this loop contains only basic blocks
+       public static boolean isLastLevelLoopStatementBlock (StatementBlock sb) 
{
+               if (!isLoopStatementBlock(sb))
+                       return false;
+               if (sb instanceof WhileStatementBlock) {
+                       WhileStatement wstmt = (WhileStatement) 
sb.getStatement(0);
+                       if (wstmt.getBody().isEmpty())
+                               return false;
+                       for(StatementBlock csb : wstmt.getBody())
+                               if (!isLastLevelStatementBlock(csb))
+                                       return false;
+               }
+               else if (sb instanceof ForStatementBlock) {
+                       ForStatement fstmt = (ForStatement) sb.getStatement(0);
+                       if (fstmt.getBody().isEmpty())
+                               return false;
+                       for(StatementBlock csb : fstmt.getBody())
+                               if(!isLastLevelStatementBlock(csb))
+                                       return false;
+               }
+               return true;
+       }
        
        public static long getMaxNrowInput(Hop hop) {
                return getMaxInputDim(hop, true);
diff --git a/src/main/java/org/apache/sysds/lops/Data.java 
b/src/main/java/org/apache/sysds/lops/Data.java
index 1489c69842..7879379a55 100644
--- a/src/main/java/org/apache/sysds/lops/Data.java
+++ b/src/main/java/org/apache/sysds/lops/Data.java
@@ -240,6 +240,14 @@ public class Data extends Lop
                        && !literal_var;
        }
 
+       public boolean isTransientWrite() {
+               return _op == OpOpData.TRANSIENTWRITE;
+       }
+
+       public boolean isTransientRead() {
+               return _op == OpOpData.TRANSIENTREAD;
+       }
+
        /**
         * Method to get CP instructions for reading/writing scalars and 
matrices from/to HDFS.
         * This method generates CP read/write instructions.
diff --git a/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java 
b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
index 35926961f2..add21f160f 100644
--- a/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
+++ b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
@@ -25,6 +25,7 @@ import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.StatementBlock;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 
@@ -55,7 +56,7 @@ public class OperatorOrderingUtils
 
        // Gather the Spark operators which return intermediates to local 
(actions/single_block)
        // In addition count the number of Spark OPs underneath every Operator
-       public static int collectSparkRoots(Lop root, Map<Long, Integer> 
sparkOpCount, List<Lop> sparkRoots) {
+       public static int collectSparkRoots(Lop root, Map<Long, Integer> 
sparkOpCount, HashSet<Lop> sparkRoots) {
                if (sparkOpCount.containsKey(root.getID())) //visited before
                        return sparkOpCount.get(root.getID());
 
@@ -71,7 +72,6 @@ public class OperatorOrderingUtils
                // Triggering point: Spark action/operator with all CP consumers
                if (isSparkTriggeringOp(root)) {
                        sparkRoots.add(root);
-                       root.setAsynchronous(true); //candidate for async. 
execution
                }
 
                return total;
@@ -82,7 +82,7 @@ public class OperatorOrderingUtils
        public static boolean isPersistableSparkOp(Lop lop) {
                return lop.isExecSpark() && (lop instanceof MapMult
                        || lop instanceof MMCJ || lop instanceof MMRJ
-                       || lop instanceof MMZip);
+                       || lop instanceof MMZip || lop instanceof 
WeightedDivMMR);
        }
 
        private static boolean isSparkTriggeringOp(Lop lop) {
@@ -109,6 +109,28 @@ public class OperatorOrderingUtils
                return isSparkOp && isBc && (lop.getDataType() == 
Types.DataType.MATRIX);
        }
 
+       // Count the number of jobs a Spark operator is part of
+       public static void markSharedSparkOps(HashSet<Lop> sparkRoots, 
Map<Long, Integer> operatorJobCount) {
+               for (Lop root : sparkRoots) {
+                       collectSharedSparkOps(root, operatorJobCount);
+                       root.resetVisitStatus();
+               }
+       }
+
+       private static void collectSharedSparkOps(Lop root, Map<Long, Integer> 
operatorJobCount) {
+               if (root.isVisited())
+                       return;
+
+               for (Lop input : root.getInputs())
+                       if (root.getBroadcastInput() != input)
+                               collectSharedSparkOps(input, operatorJobCount);
+
+               // Increment the job counter if this node is reachable from 
multiple job roots
+               operatorJobCount.merge(root.getID(), 1, Integer::sum);
+
+               root.setVisited();
+       }
+
        private static boolean addNode(ArrayList<Lop> lops, Lop node) {
                if (lops.contains(node))
                        return false;
diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index 656a0262d6..e5ee982c4f 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -25,6 +25,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -186,8 +187,9 @@ public interface ILinearize {
                        // Step 1: Collect the Spark roots and #Spark 
instructions in each subDAG
                        Map<Long, Integer> sparkOpCount = new HashMap<>();
                        List<Lop> roots = 
v.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
-                       List<Lop> sparkRoots = new ArrayList<>();
+                       HashSet<Lop> sparkRoots = new HashSet<>();
                        roots.forEach(r -> 
OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
+                       sparkRoots.forEach(sr -> sr.setAsynchronous(true));
 
                        // Step 2: Depth-first linearization of Spark roots.
                        // Maintain the default order (by ID) to trigger 
independent Spark jobs first
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java 
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 4567cf1c4e..0457558c3b 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -43,6 +43,7 @@ public class LopRewriter
                _lopSBRuleSet.add(new RewriteAddPrefetchLop());
                _lopSBRuleSet.add(new RewriteAddBroadcastLop());
                _lopSBRuleSet.add(new RewriteAddChkpointLop());
+               _lopSBRuleSet.add(new RewriteAddChkpointInLoop());
                // TODO: A rewrite pass to remove less effective chkpoints
                // Last rewrite to reset Lop IDs in a depth-first manner
                _lopSBRuleSet.add(new RewriteFixIDs());
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
new file mode 100644
index 0000000000..5e445a6bce
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.Data;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+public class RewriteAddChkpointInLoop extends LopRewriteRule
+{
+       @Override
+       public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock 
sb) {
+               if (!ConfigurationManager.isCheckpointEnabled())
+                       return List.of(sb);
+
+               if (sb == null || 
!HopRewriteUtils.isLastLevelLoopStatementBlock(sb))
+                       return List.of(sb);
+               // TODO: support If-Else block inside loop. Consumers inside 
branches.
+
+               // This rewrite adds checkpoints for the Spark intermediates, 
which
+               // are updated in each iteration of a loop. Without the 
checkpoints,
+               // CP consumers in the loop body will trigger long Spark jobs 
containing
+               // all previous iterations. Note, a checkpoint is 
counterproductive if
+               // there is no consumer in the loop body, i.e. all iterations 
combine
+               // to form a single Spark job triggered from outside the loop.
+
+               // Find the variables which are read and updated in each 
iteration
+               Set<String> readUpdatedVars = 
sb.variablesRead().getVariableNames().stream()
+                       .filter(v -> sb.variablesUpdated().containsVariable(v))
+                       .collect(Collectors.toSet());
+               if (readUpdatedVars.isEmpty())
+                       return List.of(sb);
+
+               // Collect the Spark roots in the loop body (assuming single 
block)
+               StatementBlock csb = sb instanceof WhileStatementBlock
+                       ? ((WhileStatement) sb.getStatement(0)).getBody().get(0)
+                       : ((ForStatement) sb.getStatement(0)).getBody().get(0);
+               ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(csb);
+               List<Lop> roots = 
lops.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
+               HashSet<Lop> sparkRoots = new HashSet<>();
+               roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, 
new HashMap<>(), sparkRoots));
+               if (sparkRoots.isEmpty())
+                       return List.of(sb);
+
+               // Mark the Spark intermediates which are read and updated in 
each iteration
+               Map<Long, Integer> operatorJobCount = new HashMap<>();
+               findOverlappingJobs(sparkRoots, readUpdatedVars, 
operatorJobCount);
+               if (operatorJobCount.isEmpty())
+                       return List.of(sb);
+
+               // Add checkpoint Lops after the shared operators
+               addChkpointLop(lops, operatorJobCount);
+               // TODO: A rewrite pass to remove less effective checkpoints
+               return List.of(sb);
+       }
+
+       @Override
+       public List<StatementBlock> 
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+               return sbs;
+       }
+
+       private void addChkpointLop(List<Lop> nodes, Map<Long, Integer> 
operatorJobCount) {
+               for (Lop l : nodes) {
+                       if(operatorJobCount.containsKey(l.getID()) && 
operatorJobCount.get(l.getID()) > 1) {
+                               // TODO: Check if this lop leads to one of 
those variables
+                               // This operation is shared between Spark jobs
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               // Construct a chkpoint lop that takes this 
Spark node as an input
+                               Lop checkpoint = new Checkpoint(l, 
l.getDataType(), l.getValueType(),
+                                       
Checkpoint.getDefaultStorageLevelString(), false);
+                               for (Lop out : oldOuts) {
+                                       //Rewire l -> out to l -> checkpoint -> 
out
+                                       checkpoint.addOutput(out);
+                                       out.replaceInput(l, checkpoint);
+                                       l.removeOutput(out);
+                               }
+                       }
+               }
+       }
+
+       private void findOverlappingJobs(HashSet<Lop> sparkRoots, Set<String> 
ruVars, Map<Long, Integer> operatorJobCount) {
+               HashSet<Lop> sharedRoots = new HashSet<>();
+               // Find the Spark jobs which are sharing these variables
+               for (String var : ruVars) {
+                       for (Lop root : sparkRoots) {
+                               if(ifJobContains(root, var))
+                                       sharedRoots.add(root);
+                               root.resetVisitStatus();
+                       }
+                       // Mark the operators shared by these Spark jobs
+                       if (!sharedRoots.isEmpty())
+                               
OperatorOrderingUtils.markSharedSparkOps(sharedRoots, operatorJobCount);
+                       sharedRoots.clear();
+               }
+       }
+
+       // Check if this Spark job has the passed variable as a leaf node
+       private boolean ifJobContains(Lop root, String var) {
+               if (root.isVisited())
+                       return false;
+
+               for (Lop input : root.getInputs()) {
+                       if (!(input instanceof Data) && (!input.isExecSpark() 
|| root.getBroadcastInput() == input))
+                               continue; //consider only Spark operator chains
+                       if (ifJobContains(input, var)) {
+                               root.setVisited();
+                               return true;
+                       }
+               }
+
+               if (root instanceof Data && ((Data) root).isTransientRead())
+                       if 
(root.getOutputParameters().getLabel().equalsIgnoreCase(var)) {
+                               root.setVisited();
+                               return true;
+                       }
+
+               root.setVisited();
+               return false;
+       }
+
+}
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
index 1cf5761423..6a1a1192ea 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -27,6 +27,7 @@ import org.apache.sysds.parser.StatementBlock;
 
 import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -44,7 +45,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                        return List.of(sb);
 
                // Collect the Spark roots and #Spark instructions in each 
subDAG
-               List<Lop> sparkRoots = new ArrayList<>();
+               HashSet<Lop> sparkRoots = new HashSet<>();
                Map<Long, Integer> sparkOpCount = new HashMap<>();
                List<Lop> roots = 
lops.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
                roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, 
sparkOpCount, sparkRoots));
@@ -55,10 +56,10 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                // shared among multiple Spark jobs. Only consider operators 
with
                // Spark consumers for now.
                Map<Long, Integer> operatorJobCount = new HashMap<>();
-               markPersistableSparkOps(sparkRoots, operatorJobCount);
-               // TODO: A rewrite pass to remove less effective chkpoints
-               @SuppressWarnings("unused")
-               List<Lop> nodesWithChkpt = addChkpointLop(lops, 
operatorJobCount);
+               //markPersistableSparkOps(sparkRoots, operatorJobCount);
+               OperatorOrderingUtils.markSharedSparkOps(sparkRoots, 
operatorJobCount);
+               // TODO: A rewrite pass to remove less effective checkpoints
+               addChkpointLop(lops, operatorJobCount);
                //New node is added inplace in the Lop DAG
                return List.of(sb);
        }
@@ -68,12 +69,13 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                return sbs;
        }
 
-       private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long, 
Integer> operatorJobCount) {
-               List<Lop> nodesWithChkpt = new ArrayList<>();
-
+       private void addChkpointLop(List<Lop> nodes, Map<Long, Integer> 
operatorJobCount) {
                for (Lop l : nodes) {
-                       nodesWithChkpt.add(l);
-                       if(operatorJobCount.containsKey(l.getID()) && 
operatorJobCount.get(l.getID()) > 1) {
+                       // Increment the job counter if this node benefits from 
persisting
+                       // and is reachable from multiple job roots
+                       if(operatorJobCount.containsKey(l.getID())
+                               && operatorJobCount.get(l.getID()) > 1
+                               && 
OperatorOrderingUtils.isPersistableSparkOp(l)) {
                                // This operation is expensive and shared 
between Spark jobs
                                List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
                                // Construct a chkpoint lop that takes this 
Spark node as an input
@@ -85,34 +87,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                                        out.replaceInput(l, chkpoint);
                                        l.removeOutput(out);
                                }
-                               // Place it immediately after the Spark lop in 
the node list
-                               nodesWithChkpt.add(chkpoint);
                        }
                }
-               return nodesWithChkpt;
-       }
-
-       // Count the number of jobs a Spark operator is part of
-       private static void markPersistableSparkOps(List<Lop> sparkRoots, 
Map<Long, Integer> operatorJobCount) {
-               for (Lop root : sparkRoots) {
-                       collectPersistableSparkOps(root, operatorJobCount);
-                       root.resetVisitStatus();
-               }
-       }
-
-       private static void collectPersistableSparkOps(Lop root, Map<Long, 
Integer> operatorJobCount) {
-               if (root.isVisited())
-                       return;
-
-               for (Lop input : root.getInputs())
-                       if (root.getBroadcastInput() != input)
-                               collectPersistableSparkOps(input, 
operatorJobCount);
-
-               // Increment the job counter if this node benefits from 
persisting
-               // and reachable from multiple job roots
-               if (OperatorOrderingUtils.isPersistableSparkOp(root))
-                       operatorJobCount.merge(root.getID(), 1, Integer::sum);
-
-               root.setVisited();
        }
 }
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
index 00d205b553..828eaef0c1 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
@@ -20,8 +20,12 @@
 package org.apache.sysds.lops.rewrite;
 
 import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.ForStatement;
 import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
 
 import java.util.List;
 
@@ -35,12 +39,16 @@ public class RewriteFixIDs extends LopRewriteRule
                        && !ConfigurationManager.isCheckpointEnabled())
                        return List.of(sb);
 
-               // Reset the IDs in a depth-first manner
-               if (sb.getLops() != null && !sb.getLops().isEmpty()) {
-                       for (Lop root : sb.getLops())
-                               assignNewID(root);
-                       sb.getLops().forEach(Lop::resetVisitStatus);
+               if (HopRewriteUtils.isLastLevelLoopStatementBlock(sb)) {
+                       // Some rewrites add new Lops in the last-level loop 
body
+                       StatementBlock csb = sb instanceof WhileStatementBlock
+                               ? ((WhileStatement) 
sb.getStatement(0)).getBody().get(0)
+                               : ((ForStatement) 
sb.getStatement(0)).getBody().get(0);
+                       assignNewIDStatementBlock(csb);
                }
+               else
+                       assignNewIDStatementBlock(sb);
+
                return List.of(sb);
        }
 
@@ -49,7 +57,16 @@ public class RewriteFixIDs extends LopRewriteRule
                return sbs;
        }
 
-       private void assignNewID(Lop lop) {
+       private void assignNewIDStatementBlock(StatementBlock sb) {
+               // Reset the IDs in a depth-first manner
+               if (sb.getLops() != null && !sb.getLops().isEmpty()) {
+                       for (Lop root : sb.getLops())
+                               assignNewIDLop(root);
+                       sb.getLops().forEach(Lop::resetVisitStatus);
+               }
+       }
+
+       private void assignNewIDLop(Lop lop) {
                if (lop.isVisited())
                        return;
 
@@ -59,7 +76,7 @@ public class RewriteFixIDs extends LopRewriteRule
                        return;
                }
                for (Lop input : lop.getInputs())
-                       assignNewID(input);
+                       assignNewIDLop(input);
 
                lop.setNewID();
                lop.setVisited();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
index eda92023b9..bceb4e2090 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
@@ -55,6 +55,12 @@ public class CheckpointSharedOpsTest extends 
AutomatedTestBase {
                runTest(TEST_NAME+"1");
        }
 
+       @Test
+       public void testPnmf() {
+               // Place checkpoint at the end of a loop as the updated vars 
are read in each iteration.
+               runTest(TEST_NAME+"2");
+       }
+
        public void runTest(String testname) {
                Types.ExecMode oldPlatform = setExecMode(Types.ExecMode.HYBRID);
 
@@ -86,10 +92,12 @@ public class CheckpointSharedOpsTest extends 
AutomatedTestBase {
                        OptimizerUtils.ASYNC_CHECKPOINT_SPARK = false;
 
                        //compare matrices
-                       boolean matchVal = TestUtils.compareMatrices(R, R_mp, 
1e-6, "Origin", "withPrefetch");
+                       boolean matchVal = TestUtils.compareMatrices(R, R_mp, 
1e-3, "Origin", "withChkpoint");
                        if (!matchVal)
-                               System.out.println("Value w/o Prefetch "+R+" w/ 
Prefetch "+R_mp);
-                       Assert.assertTrue("Violated checkpoint count: " + numCP 
+ " < " + numCP_maxp, numCP < numCP_maxp);
+                               System.out.println("Value w/o Checkpoint "+R+" 
w/ Checkpoint "+R_mp);
+                       //compare checkpoint instruction count
+                       if (!testname.equalsIgnoreCase(TEST_NAME+"2"))
+                               Assert.assertTrue("Violated checkpoint count: " 
+ numCP + " < " + numCP_maxp, numCP < numCP_maxp);
                } finally {
                        resetExecMode(oldPlatform);
                        InfrastructureAnalyzer.setLocalMaxMemory(oldmem);
diff --git a/src/test/scripts/functions/async/CheckpointSharedOps2.dml 
b/src/test/scripts/functions/async/CheckpointSharedOps2.dml
new file mode 100644
index 0000000000..61e2fefc2f
--- /dev/null
+++ b/src/test/scripts/functions/async/CheckpointSharedOps2.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=13000, cols=150, seed=42); #sp_rand
+[W, H] = pnmf(X=X, rnk=100, verbose=FALSE);
+
+print(sum(W %*% H));
+R = sum(W %*% H);
+write(R, $1, format="text");
+

Reply via email to