This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new d0287b1983 [SYSTEMDS-3772] Nary elementwise multiplication
rewrite/runtime
d0287b1983 is described below
commit d0287b1983c4fd915d5106199ff03a0d7183a397
Author: e-strauss <[email protected]>
AuthorDate: Sun Sep 22 12:28:30 2024 +0200
[SYSTEMDS-3772] Nary elementwise multiplication rewrite/runtime
- new runtime operations and rewrite for n-ary elementwise multiplication, analogous to n-ary addition
- fixed bug in nary hop where the output value type was computed incorrectly
for scalars with mixed value types
- fixed bug in nary hop where the output was incorrectly set as scalar (e.g., S*M*S, where the first input is a scalar but the result is a matrix)
- cleaned up and fixed the ternary aggregate rewrite to also work with nary mult,
and fixed the n* lineage issue
- increased code coverage for the ternary aggregate rewrite: colsum(A^3)
Closes #2105.
---
src/main/java/org/apache/sysds/common/Types.java | 4 +-
.../java/org/apache/sysds/hops/AggUnaryOp.java | 149 +++++++++++---------
src/main/java/org/apache/sysds/hops/NaryOp.java | 4 +-
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 11 +-
.../RewriteAlgebraicSimplificationDynamic.java | 6 +-
src/main/java/org/apache/sysds/lops/Nary.java | 2 +
.../runtime/instructions/CPInstructionParser.java | 1 +
.../runtime/instructions/SPInstructionParser.java | 1 +
.../instructions/cp/BuiltinNaryCPInstruction.java | 5 +
.../cp/MatrixBuiltinNaryCPInstruction.java | 2 +-
.../spark/BuiltinNarySPInstruction.java | 17 +--
.../runtime/lineage/LineageRecomputeUtils.java | 3 +-
.../sysds/runtime/matrix/data/LibMatrixMult.java | 20 ++-
.../sysds/runtime/matrix/data/MatrixBlock.java | 120 ++++++++++++----
.../binary/matrix/ElementwiseAdditionTest.java | 2 +-
.../functions/rewrite/RewriteNaryMultTest.java | 155 +++++++++++++++++++++
.../functions/ternary/TernaryAggregateTest.java | 39 +++++-
.../functions/rewrite/RewriteNaryMultDense1.dml | 32 +++++
.../functions/rewrite/RewriteNaryMultDense2.dml | 35 +++++
.../functions/rewrite/RewriteNaryMultSparse1.dml | 36 +++++
.../functions/rewrite/RewriteNaryMultSparse2.dml | 40 ++++++
.../functions/ternary/TernaryAggregatePow.R | 33 +++++
.../functions/ternary/TernaryAggregatePow.dml | 28 ++++
23 files changed, 632 insertions(+), 113 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index a7397ae54b..6d64cbae54 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -739,10 +739,10 @@ public interface Types {
/** Operations that require a variable number of operands*/
public enum OpOpN {
- PRINTF, CBIND, RBIND, MIN, MAX, PLUS, EVAL, LIST;
+ PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;
public boolean isCellOp() {
- return this == MIN || this == MAX || this == PLUS;
+ return this == MIN || this == MAX || this == PLUS ||
this == MULT;
}
}
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index dcd68b055e..eec86ec15b 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops;
import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.Direction;
@@ -30,6 +31,7 @@ import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.lops.Nary;
import org.apache.sysds.lops.PartialAggregate;
import org.apache.sysds.lops.TernaryAggregate;
import org.apache.sysds.lops.UAggOuterChain;
@@ -38,6 +40,8 @@ import
org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import java.util.List;
+
// Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum
public class AggUnaryOp extends MultiThreadedHop
@@ -475,6 +479,17 @@ public class AggUnaryOp extends MultiThreadedHop
}
}
}
+ if (input1.getParent().size() == 1
+ && input1 instanceof NaryOp) { //sum
single consumer
+ NaryOp nop = (NaryOp) input1;
+ if(nop.getOp() == Types.OpOpN.MULT){
+ List<Hop> inputsN = nop.getInput();
+ if(inputsN.size() == 3){
+ ret =
HopRewriteUtils.isEqualSize(inputsN.get(0), inputsN.get(1)) &&
+
HopRewriteUtils.isEqualSize(inputsN.get(1), inputsN.get(2));
+ }
+ }
+ }
}
return ret;
}
@@ -554,83 +569,91 @@ public class AggUnaryOp extends MultiThreadedHop
private Lop constructLopsTernaryAggregateRewrite(ExecType et)
{
- BinaryOp input1 = (BinaryOp)getInput().get(0);
- Hop input11 = input1.getInput().get(0);
- Hop input12 = input1.getInput().get(1);
-
Lop in1 = null, in2 = null, in3 = null;
- boolean handled = false;
-
- if (input1.getOp() == OpOp2.POW) {
- assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) :
"this case can only occur with a power of 3";
- in1 = input11.constructLops();
- in2 = in1;
- in3 = in1;
- handled = true;
- }
- else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT,
OpOp2.POW) ) {
- BinaryOp b11 = (BinaryOp)input11;
- switch( b11.getOp() ) {
- case MULT: // A*B*C case
- in1 = input11.getInput().get(0).constructLops();
- in2 = input11.getInput().get(1).constructLops();
- in3 = input12.constructLops();
- handled = true;
- break;
- case POW: // A*A*B case
- Hop b112 = b11.getInput().get(1);
- if ( !(input12 instanceof BinaryOp &&
((BinaryOp)input12).getOp()==OpOp2.MULT)
- &&
HopRewriteUtils.isLiteralOfValue(b112, 2) ) {
- in1 =
b11.getInput().get(0).constructLops();
- in2 = in1;
- in3 = input12.constructLops();
- handled = true;
- }
- break;
- default: break;
- }
- }
- else if( HopRewriteUtils.isBinary(input12, OpOp2.MULT,
OpOp2.POW) ) {
- BinaryOp b12 = (BinaryOp)input12;
- switch (b12.getOp()) {
- case MULT: // A*B*C case
+ Hop input = getInput().get(0);
+ if(input instanceof BinaryOp) {
+ BinaryOp input1 = (BinaryOp) input;
+ Hop input11 = input1.getInput().get(0);
+ Hop input12 = input1.getInput().get(1);
+
+ boolean handled = false;
+
+ if (input1.getOp() == OpOp2.POW) {
+ assert
(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with
a power of 3";
in1 = input11.constructLops();
- in2 = input12.getInput().get(0).constructLops();
- in3 = input12.getInput().get(1).constructLops();
+ in2 = in1;
+ in3 = in1;
handled = true;
- break;
- case POW: // A*B*B case
- Hop b112 = b12.getInput().get(1);
- if ( HopRewriteUtils.isLiteralOfValue(b112, 2)
) {
- in1 =
b12.getInput().get(0).constructLops();
- in2 = in1;
- in3 = input11.constructLops();
- handled = true;
+ } else if (HopRewriteUtils.isBinary(input11,
OpOp2.MULT, OpOp2.POW)) {
+ BinaryOp b11 = (BinaryOp) input11;
+ switch (b11.getOp()) {
+ case MULT: // A*B*C case
+ in1 =
input11.getInput().get(0).constructLops();
+ in2 =
input11.getInput().get(1).constructLops();
+ in3 = input12.constructLops();
+ handled = true;
+ break;
+ case POW: // A*A*B case
+ Hop b112 =
b11.getInput().get(1);
+ if (!(input12 instanceof
BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT)
+ &&
HopRewriteUtils.isLiteralOfValue(b112, 2)) {
+ in1 =
b11.getInput().get(0).constructLops();
+ in2 = in1;
+ in3 =
input12.constructLops();
+ handled = true;
+ }
+ break;
+ default:
+ break;
+ }
+ } else if (HopRewriteUtils.isBinary(input12,
OpOp2.MULT, OpOp2.POW)) {
+ BinaryOp b12 = (BinaryOp) input12;
+ switch (b12.getOp()) {
+ case MULT: // A*B*C case
+ in1 = input11.constructLops();
+ in2 =
input12.getInput().get(0).constructLops();
+ in3 =
input12.getInput().get(1).constructLops();
+ handled = true;
+ break;
+ case POW: // A*B*B case
+ Hop b112 =
b12.getInput().get(1);
+ if
(HopRewriteUtils.isLiteralOfValue(b112, 2)) {
+ in1 =
b12.getInput().get(0).constructLops();
+ in2 = in1;
+ in3 =
input11.constructLops();
+ handled = true;
+ }
+ break;
+ default:
+ break;
}
- break;
- default: break;
}
- }
- if (!handled) {
- in1 = input11.constructLops();
- in2 = input12.constructLops();
- in3 = new LiteralOp(1).constructLops();
+ if (!handled) {
+ in1 = input11.constructLops();
+ in2 = input12.constructLops();
+ in3 = new LiteralOp(1).constructLops();
+ }
+ } else {
+ NaryOp input1 = (NaryOp) input;
+ in1 = input1.getInput().get(0).constructLops();
+ in2 = input1.getInput().get(1).constructLops();
+ in3 = input1.getInput().get(2).constructLops();
}
- //create new ternary aggregate operator
- int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads
);
+ //create new ternary aggregate operator
+ int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
// The execution type of a unary aggregate instruction should
depend on the execution type of inputs to avoid OOM
// Since we only support matrix-vector and not vector-matrix,
checking the execution type of input1 should suffice.
- ExecType et_input = input1.optFindExecType();
+ ExecType et_input = input.optFindExecType();
// Because ternary aggregate are not supported on GPU
- et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
+ et_input = et_input == ExecType.GPU ? ExecType.CP : et_input;
// If forced ExecType is FED, it means that the federated
planner updated the ExecType and
// execution may fail if ExecType is not FED
et_input = (getForcedExecType() == ExecType.FED) ? ExecType.FED
: et_input;
-
- return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
- OpOp2.MULT, _direction, getDataType(), ValueType.FP64,
et_input, k);
+
+ return new TernaryAggregate(in1, in2, in3, AggOp.SUM,
+ OpOp2.MULT, _direction, getDataType(),
ValueType.FP64, et_input, k);
}
@Override
diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java
b/src/main/java/org/apache/sysds/hops/NaryOp.java
index 0a6e93ca14..44fe8ff609 100644
--- a/src/main/java/org/apache/sysds/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/NaryOp.java
@@ -200,7 +200,8 @@ public class NaryOp extends Hop {
HopRewriteUtils.getSumValidInputNnz(dc,
true));
case MIN:
case MAX:
- case PLUS: return new MatrixCharacteristics(
+ case PLUS:
+ case MULT: return new MatrixCharacteristics(
HopRewriteUtils.getMaxInputDim(this, true),
HopRewriteUtils.getMaxInputDim(this, false), -1, -1);
case LIST:
@@ -230,6 +231,7 @@ public class NaryOp extends Hop {
case MIN:
case MAX:
case PLUS:
+ case MULT:
setDim1(getDataType().isScalar() ? 0 :
HopRewriteUtils.getMaxInputDim(this, true));
setDim2(getDataType().isScalar() ? 0 :
HopRewriteUtils.getMaxInputDim(this, false));
break;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 2b84318e35..68167ac3ae 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -745,8 +745,15 @@ public class HopRewriteUtils {
public static NaryOp createNary(OpOpN op, Hop... inputs) {
Hop mainInput = inputs[0];
- NaryOp nop = new NaryOp(mainInput.getName(),
mainInput.getDataType(),
- mainInput.getValueType(), op, inputs);
+ // safe for unordered inputs of Scalars and Matrices
+ // e.g.: S*M*S = M
+ // safe for Scalar with different value type
+ // e.g.: Scalar(Int) * Scalar(FP64) = Scalar(FP64)
+ boolean containsMatrix =
Arrays.stream(inputs).anyMatch(Hop::isMatrix);
+ boolean containsFP64 = Arrays.stream(inputs).anyMatch(h ->
h.getValueType() == ValueType.FP64);
+ DataType dtOut = containsMatrix ? DataType.MATRIX :
mainInput.getDataType();
+ ValueType vtOut = containsFP64? ValueType.FP64 :
mainInput.getValueType();
+ NaryOp nop = new NaryOp(mainInput.getName(), dtOut, vtOut, op,
inputs);
nop.setBlocksize(mainInput.getBlocksize());
copyLineNumbers(mainInput, nop);
nop.refreshSizeInformation();
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 1a4c4ecebd..314d230842 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2801,8 +2801,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
private static Hop foldMultipleMinMaxOperations(Hop hi)
{
- if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX,
OpOp2.PLUS)
- || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX,
OpOpN.PLUS))
+ if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX,
OpOp2.PLUS, OpOp2.MULT)
+ || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX,
OpOpN.PLUS, OpOpN.MULT))
&& hi.getValueType() != ValueType.STRING //exclude
string concat
&& HopRewriteUtils.isNotMatrixVectorBinaryOperation(hi))
{
@@ -2839,7 +2839,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, hi, hnew);
hi = hnew;
- LOG.debug("Applied
foldMultipleMinMaxPlusOperations (line "+hi.getBeginLine()+").");
+ LOG.debug("Applied
foldMultipleMinMaxPlusMultOperations (line "+hi.getBeginLine()+").");
}
else {
converged = true;
diff --git a/src/main/java/org/apache/sysds/lops/Nary.java
b/src/main/java/org/apache/sysds/lops/Nary.java
index 68e7191ad9..950082e47e 100644
--- a/src/main/java/org/apache/sysds/lops/Nary.java
+++ b/src/main/java/org/apache/sysds/lops/Nary.java
@@ -117,6 +117,8 @@ public class Nary extends Lop {
return "n"+operationType.name().toLowerCase();
case PLUS:
return "n+";
+ case MULT:
+ return "n*";
default:
throw new UnsupportedOperationException(
"Nary operation type (" + operationType
+ ") is not defined.");
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 792eace24e..7a5ed3524c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -180,6 +180,7 @@ public class CPInstructionParser extends InstructionParser {
String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
String2CPInstructionType.put( "n+" , CPType.BuiltinNary);
+ String2CPInstructionType.put( "n*" , CPType.BuiltinNary);
String2CPInstructionType.put( "exp" , CPType.Unary);
String2CPInstructionType.put( "abs" , CPType.Unary);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 06e68a63d5..5e4dbaedeb 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -299,6 +299,7 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "nmin", SPType.BuiltinNary);
String2SPInstructionType.put( "nmax", SPType.BuiltinNary);
String2SPInstructionType.put( "n+", SPType.BuiltinNary);
+ String2SPInstructionType.put( "n*", SPType.BuiltinNary);
String2SPInstructionType.put( DataGen.RAND_OPCODE ,
SPType.Rand);
String2SPInstructionType.put( DataGen.SEQ_OPCODE ,
SPType.Rand);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
index 31110423cb..c6b590a61c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.cp;
import org.apache.sysds.common.Types.OpOpN;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -85,6 +86,10 @@ public abstract class BuiltinNaryCPInstruction extends
CPInstruction
return new MatrixBuiltinNaryCPInstruction(
new SimpleOperator(Plus.getPlusFnObject()),
opcode, str, outputOperand, inputOperands);
}
+ else if( opcode.equals("n*") ) {
+ return new MatrixBuiltinNaryCPInstruction(
+ new
SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand,
inputOperands);
+ }
else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) {
return new EvalNaryCPInstruction(null, opcode, str,
outputOperand, inputOperands);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
index d6a0ec53f5..172c15b592 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java
@@ -62,7 +62,7 @@ public class MatrixBuiltinNaryCPInstruction extends
BuiltinNaryCPInstruction imp
outBlock =
((FrameBlock)outBlock).append(frames.get(i), cbind);
}
}
- else if( ArrayUtils.contains(new String[]{"nmin", "nmax",
"n+"}, getOpcode()) ) {
+ else if( ArrayUtils.contains(new String[]{"nmin", "nmax", "n+",
"n*"}, getOpcode()) ) {
outBlock = MatrixBlock.naryOperations(_optr,
matrices.toArray(new MatrixBlock[0]),
scalars.toArray(new ScalarObject[0]), new
MatrixBlock());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
index c3118501ba..8e46a96357 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java
@@ -34,6 +34,7 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -163,7 +164,7 @@ public class BuiltinNarySPInstruction extends SPInstruction
implements LineageTr
return;
}
}
- else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+"},
getOpcode()) ) {
+ else if( ArrayUtils.contains(new
String[]{"nmin","nmax","n+","n*"}, getOpcode()) ) {
//compute output characteristics
dcout = computeMinMaxOutputDataCharacteristics(sec,
inputs);
@@ -179,7 +180,7 @@ public class BuiltinNarySPInstruction extends SPInstruction
implements LineageTr
}
//compute nary min/max (partitioning-preserving)
- out = in.mapValues(new MinMaxAddFunction(getOpcode(),
scalars));
+ out = in.mapValues(new
MinMaxAddMultFunction(getOpcode(), scalars));
}
//set output RDD and add lineage
@@ -278,17 +279,17 @@ public class BuiltinNarySPInstruction extends
SPInstruction implements LineageTr
}
}
- private static class MinMaxAddFunction implements
Function<MatrixBlock[], MatrixBlock> {
+ private static class MinMaxAddMultFunction implements
Function<MatrixBlock[], MatrixBlock> {
private static final long serialVersionUID =
-4227447915387484397L;
private final SimpleOperator _op;
private final ScalarObject[] _scalars;
-
- public MinMaxAddFunction(String opcode, List<ScalarObject>
scalars) {
+
+ public MinMaxAddMultFunction(String opcode, List<ScalarObject>
scalars) {
_scalars = scalars.toArray(new ScalarObject[0]);
- _op = new SimpleOperator(opcode.equals("n+") ?
- Plus.getPlusFnObject() :
-
Builtin.getBuiltinFnObject(opcode.substring(1)));
+ _op = new SimpleOperator(opcode.equals("n+") ?
Plus.getPlusFnObject() :
+ opcode.equals("n*") ?
Multiply.getMultiplyFnObject() :
+
Builtin.getBuiltinFnObject(opcode.substring(1)));
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index 6c691c2b8b..1188d65e30 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -390,7 +390,8 @@ public class LineageRecomputeUtils {
break;
}
case BuiltinNary: {
- String opcode =
item.getOpcode().equals("n+") ? "plus" : item.getOpcode();
+ String opcode =
item.getOpcode().equals("n+") ? "plus" :
+
item.getOpcode().equals("n*") ? "mult" : item.getOpcode();
operands.put(item.getId(), HopRewriteUtils.createNary(
OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands)));
break;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 061c09ea3d..108462c567 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -3844,6 +3844,24 @@ public class LibMatrixMult
c[ ci+6 ] *= aval; c[ ci+7 ] *= aval;
}
}
+
+ public static void vectMultiplyInPlace(final double[] a, double[] c,
int[] cix, final int ai, final int ci, final int len) {
+ final int bn = len%8;
+ //rest, not aligned to 8-blocks
+ for( int j = ci; j < ci+bn; j++ )
+ c[ j ] *= a[ ai+cix[j] ];
+ //unrolled 8-block (for better instruction-level parallelism)
+ for( int j = ci+bn; j < ci+len; j+=8 ) {
+ c[ j+0 ] *= a[ ai+cix[j+0] ];
+ c[ j+1 ] *= a[ ai+cix[j+1] ];
+ c[ j+2 ] *= a[ ai+cix[j+2] ];
+ c[ j+3 ] *= a[ ai+cix[j+3] ];
+ c[ j+4 ] *= a[ ai+cix[j+4] ];
+ c[ j+5 ] *= a[ ai+cix[j+5] ];
+ c[ j+6 ] *= a[ ai+cix[j+6] ];
+ c[ j+7 ] *= a[ ai+cix[j+7] ];
+ }
+ }
//note: public for use by codegen for consistency
public static void vectMultiplyWrite( double[] a, double[] b, double[]
c, int ai, int bi, int ci, final int len )
@@ -3889,7 +3907,7 @@ public class LibMatrixMult
}
}
- private static void vectMultiply( double[] a, double[] c, int ai, int
ci, final int len )
+ public static void vectMultiply(double[] a, double[] c, int ai, int ci,
final int len)
{
final int bn = len%8;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 8e2e7f0665..62f9a1febb 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -3805,32 +3805,38 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
//prepare operator
FunctionObject fn = ((SimpleOperator)op).fn;
boolean plus = fn instanceof Plus;
- Builtin bfn = !plus ? (Builtin)((SimpleOperator)op).fn : null;
-
+ boolean mult = fn instanceof Multiply;
+ boolean minmax = !mult && !plus;
+ Builtin bfn = minmax ? (Builtin) fn : null;
for(int i = 0; i < matrices.length; i++)
if(matrices[i] instanceof CompressedMatrixBlock)
matrices[i] =
CompressedMatrixBlock.getUncompressed(matrices[i], "Nary operation process add
row");
//process all scalars
- double init = plus ? 0 :(bfn.getBuiltinCode() ==
BuiltinCode.MIN) ?
- Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
+ double init = plus ? 0 : mult ? 1 : (bfn.getBuiltinCode() ==
BuiltinCode.MIN) ?
+ Double.POSITIVE_INFINITY :
Double.NEGATIVE_INFINITY;
for( ScalarObject so : scalars )
init = fn.execute(init, so.getDoubleValue());
//compute output dimensions and estimate sparsity
final int m = matrices.length > 0 ? matrices[0].rlen : 1;
final int n = matrices.length > 0 ? matrices[0].clen : 1;
+
+ //check for empty multiply with empty matrix
+ if(mult)
+
if(Arrays.stream(matrices).anyMatch(MatrixBlock::isEmptyBlock))
+ return new MatrixBlock(m,n, 0);
+
final long mn = (long) m * n;
- final long nnz = (!plus &&
bfn.getBuiltinCode()==BuiltinCode.MIN && init < 0)
- || (!plus && bfn.getBuiltinCode()==BuiltinCode.MAX &&
init > 0) ? mn :
+ final long nnz = (minmax &&
bfn.getBuiltinCode()==BuiltinCode.MIN && init < 0)
+ || (minmax && bfn.getBuiltinCode()==BuiltinCode.MAX &&
init > 0) ? mn :
+ mult? Arrays.stream(matrices).mapToLong(mb ->
mb.nonZeros).min().orElse(mn) :
Math.min(Arrays.stream(matrices).mapToLong(mb ->
mb.nonZeros).sum(), mn);
- boolean sp = evalSparseFormatInMemory(m, n, nnz);
+ //if multiply and least one sparse input -> output = sparse
+ boolean sp = (mult && Arrays.stream(matrices).anyMatch(mb ->
mb.sparse)) || evalSparseFormatInMemory(m, n, nnz);
//init result matrix
- if( ret == null )
- ret = new MatrixBlock(m, n, sp, nnz);
- else
- ret.reset(m, n, sp, nnz);
+ ret.reset(m, n, sp, nnz);
//main processing
if( matrices.length > 0 ) {
@@ -3839,18 +3845,21 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
mb -> mb.sparse || mb.isEmpty()) ? new int[n] :
null;
if( ret.isInSparseFormat() ) {
double[] tmp = new double[n];
- for(int i = 0; i < m; i++) {
- //reset tmp and compute row output
- Arrays.fill(tmp, init);
- if( plus )
- processAddRow(matrices, tmp, 0,
n, i);
- else
- processMinMaxRow(bfn, matrices,
tmp, 0, n, i, cnt);
- //copy to sparse output
- for(int j = 0; j < n; j++)
- if( tmp[j] != 0 )
- ret.appendValue(i, j,
tmp[j]);
- }
+ for(int i = 0; i < m; i++)
+ if (mult)
+ processMultRowSparse(matrices,
init, n, i, ret);
+ else{
+ //reset tmp and compute row
output
+ Arrays.fill(tmp, init);
+ if( plus )
+ processAddRow(matrices,
tmp, 0, n, i);
+ else
+ processMinMaxRow(bfn,
matrices, tmp, 0, n, i, cnt);
+ //copy to sparse output
+ for(int j = 0; j < n; j++)
+ if( tmp[j] != 0 )
+
ret.appendValue(i, j, tmp[j]);
+ }
}
else {
DenseBlock c = ret.getDenseBlock();
@@ -3860,17 +3869,21 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
Arrays.fill(c.values(i),
c.pos(i), c.pos(i)+n, init);
if( plus )
processAddRow(matrices,
c.values(i), c.pos(i), n, i);
+ else if (mult)
+ processMultRowDense(matrices,
c.values(i), c.pos(i), n, i);
else
processMinMaxRow(bfn, matrices,
c.values(i), c.pos(i), n, i, cnt);
lnnz +=
UtilFunctions.countNonZeros(c.values(i), c.pos(i), n);
}
ret.setNonZeros(lnnz);
+
+ //reevaluate sparsity
+ if(mult && ret.evalSparseFormatInMemory())
+ ret.denseToSparse();
}
}
- else {
+ else
ret.set(0, 0, init);
- }
-
return ret;
}
@@ -3891,6 +3904,61 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
}
}
+
+ private static void processMultRowDense(MatrixBlock[] inputs, double[]
c, int cix, int n, int i) {
+ // if inputs contain sparse -> result == sparse ->
processMultRowSparse
+ for( MatrixBlock in : inputs ) {
+ DenseBlock a = in.getDenseBlock();
+
LibMatrixMult.vectMultiply(in.getDenseBlock().values(i), c, a.pos(i), cix, n);
+ }
+ }
+
+ private static void processMultRowSparse(MatrixBlock[] inputs, double
init, int n, int i, MatrixBlock ret) {
+ SparseBlock[] sparse_inputs = new SparseBlock[inputs.length];
+ int len_sparse_inputs = 0;
+ for( MatrixBlock in : inputs )
+ if (in.isInSparseFormat())
+ sparse_inputs[len_sparse_inputs++] =
in.getSparseBlock();
+
+ int size = sparse_inputs[0].size(i);
+ int[] sparse_indices = new int[size];
+ double[] sparse_values = new double[size];
+ for (int j = 0; j < size; j++){
+ sparse_values[j] =
init*sparse_inputs[0].values(i)[sparse_inputs[0].pos(i) + j];
+ sparse_indices[j] =
sparse_inputs[0].indexes(i)[sparse_inputs[0].pos(i) + j];
+ }
+ for (int k = 1; k < len_sparse_inputs; k++){
+ SparseBlock a = sparse_inputs[k];
+
+ //calculate intersection
+ int aix = a.pos(i);
+ int aix_end = a.pos(i) + a.size(i);
+ int ix_read = 0;
+ int ix_write = 0;
+ while(ix_read < size && aix < aix_end){
+ if(sparse_indices[ix_read] < a.indexes(i)[aix])
+ ix_read += 1;
+ else if(sparse_indices[ix_read] >
a.indexes(i)[aix])
+ aix += 1;
+ else{
+ sparse_indices[ix_write] =
sparse_indices[ix_read];
+ sparse_values[ix_write] =
sparse_values[ix_read]*a.values(i)[aix];
+ aix += 1;
+ ix_read += 1;
+ ix_write += 1;
+ }
+ }
+ size = ix_write;
+ }
+
+ //iterate dense blocks
+ for( MatrixBlock in : inputs )
+ if( !in.isInSparseFormat() )
+
LibMatrixMult.vectMultiplyInPlace(in.denseBlock.values(i),
+ sparse_values, sparse_indices,
in.denseBlock.pos(i), 0, size);
+ for (int j = 0; j < size; j++)
+ ret.appendValue(i, sparse_indices[j], sparse_values[j]);
+ }
private static void processMinMaxRow(Builtin fn, MatrixBlock[] inputs,
double[] c, int cix, int n, int i, int[] cnt) {
if( cnt != null )
diff --git
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
index ac3cf3a041..a4a6b0381d 100644
---
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java
@@ -133,7 +133,7 @@ public class ElementwiseAdditionTest extends
AutomatedTestBase
config.addVariable("cols2", cols2);
loadTestConfiguration(config);
-
+ //this.programArgs
runTest(true,LanguageException.class);
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java
new file mode 100644
index 0000000000..6fab8f63a9
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RewriteNaryMultTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewriteNaryMultDense1";  // A * B * C over dense M x N matrices
+	private static final String TEST_NAME2 = "RewriteNaryMultDense2";  // mix of matrices, vectors and scalars
+	private static final String TEST_NAME3 = "RewriteNaryMultSparse1"; // A * B * D with diagonal A and D
+	private static final String TEST_NAME4 = "RewriteNaryMultSparse2"; // sparse A * B * C * D vs. staged product
+
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteNaryMultTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}) );
+		addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"}) );
+		addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"}) );
+		addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[]{"R"}) );
+	}
+
+	@Test
+	public void testExpectExceptionParseNaryCPInstruction(){
+		String opcode = "nxy";
+		try {
+			BuiltinNaryCPInstruction.parseInstruction("CP°" + opcode +"°A·MATRIX·FP64°B·MATRIX·FP64°C·MATRIX·FP64°_mVar3·MATRIX·FP64");
+			Assert.fail("expected DMLRuntimeException for unknown opcode " + opcode); // must not parse silently
+		}
+		catch (DMLRuntimeException e){
+			Assert.assertEquals("Opcode (" + opcode + ") not recognized in BuiltinMultipleCPInstruction", e.getMessage());
+		}
+	}
+
+	@Test
+	public void testNoRewriteDense1CP() {
+		testRewriteNaryMult(TEST_NAME1, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDense1CP() {
+		testRewriteNaryMult(TEST_NAME1, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDense2CP() {
+		testRewriteNaryMult(TEST_NAME2, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteSparse1CP() {
+		testRewriteNaryMult(TEST_NAME3, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteSparse2CP() {
+		testRewriteNaryMult(TEST_NAME4, true, ExecType.CP);
+	}
+
+	@Test
+	public void testNoRewriteDense1SP() {
+		testRewriteNaryMult(TEST_NAME1, false, ExecType.SPARK);
+	}
+
+	@Test
+	public void testRewriteDense1SP() {
+		testRewriteNaryMult(TEST_NAME1, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testRewriteDense2SP() {
+		testRewriteNaryMult(TEST_NAME2, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testRewriteSparse1SP() {
+		testRewriteNaryMult(TEST_NAME3, true, ExecType.SPARK);
+	}
+
+	@Test
+	public void testRewriteSparse2SP() {
+		testRewriteNaryMult(TEST_NAME4, true, ExecType.SPARK);
+	}
+
+	// Runs script 'name' with or without algebraic rewrites on backend 'et'; checks sum(R) and which mult ops ran.
+	private void testRewriteNaryMult(String name, boolean rewrites, ExecType et)
+	{
+		ExecMode oldMode = setExecMode(et);
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+
+		try
+		{
+			TestConfiguration config = getTestConfiguration(name);
+			loadTestConfiguration(config);
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + name + ".dml";
+			programArgs = new String[]{"-explain","-stats","-args", output("R") };
+
+			runTest(true, false, null, -1);
+
+			//compare output (each script writes sum(R) as a 1x1 matrix)
+			Double ret = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1));
+			if(name.equals(TEST_NAME3))
+				Assert.assertEquals(Double.valueOf(1 + 304 + 10000), ret); // B[1,1] + B[4,4] + B[100,100]
+			else if(name.equals(TEST_NAME4))
+				Assert.assertEquals(Double.valueOf(1), ret, 1e-7); // sum(R - R4) after R[1,1] += 1
+			else
+				Assert.assertEquals(Double.valueOf(2*3*4*5*6*1000), ret); // 720 per cell times M*N = 1000 cells
+
+			//check for applied nary mult (n*); without rewrites, or for Dense2, a binary '*' must run
+			String prefix = et == ExecType.SPARK ? "sp_" : "";
+			if( rewrites && !name.equals(TEST_NAME2) )
+				Assert.assertEquals(1, Statistics.getCPHeavyHitterCount(prefix + "n*"));
+			else
+				Assert.assertTrue(Statistics.getCPHeavyHitterCount(prefix+"*")>=1);
+		}
+		finally {
+			resetExecMode(oldMode);
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
index a7224d88b0..3d7de4983a 100644
--- a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java
@@ -41,10 +41,11 @@ public class TernaryAggregateTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "TernaryAggregateRC";
private final static String TEST_NAME2 = "TernaryAggregateC";
+ private final static String TEST_NAME3 = "TernaryAggregatePow";
private final static String TEST_DIR = "functions/ternary/";
private final static String TEST_CLASS_DIR = TEST_DIR +
TernaryAggregateTest.class.getSimpleName() + "/";
- private final static double eps = 1e-8;
+ private final static double eps = 1e-7;
private final static int rows = 1111;
private final static int cols = 1011;
@@ -56,7 +57,8 @@ public class TernaryAggregateTest extends AutomatedTestBase
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
- addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
}
@Test
@@ -180,6 +182,36 @@ public class TernaryAggregateTest extends AutomatedTestBase
public void testTernaryAggregateCSparseMatrixCPNoRewrite() {
runTernaryAggregateTest(TEST_NAME2, true, false, false,
ExecType.CP);
}
+
+ @Test
+ public void testTernaryAggregateNoRewritePowDenseCP() {
+ runTernaryAggregateTest(TEST_NAME3, false, true, false,
ExecType.CP); // sparse=false, vectors=true, rewrites=false (ALLOW_SUM_PRODUCT_REWRITES off)
+ }
+
+ @Test
+ public void testTernaryAggregateNoRewritePowSparseCP() {
+ runTernaryAggregateTest(TEST_NAME3, true, true, false,
ExecType.CP); // sparse=true, vectors=true, rewrites=false (ALLOW_SUM_PRODUCT_REWRITES off)
+ }
+
+ @Test
+ public void testTernaryAggregateRewritePowDenseCP() {
+ runTernaryAggregateTest(TEST_NAME3, false, true, true,
ExecType.CP); // sparse=false, vectors=true, rewrites=true (ALLOW_SUM_PRODUCT_REWRITES on)
+ }
+
+ @Test
+ public void testTernaryAggregateRewritePowDenseSP() {
+ runTernaryAggregateTest(TEST_NAME3, false, true, true,
ExecType.SPARK); // sparse=false, vectors=true, rewrites=true (ALLOW_SUM_PRODUCT_REWRITES on)
+ }
+
+ @Test
+ public void testTernaryAggregateRewritePowSparseCP() {
+ runTernaryAggregateTest(TEST_NAME3, true, true, true,
ExecType.CP); // sparse=true, vectors=true, rewrites=true (ALLOW_SUM_PRODUCT_REWRITES on)
+ }
+
+ @Test
+ public void testTernaryAggregateRewritePowSparseSP() {
+ runTernaryAggregateTest(TEST_NAME3, true, true, true,
ExecType.SPARK); // sparse=true, vectors=true, rewrites=true (ALLOW_SUM_PRODUCT_REWRITES on)
+ }
private void runTernaryAggregateTest(String testname, boolean sparse,
boolean vectors, boolean rewrites, ExecType et)
{
@@ -203,10 +235,9 @@ public class TernaryAggregateTest extends AutomatedTestBase
loadTestConfiguration(config);
OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{"-stats","-args",
input("A"), output("R")};
+ programArgs = new String[]{"-explain","-stats","-args",
input("A"), output("R")};
fullRScriptName = HOME + testname + ".R";
rCmd = "Rscript" + " " + fullRScriptName + " " +
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml
new file mode 100644
index 0000000000..ce7fdce817
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+M = 100; N = 10;
+A = matrix(4, M, N);
+B = matrix(12, M, N);
+C = matrix(15, M, N);
+while(FALSE){}
+
+R = A * B * C # every cell is 4*12*15 = 720, so sum(R) = 720*M*N = 720000
+
+while(FALSE){}
+s = as.matrix(sum(R));
+write(s,$1);
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml
new file mode 100644
index 0000000000..36dfc50523
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+M = 100; N = 10;
+A = matrix(1, M, N);
+D = matrix(2, M, N);
+B = matrix(3, M, 1); # M x 1 column vector
+E = matrix(4, 1, N); # 1 x N row vector
+C = 5
+F = 6
+while(FALSE){}
+
+R = A * B * C * D * E * F # matrices, vectors and scalars; every cell is 1*3*5*2*4*6 = 720
+
+while(FALSE){}
+s = as.matrix(sum(R));
+write(s,$1);
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml
new file mode 100644
index 0000000000..043e07741c
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+M = 100;
+A = diag(matrix(1, M, 1)); # M x M identity
+B = matrix(seq(1,M*M), M, M);
+D = matrix(0, M, 1)
+D[4,1] = 1
+D[1,1] = 1
+D[M,1] = 1
+D = diag(D) # diagonal matrix with ones at (1,1), (4,4) and (M,M)
+while(FALSE){}
+
+R = A * B * D # keeps only B[1,1], B[4,4], B[M,M] = 1, 304, 10000 for M = 100
+
+while(FALSE){}
+s = as.matrix(sum(R));
+write(s,$1);
diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml
new file mode 100644
index 0000000000..07180a33d8
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+M = 100; N = 10;
+A = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1)
+B = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=1.0)
+C = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1)
+D = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1)
+while(FALSE){}
+
+R = A * B * C * D # product under test
+
+while(FALSE){}
+R2 = B * C
+while(FALSE){}
+R3 = A * D
+while(FALSE){}
+R4 = R2*R3 # same product staged as (B*C)*(A*D); the cuts presumably keep it from being fused like R - confirm
+while(FALSE){}
+R[1,1] = R[1,1] + 1 # so sum(R - R4) is expected to be 1 up to rounding
+s = as.matrix(sum(R - R4));
+write(s,$1);
diff --git a/src/test/scripts/functions/ternary/TernaryAggregatePow.R b/src/test/scripts/functions/ternary/TernaryAggregatePow.R
new file mode 100644
index 0000000000..d0f2227869
--- /dev/null
+++ b/src/test/scripts/functions/ternary/TernaryAggregatePow.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+
+library("Matrix")
+
+A = as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+# reference result for TernaryAggregatePow.dml: column sums of the
+# element-wise cube of A, written as a 1 x ncol(A) row vector
+
+R = t(as.matrix(colSums(A^3)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
diff --git a/src/test/scripts/functions/ternary/TernaryAggregatePow.dml b/src/test/scripts/functions/ternary/TernaryAggregatePow.dml
new file mode 100644
index 0000000000..e8e2fe5ca9
--- /dev/null
+++ b/src/test/scripts/functions/ternary/TernaryAggregatePow.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+while(FALSE){}
+
+R = colSums(A^3); # the colsum(A^3) case of the ternary-aggregate rewrite
+
+write(R, $2);