This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new bc0c19dcc8 [SYSTEMDS-3469] New operator ordering to maximize inter-op parallelism
bc0c19dcc8 is described below

commit bc0c19dcc8776e4fc3ac36bfb2dc6c5394541a6f
Author: Arnab Phani <[email protected]>
AuthorDate: Fri Nov 25 17:46:31 2022 +0100

    [SYSTEMDS-3469] New operator ordering to maximize inter-op parallelism
    
    This patch introduces a new heuristic-based operator linearization order,
    which aims to maximize inter-operator parallelism among Spark and local
    operators. We first traverse the LOP DAGs to collect the roots of the Spark
    operator chains and the number of Spark instructions in each subDAG. We
    then place the Spark operator chains first, followed by the CP lanes.
    Finally, we place the appropriate asynchronous operators to trigger the
    Spark operator chains in parallel.
    This change, along with the future-based execution of Spark actions and
    a manual reuse of partitioned broadcast variables, improves lmDS by 2x.
    
    Closes #1736
---
 .../apache/sysds/conf/ConfigurationManager.java    |  15 ++
 .../java/org/apache/sysds/hops/OptimizerUtils.java |   6 +
 src/main/java/org/apache/sysds/lops/Lop.java       |  18 +-
 .../java/org/apache/sysds/lops/compile/Dag.java    |  13 +-
 .../lops/compile/linearization/ILinearize.java     | 209 ++++++++++++++++++++-
 .../context/SparkExecutionContext.java             |  10 +-
 .../spark/AggregateUnarySPInstruction.java         |   2 +-
 .../instructions/spark/TsmmSPInstruction.java      |  53 +++++-
 .../test/functions/async/AsyncBroadcastTest.java   |   2 +
 ...dcastTest.java => MaxParallelizeOrderTest.java} |  55 +++---
 .../test/functions/async/PrefetchRDDTest.java      |   5 +-
 .../linearization/DagLinearizationTest.java        |   2 +-
 .../functions/async/MaxParallelizeOrder1.dml       |  52 +++++
 .../functions/async/MaxParallelizeOrder2.dml       |  69 +++++++
 ...ect.xml => SystemDS-config-max-parallelize.xml} |   2 +-
 15 files changed, 447 insertions(+), 66 deletions(-)

diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java 
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 930d26a6d0..bb6172993a 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -23,6 +23,7 @@ import org.apache.hadoop.mapred.JobConf;
 import org.apache.sysds.conf.CompilerConfig.ConfigType;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Compression.CompressConfig;
+import org.apache.sysds.lops.compile.linearization.ILinearize;
 
 /**
  * Singleton for accessing the parsed and merged system configuration.
@@ -237,11 +238,25 @@ public class ConfigurationManager
                        || OptimizerUtils.ASYNC_PREFETCH_SPARK);
        }
 
+       public static boolean isMaxPrallelizeEnabled() {
+               return (getLinearizationOrder() == 
ILinearize.DagLinearization.MAX_PARALLELIZE
+                       || OptimizerUtils.MAX_PARALLELIZE_ORDER);
+       }
+
        public static boolean isBroadcastEnabled() {
                return 
(getDMLConfig().getBooleanValue(DMLConfig.ASYNC_SPARK_BROADCAST)
                        || OptimizerUtils.ASYNC_BROADCAST_SPARK);
        }
 
+       public static ILinearize.DagLinearization getLinearizationOrder() {
+               if (OptimizerUtils.MAX_PARALLELIZE_ORDER)
+                       return ILinearize.DagLinearization.MAX_PARALLELIZE;
+               else
+                       return ILinearize.DagLinearization
+                       
.valueOf(ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
+
+       }
+
        ///////////////////////////////////////
        // Thread-local classes
        
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index ccee9c96df..d2e9670362 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -285,6 +285,12 @@ public class OptimizerUtils
        public static boolean ASYNC_PREFETCH_SPARK = false;
        public static boolean ASYNC_BROADCAST_SPARK = false;
 
+       /**
+        * Heuristic-based instruction ordering to maximize inter-operator 
parallelism.
+        * Place the Spark operator chains first and trigger them to execute in 
parallel.
+        */
+       public static boolean MAX_PARALLELIZE_ORDER = false;
+
        //////////////////////
        // Optimizer levels //
        //////////////////////
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java 
b/src/main/java/org/apache/sysds/lops/Lop.java
index 440669d13a..3f1cdfe8f6 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -188,6 +188,14 @@ public abstract class Lop
                _visited = visited;
        }
 
+       public void setVisited() {
+               setVisited(VisitStatus.DONE);
+       }
+
+       public boolean isVisited() {
+               return _visited == VisitStatus.DONE;
+       }
+
        
        public boolean[] getReachable() {
                return reachable;
@@ -297,6 +305,10 @@ public abstract class Lop
                }
        }
 
+       public void removeInput(Lop op) {
+               inputs.remove(op);
+       }
+
        /**
         * Method to add output to Lop
         * 
@@ -414,7 +426,11 @@ public abstract class Lop
        public void setExecType(ExecType newExecType){
                lps.setExecType(newExecType);
        }
-       
+
+       public boolean isExecSpark () {
+               return (lps.getExecType() == ExecType.SPARK);
+       }
+
        public boolean getProducesIntermediateOutput() {
                return lps.getProducesIntermediateOutput();
        }
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java 
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index f87163eee3..2efbea8221 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -74,7 +74,6 @@ import 
org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
-
 /**
  * 
  * Class to maintain a DAG of lops and compile it into 
@@ -193,17 +192,11 @@ public class Dag<N extends Lop>
                }
 
                List<Lop> node_v = ILinearize.linearize(nodes);
-               
-               // add Prefetch and broadcast lops, if necessary
-               List<Lop> node_pf = ConfigurationManager.isPrefetchEnabled() ? 
addPrefetchLop(node_v) : node_v;
-               List<Lop> node_bc = ConfigurationManager.isBroadcastEnabled() ? 
addBroadcastLop(node_pf) : node_pf;
-               // TODO: Merge via a single traversal of the nodes
-
-               prefetchFederated(node_bc);
+               prefetchFederated(node_v);
 
                // do greedy grouping of operations
-               ArrayList<Instruction> inst = doPlainInstructionGen(sb, 
node_bc);
-               
+               ArrayList<Instruction> inst = doPlainInstructionGen(sb, node_v);
+
                // cleanup instruction (e.g., create packed rmvar instructions)
                return cleanupInstructions(inst);
        }
diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index f55271d530..d867a91f4a 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -21,8 +21,10 @@ package org.apache.sysds.lops.compile.linearization;
 
 import java.util.AbstractMap;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -31,30 +33,50 @@ import java.util.stream.Stream;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
 import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.CSVReBlock;
+import org.apache.sysds.lops.CentralMoment;
+import org.apache.sysds.lops.Checkpoint;
+import org.apache.sysds.lops.CoVariance;
+import org.apache.sysds.lops.GroupedAggregate;
+import org.apache.sysds.lops.GroupedAggregateM;
 import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.MMTSJ;
+import org.apache.sysds.lops.MMZip;
+import org.apache.sysds.lops.MapMultChain;
+import org.apache.sysds.lops.ParameterizedBuiltin;
+import org.apache.sysds.lops.PickByCount;
+import org.apache.sysds.lops.ReBlock;
+import org.apache.sysds.lops.SpoofFused;
+import org.apache.sysds.lops.UAggOuterChain;
+import org.apache.sysds.lops.UnaryCP;
 
 /**
  * A interface for the linearization algorithms that order the DAG nodes into 
a sequence of instructions to execute.
- * 
+ *
  * https://en.wikipedia.org/wiki/Linearizability#Linearization_points
  */
 public interface ILinearize {
        public static Log LOG = LogFactory.getLog(ILinearize.class.getName());
 
        public enum DagLinearization {
-               DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE
+               DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE
        }
 
        public static List<Lop> linearize(List<Lop> v) {
                try {
                        DMLConfig dmlConfig = 
ConfigurationManager.getDMLConfig();
-                       DagLinearization linearization = DagLinearization
-                               
.valueOf(dmlConfig.getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
+                       DagLinearization linearization = 
ConfigurationManager.getLinearizationOrder();
 
                        switch(linearization) {
+                               case MAX_PARALLELIZE:
+                                       return doMaxParallelizeSort(v);
                                case MIN_INTERMEDIATE:
                                        return doMinIntermediateSort(v);
                                case BREADTH_FIRST:
@@ -65,7 +87,7 @@ public interface ILinearize {
                        }
                }
                catch(Exception e) {
-                       LOG.warn("Invalid or failed DAG_LINEARIZATION, fallback 
to DEPTH_FIRST ordering");
+                       LOG.warn("Invalid DAG_LINEARIZATION 
"+ConfigurationManager.getLinearizationOrder()+", fallback to DEPTH_FIRST 
ordering");
                        return depthFirst(v);
                }
        }
@@ -155,4 +177,181 @@ public interface ILinearize {
                        sortRecursive(result, e.getKey().getInputs(), 
remaining);
                }
        }
+
+       // Place the Spark operation chains first (more expensive to less 
expensive),
+       // followed by asynchronously triggering operators and CP chains.
+       private static List<Lop> doMaxParallelizeSort(List<Lop> v)
+       {
+               List<Lop> final_v = null;
+               if (v.stream().anyMatch(ILinearize::isSparkAction)) {
+                       // Step 1: Collect the Spark roots and #Spark 
instructions in each subDAG
+                       Map<Long, Integer> sparkOpCount = new HashMap<>();
+                       List<Lop> roots = v.stream().filter(l -> 
l.getOutputs().isEmpty()).collect(Collectors.toList());
+                       List<Lop> sparkRoots = new ArrayList<>();
+                       roots.forEach(r -> collectSparkRoots(r, sparkOpCount, 
sparkRoots));
+
+                       // Step 2: Depth-first linearization. Place the Spark 
OPs first.
+                       // Sort the Spark roots based on number of Spark 
operators descending
+                       ArrayList<Lop> operatorList = new ArrayList<>();
+                       Lop[] sortedSPRoots = sparkRoots.toArray(new Lop[0]);
+                       Arrays.sort(sortedSPRoots, (l1, l2) -> 
sparkOpCount.get(l2.getID()) - sparkOpCount.get(l1.getID()));
+                       Arrays.stream(sortedSPRoots).forEach(r -> depthFirst(r, 
operatorList, sparkOpCount, true));
+
+                       // Step 3: Place the rest of the operators (CP). Sort 
the CP roots based on
+                       // #Spark operators in ascending order, i.e. execute 
the independent CP chains first
+                       roots.forEach(r -> depthFirst(r, operatorList, 
sparkOpCount, false));
+                       roots.forEach(Lop::resetVisitStatus);
+                       final_v = operatorList;
+               }
+               else
+                       // Fall back to depth if none of the operators returns 
results back to local
+                       final_v = depthFirst(v);
+
+               // Step 4: Add Prefetch and Broadcast lops if necessary
+               List<Lop> v_pf = ConfigurationManager.isPrefetchEnabled() ? 
addPrefetchLop(final_v) : final_v;
+               List<Lop> v_bc = ConfigurationManager.isBroadcastEnabled() ? 
addBroadcastLop(v_pf) : v_pf;
+               // TODO: Merge into a single traversal
+
+               return v_bc;
+       }
+
+       // Gather the Spark operators which return intermediates to local 
(actions/single_block)
+       // In addition count the number of Spark OPs underneath every Operator
+       private static int collectSparkRoots(Lop root, Map<Long, Integer> 
sparkOpCount, List<Lop> sparkRoots) {
+               if (sparkOpCount.containsKey(root.getID())) //visited before
+                       return sparkOpCount.get(root.getID());
+
+               // Sum Spark operators in the child DAGs
+               int total = 0;
+               for (Lop input : root.getInputs())
+                       total += collectSparkRoots(input, sparkOpCount, 
sparkRoots);
+
+               // Check if this node is Spark
+               total = root.isExecSpark() ? total + 1 : total;
+               sparkOpCount.put(root.getID(), total);
+
+               // Triggering point: Spark operator with all CP consumers
+               if (isSparkAction(root) && root.isAllOutputsCP())
+                       sparkRoots.add(root);
+
+               return total;
+       }
+
+       // Place the operators in a depth-first manner, but order
+       // the DAGs based on number of Spark operators
+       private static void depthFirst(Lop root, ArrayList<Lop> opList, 
Map<Long, Integer> sparkOpCount, boolean sparkFirst) {
+               if (root.isVisited())
+                       return;
+
+               if (root.getInputs().isEmpty()) {  //leaf node
+                       opList.add(root);
+                       root.setVisited();
+                       return;
+               }
+               // Sort the inputs based on number of Spark operators
+               Lop[] sortedInputs = root.getInputs().toArray(new Lop[0]);
+               if (sparkFirst) //to place the child DAG with more Spark OPs 
first
+                       Arrays.sort(sortedInputs, (l1, l2) -> 
sparkOpCount.get(l2.getID()) - sparkOpCount.get(l1.getID()));
+               else //to place the child DAG with more CP OPs first
+                       Arrays.sort(sortedInputs, Comparator.comparingInt(l -> 
sparkOpCount.get(l.getID())));
+
+               for (Lop input : sortedInputs)
+                       depthFirst(input, opList, sparkOpCount, sparkFirst);
+
+               opList.add(root);
+               root.setVisited();
+       }
+
+       private static boolean isSparkAction(Lop lop) {
+               return lop.isExecSpark() && (lop.getAggType() == 
SparkAggType.SINGLE_BLOCK
+                       || lop.getDataType() == DataType.SCALAR || lop 
instanceof MapMultChain
+                       || lop instanceof PickByCount || lop instanceof MMZip 
|| lop instanceof CentralMoment
+                       || lop instanceof CoVariance || lop instanceof MMTSJ);
+       }
+
+       private static List<Lop> addPrefetchLop(List<Lop> nodes) {
+               List<Lop> nodesWithPrefetch = new ArrayList<>();
+
+               //Find the Spark nodes with all CP outputs
+               for (Lop l : nodes) {
+                       nodesWithPrefetch.add(l);
+                       if (isPrefetchNeeded(l)) {
+                               //TODO: No prefetch if the parent is placed 
right after the spark OP
+                               //or push the parent further to increase 
parallelism
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               //Construct a Prefetch lop that takes this 
Spark node as a input
+                               UnaryCP prefetch = new UnaryCP(l, 
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
+                               for (Lop outCP : oldOuts) {
+                                       //Rewire l -> outCP to l -> Prefetch -> 
outCP
+                                       prefetch.addOutput(outCP);
+                                       outCP.replaceInput(l, prefetch);
+                                       l.removeOutput(outCP);
+                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
+                               }
+                               //Place it immediately after the Spark lop in 
the node list
+                               nodesWithPrefetch.add(prefetch);
+                       }
+               }
+               return nodesWithPrefetch;
+       }
+
+       private static List<Lop> addBroadcastLop(List<Lop> nodes) {
+               List<Lop> nodesWithBroadcast = new ArrayList<>();
+
+               for (Lop l : nodes) {
+                       nodesWithBroadcast.add(l);
+                       if (isBroadcastNeeded(l)) {
+                               List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
+                               //Construct a Broadcast lop that takes this 
Spark node as an input
+                               UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, 
l.getDataType(), l.getValueType(), ExecType.CP);
+                               //FIXME: Wire Broadcast only with the necessary 
outputs
+                               for (Lop outCP : oldOuts) {
+                                       //Rewire l -> outCP to l -> Broadcast 
-> outCP
+                                       bc.addOutput(outCP);
+                                       outCP.replaceInput(l, bc);
+                                       l.removeOutput(outCP);
+                                       //FIXME: Rewire _inputParams when 
needed (e.g. GroupedAggregate)
+                               }
+                               //Place it immediately after the Spark lop in 
the node list
+                               nodesWithBroadcast.add(bc);
+                       }
+               }
+               return nodesWithBroadcast;
+       }
+
+       private static boolean isPrefetchNeeded(Lop lop) {
+               // Run Prefetch for a Spark instruction if the instruction is a 
Transformation
+               // and the output is consumed by only CP instructions.
+               boolean transformOP = lop.getExecType() == ExecType.SPARK && 
lop.getAggType() != SparkAggType.SINGLE_BLOCK
+                               // Always Action operations
+                               && !(lop.getDataType() == DataType.SCALAR)
+                               && !(lop instanceof MapMultChain) && !(lop 
instanceof PickByCount)
+                               && !(lop instanceof MMZip) && !(lop instanceof 
CentralMoment)
+                               && !(lop instanceof CoVariance)
+                               // Not qualified for prefetching
+                               && !(lop instanceof Checkpoint) && !(lop 
instanceof ReBlock)
+                               && !(lop instanceof CSVReBlock)
+                               // Cannot filter Transformation cases from 
Actions (FIXME)
+                               && !(lop instanceof MMTSJ) && !(lop instanceof 
UAggOuterChain)
+                               && !(lop instanceof ParameterizedBuiltin) && 
!(lop instanceof SpoofFused);
+
+               //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
+               boolean hasParameterizedOut = lop.getOutputs().stream()
+                               .anyMatch(out -> ((out instanceof 
ParameterizedBuiltin)
+                                       || (out instanceof GroupedAggregate)
+                                       || (out instanceof GroupedAggregateM)));
+               //TODO: support non-matrix outputs
+               return transformOP && !hasParameterizedOut
+                               && lop.isAllOutputsCP() && lop.getDataType() == 
DataType.MATRIX;
+       }
+
+       private static boolean isBroadcastNeeded(Lop lop) {
+               // Asynchronously broadcast a matrix if that is produced by a 
CP instruction,
+               // and at least one Spark parent needs to broadcast this 
intermediate (eg. mapmm)
+               boolean isBc = lop.getOutputs().stream()
+                               .anyMatch(out -> (out.getBroadcastInput() == 
lop));
+               //TODO: Early broadcast objects that are bigger than a single 
block
+               //return isCP && isBc && lop.getDataTypes() == DataType.Matrix;
+               return isBc && lop.getDataType() == DataType.MATRIX;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 388eb462f9..df8f84c6b9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -672,11 +672,11 @@ public class SparkExecutionContext extends 
ExecutionContext
                //the broadcasts are created (other than in local mode) in 
order to avoid 
                //unnecessary memory requirements during the lifetime of this 
broadcast handle.
                
-               long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 
                PartitionedBroadcast<MatrixBlock> bret = null;
 
                synchronized (mo) {  //synchronize with the async. broadcast 
thread
+                       long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                        //reuse existing broadcast handle
                        if (mo.getBroadcastHandle() != null && 
mo.getBroadcastHandle().isPartitionedBroadcastValid()) {
                                bret = 
mo.getBroadcastHandle().getPartitionedBroadcast();
@@ -719,10 +719,10 @@ public class SparkExecutionContext extends 
ExecutionContext
                                        
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getDataCharacteristics()));
                                
CacheableData.addBroadcastSize(mo.getBroadcastHandle().getSize());
 
-                               if (DMLScript.STATISTICS) {
-                                       
SparkStatistics.accBroadCastTime(System.nanoTime() - t0);
-                                       SparkStatistics.incBroadcastCount(1);
-                               }
+                       }
+                       if (DMLScript.STATISTICS) {
+                               
SparkStatistics.accBroadCastTime(System.nanoTime() - t0);
+                               SparkStatistics.incBroadcastCount(1);
                        }
                }
                return bret;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 52bab3958f..89f385c54e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -109,7 +109,7 @@ public class AggregateUnarySPInstruction extends 
UnarySPInstruction {
                //perform aggregation if necessary and put output into symbol 
table
                if( _aggtype == SparkAggType.SINGLE_BLOCK )
                {
-                       if (ConfigurationManager.isPrefetchEnabled()) {
+                       if (ConfigurationManager.isMaxPrallelizeEnabled()) {
                                //Trigger the chain of Spark operations and 
maintain a future to the result
                                //TODO: Make memory for the future matrix block
                                try {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
index 69db3787e5..17cef61158 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/TsmmSPInstruction.java
@@ -23,6 +23,7 @@ package org.apache.sysds.runtime.instructions.spark;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.lops.MMTSJ.MMTSJType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -33,8 +34,13 @@ import 
org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import scala.Tuple2;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
 public class TsmmSPInstruction extends UnarySPInstruction {
        private MMTSJType _type = null;
 
@@ -61,15 +67,29 @@ public class TsmmSPInstruction extends UnarySPInstruction {
                
                //get input
                JavaPairRDD<MatrixIndexes,MatrixBlock> in = 
sec.getBinaryMatrixBlockRDDHandleForVariable( input1.getName() );
-               
-               //execute tsmm instruction (always produce exactly one output 
block)
-               //(this formulation with values() requires --conf 
spark.driver.maxResultSize=0)
-               JavaRDD<MatrixBlock> tmp = in.map(new RDDTSMMFunction(_type));
-               MatrixBlock out = RDDAggregateUtils.sumStable(tmp);
 
-               //put output block into symbol table (no lineage because single 
block)
-               //this also includes implicit maintenance of matrix 
characteristics
-               sec.setMatrixOutput(output.getName(), out);
+               if (ConfigurationManager.isMaxPrallelizeEnabled()) {
+                       try {
+                               if (CommonThreadPool.triggerRemoteOPsPool == 
null)
+                                       CommonThreadPool.triggerRemoteOPsPool = 
Executors.newCachedThreadPool();
+                               TsmmTask task = new TsmmTask(in, _type);
+                               Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                               sec.setMatrixOutput(output.getName(), 
future_out);
+                       }
+                       catch(Exception ex) {
+                               throw new DMLRuntimeException(ex);
+                       }
+               }
+               else {
+                       //execute tsmm instruction (always produce exactly one 
output block)
+                       //(this formulation with values() requires --conf 
spark.driver.maxResultSize=0)
+                       JavaRDD<MatrixBlock> tmp = in.map(new 
RDDTSMMFunction(_type));
+                       MatrixBlock out = RDDAggregateUtils.sumStable(tmp);
+
+                       //put output block into symbol table (no lineage 
because single block)
+                       //this also includes implicit maintenance of matrix 
characteristics
+                       sec.setMatrixOutput(output.getName(), out);
+               }
        }
 
        private static class RDDTSMMFunction implements 
Function<Tuple2<MatrixIndexes,MatrixBlock>, MatrixBlock> 
@@ -90,5 +110,22 @@ public class TsmmSPInstruction extends UnarySPInstruction {
                        return arg0._2().transposeSelfMatrixMultOperations(new 
MatrixBlock(), _type);
                }
        }
+
+       private static class TsmmTask implements Callable<MatrixBlock> {
+               JavaPairRDD<MatrixIndexes, MatrixBlock> _in;
+               MMTSJType _type;
+
+               TsmmTask(JavaPairRDD<MatrixIndexes, MatrixBlock> in, MMTSJType 
type) {
+                       _in = in;
+                       _type = type;
+               }
+               @Override
+               public MatrixBlock call() {
+                       //execute tsmm instruction (always produce exactly one 
output block)
+                       //(this formulation with values() requires --conf 
spark.driver.maxResultSize=0)
+                       JavaRDD<MatrixBlock> tmp = _in.map(new 
RDDTSMMFunction(_type));
+                       return RDDAggregateUtils.sumStable(tmp);
+               }
+       }
        
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
index edfef8998d..c8b7fdd94f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
@@ -88,9 +88,11 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
 
+                       OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
                        OptimizerUtils.ASYNC_BROADCAST_SPARK = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        OptimizerUtils.ASYNC_BROADCAST_SPARK = false;
+                       OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
                        HashMap<MatrixValue.CellIndex, Double> R_bc = 
readDMLScalarFromOutputDir("R");
 
                        //compare matrices
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
similarity index 69%
copy from 
src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
copy to 
src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
index edfef8998d..be011925d2 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/AsyncBroadcastTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
@@ -31,32 +31,29 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.apache.sysds.utils.Statistics;
-import org.apache.sysds.utils.stats.SparkStatistics;
-import org.junit.Assert;
 import org.junit.Test;
 
-public class AsyncBroadcastTest extends AutomatedTestBase {
-       
+public class MaxParallelizeOrderTest extends AutomatedTestBase {
+
        protected static final String TEST_DIR = "functions/async/";
-       protected static final String TEST_NAME = "BroadcastVar";
+       protected static final String TEST_NAME = "MaxParallelizeOrder";
        protected static final int TEST_VARIANTS = 2;
-       protected static String TEST_CLASS_DIR = TEST_DIR + 
AsyncBroadcastTest.class.getSimpleName() + "/";
-       
+       protected static String TEST_CLASS_DIR = TEST_DIR + 
MaxParallelizeOrderTest.class.getSimpleName() + "/";
+
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for( int i=1; i<=TEST_VARIANTS; i++ )
+               for(int i=1; i<=TEST_VARIANTS; i++)
                        addTestConfiguration(TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i));
        }
-       
+
        @Test
-       public void testAsyncBroadcast1() {
+       public void testlmds() {
                runTest(TEST_NAME+"1");
        }
 
        @Test
-       public void testAsyncBroadcast2() {
+       public void testl2svm() {
                runTest(TEST_NAME+"2");
        }
 
@@ -65,21 +62,19 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
                boolean old_trans_exec_type = 
OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE;
                ExecMode oldPlatform = setExecMode(ExecMode.HYBRID);
-               
+
                long oldmem = InfrastructureAnalyzer.getLocalMaxMemory();
                long mem = 1024*1024*8;
                InfrastructureAnalyzer.setLocalMaxMemory(mem);
-               
+
                try {
-                       //OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
-                       //OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = false;
-                       OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = false;
                        getAndLoadTestConfiguration(testname);
                        fullDMLScriptName = getScript();
-                       
+
                        List<String> proArgs = new ArrayList<>();
-                       
-                       //proArgs.add("-explain");
+
+                       proArgs.add("-explain");
+                       //proArgs.add("recompile_runtime");
                        proArgs.add("-stats");
                        proArgs.add("-args");
                        proArgs.add(output("R"));
@@ -88,21 +83,17 @@ public class AsyncBroadcastTest extends AutomatedTestBase {
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
 
-                       OptimizerUtils.ASYNC_BROADCAST_SPARK = true;
+                       OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
+                       OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-                       OptimizerUtils.ASYNC_BROADCAST_SPARK = false;
-                       HashMap<MatrixValue.CellIndex, Double> R_bc = 
readDMLScalarFromOutputDir("R");
+                       HashMap<MatrixValue.CellIndex, Double> R_mp = 
readDMLScalarFromOutputDir("R");
+                       OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+                       OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
 
                        //compare matrices
-                       TestUtils.compareMatrices(R, R_bc, 1e-6, "Origin", 
"withBroadcast");
-
-                       //assert called and successful early broadcast counts
-                       long expected_numBC = 1;
-                       long expected_successBC = 1;
-                       long numBC = 
Statistics.getCPHeavyHitterCount("broadcast");
-                       Assert.assertTrue("Violated Broadcast instruction 
count: "+numBC, numBC == expected_numBC);
-                       long successBC = 
SparkStatistics.getAsyncBroadcastCount();
-                       Assert.assertTrue("Violated successful Broadcast count: 
"+successBC, successBC == expected_successBC);
+                       boolean matchVal = TestUtils.compareMatrices(R, R_mp, 
1e-6, "Origin", "withPrefetch");
+                       if (!matchVal)
+                               System.out.println("Value w/o Prefetch "+R+" w/ 
Prefetch "+R_mp);
                } finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
old_simplification;
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = 
old_sum_product;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java 
b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
index a2fa45c2f4..11af05f19d 100644
--- a/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/async/PrefetchRDDTest.java
@@ -32,7 +32,6 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.utils.Statistics;
-import org.apache.sysds.utils.stats.SparkStatistics;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -96,13 +95,15 @@ public class PrefetchRDDTest extends AutomatedTestBase {
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R = 
readDMLScalarFromOutputDir("R");
 
+                       OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
                        OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
+                       OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
                        HashMap<MatrixValue.CellIndex, Double> R_pf = 
readDMLScalarFromOutputDir("R");
 
                        //compare matrices
-                       Boolean matchVal = TestUtils.compareMatrices(R, R_pf, 
1e-6, "Origin", "withPrefetch");
+                       boolean matchVal = TestUtils.compareMatrices(R, R_pf, 
1e-6, "Origin", "withPrefetch");
                        if (!matchVal)
                                System.out.println("Value w/o Prefetch "+R+" w/ 
Prefetch "+R_pf);
                        //assert Prefetch instructions and number of success.
diff --git 
a/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java
index dc84c75e61..843faa36b9 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/linearization/DagLinearizationTest.java
@@ -39,7 +39,7 @@ public class DagLinearizationTest extends AutomatedTestBase {
        private final String testNames[] = {"matrixmult_dag_linearization", 
"csplineCG_dag_linearization",
                "linear_regression_dag_linearization"};
 
-       private final String testConfigs[] = {"breadth-first", "depth-first", 
"incorrect", "min-intermediate"};
+       private final String testConfigs[] = {"breadth-first", "depth-first", 
"min-intermediate", "max-parallelize"};
 
        private final String testDir = "functions/linearization/";
 
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder1.dml 
b/src/test/scripts/functions/async/MaxParallelizeOrder1.dml
new file mode 100644
index 0000000000..218b09cb6e
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder1.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, 
Integer N) return (Matrix[double] beta)
+{
+  A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+  b = t(X) %*% y;
+  beta = solve(A, b);
+}
+
+no_lamda = 10;
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
+
+X = rand(rows=10000, cols=200, seed=42);
+y = rand(rows=10000, cols=1, seed=43);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+  beta = SimlinRegDS(X, y, lamda, N);
+  #beta = lmDS(X=X, y=y, reg=lamda);
+  R[,i] = beta;
+  lamda = lamda + stp;
+  i = i + 1;
+}
+
+R = sum(R);
+write(R, $1, format="text");
+
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder2.dml 
b/src/test/scripts/functions/async/MaxParallelizeOrder2.dml
new file mode 100644
index 0000000000..81e9207105
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder2.dml
@@ -0,0 +1,69 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B, 
Boolean icpt)
+return (Matrix[Double] loss) {
+  if (icpt)
+    X = cbind(X, matrix(1, nrow(X), 1));
+  loss = as.matrix(sum((y - X%*%B)^2));
+}
+
+M = 100000;
+N = 20;
+sp = 1.0;
+no_lamda = 1;
+
+X = rand(rows=M, cols=N, sparsity=sp, seed=42);
+y = rand(rows=M, cols=1, min=0, max=2, seed=42);
+y = ceil(y);
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+Rbeta = matrix(0, rows=ncol(X)+1, cols=no_lamda*2);
+Rloss = matrix(0, rows=no_lamda*2, cols=1);
+i = 1;
+
+
+for (l in 1:no_lamda)
+{
+  beta = l2svm(X=X, Y=y, intercept=FALSE, epsilon=1e-12,
+#      lambda = lamda, maxIterations=10, verbose=FALSE);
+      reg = lamda, verbose=FALSE);
+  Rbeta[1:nrow(beta),i] = beta;
+  Rloss[i,] = l2norm(X, y, beta, FALSE);
+  i = i + 1;
+
+  beta = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12,
+#      lambda = lamda, maxIterations=10, verbose=FALSE);
+      reg = lamda, verbose=FALSE);
+  Rbeta[1:nrow(beta),i] = beta;
+  Rloss[i,] = l2norm(X, y, beta, TRUE);
+  i = i + 1;
+
+  lamda = lamda + stp;
+}
+
+leastLoss = rowIndexMin(t(Rloss));
+bestModel = Rbeta[,as.scalar(leastLoss)];
+
+R = sum(bestModel);
+write(R, $1, format="text");
+
diff --git 
a/src/test/scripts/functions/linearization/SystemDS-config-incorrect.xml 
b/src/test/scripts/functions/linearization/SystemDS-config-max-parallelize.xml
similarity index 90%
rename from 
src/test/scripts/functions/linearization/SystemDS-config-incorrect.xml
rename to 
src/test/scripts/functions/linearization/SystemDS-config-max-parallelize.xml
index 62183138b3..25725397dc 100644
--- a/src/test/scripts/functions/linearization/SystemDS-config-incorrect.xml
+++ 
b/src/test/scripts/functions/linearization/SystemDS-config-max-parallelize.xml
@@ -18,5 +18,5 @@
 -->
 
 <root>
-    
<sysds.compile.linearization>something_incorrect</sysds.compile.linearization>
+    <sysds.compile.linearization>max_parallelize</sysds.compile.linearization>
 </root>


Reply via email to