This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5142a7e739 [SYSTEMDS-3497] Refactor to add LOP rewrite step in 
compilation
5142a7e739 is described below

commit 5142a7e7390ff816af99c593cf31a0402ae931ee
Author: Arnab Phani <[email protected]>
AuthorDate: Thu Feb 9 14:33:11 2023 +0100

    [SYSTEMDS-3497] Refactor to add LOP rewrite step in compilation
    
    This patch adds a new step in the compilation (and recompilation)
    steps to rewrite Lop DAGs (single and multi-statement block).
    Current rewrite passes include adding prefetch, broadcast and
    checkpoint nodes. This refactoring allows us to easily add new
    rewrite rules and separate the Lop rewrites from operator
    ordering.
    
    Closes #1783
---
 src/main/java/org/apache/sysds/api/DMLScript.java  |   5 +-
 .../apache/sysds/hops/recompile/Recompiler.java    |  21 +-
 src/main/java/org/apache/sysds/lops/Lop.java       |   4 +
 .../java/org/apache/sysds/lops/LopProperties.java  |   1 +
 .../apache/sysds/lops/OperatorOrderingUtils.java   | 125 +++++++++++
 src/main/java/org/apache/sysds/lops/UnaryCP.java   |   2 +-
 .../lops/compile/linearization/ILinearize.java     | 231 ++-------------------
 .../apache/sysds/lops/rewrite/LopRewriteRule.java  |  30 +++
 .../org/apache/sysds/lops/rewrite/LopRewriter.java | 134 ++++++++++++
 .../sysds/lops/rewrite/RewriteAddBroadcastLop.java |  83 ++++++++
 .../sysds/lops/rewrite/RewriteAddChkpointLop.java  | 117 +++++++++++
 .../sysds/lops/rewrite/RewriteAddPrefetchLop.java  | 118 +++++++++++
 .../apache/sysds/lops/rewrite/RewriteFixIDs.java   |  67 ++++++
 .../org/apache/sysds/parser/DMLTranslator.java     |   8 +-
 .../test/functions/async/AsyncBroadcastTest.java   |   4 -
 .../functions/async/CheckpointSharedOpsTest.java   |   4 +-
 .../test/functions/async/PrefetchRDDTest.java      |   1 -
 17 files changed, 725 insertions(+), 230 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java 
b/src/main/java/org/apache/sysds/api/DMLScript.java
index 0c3716dd35..ad386e29e8 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -455,8 +455,11 @@ public class DMLScript
                
                //Step 6: construct lops (incl exec type and op selection)
                dmlt.constructLops(prog);
+
+               //Step 7: rewrite LOP DAGs (incl adding new LOPs s.a. prefetch, 
broadcast)
+               dmlt.rewriteLopDAG(prog);
                
-               //Step 7: generate runtime program, incl codegen
+               //Step 8: generate runtime program, incl codegen
                Program rtprog = dmlt.getRuntimeProgram(prog, 
ConfigurationManager.getDMLConfig());
                
                //Step 9: prepare statistics [and optional explain output]
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java 
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index f3fb0fcbc6..392d303ca4 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -59,6 +59,7 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.hops.rewrite.ProgramRewriter;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.lops.rewrite.LopRewriter;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.parser.ForStatementBlock;
@@ -130,6 +131,10 @@ public class Recompiler {
        private static ThreadLocal<ProgramRewriter> _rewriter = new 
ThreadLocal<ProgramRewriter>() {
                @Override protected ProgramRewriter initialValue() { return new 
ProgramRewriter(false, true); }
        };
+
+       private static ThreadLocal<LopRewriter> _lopRewriter = new 
ThreadLocal<LopRewriter>() {
+               @Override protected LopRewriter initialValue() {return new 
LopRewriter();}
+       };
        
        public enum ResetType {
                RESET,
@@ -145,6 +150,7 @@ public class Recompiler {
         */
        public static void reinitRecompiler() {
                _rewriter.set(new ProgramRewriter(false, true));
+               _lopRewriter.set(new LopRewriter());
        }
        
        public static ArrayList<Instruction> recompileHopsDag( StatementBlock 
sb, ArrayList<Hop> hops, 
@@ -305,6 +311,7 @@ public class Recompiler {
                boolean codegen = ConfigurationManager.isCodegenEnabled()
                        && !(forceEt && et == null ) //not on reset
                        && SpoofCompiler.RECOMPILE_CODEGEN;
+               boolean rewrittenHops = false;
                
                // prepare hops dag for recompile
                if( !inplace ){ 
@@ -352,6 +359,7 @@ public class Recompiler {
                                Hop.resetVisitStatus(hops);
                                for( Hop hopRoot : hops )
                                        rUpdateStatistics( hopRoot, 
ec.getVariables() );
+                               rewrittenHops = true;
                        }
                        
                        // refresh memory estimates (based on updated stats,
@@ -382,11 +390,18 @@ public class Recompiler {
                rSetMaxParallelism(hops, maxK);
                
                // construct lops
-               Dag<Lop> dag = new Dag<>();
+               ArrayList<Lop> lops = new ArrayList<>();
                for( Hop hopRoot : hops ){
-                       Lop lops = hopRoot.constructLops();
-                       lops.addToDag(dag);
+                       lops.add(hopRoot.constructLops());
                }
+
+               // dynamic lop rewrites for the updated hop DAGs
+               if (rewrittenHops)
+                       _lopRewriter.get().rewriteLopDAG(lops);
+
+               Dag<Lop> dag = new Dag<>();
+               for (Lop l : lops)
+                       l.addToDag(dag);
                
                // generate runtime instructions (incl piggybacking)
                ArrayList<Instruction> newInst = dag
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java 
b/src/main/java/org/apache/sysds/lops/Lop.java
index ecc9e7f893..b768ded9ad 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -401,6 +401,10 @@ public abstract class Lop
        public long getID() {
                return lps.getID();
        }
+
+       public void setNewID() {
+               lps.setNewID();
+       }
        
        public int getLevel() {
                return lps.getLevel();
diff --git a/src/main/java/org/apache/sysds/lops/LopProperties.java 
b/src/main/java/org/apache/sysds/lops/LopProperties.java
index ed788c79fa..9fce0b6fb0 100644
--- a/src/main/java/org/apache/sysds/lops/LopProperties.java
+++ b/src/main/java/org/apache/sysds/lops/LopProperties.java
@@ -54,6 +54,7 @@ public class LopProperties
        }
        
        public long getID() { return ID; }
+       public void setNewID() { ID = UniqueLopID.getNextID(); }
        public int getLevel() { return level; }
        public void setLevel( int l ) { level = l; }
        
diff --git a/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java 
b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
new file mode 100644
index 0000000000..35926961f2
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/OperatorOrderingUtils.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+public class OperatorOrderingUtils
+{
+       // Return a list representation of all the lops in a SB
+       public static ArrayList<Lop> getLopList(StatementBlock sb) {
+               ArrayList<Lop> lops = null;
+               if (sb.getLops() != null && !sb.getLops().isEmpty()) {
+                       lops = new ArrayList<>();
+                       for (Lop root : sb.getLops())
+                               addToLopList(lops, root);
+               }
+               return lops;
+       }
+
+       // Determine if a lop is root of a DAG
+       public static boolean isLopRoot(Lop lop) {
+               if (lop.getOutputs().isEmpty())
+                       return true;
+               //TODO: Handle internal builtins (e.g. eigen)
+               if (lop instanceof FunctionCallCP &&
+                       ((FunctionCallCP) 
lop).getFnamespace().equalsIgnoreCase(DMLProgram.INTERNAL_NAMESPACE)) {
+                       return true;
+               }
+               return false;
+       }
+
+       // Gather the Spark operators which return intermediates to local 
(actions/single_block)
+       // In addition count the number of Spark OPs underneath every Operator
+       public static int collectSparkRoots(Lop root, Map<Long, Integer> 
sparkOpCount, List<Lop> sparkRoots) {
+               if (sparkOpCount.containsKey(root.getID())) //visited before
+                       return sparkOpCount.get(root.getID());
+
+               // Aggregate #Spark operators in the child DAGs
+               int total = 0;
+               for (Lop input : root.getInputs())
+                       total += collectSparkRoots(input, sparkOpCount, 
sparkRoots);
+
+               // Check if this node is Spark
+               total = root.isExecSpark() ? total + 1 : total;
+               sparkOpCount.put(root.getID(), total);
+
+               // Triggering point: Spark action/operator with all CP consumers
+               if (isSparkTriggeringOp(root)) {
+                       sparkRoots.add(root);
+                       root.setAsynchronous(true); //candidate for async. 
execution
+               }
+
+               return total;
+       }
+
+       // Dictionary of Spark operators which are expensive enough to be
+       // benefited from persisting if shared among jobs.
+       public static boolean isPersistableSparkOp(Lop lop) {
+               return lop.isExecSpark() && (lop instanceof MapMult
+                       || lop instanceof MMCJ || lop instanceof MMRJ
+                       || lop instanceof MMZip);
+       }
+
+       private static boolean isSparkTriggeringOp(Lop lop) {
+               boolean rightSpLop = lop.isExecSpark() && (lop.getAggType() == 
AggBinaryOp.SparkAggType.SINGLE_BLOCK
+                       || lop.getDataType() == Types.DataType.SCALAR || lop 
instanceof MapMultChain
+                       || lop instanceof PickByCount || lop instanceof MMZip 
|| lop instanceof CentralMoment
+                       || lop instanceof CoVariance || lop instanceof MMTSJ || 
lop.isAllOutputsCP());
+               boolean isPrefetched = lop.getOutputs().size() == 1
+                       && lop.getOutputs().get(0) instanceof UnaryCP
+                       && ((UnaryCP) 
lop.getOutputs().get(0)).getOpCode().equalsIgnoreCase("prefetch");
+               boolean col2Bc = isCollectForBroadcast(lop);
+               boolean prefetch = (lop instanceof UnaryCP) &&
+                       ((UnaryCP) 
lop).getOpCode().equalsIgnoreCase("prefetch");
+               return (rightSpLop || col2Bc || prefetch) && !isPrefetched;
+       }
+
+       // Determine if the result of this operator is collected to
+       // broadcast for the next operator (e.g. mapmm --> map+)
+       public static boolean isCollectForBroadcast(Lop lop) {
+               boolean isSparkOp = lop.isExecSpark();
+               boolean isBc = lop.getOutputs().stream()
+                       .allMatch(out -> (out.getBroadcastInput() == lop));
+               //TODO: Handle Lops with mixed Spark (broadcast) CP consumers
+               return isSparkOp && isBc && (lop.getDataType() == 
Types.DataType.MATRIX);
+       }
+
+       private static boolean addNode(ArrayList<Lop> lops, Lop node) {
+               if (lops.contains(node))
+                       return false;
+               lops.add(node);
+               return true;
+       }
+
+       private static void addToLopList(ArrayList<Lop> lops, Lop lop) {
+               if (addNode(lops, lop))
+                       for (Lop in : lop.getInputs())
+                               addToLopList(lops, in);
+       }
+
+}
diff --git a/src/main/java/org/apache/sysds/lops/UnaryCP.java 
b/src/main/java/org/apache/sysds/lops/UnaryCP.java
index 4b95e8c1b0..7dd6a30e58 100644
--- a/src/main/java/org/apache/sysds/lops/UnaryCP.java
+++ b/src/main/java/org/apache/sysds/lops/UnaryCP.java
@@ -66,7 +66,7 @@ public class UnaryCP extends Lop {
                return "Operation: " + getInstructions("", "");
        }
        
-       private String getOpCode() {
+       public String getOpCode() {
                return operation.toString();
        }
 
diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index eef3085917..656a0262d6 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -34,7 +34,6 @@ import java.util.stream.Stream;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
@@ -43,24 +42,19 @@ import org.apache.sysds.lops.CSVReBlock;
 import org.apache.sysds.lops.CentralMoment;
 import org.apache.sysds.lops.Checkpoint;
 import org.apache.sysds.lops.CoVariance;
-import org.apache.sysds.lops.DataGen;
-import org.apache.sysds.lops.FunctionCallCP;
 import org.apache.sysds.lops.GroupedAggregate;
 import org.apache.sysds.lops.GroupedAggregateM;
 import org.apache.sysds.lops.Lop;
-import org.apache.sysds.lops.MMCJ;
-import org.apache.sysds.lops.MMRJ;
 import org.apache.sysds.lops.MMTSJ;
 import org.apache.sysds.lops.MMZip;
-import org.apache.sysds.lops.MapMult;
 import org.apache.sysds.lops.MapMultChain;
+import org.apache.sysds.lops.OperatorOrderingUtils;
 import org.apache.sysds.lops.ParameterizedBuiltin;
 import org.apache.sysds.lops.PickByCount;
 import org.apache.sysds.lops.ReBlock;
 import org.apache.sysds.lops.SpoofFused;
 import org.apache.sysds.lops.UAggOuterChain;
 import org.apache.sysds.lops.UnaryCP;
-import org.apache.sysds.parser.DMLProgram;
 
 /**
  * A interface for the linearization algorithms that order the DAG nodes into 
a sequence of instructions to execute.
@@ -187,15 +181,17 @@ public interface ILinearize {
        private static List<Lop> doMaxParallelizeSort(List<Lop> v)
        {
                List<Lop> final_v = null;
-               if (v.stream().anyMatch(ILinearize::isSparkTriggeringOp)) {
+               // Fallback to default depth-first if all operators are CP
+               if (v.stream().anyMatch(ILinearize::isDistributedOp)) {
                        // Step 1: Collect the Spark roots and #Spark 
instructions in each subDAG
                        Map<Long, Integer> sparkOpCount = new HashMap<>();
-                       List<Lop> roots = 
v.stream().filter(ILinearize::isRoot).collect(Collectors.toList());
+                       List<Lop> roots = 
v.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
                        List<Lop> sparkRoots = new ArrayList<>();
-                       roots.forEach(r -> collectSparkRoots(r, sparkOpCount, 
sparkRoots));
+                       roots.forEach(r -> 
OperatorOrderingUtils.collectSparkRoots(r, sparkOpCount, sparkRoots));
 
-                       // Step 2: Depth-first linearization. Place the CP OPs 
first to increase broadcast potentials.
+                       // Step 2: Depth-first linearization of Spark roots.
                        // Maintain the default order (by ID) to trigger 
independent Spark jobs first
+                       // This allows parallel execution of the jobs in the 
cluster
                        ArrayList<Lop> operatorList = new ArrayList<>();
                        sparkRoots.forEach(r -> depthFirst(r, operatorList, 
sparkOpCount, false));
 
@@ -204,81 +200,12 @@ public interface ILinearize {
                        roots.forEach(r -> depthFirst(r, operatorList, 
sparkOpCount, false));
                        roots.forEach(Lop::resetVisitStatus);
 
-                       // Step 4: Add Chkpoint lops after the expensive Spark 
operators, which
-                       // are shared among multiple Spark jobs. Only consider 
operators with
-                       // Spark consumers for now.
-                       Map<Long, Integer> operatorJobCount = new HashMap<>();
-                       markPersistableSparkOps(sparkRoots, operatorJobCount);
-                       final_v = addChkpointLop(operatorList, 
operatorJobCount);
-                       // TODO: A rewrite pass to remove less effective 
chkpoints
+                       final_v = operatorList;
                }
                else
-                       // Fall back to depth if none of the operators returns 
results back to local
                        final_v = depthFirst(v);
 
-               // Step 4: Add Prefetch and Broadcast lops if necessary
-               List<Lop> v_pf = ConfigurationManager.isPrefetchEnabled() ? 
addPrefetchLop(final_v) : final_v;
-               List<Lop> v_bc = ConfigurationManager.isBroadcastEnabled() ? 
addBroadcastLop(v_pf) : v_pf;
-
-               return v_bc;
-       }
-
-       private static boolean isRoot(Lop lop) {
-               if (lop.getOutputs().isEmpty())
-                       return true;
-               if (lop instanceof FunctionCallCP &&
-                       ((FunctionCallCP) 
lop).getFnamespace().equalsIgnoreCase(DMLProgram.INTERNAL_NAMESPACE)) {
-                       return true;
-               }
-               return false;
-       }
-
-       // Gather the Spark operators which return intermediates to local 
(actions/single_block)
-       // In addition count the number of Spark OPs underneath every Operator
-       private static int collectSparkRoots(Lop root, Map<Long, Integer> 
sparkOpCount, List<Lop> sparkRoots) {
-               if (sparkOpCount.containsKey(root.getID())) //visited before
-                       return sparkOpCount.get(root.getID());
-
-               // Aggregate #Spark operators in the child DAGs
-               int total = 0;
-               for (Lop input : root.getInputs())
-                       total += collectSparkRoots(input, sparkOpCount, 
sparkRoots);
-
-               // Check if this node is Spark
-               total = root.isExecSpark() ? total + 1 : total;
-               sparkOpCount.put(root.getID(), total);
-
-               // Triggering point: Spark action/operator with all CP consumers
-               if (isSparkTriggeringOp(root)) {
-                       sparkRoots.add(root);
-                       root.setAsynchronous(true); //candidate for async. 
execution
-               }
-
-               return total;
-       }
-
-       // Count the number of jobs a Spark operator is part of
-       private static void markPersistableSparkOps(List<Lop> sparkRoots, 
Map<Long, Integer> operatorJobCount) {
-               for (Lop root : sparkRoots) {
-                       collectPersistableSparkOps(root, operatorJobCount);
-                       root.resetVisitStatus();
-               }
-       }
-
-       private static void collectPersistableSparkOps(Lop root, Map<Long, 
Integer> operatorJobCount) {
-               if (root.isVisited())
-                       return;
-
-               for (Lop input : root.getInputs())
-                       if (root.getBroadcastInput() != input)
-                               collectPersistableSparkOps(input, 
operatorJobCount);
-
-               // Increment the job counter if this node benefits from 
persisting
-               // and reachable from multiple job roots
-               if (isPersistableSparkOp(root))
-                       operatorJobCount.merge(root.getID(), 1, Integer::sum);
-
-               root.setVisited();
+               return final_v;
        }
 
        // Place the operators in a depth-first manner, but order
@@ -306,104 +233,11 @@ public interface ILinearize {
                root.setVisited();
        }
 
-       private static boolean isSparkTriggeringOp(Lop lop) {
-               return lop.isExecSpark() && (lop.getAggType() == 
SparkAggType.SINGLE_BLOCK
-                       || lop.getDataType() == DataType.SCALAR || lop 
instanceof MapMultChain
-                       || lop instanceof PickByCount || lop instanceof MMZip 
|| lop instanceof CentralMoment
-                       || lop instanceof CoVariance || lop instanceof MMTSJ || 
lop.isAllOutputsCP())
-                       || isCollectForBroadcast(lop);
-       }
-
-       private static boolean isCollectForBroadcast(Lop lop) {
-               boolean isSparkOp = lop.isExecSpark();
-               boolean isBc = lop.getOutputs().stream()
-                       .allMatch(out -> (out.getBroadcastInput() == lop));
-               //TODO: Handle Lops with mixed Spark (broadcast) CP consumers
-               return isSparkOp && isBc && (lop.getDataType() == 
DataType.MATRIX);
-       }
-
-       // Dictionary of Spark operators which are expensive enough to be
-       // benefited from persisting if shared among jobs.
-       private static boolean isPersistableSparkOp(Lop lop) {
-               return lop.isExecSpark() && (lop instanceof MapMult
-                       || lop instanceof MMCJ || lop instanceof MMRJ
-                       || lop instanceof MMZip);
-       }
-
-       private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long, 
Integer> operatorJobCount) {
-               List<Lop> nodesWithChkpt = new ArrayList<>();
-
-               for (Lop l : nodes) {
-                       nodesWithChkpt.add(l);
-                       if(operatorJobCount.containsKey(l.getID()) && 
operatorJobCount.get(l.getID()) > 1) {
-                               //This operation is expensive and shared 
between Spark jobs
-                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
-                               //Construct a chkpoint lop that takes this 
Spark node as a input
-                               Lop chkpoint = new Checkpoint(l, 
l.getDataType(), l.getValueType(),
-                                       
Checkpoint.getDefaultStorageLevelString(), false);
-                               for (Lop out : oldOuts) {
-                                       //Rewire l -> out to l -> chkpoint -> 
out
-                                       chkpoint.addOutput(out);
-                                       out.replaceInput(l, chkpoint);
-                                       l.removeOutput(out);
-                               }
-                               //Place it immediately after the Spark lop in 
the node list
-                               nodesWithChkpt.add(chkpoint);
-                       }
-               }
-               return nodesWithChkpt;
-       }
-
-       private static List<Lop> addPrefetchLop(List<Lop> nodes) {
-               List<Lop> nodesWithPrefetch = new ArrayList<>();
-
-               //Find the Spark nodes with all CP outputs
-               for (Lop l : nodes) {
-                       nodesWithPrefetch.add(l);
-                       if (isPrefetchNeeded(l)) {
-                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
-                               //Construct a Prefetch lop that takes this 
Spark node as a input
-                               UnaryCP prefetch = new UnaryCP(l, 
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
-                               prefetch.setAsynchronous(true);
-                               //Reset asynchronous flag for the input if 
already set (e.g. mapmm -> prefetch)
-                               l.setAsynchronous(false);
-                               for (Lop outCP : oldOuts) {
-                                       //Rewire l -> outCP to l -> Prefetch -> 
outCP
-                                       prefetch.addOutput(outCP);
-                                       outCP.replaceInput(l, prefetch);
-                                       l.removeOutput(outCP);
-                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
-                               }
-                               //Place it immediately after the Spark lop in 
the node list
-                               nodesWithPrefetch.add(prefetch);
-                       }
-               }
-               return nodesWithPrefetch;
-       }
-
-       private static List<Lop> addBroadcastLop(List<Lop> nodes) {
-               List<Lop> nodesWithBroadcast = new ArrayList<>();
-
-               for (Lop l : nodes) {
-                       nodesWithBroadcast.add(l);
-                       if (isBroadcastNeeded(l)) {
-                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
-                               //Construct a Broadcast lop that takes this 
Spark node as an input
-                               UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, 
l.getDataType(), l.getValueType(), ExecType.CP);
-                               bc.setAsynchronous(true);
-                               //FIXME: Wire Broadcast only with the necessary 
outputs
-                               for (Lop outCP : oldOuts) {
-                                       //Rewire l -> outCP to l -> Broadcast 
-> outCP
-                                       bc.addOutput(outCP);
-                                       outCP.replaceInput(l, bc);
-                                       l.removeOutput(outCP);
-                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
-                               }
-                               //Place it immediately after the Spark lop in 
the node list
-                               nodesWithBroadcast.add(bc);
-                       }
-               }
-               return nodesWithBroadcast;
+       private static boolean isDistributedOp(Lop lop) {
+               return lop.isExecSpark()
+                       || (lop instanceof UnaryCP
+                       && (((UnaryCP) 
lop).getOpCode().equalsIgnoreCase("prefetch")
+                       || ((UnaryCP) 
lop).getOpCode().equalsIgnoreCase("broadcast")));
        }
 
        @SuppressWarnings("unused")
@@ -432,43 +266,6 @@ public interface ILinearize {
                return nodesWithCheckpoint;
        }
 
-       private static boolean isPrefetchNeeded(Lop lop) {
-               // Run Prefetch for a Spark instruction if the instruction is a 
Transformation
-               // and the output is consumed by only CP instructions.
-               boolean transformOP = lop.getExecType() == ExecType.SPARK && 
lop.getAggType() != SparkAggType.SINGLE_BLOCK
-                               // Always Action operations
-                               && !(lop.getDataType() == DataType.SCALAR)
-                               && !(lop instanceof MapMultChain) && !(lop 
instanceof PickByCount)
-                               && !(lop instanceof MMZip) && !(lop instanceof 
CentralMoment)
-                               && !(lop instanceof CoVariance)
-                               // Not qualified for prefetching
-                               && !(lop instanceof Checkpoint) && !(lop 
instanceof ReBlock)
-                               && !(lop instanceof CSVReBlock) && !(lop 
instanceof DataGen)
-                               // Cannot filter Transformation cases from 
Actions (FIXME)
-                               && !(lop instanceof MMTSJ) && !(lop instanceof 
UAggOuterChain)
-                               && !(lop instanceof ParameterizedBuiltin) && 
!(lop instanceof SpoofFused);
-
-               //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
-               boolean hasParameterizedOut = lop.getOutputs().stream()
-                               .anyMatch(out -> ((out instanceof 
ParameterizedBuiltin)
-                                       || (out instanceof GroupedAggregate)
-                                       || (out instanceof GroupedAggregateM)));
-               //TODO: support non-matrix outputs
-               return transformOP && !hasParameterizedOut
-                               && (lop.isAllOutputsCP() || 
isCollectForBroadcast(lop))
-                               && lop.getDataType() == DataType.MATRIX;
-       }
-
-       private static boolean isBroadcastNeeded(Lop lop) {
-               // Asynchronously broadcast a matrix if that is produced by a 
CP instruction,
-               // and at least one Spark parent needs to broadcast this 
intermediate (eg. mapmm)
-               boolean isBc = lop.getOutputs().stream()
-                               .anyMatch(out -> (out.getBroadcastInput() == 
lop));
-               //TODO: Early broadcast objects that are bigger than a single 
block
-               //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
-               return isBc && lop.getDataType() == DataType.MATRIX;
-       }
-
        private static boolean isCheckpointNeeded(Lop lop) {
                // Place checkpoint_e just before a Spark action (FIXME)
                boolean actionOP = lop.getExecType() == ExecType.SPARK
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriteRule.java 
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriteRule.java
new file mode 100644
index 0000000000..5af5d65244
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriteRule.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.List;
+
+public abstract class LopRewriteRule
+{
+       public abstract List<StatementBlock> 
rewriteLOPinStatementBlock(StatementBlock sb);
+       public abstract List<StatementBlock> 
rewriteLOPinStatementBlocks(List<StatementBlock> sb);
+}
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java 
b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
new file mode 100644
index 0000000000..4567cf1c4e
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/LopRewriter.java
@@ -0,0 +1,134 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class LopRewriter
+{
+       private ArrayList<LopRewriteRule> _lopSBRuleSet = null;
+
+       public LopRewriter() {
+               _lopSBRuleSet = new ArrayList<>();
+               // Add rewrite rules (single and multi-statement block)
+               _lopSBRuleSet.add(new RewriteAddPrefetchLop());
+               _lopSBRuleSet.add(new RewriteAddBroadcastLop());
+               _lopSBRuleSet.add(new RewriteAddChkpointLop());
+               // TODO: A rewrite pass to remove less effective chkpoints
+               // Last rewrite to reset Lop IDs in a depth-first manner
+               _lopSBRuleSet.add(new RewriteFixIDs());
+       }
+
+       public void rewriteProgramLopDAGs(DMLProgram dmlp) {
+               for (String namespaceKey : dmlp.getNamespaces().keySet())
+                       // for each namespace, handle function statement blocks
+                       for (String fname : 
dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
+                               FunctionStatementBlock fsblock = 
dmlp.getFunctionStatementBlock(namespaceKey,fname);
+                               rewriteLopDAGsFunction(fsblock);
+                       }
+
+               if (!_lopSBRuleSet.isEmpty()) {
+                       ArrayList<StatementBlock> sbs = 
rRewriteLops(dmlp.getStatementBlocks());
+                       dmlp.setStatementBlocks(sbs);
+               }
+       }
+
+       public void rewriteLopDAGsFunction(FunctionStatementBlock fsb) {
+               if( !_lopSBRuleSet.isEmpty() )
+                       rRewriteLop(fsb);
+       }
+
+       public ArrayList<Lop> rewriteLopDAG(ArrayList<Lop> lops) {
+               StatementBlock sb = new StatementBlock();
+               sb.setLops(lops);
+               return rRewriteLop(sb).get(0).getLops();
+       }
+
+       public ArrayList<StatementBlock> rRewriteLops(ArrayList<StatementBlock> 
sbs) {
+               // Apply rewrite rules to the lops of the list of statement 
blocks
+               List<StatementBlock> tmp = sbs;
+               for(LopRewriteRule r : _lopSBRuleSet)
+                       tmp = r.rewriteLOPinStatementBlocks(tmp);
+
+               // Recursively rewrite lops in statement blocks
+               List<StatementBlock> tmp2 = new ArrayList<>();
+               for( StatementBlock sb : tmp )
+                       tmp2.addAll(rRewriteLop(sb));
+
+               // Prepare output list
+               sbs.clear();
+               sbs.addAll(tmp2);
+               return sbs;
+       }
+
+       public ArrayList<StatementBlock> rRewriteLop(StatementBlock sb) {
+               ArrayList<StatementBlock> ret = new ArrayList<>();
+               ret.add(sb);
+
+               // Recursive invocation
+               if (sb instanceof FunctionStatementBlock) {
+                       FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
+                       FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
+                       fstmt.setBody(rRewriteLops(fstmt.getBody()));
+               }
+               else if (sb instanceof WhileStatementBlock) {
+                       WhileStatementBlock wsb = (WhileStatementBlock) sb;
+                       WhileStatement wstmt = 
(WhileStatement)wsb.getStatement(0);
+                       wstmt.setBody(rRewriteLops(wstmt.getBody()));
+               }
+               else if (sb instanceof IfStatementBlock) {
+                       IfStatementBlock isb = (IfStatementBlock) sb;
+                       IfStatement istmt = (IfStatement)isb.getStatement(0);
+                       istmt.setIfBody(rRewriteLops(istmt.getIfBody()));
+                       istmt.setElseBody(rRewriteLops(istmt.getElseBody()));
+               }
+               else if (sb instanceof ForStatementBlock) { //incl parfor
+                       //TODO: parfor statement blocks
+                       ForStatementBlock fsb = (ForStatementBlock) sb;
+                       ForStatement fstmt = (ForStatement)fsb.getStatement(0);
+                       fstmt.setBody(rRewriteLops(fstmt.getBody()));
+               }
+
+               // Apply rewrite rules to individual statement blocks
+               for(LopRewriteRule r : _lopSBRuleSet) {
+                       ArrayList<StatementBlock> tmp = new ArrayList<>();
+                       for( StatementBlock sbc : ret )
+                               tmp.addAll( r.rewriteLOPinStatementBlock(sbc) );
+
+                       // Take over set of rewritten sbs
+                       ret.clear();
+                       ret.addAll(tmp);
+               }
+
+               return ret;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java
new file mode 100644
index 0000000000..da22c51186
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class RewriteAddBroadcastLop extends LopRewriteRule
+{
+       @Override
+       public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock 
sb)
+       {
+               if (!ConfigurationManager.isBroadcastEnabled())
+                       return List.of(sb);
+
+               ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
+               if (lops == null)
+                       return List.of(sb);
+
+               ArrayList<Lop> nodesWithBroadcast = new ArrayList<>();
+               for (Lop l : lops) {
+                       nodesWithBroadcast.add(l);
+                       if (isBroadcastNeeded(l)) {
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               // Construct a Broadcast lop that takes this 
Spark node as an input
+                               UnaryCP bc = new UnaryCP(l, 
Types.OpOp1.BROADCAST, l.getDataType(), l.getValueType(), Types.ExecType.CP);
+                               bc.setAsynchronous(true);
+                               //FIXME: Wire Broadcast only with the necessary 
outputs
+                               for (Lop outCP : oldOuts) {
+                                       // Rewire l -> outCP to l -> Broadcast 
-> outCP
+                                       bc.addOutput(outCP);
+                                       outCP.replaceInput(l, bc);
+                                       l.removeOutput(outCP);
+                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
+                               }
+                               //Place it immediately after the Spark lop in 
the node list
+                               nodesWithBroadcast.add(bc);
+                       }
+               }
+               // New node is added inplace in the Lop DAG
+               return Arrays.asList(sb);
+       }
+
+       @Override
+       public List<StatementBlock> 
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+               return sbs;
+       }
+
+       private static boolean isBroadcastNeeded(Lop lop) {
+               // Asynchronously broadcast a matrix if that is produced by a 
CP instruction,
+               // and at least one Spark parent needs to broadcast this 
intermediate (eg. mapmm)
+               boolean isBc = lop.getOutputs().stream()
+                       .anyMatch(out -> (out.getBroadcastInput() == lop));
+               //TODO: Early broadcast objects that are bigger than a single 
block
+               boolean isCP = lop.getExecType() == Types.ExecType.CP;
+               return isCP && isBc && lop.getDataType() == 
Types.DataType.MATRIX;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
new file mode 100644
index 0000000000..d4f976a795
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddChkpointLop.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+public class RewriteAddChkpointLop extends LopRewriteRule
+{
+       @Override
+       public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock 
sb)
+       {
+               if (!ConfigurationManager.isCheckpointEnabled())
+                       return List.of(sb);
+
+               ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
+               if (lops == null)
+                       return List.of(sb);
+
+               // Collect the Spark roots and #Spark instructions in each 
subDAG
+               List<Lop> sparkRoots = new ArrayList<>();
+               Map<Long, Integer> sparkOpCount = new HashMap<>();
+               List<Lop> roots = 
lops.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
+               roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, 
sparkOpCount, sparkRoots));
+               if (sparkRoots.isEmpty())
+                       return List.of(sb);
+
+               // Add Chkpoint lops after the expensive Spark operators, which 
are
+               // shared among multiple Spark jobs. Only consider operators 
with
+               // Spark consumers for now.
+               Map<Long, Integer> operatorJobCount = new HashMap<>();
+               markPersistableSparkOps(sparkRoots, operatorJobCount);
+               // TODO: A rewrite pass to remove less effective chkpoints
+               List<Lop> nodesWithChkpt = addChkpointLop(lops, 
operatorJobCount);
+               //New node is added inplace in the Lop DAG
+               return List.of(sb);
+       }
+
+       @Override
+       public List<StatementBlock> 
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+               return sbs;
+       }
+
+       private static List<Lop> addChkpointLop(List<Lop> nodes, Map<Long, 
Integer> operatorJobCount) {
+               List<Lop> nodesWithChkpt = new ArrayList<>();
+
+               for (Lop l : nodes) {
+                       nodesWithChkpt.add(l);
+                       if(operatorJobCount.containsKey(l.getID()) && 
operatorJobCount.get(l.getID()) > 1) {
+                               // This operation is expensive and shared 
between Spark jobs
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               // Construct a chkpoint lop that takes this 
Spark node as an input
+                               Lop chkpoint = new Checkpoint(l, 
l.getDataType(), l.getValueType(),
+                                       
Checkpoint.getDefaultStorageLevelString(), false);
+                               for (Lop out : oldOuts) {
+                                       //Rewire l -> out to l -> chkpoint -> 
out
+                                       chkpoint.addOutput(out);
+                                       out.replaceInput(l, chkpoint);
+                                       l.removeOutput(out);
+                               }
+                               // Place it immediately after the Spark lop in 
the node list
+                               nodesWithChkpt.add(chkpoint);
+                       }
+               }
+               return nodesWithChkpt;
+       }
+
+       // Count the number of jobs a Spark operator is part of
+       private static void markPersistableSparkOps(List<Lop> sparkRoots, 
Map<Long, Integer> operatorJobCount) {
+               for (Lop root : sparkRoots) {
+                       collectPersistableSparkOps(root, operatorJobCount);
+                       root.resetVisitStatus();
+               }
+       }
+
+       private static void collectPersistableSparkOps(Lop root, Map<Long, 
Integer> operatorJobCount) {
+               if (root.isVisited())
+                       return;
+
+               for (Lop input : root.getInputs())
+                       if (root.getBroadcastInput() != input)
+                               collectPersistableSparkOps(input, 
operatorJobCount);
+
+               // Increment the job counter if this node benefits from 
persisting
+               // and is reachable from multiple job roots
+               if (OperatorOrderingUtils.isPersistableSparkOp(root))
+                       operatorJobCount.merge(root.getID(), 1, Integer::sum);
+
+               root.setVisited();
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
new file mode 100644
index 0000000000..6eb52e0d9f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.lops.CSVReBlock;
+import org.apache.sysds.lops.CentralMoment;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.CoVariance;
+import org.apache.sysds.lops.DataGen;
+import org.apache.sysds.lops.GroupedAggregate;
+import org.apache.sysds.lops.GroupedAggregateM;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.lops.MMZip;
+import org.apache.sysds.lops.MapMultChain;
+import org.apache.sysds.lops.OperatorOrderingUtils;
+import org.apache.sysds.lops.ParameterizedBuiltin;
+import org.apache.sysds.lops.PickByCount;
+import org.apache.sysds.lops.ReBlock;
+import org.apache.sysds.lops.SpoofFused;
+import org.apache.sysds.lops.UAggOuterChain;
+import org.apache.sysds.lops.UnaryCP;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class RewriteAddPrefetchLop extends LopRewriteRule
+{
+       @Override
+       public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock 
sb)
+       {
+               if (!ConfigurationManager.isPrefetchEnabled())
+                       return List.of(sb);
+
+               ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
+               if (lops == null)
+                       return List.of(sb);
+
+               ArrayList<Lop> nodesWithPrefetch = new ArrayList<>();
+               //Find the Spark nodes with all CP outputs
+               for (Lop l : lops) {
+                       nodesWithPrefetch.add(l);
+                       if (isPrefetchNeeded(l)) {
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               //Construct a Prefetch lop that takes this 
Spark node as an input
+                               UnaryCP prefetch = new UnaryCP(l, 
Types.OpOp1.PREFETCH, l.getDataType(), l.getValueType(), Types.ExecType.CP);
+                               prefetch.setAsynchronous(true);
+                               //Reset asynchronous flag for the input if 
already set (e.g. mapmm -> prefetch)
+                               l.setAsynchronous(false);
+                               for (Lop outCP : oldOuts) {
+                                       //Rewire l -> outCP to l -> Prefetch -> 
outCP
+                                       prefetch.addOutput(outCP);
+                                       outCP.replaceInput(l, prefetch);
+                                       l.removeOutput(outCP);
+                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
+                               }
+                               //Place it immediately after the Spark lop in 
the node list
+                               nodesWithPrefetch.add(prefetch);
+                       }
+               }
+               //New node is added inplace in the Lop DAG
+               return Arrays.asList(sb);
+       }
+
+       @Override
+       public List<StatementBlock> 
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+               return sbs;
+       }
+
+       private boolean isPrefetchNeeded(Lop lop) {
+               // Run Prefetch for a Spark instruction if the instruction is a 
Transformation
+               // and the output is consumed by only CP instructions.
+               boolean transformOP = lop.getExecType() == Types.ExecType.SPARK 
&& lop.getAggType() != AggBinaryOp.SparkAggType.SINGLE_BLOCK
+                       // Always Action operations
+                       && !(lop.getDataType() == Types.DataType.SCALAR)
+                       && !(lop instanceof MapMultChain) && !(lop instanceof 
PickByCount)
+                       && !(lop instanceof MMZip) && !(lop instanceof 
CentralMoment)
+                       && !(lop instanceof CoVariance)
+                       // Not qualified for prefetching
+                       && !(lop instanceof Checkpoint) && !(lop instanceof 
ReBlock)
+                       && !(lop instanceof CSVReBlock) && !(lop instanceof 
DataGen)
+                       // Cannot filter Transformation cases from Actions 
(FIXME)
+                       && !(lop instanceof MMTSJ) && !(lop instanceof 
UAggOuterChain)
+                       && !(lop instanceof ParameterizedBuiltin) && !(lop 
instanceof SpoofFused);
+
+               //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+               boolean hasParameterizedOut = lop.getOutputs().stream()
+                       .anyMatch(out -> ((out instanceof ParameterizedBuiltin)
+                               || (out instanceof GroupedAggregate)
+                               || (out instanceof GroupedAggregateM)));
+               //TODO: support non-matrix outputs
+               return transformOP && !hasParameterizedOut
+                       && (lop.isAllOutputsCP() || 
OperatorOrderingUtils.isCollectForBroadcast(lop))
+                       && lop.getDataType() == Types.DataType.MATRIX;
+       }
+}
diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java 
b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
new file mode 100644
index 0000000000..00d205b553
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteFixIDs.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.rewrite;
+
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.parser.StatementBlock;
+
+import java.util.List;
+
+public class RewriteFixIDs extends LopRewriteRule
+{
+       @Override
+       public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock 
sb)
+       {
+               // Skip if no new Lop nodes are added
+               if (!ConfigurationManager.isPrefetchEnabled() && 
!ConfigurationManager.isBroadcastEnabled()
+                       && !ConfigurationManager.isCheckpointEnabled())
+                       return List.of(sb);
+
+               // Reset the IDs in a depth-first manner
+               if (sb.getLops() != null && !sb.getLops().isEmpty()) {
+                       for (Lop root : sb.getLops())
+                               assignNewID(root);
+                       sb.getLops().forEach(Lop::resetVisitStatus);
+               }
+               return List.of(sb);
+       }
+
+       @Override
+       public List<StatementBlock> 
rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
+               return sbs;
+       }
+
+       private void assignNewID(Lop lop) {
+               if (lop.isVisited())
+                       return;
+
+               if (lop.getInputs().isEmpty()) {  //leaf node
+                       lop.setNewID();
+                       lop.setVisited();
+                       return;
+               }
+               for (Lop input : lop.getInputs())
+                       assignNewID(input);
+
+               lop.setNewID();
+               lop.setVisited();
+       }
+}
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 03e2856e6c..c6ed2c5b84 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -78,6 +78,7 @@ import org.apache.sysds.hops.rewrite.ProgramRewriter;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.LopsException;
 import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.lops.rewrite.LopRewriter;
 import org.apache.sysds.parser.PrintStatement.PRINTTYPE;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
@@ -312,6 +313,11 @@ public class DMLTranslator
                                codgenHopsDAG(dmlp);
                }
        }
+
+       public void rewriteLopDAG(DMLProgram dmlp) {
+               LopRewriter rewriter = new LopRewriter();
+               rewriter.rewriteProgramLopDAGs(dmlp);
+       }
        
        public void codgenHopsDAG(DMLProgram dmlp) {
                SpoofCompiler.generateCode(dmlp);
@@ -482,7 +488,7 @@ public class DMLTranslator
        }
        
        public ProgramBlock createRuntimeProgramBlock(Program prog, 
StatementBlock sb, DMLConfig config) {
-               Dag<Lop> dag = null; 
+               Dag<Lop> dag = null;
                Dag<Lop> pred_dag = null;
 
                ArrayList<Instruction> instruct;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
index c8b7fdd94f..800655a8fc 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
@@ -71,8 +71,6 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
                InfrastructureAnalyzer.setLocalMaxMemory(mem);
                
                try {
-                       //OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
-                       //OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false;
                        OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
                        getAndLoadTestConfiguration(testname);
                        fullDMLScriptName = getScript();
@@ -88,11 +86,9 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
 
-                       OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
                        OptimizerUtils.ASYNC_BROADCAST_SPARK = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        OptimizerUtils.ASYNC_BROADCAST_SPARK = false;
-                       OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
                        HashMap<MatrixValue.CellIndex, Double> R_bc = 
readDMLScalarFromOutputDir("R");
 
                        //compare matrices
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
index 1a899d3d66..eda92023b9 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/CheckpointSharedOpsTest.java
@@ -79,11 +79,11 @@ public class CheckpointSharedOpsTest extends 
AutomatedTestBase {
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
                        long numCP = 
Statistics.getCPHeavyHitterCount("sp_chkpoint");
 
-                       OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
+                       OptimizerUtils.ASYNC_CHECKPOINT_SPARK = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R_mp = 
readDMLScalarFromOutputDir("R");
                        long numCP_maxp = 
Statistics.getCPHeavyHitterCount("sp_chkpoint");
-                       OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
+                       OptimizerUtils.ASYNC_CHECKPOINT_SPARK = false;
 
                        //compare matrices
                        boolean matchVal = TestUtils.compareMatrices(R, R_mp, 
1e-6, "Origin", "withPrefetch");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index b863c81a29..f821af5eb0 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -97,7 +97,6 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
 
-                       OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
                        OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        OptimizerUtils.ASYNC_PREFETCH_SPARK = false;

Reply via email to