This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new d0287b1983 [SYSTEMDS-3772] Nary elementwise multiplication 
rewrite/runtime
d0287b1983 is described below

commit d0287b1983c4fd915d5106199ff03a0d7183a397
Author: e-strauss <lathan...@gmx.de>
AuthorDate: Sun Sep 22 12:28:30 2024 +0200

    [SYSTEMDS-3772] Nary elementwise multiplication rewrite/runtime
    
    - new runtime operations and rewrite for nary mult similar to add
    - fixed bug in nary hop, where output value type was computed incorrectly
    in case of scalars with mixed value types
    - fixed bug in nary hop, where output was incorrectly set as scalar
    - cleaned up & fixed the aggregate ternary rewrite to also work with nary
    mult & fixed the n* lineage issue
    - increased code coverage for rewrite ternary aggregate: colsum(A^3)
    
    Closes #2105.
---
 src/main/java/org/apache/sysds/common/Types.java   |   4 +-
 .../java/org/apache/sysds/hops/AggUnaryOp.java     | 149 +++++++++++---------
 src/main/java/org/apache/sysds/hops/NaryOp.java    |   4 +-
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |  11 +-
 .../RewriteAlgebraicSimplificationDynamic.java     |   6 +-
 src/main/java/org/apache/sysds/lops/Nary.java      |   2 +
 .../runtime/instructions/CPInstructionParser.java  |   1 +
 .../runtime/instructions/SPInstructionParser.java  |   1 +
 .../instructions/cp/BuiltinNaryCPInstruction.java  |   5 +
 .../cp/MatrixBuiltinNaryCPInstruction.java         |   2 +-
 .../spark/BuiltinNarySPInstruction.java            |  17 +--
 .../runtime/lineage/LineageRecomputeUtils.java     |   3 +-
 .../sysds/runtime/matrix/data/LibMatrixMult.java   |  20 ++-
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 120 ++++++++++++----
 .../binary/matrix/ElementwiseAdditionTest.java     |   2 +-
 .../functions/rewrite/RewriteNaryMultTest.java     | 155 +++++++++++++++++++++
 .../functions/ternary/TernaryAggregateTest.java    |  39 +++++-
 .../functions/rewrite/RewriteNaryMultDense1.dml    |  32 +++++
 .../functions/rewrite/RewriteNaryMultDense2.dml    |  35 +++++
 .../functions/rewrite/RewriteNaryMultSparse1.dml   |  36 +++++
 .../functions/rewrite/RewriteNaryMultSparse2.dml   |  40 ++++++
 .../functions/ternary/TernaryAggregatePow.R        |  33 +++++
 .../functions/ternary/TernaryAggregatePow.dml      |  28 ++++
 23 files changed, 632 insertions(+), 113 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index a7397ae54b..6d64cbae54 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -739,10 +739,10 @@ public interface Types {
        
        /** Operations that require a variable number of operands*/
        public enum OpOpN {
-               PRINTF, CBIND, RBIND, MIN, MAX, PLUS, EVAL, LIST;
+               PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;
                
                public boolean isCellOp() {
-                       return this == MIN || this == MAX || this == PLUS;
+                       return this == MIN || this == MAX || this == PLUS || 
this == MULT;
                }
        }
        
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index dcd68b055e..eec86ec15b 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.hops;
 
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.AggOp;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.Direction;
@@ -30,6 +31,7 @@ import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.Nary;
 import org.apache.sysds.lops.PartialAggregate;
 import org.apache.sysds.lops.TernaryAggregate;
 import org.apache.sysds.lops.UAggOuterChain;
@@ -38,6 +40,8 @@ import 
org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
+import java.util.List;
+
 // Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum
 
 public class AggUnaryOp extends MultiThreadedHop
@@ -475,6 +479,17 @@ public class AggUnaryOp extends MultiThreadedHop
                                        }
                                }
                        }
+                       if (input1.getParent().size() == 1
+                                       && input1 instanceof NaryOp) { //sum 
single consumer
+                               NaryOp nop = (NaryOp) input1;
+                               if(nop.getOp() == Types.OpOpN.MULT){
+                                       List<Hop> inputsN = nop.getInput();
+                                       if(inputsN.size() == 3){
+                                               ret = 
HopRewriteUtils.isEqualSize(inputsN.get(0), inputsN.get(1)) &&
+                                                               
HopRewriteUtils.isEqualSize(inputsN.get(1), inputsN.get(2));
+                                       }
+                               }
+                       }
                }
                return ret;
        }
@@ -554,83 +569,91 @@ public class AggUnaryOp extends MultiThreadedHop
 
        private Lop constructLopsTernaryAggregateRewrite(ExecType et) 
        {
-               BinaryOp input1 = (BinaryOp)getInput().get(0);
-               Hop input11 = input1.getInput().get(0);
-               Hop input12 = input1.getInput().get(1);
-               
                Lop in1 = null, in2 = null, in3 = null;
-               boolean handled = false;
-
-               if (input1.getOp() == OpOp2.POW) {
-                       assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : 
"this case can only occur with a power of 3";
-                       in1 = input11.constructLops();
-                       in2 = in1;
-                       in3 = in1;
-                       handled = true;
-               }
-               else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, 
OpOp2.POW) ) {
-                       BinaryOp b11 = (BinaryOp)input11;
-                       switch( b11.getOp() ) {
-                       case MULT: // A*B*C case
-                               in1 = input11.getInput().get(0).constructLops();
-                               in2 = input11.getInput().get(1).constructLops();
-                               in3 = input12.constructLops();
-                               handled = true;
-                               break;
-                       case POW: // A*A*B case
-                               Hop b112 = b11.getInput().get(1);
-                               if ( !(input12 instanceof BinaryOp && 
((BinaryOp)input12).getOp()==OpOp2.MULT)
-                                               && 
HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
-                                       in1 = 
b11.getInput().get(0).constructLops();
-                                       in2 = in1;
-                                       in3 = input12.constructLops();
-                                       handled = true;
-                               }
-                               break;
-                       default: break;
-                       }
-               }
-               else if( HopRewriteUtils.isBinary(input12, OpOp2.MULT, 
OpOp2.POW) ) {
-                       BinaryOp b12 = (BinaryOp)input12;
-                       switch (b12.getOp()) {
-                       case MULT: // A*B*C case
+               Hop input = getInput().get(0);
+               if(input instanceof BinaryOp) {
+                       BinaryOp input1 = (BinaryOp) input;
+                       Hop input11 = input1.getInput().get(0);
+                       Hop input12 = input1.getInput().get(1);
+
+                       boolean handled = false;
+
+                       if (input1.getOp() == OpOp2.POW) {
+                               assert 
(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with 
a power of 3";
                                in1 = input11.constructLops();
-                               in2 = input12.getInput().get(0).constructLops();
-                               in3 = input12.getInput().get(1).constructLops();
+                               in2 = in1;
+                               in3 = in1;
                                handled = true;
-                               break;
-                       case POW: // A*B*B case
-                               Hop b112 = b12.getInput().get(1);
-                               if ( HopRewriteUtils.isLiteralOfValue(b112, 2) 
) {
-                                       in1 = 
b12.getInput().get(0).constructLops();
-                                       in2 = in1;
-                                       in3 = input11.constructLops();
-                                       handled = true;
+                       } else if (HopRewriteUtils.isBinary(input11, 
OpOp2.MULT, OpOp2.POW)) {
+                               BinaryOp b11 = (BinaryOp) input11;
+                               switch (b11.getOp()) {
+                                       case MULT: // A*B*C case
+                                               in1 = 
input11.getInput().get(0).constructLops();
+                                               in2 = 
input11.getInput().get(1).constructLops();
+                                               in3 = input12.constructLops();
+                                               handled = true;
+                                               break;
+                                       case POW: // A*A*B case
+                                               Hop b112 = 
b11.getInput().get(1);
+                                               if (!(input12 instanceof 
BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT)
+                                                               && 
HopRewriteUtils.isLiteralOfValue(b112, 2)) {
+                                                       in1 = 
b11.getInput().get(0).constructLops();
+                                                       in2 = in1;
+                                                       in3 = 
input12.constructLops();
+                                                       handled = true;
+                                               }
+                                               break;
+                                       default:
+                                               break;
+                               }
+                       } else if (HopRewriteUtils.isBinary(input12, 
OpOp2.MULT, OpOp2.POW)) {
+                               BinaryOp b12 = (BinaryOp) input12;
+                               switch (b12.getOp()) {
+                                       case MULT: // A*B*C case
+                                               in1 = input11.constructLops();
+                                               in2 = 
input12.getInput().get(0).constructLops();
+                                               in3 = 
input12.getInput().get(1).constructLops();
+                                               handled = true;
+                                               break;
+                                       case POW: // A*B*B case
+                                               Hop b112 = 
b12.getInput().get(1);
+                                               if 
(HopRewriteUtils.isLiteralOfValue(b112, 2)) {
+                                                       in1 = 
b12.getInput().get(0).constructLops();
+                                                       in2 = in1;
+                                                       in3 = 
input11.constructLops();
+                                                       handled = true;
+                                               }
+                                               break;
+                                       default:
+                                               break;
                                }
-                               break;
-                       default: break;
                        }
-               }
 
-               if (!handled) {
-                       in1 = input11.constructLops();
-                       in2 = input12.constructLops();
-                       in3 = new LiteralOp(1).constructLops();
+                       if (!handled) {
+                               in1 = input11.constructLops();
+                               in2 = input12.constructLops();
+                               in3 = new LiteralOp(1).constructLops();
+                       }
+               } else {
+                       NaryOp input1 = (NaryOp) input;
+                       in1 = input1.getInput().get(0).constructLops();
+                       in2 = input1.getInput().get(1).constructLops();
+                       in3 = input1.getInput().get(2).constructLops();
                }
 
-               //create new ternary aggregate operator 
-               int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads 
);
+               //create new ternary aggregate operator
+               int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
                // The execution type of a unary aggregate instruction should 
depend on the execution type of inputs to avoid OOM
                // Since we only support matrix-vector and not vector-matrix, 
checking the execution type of input1 should suffice.
-               ExecType et_input = input1.optFindExecType();
+               ExecType et_input = input.optFindExecType();
                // Because ternary aggregate are not supported on GPU
-               et_input = et_input == ExecType.GPU ? ExecType.CP :  et_input;
+               et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
                // If forced ExecType is FED, it means that the federated 
planner updated the ExecType and
                // execution may fail if ExecType is not FED
                et_input = (getForcedExecType() == ExecType.FED) ? ExecType.FED 
: et_input;
-               
-               return new TernaryAggregate(in1, in2, in3, AggOp.SUM, 
-                       OpOp2.MULT, _direction, getDataType(), ValueType.FP64, 
et_input, k);
+
+               return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
+                               OpOp2.MULT, _direction, getDataType(), 
ValueType.FP64, et_input, k);
        }
        
        @Override
diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java 
b/src/main/java/org/apache/sysds/hops/NaryOp.java
index 0a6e93ca14..44fe8ff609 100644
--- a/src/main/java/org/apache/sysds/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/NaryOp.java
@@ -200,7 +200,8 @@ public class NaryOp extends Hop {
                                        HopRewriteUtils.getSumValidInputNnz(dc, 
true));
                                case MIN:
                                case MAX:
-                               case PLUS: return new MatrixCharacteristics(
+                               case PLUS:
+                               case MULT: return new MatrixCharacteristics(
                                                
HopRewriteUtils.getMaxInputDim(this, true),
                                                
HopRewriteUtils.getMaxInputDim(this, false), -1, -1);
                                case LIST:
@@ -230,6 +231,7 @@ public class NaryOp extends Hop {
                        case MIN:
                        case MAX:
                        case PLUS:
+                       case MULT:
                                setDim1(getDataType().isScalar() ? 0 : 
HopRewriteUtils.getMaxInputDim(this, true));
                                setDim2(getDataType().isScalar() ? 0 : 
HopRewriteUtils.getMaxInputDim(this, false));
                                break;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 2b84318e35..68167ac3ae 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -745,8 +745,15 @@ public class HopRewriteUtils {
        
        public static NaryOp createNary(OpOpN op, Hop... inputs) {
                Hop mainInput = inputs[0];
-               NaryOp nop = new NaryOp(mainInput.getName(), 
mainInput.getDataType(),
-                       mainInput.getValueType(), op, inputs);
+               // safe for unordered inputs of Scalars and Matrices
+               // e.g.: S*M*S = M
+               // safe for Scalar with different value type
+               // e.g.: Scalar(Int) * Scalar(FP64) = Scalar(FP64)
+               boolean containsMatrix = 
Arrays.stream(inputs).anyMatch(Hop::isMatrix);
+               boolean containsFP64 = Arrays.stream(inputs).anyMatch(h -> 
h.getValueType() == ValueType.FP64);
+               DataType dtOut = containsMatrix ? DataType.MATRIX : 
mainInput.getDataType();
+               ValueType vtOut = containsFP64? ValueType.FP64 : 
mainInput.getValueType();
+               NaryOp nop = new NaryOp(mainInput.getName(), dtOut, vtOut, op, 
inputs);
                nop.setBlocksize(mainInput.getBlocksize());
                copyLineNumbers(mainInput, nop);
                nop.refreshSizeInformation();
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 1a4c4ecebd..314d230842 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2801,8 +2801,8 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
        
        private static Hop foldMultipleMinMaxOperations(Hop hi) 
        {
-               if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, 
OpOp2.PLUS) 
-                       || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, 
OpOpN.PLUS))
+               if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, 
OpOp2.PLUS, OpOp2.MULT)
+                       || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, 
OpOpN.PLUS, OpOpN.MULT))
                        && hi.getValueType() != ValueType.STRING //exclude 
string concat
                        && HopRewriteUtils.isNotMatrixVectorBinaryOperation(hi))
                {
@@ -2839,7 +2839,7 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                        for( Hop p : parents )
                                                
HopRewriteUtils.replaceChildReference(p, hi, hnew);
                                        hi = hnew;
-                                       LOG.debug("Applied 
foldMultipleMinMaxPlusOperations (line "+hi.getBeginLine()+").");
+                                       LOG.debug("Applied 
foldMultipleMinMaxPlusMultOperations (line "+hi.getBeginLine()+").");
                                }
                                else {
                                        converged = true;
diff --git a/src/main/java/org/apache/sysds/lops/Nary.java 
b/src/main/java/org/apache/sysds/lops/Nary.java
index 68e7191ad9..950082e47e 100644
--- a/src/main/java/org/apache/sysds/lops/Nary.java
+++ b/src/main/java/org/apache/sysds/lops/Nary.java
@@ -117,6 +117,8 @@ public class Nary extends Lop {
                                return "n"+operationType.name().toLowerCase();
                        case PLUS:
                                return "n+";
+                       case MULT:
+                               return "n*";
                        default:
                                throw new UnsupportedOperationException(
                                        "Nary operation type (" + operationType 
+ ") is not defined.");
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 792eace24e..7a5ed3524c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -180,6 +180,7 @@ public class CPInstructionParser extends InstructionParser {
                String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
                String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
                String2CPInstructionType.put( "n+"  , CPType.BuiltinNary);
+               String2CPInstructionType.put( "n*"  , CPType.BuiltinNary);
 
                String2CPInstructionType.put( "exp"   , CPType.Unary);
                String2CPInstructionType.put( "abs"   , CPType.Unary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 06e68a63d5..5e4dbaedeb 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -299,6 +299,7 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "nmin",  SPType.BuiltinNary);
                String2SPInstructionType.put( "nmax",  SPType.BuiltinNary);
                String2SPInstructionType.put( "n+",  SPType.BuiltinNary);
+               String2SPInstructionType.put( "n*",  SPType.BuiltinNary);
 
                String2SPInstructionType.put( DataGen.RAND_OPCODE  , 
SPType.Rand);
                String2SPInstructionType.put( DataGen.SEQ_OPCODE   , 
SPType.Rand);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
index 31110423cb..c6b590a61c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.cp;
 import org.apache.sysds.common.Types.OpOpN;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -85,6 +86,10 @@ public abstract class BuiltinNaryCPInstruction extends 
CPInstruction
                        return new MatrixBuiltinNaryCPInstruction(
                                new SimpleOperator(Plus.getPlusFnObject()), 
opcode, str, outputOperand, inputOperands);
                }
+               else if( opcode.equals("n*") ) {
+                       return new MatrixBuiltinNaryCPInstruction(
+                                       new 
SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand, 
inputOperands);
+               }
                else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) {
                        return new EvalNaryCPInstruction(null, opcode, str, 
outputOperand, inputOperands);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
index d6a0ec53f5..172c15b592 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
@@ -62,7 +62,7 @@ public class MatrixBuiltinNaryCPInstruction extends 
BuiltinNaryCPInstruction imp
                                        outBlock = 
((FrameBlock)outBlock).append(frames.get(i), cbind);
                        }
                }
-               else if( ArrayUtils.contains(new String[]{"nmin", "nmax", 
"n+"}, getOpcode()) ) {
+               else if( ArrayUtils.contains(new String[]{"nmin", "nmax", "n+", 
"n*"}, getOpcode()) ) {
                        outBlock = MatrixBlock.naryOperations(_optr, 
matrices.toArray(new MatrixBlock[0]),
                                scalars.toArray(new ScalarObject[0]), new 
MatrixBlock());
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
index c3118501ba..8e46a96357 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
@@ -34,6 +34,7 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Multiply;
 import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -163,7 +164,7 @@ public class BuiltinNarySPInstruction extends SPInstruction 
implements LineageTr
                                return;
                        }
                }
-               else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+"}, 
getOpcode()) ) {
+               else if( ArrayUtils.contains(new 
String[]{"nmin","nmax","n+","n*"}, getOpcode()) ) {
                        //compute output characteristics
                        dcout = computeMinMaxOutputDataCharacteristics(sec, 
inputs);
                        
@@ -179,7 +180,7 @@ public class BuiltinNarySPInstruction extends SPInstruction 
implements LineageTr
                        }
                        
                        //compute nary min/max (partitioning-preserving)
-                       out = in.mapValues(new MinMaxAddFunction(getOpcode(), 
scalars));
+                       out = in.mapValues(new 
MinMaxAddMultFunction(getOpcode(), scalars));
                }
                
                //set output RDD and add lineage
@@ -278,17 +279,17 @@ public class BuiltinNarySPInstruction extends 
SPInstruction implements LineageTr
                }
        }
        
-       private static class MinMaxAddFunction implements 
Function<MatrixBlock[], MatrixBlock> {
+       private static class MinMaxAddMultFunction implements 
Function<MatrixBlock[], MatrixBlock> {
                private static final long serialVersionUID = 
-4227447915387484397L;
                
                private final SimpleOperator _op;
                private final ScalarObject[] _scalars;
-               
-               public MinMaxAddFunction(String opcode, List<ScalarObject> 
scalars) {
+
+               public MinMaxAddMultFunction(String opcode, List<ScalarObject> 
scalars) {
                        _scalars = scalars.toArray(new ScalarObject[0]);
-                       _op = new SimpleOperator(opcode.equals("n+") ?
-                               Plus.getPlusFnObject() :
-                               
Builtin.getBuiltinFnObject(opcode.substring(1)));
+                       _op = new SimpleOperator(opcode.equals("n+") ? 
Plus.getPlusFnObject() :
+                                       opcode.equals("n*") ? 
Multiply.getMultiplyFnObject() :
+                                                       
Builtin.getBuiltinFnObject(opcode.substring(1)));
                }
                
                @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index 6c691c2b8b..1188d65e30 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -390,7 +390,8 @@ public class LineageRecomputeUtils {
                                                        break;
                                                }
                                                case BuiltinNary: {
-                                                       String opcode = 
item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
+                                                       String opcode = 
item.getOpcode().equals("n+") ? "plus" :
+                                                                       
item.getOpcode().equals("n*") ? "mult" : item.getOpcode();
                                                        
operands.put(item.getId(), HopRewriteUtils.createNary(
                                                                
OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands)));
                                                        break;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 061c09ea3d..108462c567 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -3844,6 +3844,24 @@ public class LibMatrixMult
                        c[ ci+6 ] *= aval; c[ ci+7 ] *= aval;
                }
        }
+       
+       public static void vectMultiplyInPlace(final double[] a, double[] c, 
int[] cix, final int ai, final int ci, final int len) {
+               final int bn = len%8;
+               //rest, not aligned to 8-blocks
+               for( int j = ci; j < ci+bn; j++ )
+                       c[ j ] *= a[ ai+cix[j] ];
+               //unrolled 8-block (for better instruction-level parallelism)
+               for( int j = ci+bn; j < ci+len; j+=8 ) {
+                       c[ j+0 ] *= a[ ai+cix[j+0] ];
+                       c[ j+1 ] *= a[ ai+cix[j+1] ];
+                       c[ j+2 ] *= a[ ai+cix[j+2] ];
+                       c[ j+3 ] *= a[ ai+cix[j+3] ];
+                       c[ j+4 ] *= a[ ai+cix[j+4] ];
+                       c[ j+5 ] *= a[ ai+cix[j+5] ];
+                       c[ j+6 ] *= a[ ai+cix[j+6] ];
+                       c[ j+7 ] *= a[ ai+cix[j+7] ];
+               }
+       }
 
        //note: public for use by codegen for consistency
        public static void vectMultiplyWrite( double[] a, double[] b, double[] 
c, int ai, int bi, int ci, final int len )
@@ -3889,7 +3907,7 @@ public class LibMatrixMult
                }
        }
 
-       private static void vectMultiply( double[] a, double[] c, int ai, int 
ci, final int len )
+       public static void vectMultiply(double[] a, double[] c, int ai, int ci, 
final int len)
        {
                final int bn = len%8;
                
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 8e2e7f0665..62f9a1febb 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -3805,32 +3805,38 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                //prepare operator
                FunctionObject fn = ((SimpleOperator)op).fn;
                boolean plus = fn instanceof Plus;
-               Builtin bfn = !plus ? (Builtin)((SimpleOperator)op).fn : null;
-               
+               boolean mult = fn instanceof Multiply;
+               boolean minmax = !mult && !plus;
+               Builtin bfn = minmax ? (Builtin) fn : null;
                for(int i = 0; i < matrices.length; i++)
                        if(matrices[i] instanceof CompressedMatrixBlock)
                                matrices[i] = 
CompressedMatrixBlock.getUncompressed(matrices[i], "Nary operation process add 
row");
 
                //process all scalars
-               double init = plus ? 0 :(bfn.getBuiltinCode() == 
BuiltinCode.MIN) ?
-                       Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
+               double init = plus ? 0 : mult ? 1 : (bfn.getBuiltinCode() == 
BuiltinCode.MIN) ?
+                               Double.POSITIVE_INFINITY : 
Double.NEGATIVE_INFINITY;
                for( ScalarObject so : scalars )
                        init = fn.execute(init, so.getDoubleValue());
 
                //compute output dimensions and estimate sparsity
                final int m = matrices.length > 0 ? matrices[0].rlen : 1;
                final int n = matrices.length > 0 ? matrices[0].clen : 1;
+
+               //check for empty multiply with empty matrix
+               if(mult)
+                       
if(Arrays.stream(matrices).anyMatch(MatrixBlock::isEmptyBlock))
+                               return new MatrixBlock(m,n, 0);
+
                final long mn = (long) m * n;
-               final long nnz = (!plus && 
bfn.getBuiltinCode()==BuiltinCode.MIN && init < 0)
-                       || (!plus && bfn.getBuiltinCode()==BuiltinCode.MAX && 
init > 0) ? mn :
+               final long nnz = (minmax && 
bfn.getBuiltinCode()==BuiltinCode.MIN && init < 0)
+                       || (minmax && bfn.getBuiltinCode()==BuiltinCode.MAX && 
init > 0) ? mn :
+                               mult? Arrays.stream(matrices).mapToLong(mb -> 
mb.nonZeros).min().orElse(mn) :
                        Math.min(Arrays.stream(matrices).mapToLong(mb -> 
mb.nonZeros).sum(), mn);
-               boolean sp = evalSparseFormatInMemory(m, n, nnz);
+               //if multiply and least one sparse input -> output = sparse
+               boolean sp = (mult && Arrays.stream(matrices).anyMatch(mb -> 
mb.sparse)) || evalSparseFormatInMemory(m, n, nnz);
                
                //init result matrix 
-               if( ret == null )
-                       ret = new MatrixBlock(m, n, sp, nnz);
-               else
-                       ret.reset(m, n, sp, nnz);
+               ret.reset(m, n, sp, nnz);
                
                //main processing
                if( matrices.length > 0 ) {
@@ -3839,18 +3845,21 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                                mb -> mb.sparse || mb.isEmpty()) ? new int[n] : 
null;
                        if( ret.isInSparseFormat() ) {
                                double[] tmp = new double[n];
-                               for(int i = 0; i < m; i++) {
-                                       //reset tmp and compute row output
-                                       Arrays.fill(tmp, init);
-                                       if( plus )
-                                               processAddRow(matrices, tmp, 0, 
n, i);
-                                       else
-                                               processMinMaxRow(bfn, matrices, 
tmp, 0, n, i, cnt);
-                                       //copy to sparse output
-                                       for(int j = 0; j < n; j++)
-                                               if( tmp[j] != 0 )
-                                                       ret.appendValue(i, j, 
tmp[j]);
-                               }
+                               for(int i = 0; i < m; i++)
+                                       if (mult)
+                                               processMultRowSparse(matrices, 
init, n, i, ret);
+                                       else{
+                                               //reset tmp and compute row 
output
+                                               Arrays.fill(tmp, init);
+                                               if( plus )
+                                                       processAddRow(matrices, 
tmp, 0, n, i);
+                                               else
+                                                       processMinMaxRow(bfn, 
matrices, tmp, 0, n, i, cnt);
+                                               //copy to sparse output
+                                               for(int j = 0; j < n; j++)
+                                                       if( tmp[j] != 0 )
+                                                               
ret.appendValue(i, j, tmp[j]);
+                                       }
                        }
                        else {
                                DenseBlock c = ret.getDenseBlock();
@@ -3860,17 +3869,21 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                                                Arrays.fill(c.values(i), 
c.pos(i), c.pos(i)+n, init);
                                        if( plus )
                                                processAddRow(matrices, 
c.values(i), c.pos(i), n, i);
+                                       else if (mult)
+                                               processMultRowDense(matrices, 
c.values(i), c.pos(i), n, i);
                                        else
                                                processMinMaxRow(bfn, matrices, 
c.values(i), c.pos(i), n, i, cnt);
                                        lnnz += 
UtilFunctions.countNonZeros(c.values(i), c.pos(i), n);
                                }
                                ret.setNonZeros(lnnz);
+
+                               //reevaluate sparsity
+                               if(mult && ret.evalSparseFormatInMemory())
+                                       ret.denseToSparse();
                        }
                }
-               else {
+               else
                        ret.set(0, 0, init);
-               }
-               
                return ret;
        }
        
@@ -3891,6 +3904,61 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                        }
                }
        }
+
+       /**
+        * Folds row i of every input into the dense output row by handing each
+        * input row to LibMatrixMult.vectMultiply together with the output range
+        * c[cix, cix+n).
+        * Only the dense block of each input is read here.
+        *
+        * @param inputs input matrices, accessed through their dense blocks
+        * @param c      values array holding the output row
+        * @param cix    start position of the output row in c
+        * @param n      number of columns
+        * @param i      row index
+        */
+       private static void processMultRowDense(MatrixBlock[] inputs, double[] c, int cix, int n, int i) {
+               // if inputs contain sparse -> result == sparse -> processMultRowSparse
+               for( int k = 0; k < inputs.length; k++ ) {
+                       DenseBlock block = inputs[k].getDenseBlock();
+                       double[] avals = block.values(i);
+                       int apos = block.pos(i);
+                       LibMatrixMult.vectMultiply(avals, c, apos, cix, n);
+               }
+       }
+
+       /**
+        * Computes row i of an n-ary element-wise multiplication and appends the
+        * non-zero products to the sparse output.
+        *
+        * The candidate cells are seeded from the first sparse-format input (scaled
+        * by init), intersected with the column indexes of every further
+        * sparse-format input, and finally scaled by the matching cells of all
+        * dense-format inputs. If no input is in sparse format (the output was
+        * chosen sparse only because of the nnz estimate), all n columns are
+        * candidates; previously this case dereferenced a null sparse block.
+        *
+        * @param inputs input matrices in sparse or dense format
+        * @param init   product of all scalar inputs (1 if there are none)
+        * @param n      number of columns
+        * @param i      row index
+        * @param ret    output matrix that receives the appended values
+        */
+       private static void processMultRowSparse(MatrixBlock[] inputs, double init, int n, int i, MatrixBlock ret) {
+               //collect the sparse-format inputs
+               SparseBlock[] sparseInputs = new SparseBlock[inputs.length];
+               int numSparse = 0;
+               for( MatrixBlock in : inputs )
+                       if( in.isInSparseFormat() )
+                               sparseInputs[numSparse++] = in.getSparseBlock();
+
+               //seed the candidate cells (column indexes and running products)
+               int size;
+               int[] ix;
+               double[] vals;
+               if( numSparse > 0 ) {
+                       SparseBlock first = sparseInputs[0];
+                       size = first.size(i);
+                       ix = new int[size];
+                       vals = new double[size];
+                       if( size > 0 ) {
+                               //row accessors are loop-invariant, fetch them once
+                               int apos = first.pos(i);
+                               int[] aix = first.indexes(i);
+                               double[] avals = first.values(i);
+                               for( int j = 0; j < size; j++ ) {
+                                       ix[j] = aix[apos+j];
+                                       vals[j] = init * avals[apos+j];
+                               }
+                       }
+               }
+               else {
+                       //only dense-format inputs: every column is a candidate
+                       size = n;
+                       ix = new int[n];
+                       vals = new double[n];
+                       for( int j = 0; j < n; j++ )
+                               ix[j] = j;
+                       Arrays.fill(vals, init);
+               }
+
+               //intersect with the remaining sparse inputs (merge over sorted indexes)
+               for( int k = 1; k < numSparse && size > 0; k++ ) {
+                       SparseBlock b = sparseInputs[k];
+                       int blen = b.size(i);
+                       int write = 0;
+                       if( blen > 0 ) {
+                               int[] bix = b.indexes(i);
+                               double[] bvals = b.values(i);
+                               int bpos = b.pos(i);
+                               int bend = bpos + blen;
+                               int read = 0;
+                               while( read < size && bpos < bend ) {
+                                       if( ix[read] < bix[bpos] )
+                                               read++;
+                                       else if( ix[read] > bix[bpos] )
+                                               bpos++;
+                                       else {
+                                               //matching column: keep it and extend the product
+                                               ix[write] = ix[read];
+                                               vals[write] = vals[read] * bvals[bpos];
+                                               read++;
+                                               bpos++;
+                                               write++;
+                                       }
+                               }
+                       }
+                       //an empty row in any sparse input empties the result row
+                       size = write;
+               }
+
+               //scale the surviving candidates by all dense inputs
+               if( size > 0 ) {
+                       for( MatrixBlock in : inputs )
+                               if( !in.isInSparseFormat() )
+                                       LibMatrixMult.vectMultiplyInPlace(in.denseBlock.values(i),
+                                               vals, ix, in.denseBlock.pos(i), 0, size);
+               }
+
+               //append non-zero products only (consistent with the other sparse output path)
+               for( int j = 0; j < size; j++ )
+                       if( vals[j] != 0 )
+                               ret.appendValue(i, ix[j], vals[j]);
+       }
        
        private static void processMinMaxRow(Builtin fn, MatrixBlock[] inputs, 
double[] c, int cix, int n, int i, int[] cnt) {
                if( cnt != null )
diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
index ac3cf3a041..a4a6b0381d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
@@ -133,7 +133,7 @@ public class ElementwiseAdditionTest extends 
AutomatedTestBase
                config.addVariable("cols2", cols2);
                
                loadTestConfiguration(config);
-               
+               //this.programArgs
                runTest(true,LanguageException.class);
        }
        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java
new file mode 100644
index 0000000000..6fab8f63a9
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests the rewrite of chains of element-wise multiplications into a single
+ * n-ary multiply (opcode "n*") for dense and sparse inputs, in CP and Spark,
+ * and checks the parser error for an unknown n-ary opcode.
+ */
+public class RewriteNaryMultTest extends AutomatedTestBase
+{
+       private static final String TEST_NAME1 = "RewriteNaryMultDense1";
+       private static final String TEST_NAME2 = "RewriteNaryMultDense2";
+       private static final String TEST_NAME3 = "RewriteNaryMultSparse1";
+       private static final String TEST_NAME4 = "RewriteNaryMultSparse2";
+       
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + RewriteNaryMultTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               //each configuration is registered with its own script name
+               addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}) );
+               addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"}) );
+               addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"}) );
+               addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[]{"R"}) );
+       }
+
+       @Test
+       public void testExpectExceptionParseNaryCPInstruction(){
+               String opcode = "nxy";
+               try {
+                       BuiltinNaryCPInstruction.parseInstruction("CP°" + opcode +"°A·MATRIX·FP64°B·MATRIX·FP64°C·MATRIX·FP64°_mVar3·MATRIX·FP64");
+                       //reaching this line means the unknown opcode was accepted
+                       Assert.fail("Expected DMLRuntimeException for unknown opcode " + opcode);
+               }
+               catch (DMLRuntimeException e){
+                       //JUnit assertion instead of the assert keyword, which is only checked with -ea
+                       Assert.assertEquals("Opcode (" + opcode + ") not recognized in BuiltinMultipleCPInstruction", e.getMessage());
+               }
+       }
+
+       @Test
+       public void testNoRewriteDense1CP() {
+               testRewriteNaryMult(TEST_NAME1, false, ExecType.CP);
+       }
+
+       @Test
+       public void testRewriteDense1CP() {
+               testRewriteNaryMult(TEST_NAME1, true, ExecType.CP);
+       }
+
+       @Test
+       public void testRewriteDense2CP() {
+               testRewriteNaryMult(TEST_NAME2, true, ExecType.CP);
+       }
+
+       @Test
+       public void testRewriteSparse1CP() {
+               testRewriteNaryMult(TEST_NAME3, true, ExecType.CP);
+       }
+
+       @Test
+       public void testRewriteSparse2CP() {
+               testRewriteNaryMult(TEST_NAME4, true, ExecType.CP);
+       }
+
+       @Test
+       public void testNoRewriteDense1SP() {
+               testRewriteNaryMult(TEST_NAME1, false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testRewriteDense1SP() {
+               testRewriteNaryMult(TEST_NAME1, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testRewriteDense2SP() {
+               testRewriteNaryMult(TEST_NAME2, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testRewriteSparse1SP() {
+               testRewriteNaryMult(TEST_NAME3, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testRewriteSparse2SP() {
+               testRewriteNaryMult(TEST_NAME4, true, ExecType.SPARK);
+       }
+
+       /**
+        * Runs the given script, compares the single output cell with the expected
+        * sum, and checks via the heavy hitter statistics which multiply
+        * instruction was executed.
+        *
+        * @param name     test/script name
+        * @param rewrites value for OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION during the run
+        * @param et       execution type (CP or SPARK)
+        */
+       private void testRewriteNaryMult(String name, boolean rewrites, ExecType et)
+       {
+               ExecMode oldMode = setExecMode(et);
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               
+               try
+               {
+                       TestConfiguration config = getTestConfiguration(name);
+                       loadTestConfiguration(config);
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + name + ".dml";
+                       programArgs = new String[]{"-explain","-stats","-args", output("R") };
+
+                       runTest(true, false, null, -1); 
+                       
+                       //compare output
+                       Double ret = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
+                       if(name.equals(TEST_NAME3))
+                               Assert.assertEquals(Double.valueOf(1 + 304 + 10000), ret);
+                       else if(name.equals(TEST_NAME4))
+                               Assert.assertEquals(Double.valueOf(1), ret, 1e-7);
+                       else
+                               Assert.assertEquals(Double.valueOf(2*3*4*5*6*1000), ret);
+                       
+                       //check for applied nary mult (TEST_NAME2 is expected to keep binary multiplies)
+                       String prefix = et == ExecType.SPARK ? "sp_" : "";
+                       if( rewrites && !name.equals(TEST_NAME2) )
+                               Assert.assertEquals(1, Statistics.getCPHeavyHitterCount(prefix + "n*"));
+                       else
+                               Assert.assertTrue(Statistics.getCPHeavyHitterCount(prefix+"*")>=1);
+               }
+               finally {
+                       //restore global state for subsequent tests
+                       resetExecMode(oldMode);
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
 
b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
index a7224d88b0..3d7de4983a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
@@ -41,10 +41,11 @@ public class TernaryAggregateTest extends AutomatedTestBase
 {
        private final static String TEST_NAME1 = "TernaryAggregateRC";
        private final static String TEST_NAME2 = "TernaryAggregateC";
+       private final static String TEST_NAME3 = "TernaryAggregatePow";
        
        private final static String TEST_DIR = "functions/ternary/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
TernaryAggregateTest.class.getSimpleName() + "/";
-       private final static double eps = 1e-8;
+       private final static double eps = 1e-7;
        
        private final static int rows = 1111;
        private final static int cols = 1011;
@@ -56,7 +57,8 @@ public class TernaryAggregateTest extends AutomatedTestBase
        public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); 
-               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); 
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
        }
 
        @Test
@@ -180,6 +182,36 @@ public class TernaryAggregateTest extends AutomatedTestBase
        public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
                runTernaryAggregateTest(TEST_NAME2, true, false, false, 
ExecType.CP);
        }
+
+       @Test
+       public void testTernaryAggregateNoRewritePowDenseCP() {
+               runTernaryAggregateTest(TEST_NAME3, false, true, false, 
ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateNoRewritePowSparseCP() {
+               runTernaryAggregateTest(TEST_NAME3, true, true, false, 
ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRewritePowDenseCP() {
+               runTernaryAggregateTest(TEST_NAME3, false, true, true, 
ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRewritePowDenseSP() {
+               runTernaryAggregateTest(TEST_NAME3, false, true, true, 
ExecType.SPARK);
+       }
+
+       @Test
+       public void testTernaryAggregateRewritePowSparseCP() {
+               runTernaryAggregateTest(TEST_NAME3, true, true, true, 
ExecType.CP);
+       }
+
+       @Test
+       public void testTernaryAggregateRewritePowSparseSP() {
+               runTernaryAggregateTest(TEST_NAME3, true, true, true, 
ExecType.SPARK);
+       }
        
        private void runTernaryAggregateTest(String testname, boolean sparse, 
boolean vectors, boolean rewrites, ExecType et)
        {
@@ -203,10 +235,9 @@ public class TernaryAggregateTest extends AutomatedTestBase
                        loadTestConfiguration(config);
                        
                        OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-                       
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{"-stats","-args", 
input("A"), output("R")};
+                       programArgs = new String[]{"-explain","-stats","-args", 
input("A"), output("R")};
                        
                        fullRScriptName = HOME + testname + ".R";
                        rCmd = "Rscript" + " " + fullRScriptName + " " + 
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml 
b/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml
new file mode 100644
index 0000000000..ce7fdce817
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Three constant 100x10 matrices multiplied element-wise;
+# sum(R) = 4 * 12 * 15 * (100*10) = 720000.
+M = 100; N = 10;
+A = matrix(4, M, N);
+B = matrix(12, M, N);
+C = matrix(15, M, N);
+# NOTE(review): the empty while loops presumably separate statement blocks so
+# that the multiply chain is compiled in isolation - confirm.
+while(FALSE){}
+
+R = A * B * C
+
+while(FALSE){}
+s = as.matrix(sum(R));
+write(s,$1);
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml 
b/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml
new file mode 100644
index 0000000000..36dfc50523
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Multiply chain mixing two 100x10 matrices (A, D), a 100x1 column vector (B),
+# a 1x10 row vector (E) and two scalars (C, F);
+# sum(R) = 1 * 3 * 5 * 2 * 4 * 6 * (100*10) = 720000.
+M = 100; N = 10;
+A = matrix(1, M, N);
+D = matrix(2, M, N);
+B = matrix(3, M, 1);
+E = matrix(4, 1, N);
+C = 5
+F = 6
+# NOTE(review): the empty while loops presumably separate statement blocks so
+# that the multiply chain is compiled in isolation - confirm.
+while(FALSE){}
+
+R = A * B * C * D * E * F
+
+while(FALSE){}
+s = as.matrix(sum(R));
+write(s,$1);
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml 
b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml
new file mode 100644
index 0000000000..043e07741c
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# A is a 100x100 diagonal matrix of ones, B holds the values 1..M*M, and D is a
+# diagonal matrix with ones only at positions 1, 4 and M. The product therefore
+# keeps only three diagonal cells of B; the accompanying Java test expects
+# sum(R) = 1 + 304 + 10000.
+M = 100;
+A = diag(matrix(1, M, 1));
+B = matrix(seq(1,M*M), M, M);
+D = matrix(0, M, 1)
+D[4,1] = 1
+D[1,1] = 1
+D[M,1] = 1
+D = diag(D)
+# NOTE(review): the empty while loops presumably separate statement blocks so
+# that the multiply chain is compiled in isolation - confirm.
+while(FALSE){}
+
+R = A * B * D
+
+while(FALSE){}
+s = as.matrix(sum(R));
+write(s,$1);
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml 
b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml
new file mode 100644
index 0000000000..07180a33d8
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Random inputs: A, C and D are generated with sparsity 0.1, B with sparsity 1.0.
+M = 100; N = 10;
+A = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1)
+B = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=1.0)
+C = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1)
+D = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1)
+# NOTE(review): the empty while loops presumably separate statement blocks so
+# that each multiply below is compiled in isolation - confirm.
+while(FALSE){}
+
+# chain under test
+R = A * B * C * D
+
+# reference result built from separate binary multiplies
+while(FALSE){}
+R2 = B * C
+while(FALSE){}
+R3 = A * D
+while(FALSE){}
+R4 = R2*R3
+while(FALSE){}
+# offset one cell by 1 so that sum(R - R4) is 1 when R and R4 otherwise agree
+R[1,1] = R[1,1] + 1
+s = as.matrix(sum(R - R4));
+write(s,$1);
diff --git a/src/test/scripts/functions/ternary/TernaryAggregatePow.R 
b/src/test/scripts/functions/ternary/TernaryAggregatePow.R
new file mode 100644
index 0000000000..d0f2227869
--- /dev/null
+++ b/src/test/scripts/functions/ternary/TernaryAggregatePow.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reference script: column sums of the element-wise cube of A.
+# args[1]: input directory containing A.mtx, args[2]: output directory.
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+
+# result as a 1 x ncol(A) row vector
+R = t(as.matrix(colSums(A^3)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 
diff --git a/src/test/scripts/functions/ternary/TernaryAggregatePow.dml 
b/src/test/scripts/functions/ternary/TernaryAggregatePow.dml
new file mode 100644
index 0000000000..e8e2fe5ca9
--- /dev/null
+++ b/src/test/scripts/functions/ternary/TernaryAggregatePow.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Column sums of the element-wise cube of the input matrix;
+# $1: input matrix A, $2: output R.
+A = read($1);
+
+# NOTE(review): the empty while loop presumably separates the read from the
+# aggregate so that both are compiled in different statement blocks - confirm.
+while(FALSE){}
+
+R = colSums(A^3);
+
+write(R, $2);

Reply via email to