This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new d3480c1ee0 [SYSTEMDS-3562] Save multi-statementblock checkpoints for 
recompiler
d3480c1ee0 is described below

commit d3480c1ee032694aaf86839efb5ad656c041f16f
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Jun 14 10:53:34 2023 +0200

    [SYSTEMDS-3562] Save multi-statementblock checkpoints for recompiler
    
    This patch adds a temporary fix that saves the positions of the
    checkpoint instructions placed in a loop body during compilation and
    re-inserts them at the same positions during recompilation. A better
    fix would be to enable recompilation for that loop or the function.
    
    Closes #1844
---
 .../apache/sysds/hops/recompile/Recompiler.java    |  4 +-
 .../org/apache/sysds/lops/rewrite/LopRewriter.java |  4 +-
 .../lops/rewrite/RewriteAddChkpointInLoop.java     |  6 ++-
 .../sysds/lops/rewrite/RewriteAddChkpointLop.java  | 47 +++++++++++++++++++++-
 .../org/apache/sysds/parser/StatementBlock.java    | 18 +++++++++
 .../functions/async/CheckpointSharedOpsTest.java   |  3 +-
 6 files changed, 73 insertions(+), 9 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index d5c8169d68..01945f90d9 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -404,8 +404,8 @@ public class Recompiler {
                }
 
                // dynamic lop rewrites for the updated hop DAGs
-               if (rewrittenHops)
-                       _lopRewriter.get().rewriteLopDAG(lops);
+               if (rewrittenHops && sb != null)
+                       _lopRewriter.get().rewriteLopDAG(sb, lops);
 
                Dag<Lop> dag = new Dag<>();
                for (Lop l : lops)
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java 
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
index 55b590543a..2b054d9b2b 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -69,8 +69,8 @@ public class LopRewriter
                        rRewriteLop(fsb);
        }
 
-       public ArrayList<Lop> rewriteLopDAG(ArrayList<Lop> lops) {
-               StatementBlock sb = new StatementBlock();
+       public ArrayList<Lop> rewriteLopDAG(StatementBlock sb, ArrayList<Lop> 
lops) {
+               //StatementBlock sb = new StatementBlock();
                sb.setLops(lops);
                return rRewriteLop(sb).get(0).getLops();
        }
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
index 5e445a6bce..27a1a552ab 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.java
@@ -81,7 +81,7 @@ public class RewriteAddChkpointInLoop extends LopRewriteRule
                        return List.of(sb);
 
                // Add checkpoint Lops after the shared operators
-               addChkpointLop(lops, operatorJobCount);
+               addChkpointLop(lops, operatorJobCount, csb);
                // TODO: A rewrite pass to remove less effective checkpoints
                return List.of(sb);
        }
@@ -91,7 +91,7 @@ public class RewriteAddChkpointInLoop extends LopRewriteRule
                return sbs;
        }
 
-       private void addChkpointLop(List<Lop> nodes, Map<Long, Integer> 
operatorJobCount) {
+       private void addChkpointLop(List<Lop> nodes, Map<Long, Integer> 
operatorJobCount, StatementBlock sb) {
                for (Lop l : nodes) {
                        if(operatorJobCount.containsKey(l.getID()) && 
operatorJobCount.get(l.getID()) > 1) {
                                // TODO: Check if this lop leads to one of 
those variables
@@ -106,6 +106,8 @@ public class RewriteAddChkpointInLoop extends LopRewriteRule
                                        out.replaceInput(l, checkpoint);
                                        l.removeOutput(out);
                                }
+                               // Save the checkpoint position for the 
recompiler
+                               sb.setCheckpointPosition(l, oldOuts);
                        }
                }
        }
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
index 6a1a1192ea..701c604bc8 100644
--- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -60,6 +60,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                OperatorOrderingUtils.markSharedSparkOps(sparkRoots, 
operatorJobCount);
                // TODO: A rewrite pass to remove less effective checkpoints
                addChkpointLop(lops, operatorJobCount);
+               placeCompiledCheckpoints(lops, sb);
                //New node is added inplace in the Lop DAG
                return List.of(sb);
        }
@@ -78,7 +79,7 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                                && 
OperatorOrderingUtils.isPersistableSparkOp(l)) {
                                // This operation is expensive and shared 
between Spark jobs
                                List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
-                               // Construct a chkpoint lop that takes this 
Spark node as a input
+                               // Construct a chkpoint lop that takes this 
Spark node as an input
                                Lop chkpoint = new Checkpoint(l, 
l.getDataType(), l.getValueType(),
                                        
Checkpoint.getDefaultStorageLevelString(), false);
                                for (Lop out : oldOuts) {
@@ -90,4 +91,48 @@ public class RewriteAddChkpointLop extends LopRewriteRule
                        }
                }
        }
+
+       private void placeCompiledCheckpoints(List<Lop> nodes, StatementBlock 
sb) {
+               if (sb.getCheckpointPositions() == null)
+                       return;
+
+               for (Lop l : nodes) {
+                       // Check if the compiler placed and saved a checkpoint
+                       // TODO: Call recompiler on the loops
+                       if (isCheckpointed(l, sb)) {
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               // Construct a chkpoint lop that takes this 
Spark node as an input
+                               Lop chkpoint = new Checkpoint(l, 
l.getDataType(), l.getValueType(),
+                                       
Checkpoint.getDefaultStorageLevelString(), false);
+                               for (Lop out : oldOuts) {
+                                       //Rewire l -> out to l -> chkpoint -> 
out
+                                       chkpoint.addOutput(out);
+                                       out.replaceInput(l, chkpoint);
+                                       l.removeOutput(out);
+                               }
+                       }
+               }
+       }
+
+       private boolean isCheckpointed(Lop lop, StatementBlock sb) {
+               var cpPositions = sb.getCheckpointPositions();
+               if (cpPositions == null)
+                       return false;
+
+               if (cpPositions.containsKey(lop.getType())) {
+                       List<Lop.Type> outputsT = 
cpPositions.get(lop.getType());
+                       List<Lop> outputs = new ArrayList<>(lop.getOutputs());
+                       if (outputs.size() != outputsT.size())
+                               return false;
+                       for (int i=0; i< outputs.size(); i++) {
+                               if (outputs.get(i).getType() != outputsT.get(i)
+                                       || !outputs.get(i).isExecSpark())
+                                       return false;
+                       }
+               }
+               else
+                       return false;
+
+               return true;
+       }
 }
diff --git a/src/main/java/org/apache/sysds/parser/StatementBlock.java 
b/src/main/java/org/apache/sysds/parser/StatementBlock.java
index b4ee82405b..3deb6a8001 100644
--- a/src/main/java/org/apache/sysds/parser/StatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/StatementBlock.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.stream.Collectors;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -64,6 +65,7 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        private boolean _requiresRecompile = false;
        private boolean _splitDag = false;
        private boolean _nondeterministic = false;
+       private HashMap<Lop.Type, List<Lop.Type>> _checkpointPositions = null;
 
        protected double repetitions = 1;
        public final static double DEFAULT_LOOP_REPETITIONS = 10;
@@ -1393,4 +1395,20 @@ public class StatementBlock extends LiveVariableAnalysis 
implements ParseInfo
        public boolean isNondeterministic() {
                return _nondeterministic;
        }
+
+       public void setCheckpointPosition(Lop input, List<Lop> outputs) {
+               // FIXME: Type is not the best key as many Lops may have the 
same types
+               Lop.Type inputT = input.getType();
+               List<Lop.Type> outputsT = 
outputs.stream().map(Lop::getType).collect(Collectors.toList());
+
+               if (_checkpointPositions == null)
+                       _checkpointPositions = new HashMap<>();
+               if (!_checkpointPositions.containsKey(inputT)) {
+                       _checkpointPositions.put(inputT, outputsT);
+               }
+       }
+
+       public HashMap<Lop.Type, List<Lop.Type>> getCheckpointPositions() {
+               return _checkpointPositions;
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
index bceb4e2090..6898b9ba88 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
@@ -96,8 +96,7 @@ public class CheckpointSharedOpsTest extends 
AutomatedTestBase {
                        if (!matchVal)
                                System.out.println("Value w/o Checkpoint "+R+" 
w/ Checkpoint "+R_mp);
                        //compare checkpoint instruction count
-                       if (!testname.equalsIgnoreCase(TEST_NAME+"2"))
-                               Assert.assertTrue("Violated checkpoint count: " 
+ numCP + " < " + numCP_maxp, numCP < numCP_maxp);
+                       Assert.assertTrue("Violated checkpoint count: " + numCP 
+ " < " + numCP_maxp, numCP < numCP_maxp);
                } finally {
                        resetExecMode(oldPlatform);
                        InfrastructureAnalyzer.setLocalMaxMemory(oldmem);

Reply via email to