This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 956a686c60 [SYSTEMDS-3473] Push down rmvars for asynchronous instructions
956a686c60 is described below

commit 956a686c60afdcce95ac6f5398f181dc66e62ba2
Author: Arnab Phani <[email protected]>
AuthorDate: Thu Dec 1 22:52:24 2022 +0100

    [SYSTEMDS-3473] Push down rmvars for asynchronous instructions
    
    This patch moves the rmvar instructions for the inputs of an
    asynchronous instruction so that they are placed after the consumers
    of that asynchronous instruction. This keeps the inputs of an
    asynchronous operator alive until get() is called on its Future
    object. Moreover, this patch adds tests for the new operator ordering
    (MaxParallelizeOrder3.dml and MaxParallelizeOrder4.dml) and fixes
    minor bugs.
    
    Closes #1745
---
 src/main/java/org/apache/sysds/lops/Lop.java       |  14 +++
 .../java/org/apache/sysds/lops/compile/Dag.java    |  26 +++-
 .../lops/compile/linearization/ILinearize.java     |  20 ++--
 .../instructions/spark/CpmmSPInstruction.java      | 133 +++++++++++++++++----
 .../instructions/spark/MapmmSPInstruction.java     |  50 +++++++-
 .../functions/async/MaxParallelizeOrderTest.java   |  15 ++-
 .../functions/async/MaxParallelizeOrder3.dml       |  36 ++++++
 .../functions/async/MaxParallelizeOrder4.dml       |  37 ++++++
 8 files changed, 288 insertions(+), 43 deletions(-)

diff --git a/src/main/java/org/apache/sysds/lops/Lop.java 
b/src/main/java/org/apache/sysds/lops/Lop.java
index 3f1cdfe8f6..3064d2b39a 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -150,6 +150,12 @@ public abstract class Lop
        protected OutputParameters outParams = null;
 
        protected LopProperties lps = null;
+
+       /**
+        * Indicates if this lop is a candidate for asynchronous execution.
+        * Examples include spark unary aggregate, mapmm, prefetch
+        */
+       protected boolean _asynchronous = false;
        
 
        /**
@@ -365,6 +371,14 @@ public abstract class Lop
                return consumerCount;
        }
 
+       public void setAsynchronous(boolean isAsync) {
+               _asynchronous = isAsync;
+       }
+
+       public boolean isAsynchronousOp() {
+               return _asynchronous;
+       }
+
        /**
         * Method to have Lops print their state. This is for debugging 
purposes.
         */
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java 
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 77325fb297..ade809aea6 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -422,16 +422,27 @@ public class Dag<N extends Lop>
         * @param delteInst list of instructions
         */
        private static void processConsumersForInputs(Lop node, 
List<Instruction> inst, List<Instruction> delteInst) {
+               // The asynchronous instructions execute lazily. The inputs to 
an asynchronous instruction
+               // must live till the outputs of the async. instruction are 
consumed (i.e. future.get is called)
+               if (node.isAsynchronousOp())
+                       return;
+
                // reduce the consumer count for all input lops
-               // if the count becomes zero, then then variable associated w/ 
input can be removed
+               // if the count becomes zero, then variable associated w/ input 
can be removed
                for(Lop in : node.getInputs() )
                        processConsumers(in, inst, delteInst, null);
        }
        
        private static void processConsumers(Lop node, List<Instruction> inst, 
List<Instruction> deleteInst, Lop locationInfo) {
                // reduce the consumer count for all input lops
-               // if the count becomes zero, then then variable associated w/ 
input can be removed
+               // if the count becomes zero, then variable associated w/ input 
can be removed
+
                if ( node.removeConsumer() == 0 ) {
+                       // The inputs to the asynchronous input can be safely 
removed at this point as
+                       // the outputs of the asynchronous instruction are 
consumed.
+                       if (node.isAsynchronousOp())
+                               processConsumerIfAsync(node, inst, deleteInst);
+
                        if ( node.isDataExecLocation() && 
((Data)node).isLiteral() ) {
                                return;
                        }
@@ -450,6 +461,17 @@ public class Dag<N extends Lop>
                        excludeRemoveInstruction(label, deleteInst);
                }
        }
+
+       // Generate rmvar instructions for the inputs of an asynchronous 
instruction.
+       private static void processConsumerIfAsync(Lop node, List<Instruction> 
inst, List<Instruction> deleteInst) {
+               if (!node.isAsynchronousOp())
+                       return;
+
+               // Temporarily disable the _asynchronous flag to generate 
rmvars for the inputs
+               node.setAsynchronous(false);
+               processConsumersForInputs(node, inst, deleteInst);
+               node.setAsynchronous(true);
+       }
        
        /**
         * Method to generate instructions that are executed in Control 
Program. At
diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
index e78530b33b..7eee970e2b 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
@@ -183,7 +183,7 @@ public interface ILinearize {
        private static List<Lop> doMaxParallelizeSort(List<Lop> v)
        {
                List<Lop> final_v = null;
-               if (v.stream().anyMatch(ILinearize::isSparkAction)) {
+               if (v.stream().anyMatch(ILinearize::isSparkTriggeringOp)) {
                        // Step 1: Collect the Spark roots and #Spark 
instructions in each subDAG
                        Map<Long, Integer> sparkOpCount = new HashMap<>();
                        List<Lop> roots = v.stream().filter(l -> 
l.getOutputs().isEmpty()).collect(Collectors.toList());
@@ -221,7 +221,7 @@ public interface ILinearize {
                if (sparkOpCount.containsKey(root.getID())) //visited before
                        return sparkOpCount.get(root.getID());
 
-               // Sum Spark operators in the child DAGs
+               // Aggregate #Spark operators in the child DAGs
                int total = 0;
                for (Lop input : root.getInputs())
                        total += collectSparkRoots(input, sparkOpCount, 
sparkRoots);
@@ -230,9 +230,11 @@ public interface ILinearize {
                total = root.isExecSpark() ? total + 1 : total;
                sparkOpCount.put(root.getID(), total);
 
-               // Triggering point: Spark operator with all CP consumers
-               if (isSparkAction(root) && root.isAllOutputsCP())
+               // Triggering point: Spark action/operator with all CP consumers
+               if (isSparkTriggeringOp(root)) {
                        sparkRoots.add(root);
+                       root.setAsynchronous(true); //candidate for async. 
execution
+               }
 
                return total;
        }
@@ -262,11 +264,11 @@ public interface ILinearize {
                root.setVisited();
        }
 
-       private static boolean isSparkAction(Lop lop) {
+       private static boolean isSparkTriggeringOp(Lop lop) {
                return lop.isExecSpark() && (lop.getAggType() == 
SparkAggType.SINGLE_BLOCK
                        || lop.getDataType() == DataType.SCALAR || lop 
instanceof MapMultChain
                        || lop instanceof PickByCount || lop instanceof MMZip 
|| lop instanceof CentralMoment
-                       || lop instanceof CoVariance || lop instanceof MMTSJ);
+                       || lop instanceof CoVariance || lop instanceof MMTSJ || 
lop.isAllOutputsCP());
        }
 
        private static List<Lop> addPrefetchLop(List<Lop> nodes) {
@@ -276,11 +278,12 @@ public interface ILinearize {
                for (Lop l : nodes) {
                        nodesWithPrefetch.add(l);
                        if (isPrefetchNeeded(l)) {
-                               //TODO: No prefetch if the parent is placed 
right after the spark OP
-                               //or push the parent further to increase 
parallelism
                                List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
                                //Construct a Prefetch lop that takes this 
Spark node as a input
                                UnaryCP prefetch = new UnaryCP(l, 
OpOp1.PREFETCH, l.getDataType(), l.getValueType(), ExecType.CP);
+                               prefetch.setAsynchronous(true);
+                               //Reset asynchronous flag for the input if 
already set (e.g. mapmm -> prefetch)
+                               l.setAsynchronous(false);
                                for (Lop outCP : oldOuts) {
                                        //Rewire l -> outCP to l -> Prefetch -> 
outCP
                                        prefetch.addOutput(outCP);
@@ -304,6 +307,7 @@ public interface ILinearize {
                                List<Lop> oldOuts = new 
ArrayList<>(l.getOutputs());
                                //Construct a Broadcast lop that takes this 
Spark node as an input
                                UnaryCP bc = new UnaryCP(l, OpOp1.BROADCAST, 
l.getDataType(), l.getValueType(), ExecType.CP);
+                               bc.setAsynchronous(true);
                                //FIXME: Wire Broadcast only with the necessary 
outputs
                                for (Lop outCP : oldOuts) {
                                        //Rewire l -> outCP to l -> Broadcast 
-> outCP
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
index 0cbd4acfe0..653596806d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/CpmmSPInstruction.java
@@ -23,6 +23,7 @@ import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -43,8 +44,13 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import scala.Tuple2;
 
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
 /**
  * Cpmm: cross-product matrix multiplication operation (distributed matrix 
multiply
  * by join over common dimension and subsequent aggregation of partial 
results).
@@ -96,19 +102,31 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                }
                
                if( SparkUtils.isHashPartitioned(in1) //ZIPMM-like CPMM
-                       && mc1.getNumRowBlocks()==1 && mc2.getCols()==1 ) {
-                       //note: if the major input is hash-partitioned and it's 
a matrix-vector
-                       //multiply, avoid the index mapping to preserve the 
partitioning similar
-                       //to a ZIPMM but with different transpose 
characteristics
-                       JavaRDD<MatrixBlock> out = in1
-                               .join(in2.mapToPair(new ReorgMapFunction("r'")))
-                               .values().map(new Cpmm2MultiplyFunction())
-                               .filter(new FilterNonEmptyBlocksFunction2());
-                       MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
-                       
-                       //put output block into symbol table (no lineage 
because single block)
-                       //this also includes implicit maintenance of matrix 
characteristics
-                       sec.setMatrixOutput(output.getName(), out2);
+                       && mc1.getNumRowBlocks()==1 && mc2.getCols()==1 )
+               //note: if the major input is hash-partitioned and it's a 
matrix-vector
+               //multiply, avoid the index mapping to preserve the 
partitioning similar
+               //to a ZIPMM but with different transpose characteristics
+               {
+                       if (ConfigurationManager.isMaxPrallelizeEnabled()) {
+                               try {
+                                       
if(CommonThreadPool.triggerRemoteOPsPool == null)
+                                               
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
+                                       CpmmMatrixVectorTask task = new 
CpmmMatrixVectorTask(in1, in2);
+                                       Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                       sec.setMatrixOutput(output.getName(), 
future_out);
+                               }
+                               catch(Exception ex) {
+                                       throw new DMLRuntimeException(ex);
+                               }
+                       }
+                       else {
+                               JavaRDD<MatrixBlock> out = 
in1.join(in2.mapToPair(new ReorgMapFunction("r'"))).values().map(new 
Cpmm2MultiplyFunction()).filter(new FilterNonEmptyBlocksFunction2());
+                               MatrixBlock out2 = 
RDDAggregateUtils.sumStable(out);
+
+                               //put output block into symbol table (no 
lineage because single block)
+                               //this also includes implicit maintenance of 
matrix characteristics
+                               sec.setMatrixOutput(output.getName(), out2);
+                       }
                }
                else //GENERAL CPMM
                {
@@ -119,21 +137,39 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                        //process core cpmm matrix multiply 
                        JavaPairRDD<Long, IndexedMatrixValue> tmp1 = 
in1.mapToPair(new CpmmIndexFunction(true));
                        JavaPairRDD<Long, IndexedMatrixValue> tmp2 = 
in2.mapToPair(new CpmmIndexFunction(false));
-                       JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1
-                               .join(tmp2, numPartJoin)                // join 
over common dimension
-                               .mapToPair(new CpmmMultiplyFunction()); // 
compute block multiplications
-                       
+
                        //process cpmm aggregation and handle outputs
-                       if( _aggtype == SparkAggType.SINGLE_BLOCK ) {
-                               //prune empty blocks and aggregate all results
-                               out = out.filter(new 
FilterNonEmptyBlocksFunction());
-                               MatrixBlock out2 = 
RDDAggregateUtils.sumStable(out);
-                               
-                               //put output block into symbol table (no 
lineage because single block)
-                               //this also includes implicit maintenance of 
matrix characteristics
-                               sec.setMatrixOutput(output.getName(), out2);
+                       if( _aggtype == SparkAggType.SINGLE_BLOCK )
+                       {
+                               if 
(ConfigurationManager.isMaxPrallelizeEnabled()) {
+                                       try {
+                                               
if(CommonThreadPool.triggerRemoteOPsPool == null)
+                                                       
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
+                                               CpmmMatrixMatrixTask task = new 
CpmmMatrixMatrixTask(in1, in2, numPartJoin);
+                                               Future<MatrixBlock> future_out 
= CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                               
sec.setMatrixOutput(output.getName(), future_out);
+                                       }
+                                       catch(Exception ex) { throw new 
DMLRuntimeException(ex); }
+                               }
+                               else {
+                                       JavaPairRDD<MatrixIndexes, MatrixBlock> 
out = tmp1
+                                               .join(tmp2, numPartJoin)        
        // join over common dimension
+                                               .mapToPair(new 
CpmmMultiplyFunction()); // compute block multiplications
+                                       //prune empty blocks and aggregate all 
results
+                                       out = out.filter(new 
FilterNonEmptyBlocksFunction());
+                                       MatrixBlock out2 = 
RDDAggregateUtils.sumStable(out);
+
+                                       //put output block into symbol table 
(no lineage because single block)
+                                       //this also includes implicit 
maintenance of matrix characteristics
+                                       sec.setMatrixOutput(output.getName(), 
out2);
+                               }
+
                        }
-                       else { //DEFAULT: MULTI_BLOCK
+                       else
+                       { //DEFAULT: MULTI_BLOCK
+                               JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
tmp1
+                                       .join(tmp2, numPartJoin)                
// join over common dimension
+                                       .mapToPair(new CpmmMultiplyFunction()); 
// compute block multiplications
                                if( !_outputEmptyBlocks || 
mc1.isNoEmptyBlocks() || mc2.isNoEmptyBlocks() )
                                        out = out.filter(new 
FilterNonEmptyBlocksFunction());
                                out = RDDAggregateUtils.sumByKeyStable(out, 
false);
@@ -234,4 +270,49 @@ public class CpmmSPInstruction extends 
AggregateBinarySPInstruction {
                        return OperationsOnMatrixValues.matMult(in1, in2, new 
MatrixBlock(), _op);
                }
        }
+
+       private static class CpmmMatrixVectorTask implements 
Callable<MatrixBlock>
+       {
+               JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
+               JavaPairRDD<MatrixIndexes, MatrixBlock> _in2;
+
+               CpmmMatrixVectorTask(JavaPairRDD<MatrixIndexes, MatrixBlock> 
in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2) {
+                       _in1 = in1;
+                       _in2 = in2;
+               }
+               @Override
+               public MatrixBlock call() {
+                               JavaRDD<MatrixBlock> out = _in1
+                               .join(_in2.mapToPair(new 
ReorgMapFunction("r'")))
+                               .values().map(new Cpmm2MultiplyFunction())
+                               .filter(new FilterNonEmptyBlocksFunction2());
+                       return RDDAggregateUtils.sumStable(out);
+               }
+       }
+
+       private static class CpmmMatrixMatrixTask implements 
Callable<MatrixBlock>
+       {
+               JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
+               JavaPairRDD<MatrixIndexes, MatrixBlock> _in2;
+               int _numPartJoin;
+
+               CpmmMatrixMatrixTask(JavaPairRDD<MatrixIndexes, MatrixBlock> 
in1, JavaPairRDD<MatrixIndexes, MatrixBlock> in2, int nPartJoin) {
+                       _in1 = in1;
+                       _in2 = in2;
+                       _numPartJoin = nPartJoin;
+               }
+               @Override
+               public MatrixBlock call() {
+                       //process core cpmm matrix multiply
+                       JavaPairRDD<Long, IndexedMatrixValue> tmp1 = 
_in1.mapToPair(new CpmmIndexFunction(true));
+                       JavaPairRDD<Long, IndexedMatrixValue> tmp2 = 
_in2.mapToPair(new CpmmIndexFunction(false));
+                       JavaPairRDD<MatrixIndexes,MatrixBlock> out = tmp1
+                               .join(tmp2, _numPartJoin)                // 
join over common dimension
+                               .mapToPair(new CpmmMultiplyFunction()); // 
compute block multiplications
+
+                       //prune empty blocks and aggregate all results
+                       out = out.filter(new FilterNonEmptyBlocksFunction());
+                       return RDDAggregateUtils.sumStable(out);
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
index 3a1a6c27d9..29f28b604e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MapmmSPInstruction.java
@@ -21,6 +21,9 @@ package org.apache.sysds.runtime.instructions.spark;
 
 
 import java.util.Iterator;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.stream.IntStream;
 
 import org.apache.commons.logging.Log;
@@ -30,6 +33,7 @@ import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.MapMult;
@@ -54,6 +58,7 @@ import 
org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import scala.Tuple2;
 
 public class MapmmSPInstruction extends AggregateBinarySPInstruction {
@@ -135,12 +140,24 @@ public class MapmmSPInstruction extends 
AggregateBinarySPInstruction {
                //execute mapmm and aggregation if necessary and put output 
into symbol table
                if( _aggtype == SparkAggType.SINGLE_BLOCK )
                {
-                       JavaRDD<MatrixBlock> out = in1.map(new 
RDDMapMMFunction2(type, in2));
-                       MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
-                       
-                       //put output block into symbol table (no lineage 
because single block)
-                       //this also includes implicit maintenance of matrix 
characteristics
-                       sec.setMatrixOutput(output.getName(), out2);
+                       if (ConfigurationManager.isMaxPrallelizeEnabled()) {
+                               try {
+                                       
if(CommonThreadPool.triggerRemoteOPsPool == null)
+                                               
CommonThreadPool.triggerRemoteOPsPool = Executors.newCachedThreadPool();
+                                       RDDMapmmTask task = new  
RDDMapmmTask(in1, in2, type);
+                                       Future<MatrixBlock> future_out = 
CommonThreadPool.triggerRemoteOPsPool.submit(task);
+                                       sec.setMatrixOutput(output.getName(), 
future_out);
+                               }
+                               catch(Exception ex) { throw new 
DMLRuntimeException(ex); }
+                       }
+                       else {
+                               JavaRDD<MatrixBlock> out = in1.map(new 
RDDMapMMFunction2(type, in2));
+                               MatrixBlock out2 = 
RDDAggregateUtils.sumStable(out);
+
+                               //put output block into symbol table (no 
lineage because single block)
+                               //this also includes implicit maintenance of 
matrix characteristics
+                               sec.setMatrixOutput(output.getName(), out2);
+                       }
                }
                else //MULTI_BLOCK or NONE
                {
@@ -443,4 +460,25 @@ public class MapmmSPInstruction extends 
AggregateBinarySPInstruction {
                        }
                }
        }
+
+       private static class RDDMapmmTask implements Callable<MatrixBlock>
+       {
+               JavaPairRDD<MatrixIndexes, MatrixBlock> _in1;
+               PartitionedBroadcast<MatrixBlock> _in2;
+               MapMult.CacheType _type;
+
+               RDDMapmmTask(JavaPairRDD<MatrixIndexes, MatrixBlock> in1, 
PartitionedBroadcast<MatrixBlock> in2, MapMult.CacheType type) {
+                       _in1 = in1;
+                       _in2 = in2;
+                       _type = type;
+               }
+
+               @Override
+               public MatrixBlock call() {
+                       //execute mapmm and aggregation if necessary and put 
output into symbol table
+                       JavaRDD<MatrixBlock> out = _in1.map(new 
RDDMapMMFunction2(_type, _in2));
+                       MatrixBlock out2 = RDDAggregateUtils.sumStable(out);
+                       return out2;
+               }
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
 
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
index be011925d2..ee89824c64 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/async/MaxParallelizeOrderTest.java
@@ -37,7 +37,7 @@ public class MaxParallelizeOrderTest extends 
AutomatedTestBase {
 
        protected static final String TEST_DIR = "functions/async/";
        protected static final String TEST_NAME = "MaxParallelizeOrder";
-       protected static final int TEST_VARIANTS = 2;
+       protected static final int TEST_VARIANTS = 4;
        protected static String TEST_CLASS_DIR = TEST_DIR + 
MaxParallelizeOrderTest.class.getSimpleName() + "/";
 
        @Override
@@ -57,6 +57,16 @@ public class MaxParallelizeOrderTest extends 
AutomatedTestBase {
                runTest(TEST_NAME+"2");
        }
 
+       @Test
+       public void testSparkAction() {
+               runTest(TEST_NAME+"3");
+       }
+
+       @Test
+       public void testSparkTransformations() {
+               runTest(TEST_NAME+"4");
+       }
+
        public void runTest(String testname) {
                boolean old_simplification = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                boolean old_sum_product = 
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
@@ -85,10 +95,13 @@ public class MaxParallelizeOrderTest extends 
AutomatedTestBase {
 
                        OptimizerUtils.ASYNC_PREFETCH_SPARK = true;
                        OptimizerUtils.MAX_PARALLELIZE_ORDER = true;
+                       if (testname.equalsIgnoreCase(TEST_NAME+"4"))
+                               OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE 
= false;
                        runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
                        HashMap<MatrixValue.CellIndex, Double> R_mp = 
readDMLScalarFromOutputDir("R");
                        OptimizerUtils.ASYNC_PREFETCH_SPARK = false;
                        OptimizerUtils.MAX_PARALLELIZE_ORDER = false;
+                       OptimizerUtils.ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true;
 
                        //compare matrices
                        boolean matchVal = TestUtils.compareMatrices(R, R_mp, 
1e-6, "Origin", "withPrefetch");
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder3.dml 
b/src/test/scripts/functions/async/MaxParallelizeOrder3.dml
new file mode 100644
index 0000000000..c26dee67c5
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder3.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+
+# CP instructions
+v = ((v + v) * 1 - v) / (1+1);
+v = ((v + v) * 2 - v) / (2+1);
+
+# Spark transformation operations 
+sp1 = X + ceil(X);
+sp2 = sp1 %*% v; #output fits in local
+
+# CP binary triggers the DAG of SP operations
+# if transitive spark exec type is off
+cp = sp2 + sum(v);
+R = sum(cp);
+write(R, $1, format="text");
diff --git a/src/test/scripts/functions/async/MaxParallelizeOrder4.dml 
b/src/test/scripts/functions/async/MaxParallelizeOrder4.dml
new file mode 100644
index 0000000000..2a7f7001ac
--- /dev/null
+++ b/src/test/scripts/functions/async/MaxParallelizeOrder4.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+X = rand(rows=10000, cols=200, seed=42); #sp_rand
+v = rand(rows=200, cols=1, seed=42); #cp_rand
+v2 = rand(rows=200, cols=1, seed=43); #cp_rand
+
+# CP instructions
+v = ((v + v) * 1 - v) / (1+1);
+v = ((v + v) * 2 - v) / (2+1);
+
+# Spark transformation operations 
+sp1 = X + ceil(X);
+sp2 = sp1 %*% v2; #output fits in local
+
+# CP binary triggers the DAG of SP operations
+# if transitive spark exec type is off
+cp = sp2 + sum(v);
+R = sum(cp);
+write(R, $1, format="text");

Reply via email to