This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new d0287b1983 [SYSTEMDS-3772] Nary elementwise multiplication rewrite/runtime d0287b1983 is described below commit d0287b1983c4fd915d5106199ff03a0d7183a397 Author: e-strauss <lathan...@gmx.de> AuthorDate: Sun Sep 22 12:28:30 2024 +0200 [SYSTEMDS-3772] Nary elementwise multiplication rewrite/runtime - new runtime operations and rewrite for nary mult similar to add - fixed bug in nary hop, where output value type was computed wrongfully in case of scalars with mixed value types - fixed bug in nary hop, where output is wrongfully set as scalar - clean up & fixed the aggregate ternary rewrite to also work nary mult & fixed the n* lineage issue - increased code coverage for rewrite ternary aggregate: colsum(A^3) Closes #2105. --- src/main/java/org/apache/sysds/common/Types.java | 4 +- .../java/org/apache/sysds/hops/AggUnaryOp.java | 149 +++++++++++--------- src/main/java/org/apache/sysds/hops/NaryOp.java | 4 +- .../apache/sysds/hops/rewrite/HopRewriteUtils.java | 11 +- .../RewriteAlgebraicSimplificationDynamic.java | 6 +- src/main/java/org/apache/sysds/lops/Nary.java | 2 + .../runtime/instructions/CPInstructionParser.java | 1 + .../runtime/instructions/SPInstructionParser.java | 1 + .../instructions/cp/BuiltinNaryCPInstruction.java | 5 + .../cp/MatrixBuiltinNaryCPInstruction.java | 2 +- .../spark/BuiltinNarySPInstruction.java | 17 +-- .../runtime/lineage/LineageRecomputeUtils.java | 3 +- .../sysds/runtime/matrix/data/LibMatrixMult.java | 20 ++- .../sysds/runtime/matrix/data/MatrixBlock.java | 120 ++++++++++++---- .../binary/matrix/ElementwiseAdditionTest.java | 2 +- .../functions/rewrite/RewriteNaryMultTest.java | 155 +++++++++++++++++++++ .../functions/ternary/TernaryAggregateTest.java | 39 +++++- .../functions/rewrite/RewriteNaryMultDense1.dml | 32 +++++ .../functions/rewrite/RewriteNaryMultDense2.dml | 35 +++++ .../functions/rewrite/RewriteNaryMultSparse1.dml | 36 +++++ 
.../functions/rewrite/RewriteNaryMultSparse2.dml | 40 ++++++ .../functions/ternary/TernaryAggregatePow.R | 33 +++++ .../functions/ternary/TernaryAggregatePow.dml | 28 ++++ 23 files changed, 632 insertions(+), 113 deletions(-) diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index a7397ae54b..6d64cbae54 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -739,10 +739,10 @@ public interface Types { /** Operations that require a variable number of operands*/ public enum OpOpN { - PRINTF, CBIND, RBIND, MIN, MAX, PLUS, EVAL, LIST; + PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST; public boolean isCellOp() { - return this == MIN || this == MAX || this == PLUS; + return this == MIN || this == MAX || this == PLUS || this == MULT; } } diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java index dcd68b055e..eec86ec15b 100644 --- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java @@ -20,6 +20,7 @@ package org.apache.sysds.hops; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.AggOp; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.Direction; @@ -30,6 +31,7 @@ import org.apache.sysds.hops.AggBinaryOp.SparkAggType; import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.lops.Lop; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.lops.Nary; import org.apache.sysds.lops.PartialAggregate; import org.apache.sysds.lops.TernaryAggregate; import org.apache.sysds.lops.UAggOuterChain; @@ -38,6 +40,8 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.meta.DataCharacteristics; import 
org.apache.sysds.runtime.meta.MatrixCharacteristics; +import java.util.List; + // Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum public class AggUnaryOp extends MultiThreadedHop @@ -475,6 +479,17 @@ public class AggUnaryOp extends MultiThreadedHop } } } + if (input1.getParent().size() == 1 + && input1 instanceof NaryOp) { //sum single consumer + NaryOp nop = (NaryOp) input1; + if(nop.getOp() == Types.OpOpN.MULT){ + List<Hop> inputsN = nop.getInput(); + if(inputsN.size() == 3){ + ret = HopRewriteUtils.isEqualSize(inputsN.get(0), inputsN.get(1)) && + HopRewriteUtils.isEqualSize(inputsN.get(1), inputsN.get(2)); + } + } + } } return ret; } @@ -554,83 +569,91 @@ public class AggUnaryOp extends MultiThreadedHop private Lop constructLopsTernaryAggregateRewrite(ExecType et) { - BinaryOp input1 = (BinaryOp)getInput().get(0); - Hop input11 = input1.getInput().get(0); - Hop input12 = input1.getInput().get(1); - Lop in1 = null, in2 = null, in3 = null; - boolean handled = false; - - if (input1.getOp() == OpOp2.POW) { - assert(HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3"; - in1 = input11.constructLops(); - in2 = in1; - in3 = in1; - handled = true; - } - else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, OpOp2.POW) ) { - BinaryOp b11 = (BinaryOp)input11; - switch( b11.getOp() ) { - case MULT: // A*B*C case - in1 = input11.getInput().get(0).constructLops(); - in2 = input11.getInput().get(1).constructLops(); - in3 = input12.constructLops(); - handled = true; - break; - case POW: // A*A*B case - Hop b112 = b11.getInput().get(1); - if ( !(input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT) - && HopRewriteUtils.isLiteralOfValue(b112, 2) ) { - in1 = b11.getInput().get(0).constructLops(); - in2 = in1; - in3 = input12.constructLops(); - handled = true; - } - break; - default: break; - } - } - else if( HopRewriteUtils.isBinary(input12, OpOp2.MULT, OpOp2.POW) ) { - BinaryOp b12 = (BinaryOp)input12; - 
switch (b12.getOp()) { - case MULT: // A*B*C case + Hop input = getInput().get(0); + if(input instanceof BinaryOp) { + BinaryOp input1 = (BinaryOp) input; + Hop input11 = input1.getInput().get(0); + Hop input12 = input1.getInput().get(1); + + boolean handled = false; + + if (input1.getOp() == OpOp2.POW) { + assert (HopRewriteUtils.isLiteralOfValue(input12, 3)) : "this case can only occur with a power of 3"; in1 = input11.constructLops(); - in2 = input12.getInput().get(0).constructLops(); - in3 = input12.getInput().get(1).constructLops(); + in2 = in1; + in3 = in1; handled = true; - break; - case POW: // A*B*B case - Hop b112 = b12.getInput().get(1); - if ( HopRewriteUtils.isLiteralOfValue(b112, 2) ) { - in1 = b12.getInput().get(0).constructLops(); - in2 = in1; - in3 = input11.constructLops(); - handled = true; + } else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, OpOp2.POW)) { + BinaryOp b11 = (BinaryOp) input11; + switch (b11.getOp()) { + case MULT: // A*B*C case + in1 = input11.getInput().get(0).constructLops(); + in2 = input11.getInput().get(1).constructLops(); + in3 = input12.constructLops(); + handled = true; + break; + case POW: // A*A*B case + Hop b112 = b11.getInput().get(1); + if (!(input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) + && HopRewriteUtils.isLiteralOfValue(b112, 2)) { + in1 = b11.getInput().get(0).constructLops(); + in2 = in1; + in3 = input12.constructLops(); + handled = true; + } + break; + default: + break; + } + } else if (HopRewriteUtils.isBinary(input12, OpOp2.MULT, OpOp2.POW)) { + BinaryOp b12 = (BinaryOp) input12; + switch (b12.getOp()) { + case MULT: // A*B*C case + in1 = input11.constructLops(); + in2 = input12.getInput().get(0).constructLops(); + in3 = input12.getInput().get(1).constructLops(); + handled = true; + break; + case POW: // A*B*B case + Hop b112 = b12.getInput().get(1); + if (HopRewriteUtils.isLiteralOfValue(b112, 2)) { + in1 = b12.getInput().get(0).constructLops(); + in2 = in1; + in3 = 
input11.constructLops(); + handled = true; + } + break; + default: + break; } - break; - default: break; } - } - if (!handled) { - in1 = input11.constructLops(); - in2 = input12.constructLops(); - in3 = new LiteralOp(1).constructLops(); + if (!handled) { + in1 = input11.constructLops(); + in2 = input12.constructLops(); + in3 = new LiteralOp(1).constructLops(); + } + } else { + NaryOp input1 = (NaryOp) input; + in1 = input1.getInput().get(0).constructLops(); + in2 = input1.getInput().get(1).constructLops(); + in3 = input1.getInput().get(2).constructLops(); } - //create new ternary aggregate operator - int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ); + //create new ternary aggregate operator + int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); // The execution type of a unary aggregate instruction should depend on the execution type of inputs to avoid OOM // Since we only support matrix-vector and not vector-matrix, checking the execution type of input1 should suffice. - ExecType et_input = input1.optFindExecType(); + ExecType et_input = input.optFindExecType(); // Because ternary aggregate are not supported on GPU - et_input = et_input == ExecType.GPU ? ExecType.CP : et_input; + et_input = et_input == ExecType.GPU ? ExecType.CP : et_input; // If forced ExecType is FED, it means that the federated planner updated the ExecType and // execution may fail if ExecType is not FED et_input = (getForcedExecType() == ExecType.FED) ? 
ExecType.FED : et_input; - - return new TernaryAggregate(in1, in2, in3, AggOp.SUM, - OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k); + + return new TernaryAggregate(in1, in2, in3, AggOp.SUM, + OpOp2.MULT, _direction, getDataType(), ValueType.FP64, et_input, k); } @Override diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java b/src/main/java/org/apache/sysds/hops/NaryOp.java index 0a6e93ca14..44fe8ff609 100644 --- a/src/main/java/org/apache/sysds/hops/NaryOp.java +++ b/src/main/java/org/apache/sysds/hops/NaryOp.java @@ -200,7 +200,8 @@ public class NaryOp extends Hop { HopRewriteUtils.getSumValidInputNnz(dc, true)); case MIN: case MAX: - case PLUS: return new MatrixCharacteristics( + case PLUS: + case MULT: return new MatrixCharacteristics( HopRewriteUtils.getMaxInputDim(this, true), HopRewriteUtils.getMaxInputDim(this, false), -1, -1); case LIST: @@ -230,6 +231,7 @@ public class NaryOp extends Hop { case MIN: case MAX: case PLUS: + case MULT: setDim1(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, true)); setDim2(getDataType().isScalar() ? 0 : HopRewriteUtils.getMaxInputDim(this, false)); break; diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java index 2b84318e35..68167ac3ae 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java @@ -745,8 +745,15 @@ public class HopRewriteUtils { public static NaryOp createNary(OpOpN op, Hop... 
inputs) { Hop mainInput = inputs[0]; - NaryOp nop = new NaryOp(mainInput.getName(), mainInput.getDataType(), - mainInput.getValueType(), op, inputs); + // safe for unordered inputs of Scalars and Matrices + // e.g.: S*M*S = M + // safe for Scalar with different value type + // e.g.: Scalar(Int) * Scalar(FP64) = Scalar(FP64) + boolean containsMatrix = Arrays.stream(inputs).anyMatch(Hop::isMatrix); + boolean containsFP64 = Arrays.stream(inputs).anyMatch(h -> h.getValueType() == ValueType.FP64); + DataType dtOut = containsMatrix ? DataType.MATRIX : mainInput.getDataType(); + ValueType vtOut = containsFP64? ValueType.FP64 : mainInput.getValueType(); + NaryOp nop = new NaryOp(mainInput.getName(), dtOut, vtOut, op, inputs); nop.setBlocksize(mainInput.getBlocksize()); copyLineNumbers(mainInput, nop); nop.refreshSizeInformation(); diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 1a4c4ecebd..314d230842 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2801,8 +2801,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule private static Hop foldMultipleMinMaxOperations(Hop hi) { - if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, OpOp2.PLUS) - || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS)) + if( (HopRewriteUtils.isBinary(hi, OpOp2.MIN, OpOp2.MAX, OpOp2.PLUS, OpOp2.MULT) + || HopRewriteUtils.isNary(hi, OpOpN.MIN, OpOpN.MAX, OpOpN.PLUS, OpOpN.MULT)) && hi.getValueType() != ValueType.STRING //exclude string concat && HopRewriteUtils.isNotMatrixVectorBinaryOperation(hi)) { @@ -2839,7 +2839,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule for( Hop p : parents ) HopRewriteUtils.replaceChildReference(p, hi, hnew); hi = 
hnew; - LOG.debug("Applied foldMultipleMinMaxPlusOperations (line "+hi.getBeginLine()+")."); + LOG.debug("Applied foldMultipleMinMaxPlusMultOperations (line "+hi.getBeginLine()+")."); } else { converged = true; diff --git a/src/main/java/org/apache/sysds/lops/Nary.java b/src/main/java/org/apache/sysds/lops/Nary.java index 68e7191ad9..950082e47e 100644 --- a/src/main/java/org/apache/sysds/lops/Nary.java +++ b/src/main/java/org/apache/sysds/lops/Nary.java @@ -117,6 +117,8 @@ public class Nary extends Lop { return "n"+operationType.name().toLowerCase(); case PLUS: return "n+"; + case MULT: + return "n*"; default: throw new UnsupportedOperationException( "Nary operation type (" + operationType + ") is not defined."); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index 792eace24e..7a5ed3524c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -180,6 +180,7 @@ public class CPInstructionParser extends InstructionParser { String2CPInstructionType.put( "nmax", CPType.BuiltinNary); String2CPInstructionType.put( "nmin", CPType.BuiltinNary); String2CPInstructionType.put( "n+" , CPType.BuiltinNary); + String2CPInstructionType.put( "n*" , CPType.BuiltinNary); String2CPInstructionType.put( "exp" , CPType.Unary); String2CPInstructionType.put( "abs" , CPType.Unary); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index 06e68a63d5..5e4dbaedeb 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -299,6 +299,7 @@ public class SPInstructionParser extends InstructionParser String2SPInstructionType.put( 
"nmin", SPType.BuiltinNary); String2SPInstructionType.put( "nmax", SPType.BuiltinNary); String2SPInstructionType.put( "n+", SPType.BuiltinNary); + String2SPInstructionType.put( "n*", SPType.BuiltinNary); String2SPInstructionType.put( DataGen.RAND_OPCODE , SPType.Rand); String2SPInstructionType.put( DataGen.SEQ_OPCODE , SPType.Rand); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java index 31110423cb..c6b590a61c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java @@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.cp; import org.apache.sysds.common.Types.OpOpN; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -85,6 +86,10 @@ public abstract class BuiltinNaryCPInstruction extends CPInstruction return new MatrixBuiltinNaryCPInstruction( new SimpleOperator(Plus.getPlusFnObject()), opcode, str, outputOperand, inputOperands); } + else if( opcode.equals("n*") ) { + return new MatrixBuiltinNaryCPInstruction( + new SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand, inputOperands); + } else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) { return new EvalNaryCPInstruction(null, opcode, str, outputOperand, inputOperands); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java index d6a0ec53f5..172c15b592 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MatrixBuiltinNaryCPInstruction.java @@ -62,7 +62,7 @@ public class MatrixBuiltinNaryCPInstruction extends BuiltinNaryCPInstruction imp outBlock = ((FrameBlock)outBlock).append(frames.get(i), cbind); } } - else if( ArrayUtils.contains(new String[]{"nmin", "nmax", "n+"}, getOpcode()) ) { + else if( ArrayUtils.contains(new String[]{"nmin", "nmax", "n+", "n*"}, getOpcode()) ) { outBlock = MatrixBlock.naryOperations(_optr, matrices.toArray(new MatrixBlock[0]), scalars.toArray(new ScalarObject[0]), new MatrixBlock()); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java index c3118501ba..8e46a96357 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/BuiltinNarySPInstruction.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.functionobjects.Builtin; +import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; @@ -163,7 +164,7 @@ public class BuiltinNarySPInstruction extends SPInstruction implements LineageTr return; } } - else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+"}, getOpcode()) ) { + else if( ArrayUtils.contains(new String[]{"nmin","nmax","n+","n*"}, getOpcode()) ) { //compute output characteristics dcout = computeMinMaxOutputDataCharacteristics(sec, inputs); @@ -179,7 +180,7 
@@ public class BuiltinNarySPInstruction extends SPInstruction implements LineageTr } //compute nary min/max (partitioning-preserving) - out = in.mapValues(new MinMaxAddFunction(getOpcode(), scalars)); + out = in.mapValues(new MinMaxAddMultFunction(getOpcode(), scalars)); } //set output RDD and add lineage @@ -278,17 +279,17 @@ public class BuiltinNarySPInstruction extends SPInstruction implements LineageTr } } - private static class MinMaxAddFunction implements Function<MatrixBlock[], MatrixBlock> { + private static class MinMaxAddMultFunction implements Function<MatrixBlock[], MatrixBlock> { private static final long serialVersionUID = -4227447915387484397L; private final SimpleOperator _op; private final ScalarObject[] _scalars; - - public MinMaxAddFunction(String opcode, List<ScalarObject> scalars) { + + public MinMaxAddMultFunction(String opcode, List<ScalarObject> scalars) { _scalars = scalars.toArray(new ScalarObject[0]); - _op = new SimpleOperator(opcode.equals("n+") ? - Plus.getPlusFnObject() : - Builtin.getBuiltinFnObject(opcode.substring(1))); + _op = new SimpleOperator(opcode.equals("n+") ? Plus.getPlusFnObject() : + opcode.equals("n*") ? Multiply.getMultiplyFnObject() : + Builtin.getBuiltinFnObject(opcode.substring(1))); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java index 6c691c2b8b..1188d65e30 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java @@ -390,7 +390,8 @@ public class LineageRecomputeUtils { break; } case BuiltinNary: { - String opcode = item.getOpcode().equals("n+") ? "plus" : item.getOpcode(); + String opcode = item.getOpcode().equals("n+") ? "plus" : + item.getOpcode().equals("n*") ? 
"mult" : item.getOpcode(); operands.put(item.getId(), HopRewriteUtils.createNary( OpOpN.valueOf(opcode.toUpperCase()), createNaryInputs(item, operands))); break; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index 061c09ea3d..108462c567 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -3844,6 +3844,24 @@ public class LibMatrixMult c[ ci+6 ] *= aval; c[ ci+7 ] *= aval; } } + + public static void vectMultiplyInPlace(final double[] a, double[] c, int[] cix, final int ai, final int ci, final int len) { + final int bn = len%8; + //rest, not aligned to 8-blocks + for( int j = ci; j < ci+bn; j++ ) + c[ j ] *= a[ ai+cix[j] ]; + //unrolled 8-block (for better instruction-level parallelism) + for( int j = ci+bn; j < ci+len; j+=8 ) { + c[ j+0 ] *= a[ ai+cix[j+0] ]; + c[ j+1 ] *= a[ ai+cix[j+1] ]; + c[ j+2 ] *= a[ ai+cix[j+2] ]; + c[ j+3 ] *= a[ ai+cix[j+3] ]; + c[ j+4 ] *= a[ ai+cix[j+4] ]; + c[ j+5 ] *= a[ ai+cix[j+5] ]; + c[ j+6 ] *= a[ ai+cix[j+6] ]; + c[ j+7 ] *= a[ ai+cix[j+7] ]; + } + } //note: public for use by codegen for consistency public static void vectMultiplyWrite( double[] a, double[] b, double[] c, int ai, int bi, int ci, final int len ) @@ -3889,7 +3907,7 @@ public class LibMatrixMult } } - private static void vectMultiply( double[] a, double[] c, int ai, int ci, final int len ) + public static void vectMultiply(double[] a, double[] c, int ai, int ci, final int len) { final int bn = len%8; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 8e2e7f0665..62f9a1febb 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -3805,32 +3805,38 @@ 
public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, //prepare operator FunctionObject fn = ((SimpleOperator)op).fn; boolean plus = fn instanceof Plus; - Builtin bfn = !plus ? (Builtin)((SimpleOperator)op).fn : null; - + boolean mult = fn instanceof Multiply; + boolean minmax = !mult && !plus; + Builtin bfn = minmax ? (Builtin) fn : null; for(int i = 0; i < matrices.length; i++) if(matrices[i] instanceof CompressedMatrixBlock) matrices[i] = CompressedMatrixBlock.getUncompressed(matrices[i], "Nary operation process add row"); //process all scalars - double init = plus ? 0 :(bfn.getBuiltinCode() == BuiltinCode.MIN) ? - Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; + double init = plus ? 0 : mult ? 1 : (bfn.getBuiltinCode() == BuiltinCode.MIN) ? + Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY; for( ScalarObject so : scalars ) init = fn.execute(init, so.getDoubleValue()); //compute output dimensions and estimate sparsity final int m = matrices.length > 0 ? matrices[0].rlen : 1; final int n = matrices.length > 0 ? matrices[0].clen : 1; + + //check for empty multiply with empty matrix + if(mult) + if(Arrays.stream(matrices).anyMatch(MatrixBlock::isEmptyBlock)) + return new MatrixBlock(m,n, 0); + final long mn = (long) m * n; - final long nnz = (!plus && bfn.getBuiltinCode()==BuiltinCode.MIN && init < 0) - || (!plus && bfn.getBuiltinCode()==BuiltinCode.MAX && init > 0) ? mn : + final long nnz = (minmax && bfn.getBuiltinCode()==BuiltinCode.MIN && init < 0) + || (minmax && bfn.getBuiltinCode()==BuiltinCode.MAX && init > 0) ? mn : + mult? 
Arrays.stream(matrices).mapToLong(mb -> mb.nonZeros).min().orElse(mn) : Math.min(Arrays.stream(matrices).mapToLong(mb -> mb.nonZeros).sum(), mn); - boolean sp = evalSparseFormatInMemory(m, n, nnz); + //if multiply and least one sparse input -> output = sparse + boolean sp = (mult && Arrays.stream(matrices).anyMatch(mb -> mb.sparse)) || evalSparseFormatInMemory(m, n, nnz); //init result matrix - if( ret == null ) - ret = new MatrixBlock(m, n, sp, nnz); - else - ret.reset(m, n, sp, nnz); + ret.reset(m, n, sp, nnz); //main processing if( matrices.length > 0 ) { @@ -3839,18 +3845,21 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, mb -> mb.sparse || mb.isEmpty()) ? new int[n] : null; if( ret.isInSparseFormat() ) { double[] tmp = new double[n]; - for(int i = 0; i < m; i++) { - //reset tmp and compute row output - Arrays.fill(tmp, init); - if( plus ) - processAddRow(matrices, tmp, 0, n, i); - else - processMinMaxRow(bfn, matrices, tmp, 0, n, i, cnt); - //copy to sparse output - for(int j = 0; j < n; j++) - if( tmp[j] != 0 ) - ret.appendValue(i, j, tmp[j]); - } + for(int i = 0; i < m; i++) + if (mult) + processMultRowSparse(matrices, init, n, i, ret); + else{ + //reset tmp and compute row output + Arrays.fill(tmp, init); + if( plus ) + processAddRow(matrices, tmp, 0, n, i); + else + processMinMaxRow(bfn, matrices, tmp, 0, n, i, cnt); + //copy to sparse output + for(int j = 0; j < n; j++) + if( tmp[j] != 0 ) + ret.appendValue(i, j, tmp[j]); + } } else { DenseBlock c = ret.getDenseBlock(); @@ -3860,17 +3869,21 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, Arrays.fill(c.values(i), c.pos(i), c.pos(i)+n, init); if( plus ) processAddRow(matrices, c.values(i), c.pos(i), n, i); + else if (mult) + processMultRowDense(matrices, c.values(i), c.pos(i), n, i); else processMinMaxRow(bfn, matrices, c.values(i), c.pos(i), n, i, cnt); lnnz += UtilFunctions.countNonZeros(c.values(i), c.pos(i), n); } 
ret.setNonZeros(lnnz); + + //reevaluate sparsity + if(mult && ret.evalSparseFormatInMemory()) + ret.denseToSparse(); } } - else { + else ret.set(0, 0, init); - } - return ret; } @@ -3891,6 +3904,61 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, } } } + + private static void processMultRowDense(MatrixBlock[] inputs, double[] c, int cix, int n, int i) { + // if inputs contain sparse -> result == sparse -> processMultRowSparse + for( MatrixBlock in : inputs ) { + DenseBlock a = in.getDenseBlock(); + LibMatrixMult.vectMultiply(in.getDenseBlock().values(i), c, a.pos(i), cix, n); + } + } + + private static void processMultRowSparse(MatrixBlock[] inputs, double init, int n, int i, MatrixBlock ret) { + SparseBlock[] sparse_inputs = new SparseBlock[inputs.length]; + int len_sparse_inputs = 0; + for( MatrixBlock in : inputs ) + if (in.isInSparseFormat()) + sparse_inputs[len_sparse_inputs++] = in.getSparseBlock(); + + int size = sparse_inputs[0].size(i); + int[] sparse_indices = new int[size]; + double[] sparse_values = new double[size]; + for (int j = 0; j < size; j++){ + sparse_values[j] = init*sparse_inputs[0].values(i)[sparse_inputs[0].pos(i) + j]; + sparse_indices[j] = sparse_inputs[0].indexes(i)[sparse_inputs[0].pos(i) + j]; + } + for (int k = 1; k < len_sparse_inputs; k++){ + SparseBlock a = sparse_inputs[k]; + + //calculate intersection + int aix = a.pos(i); + int aix_end = a.pos(i) + a.size(i); + int ix_read = 0; + int ix_write = 0; + while(ix_read < size && aix < aix_end){ + if(sparse_indices[ix_read] < a.indexes(i)[aix]) + ix_read += 1; + else if(sparse_indices[ix_read] > a.indexes(i)[aix]) + aix += 1; + else{ + sparse_indices[ix_write] = sparse_indices[ix_read]; + sparse_values[ix_write] = sparse_values[ix_read]*a.values(i)[aix]; + aix += 1; + ix_read += 1; + ix_write += 1; + } + } + size = ix_write; + } + + //iterate dense blocks + for( MatrixBlock in : inputs ) + if( !in.isInSparseFormat() ) + 
LibMatrixMult.vectMultiplyInPlace(in.denseBlock.values(i), + sparse_values, sparse_indices, in.denseBlock.pos(i), 0, size); + for (int j = 0; j < size; j++) + ret.appendValue(i, sparse_indices[j], sparse_values[j]); + } private static void processMinMaxRow(Builtin fn, MatrixBlock[] inputs, double[] c, int cix, int n, int i, int[] cnt) { if( cnt != null ) diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java index ac3cf3a041..a4a6b0381d 100644 --- a/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java +++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix/ElementwiseAdditionTest.java @@ -133,7 +133,7 @@ public class ElementwiseAdditionTest extends AutomatedTestBase config.addVariable("cols2", cols2); loadTestConfiguration(config); - + //this.programArgs runTest(true,LanguageException.class); } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java new file mode 100644 index 0000000000..6fab8f63a9 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteNaryMultTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; +import org.junit.Test; + +public class RewriteNaryMultTest extends AutomatedTestBase +{ //checks results and instruction counts of chained elementwise multiplications with and without algebraic rewrites + private static final String TEST_NAME1 = "RewriteNaryMultDense1"; + private static final String TEST_NAME2 = "RewriteNaryMultDense2"; + private static final String TEST_NAME3 = "RewriteNaryMultSparse1"; + + private static final String TEST_NAME4 = "RewriteNaryMultSparse2"; + + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewriteNaryMultTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { //one configuration per DML script, each registered under its own script name + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}) ); + addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"}) ); + addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"}) ); + addTestConfiguration( 
TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[]{"R"}) ); + } + + @Test + public void testExpectExceptionParseNaryCPInstruction(){ + String opcode = "nxy"; //deliberately unknown opcode + try { + BuiltinNaryCPInstruction.parseInstruction("CP°" + opcode +"°A·MATRIX·FP64°B·MATRIX·FP64°C·MATRIX·FP64°_mVar3·MATRIX·FP64"); + Assert.fail("Expected DMLRuntimeException for unrecognized opcode"); + } catch (DMLRuntimeException e){ + Assert.assertEquals("Opcode (" + opcode + ") not recognized in BuiltinMultipleCPInstruction", e.getMessage()); + } + } + + @Test + public void testNoRewriteDense1CP() { + testRewriteNaryMult(TEST_NAME1, false, ExecType.CP); + } + + @Test + public void testRewriteDense1CP() { + testRewriteNaryMult(TEST_NAME1, true, ExecType.CP); + } + + @Test + public void testRewriteDense2CP() { + testRewriteNaryMult(TEST_NAME2, true, ExecType.CP); + } + + @Test + public void testRewriteSparse1CP() { + testRewriteNaryMult(TEST_NAME3, true, ExecType.CP); + } + + @Test + public void testRewriteSparse2CP() { + testRewriteNaryMult(TEST_NAME4, true, ExecType.CP); + } + + @Test + public void testNoRewriteDense1SP() { + testRewriteNaryMult(TEST_NAME1, false, ExecType.SPARK); + } + + @Test + public void testRewriteDense1SP() { + testRewriteNaryMult(TEST_NAME1, true, ExecType.SPARK); + } + + @Test + public void testRewriteDense2SP() { + testRewriteNaryMult(TEST_NAME2, true, ExecType.SPARK); + } + + @Test + public void testRewriteSparse1SP() { + testRewriteNaryMult(TEST_NAME3, true, ExecType.SPARK); + } + + @Test + public void testRewriteSparse2SP() { + testRewriteNaryMult(TEST_NAME4, true, ExecType.SPARK); + } + + + private void testRewriteNaryMult(String name, boolean rewrites, ExecType et) + { + ExecMode oldMode = setExecMode(et); + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try + { + TestConfiguration config = getTestConfiguration(name); + loadTestConfiguration(config); + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + name + ".dml"; + 
programArgs = new String[]{"-explain","-stats","-args", output("R") }; + + runTest(true, false, null, -1); + + //compare output: every script writes its scalar result to cell (1,1) of R + Double ret = readDMLMatrixFromOutputDir("R").get(new CellIndex(1,1)); + if(name.equals(TEST_NAME3)) + Assert.assertEquals(Double.valueOf(1 + 304 + 10000), ret); + else if(name.equals(TEST_NAME4)) + Assert.assertEquals(Double.valueOf(1), ret, 1e-7); + else + Assert.assertEquals(Double.valueOf(2*3*4*5*6*1000), ret); + + //check for applied nary mult: exactly one n* with rewrites (except Dense2), otherwise at least one binary * + String prefix = et == ExecType.SPARK ? "sp_" : ""; + if( rewrites && !name.equals(TEST_NAME2) ) + Assert.assertEquals(1, Statistics.getCPHeavyHitterCount(prefix + "n*")); + else + Assert.assertTrue(Statistics.getCPHeavyHitterCount(prefix+"*")>=1); + } + finally { + resetExecMode(oldMode); + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java index a7224d88b0..3d7de4983a 100644 --- a/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ternary/TernaryAggregateTest.java @@ -41,10 +41,11 @@ public class TernaryAggregateTest extends AutomatedTestBase { private final static String TEST_NAME1 = "TernaryAggregateRC"; private final static String TEST_NAME2 = "TernaryAggregateC"; + private final static String TEST_NAME3 = "TernaryAggregatePow"; private final static String TEST_DIR = "functions/ternary/"; private final static String TEST_CLASS_DIR = TEST_DIR + TernaryAggregateTest.class.getSimpleName() + "/"; - private final static double eps = 1e-8; + private final static double eps = 1e-7; private final static int rows = 1111; private final static int cols = 1011; @@ -56,7 +57,8 @@ public class TernaryAggregateTest extends AutomatedTestBase public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME1, 
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); - addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); } @Test @@ -180,6 +182,36 @@ public class TernaryAggregateTest extends AutomatedTestBase public void testTernaryAggregateCSparseMatrixCPNoRewrite() { runTernaryAggregateTest(TEST_NAME2, true, false, false, ExecType.CP); } + + @Test + public void testTernaryAggregateNoRewritePowDenseCP() { + runTernaryAggregateTest(TEST_NAME3, false, true, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateNoRewritePowSparseCP() { + runTernaryAggregateTest(TEST_NAME3, true, true, false, ExecType.CP); + } + + @Test + public void testTernaryAggregateRewritePowDenseCP() { + runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateRewritePowDenseSP() { + runTernaryAggregateTest(TEST_NAME3, false, true, true, ExecType.SPARK); + } + + @Test + public void testTernaryAggregateRewritePowSparseCP() { + runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.CP); + } + + @Test + public void testTernaryAggregateRewritePowSparseSP() { + runTernaryAggregateTest(TEST_NAME3, true, true, true, ExecType.SPARK); + } private void runTernaryAggregateTest(String testname, boolean sparse, boolean vectors, boolean rewrites, ExecType et) { @@ -203,10 +235,9 @@ public class TernaryAggregateTest extends AutomatedTestBase loadTestConfiguration(config); OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites; - String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{"-stats","-args", input("A"), output("R")}; + programArgs = new String[]{"-explain","-stats","-args", input("A"), 
output("R")}; fullRScriptName = HOME + testname + ".R"; rCmd = "Rscript" + " " + fullRScriptName + " " + diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml new file mode 100644 index 0000000000..ce7fdce817 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteNaryMultDense1.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +M = 100; N = 10; +A = matrix(4, M, N); +B = matrix(12, M, N); +C = matrix(15, M, N); +while(FALSE){} + +R = A * B * C + +while(FALSE){} +s = as.matrix(sum(R)); +write(s,$1); diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml new file mode 100644 index 0000000000..36dfc50523 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteNaryMultDense2.dml @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +M = 100; N = 10; +A = matrix(1, M, N); +D = matrix(2, M, N); +B = matrix(3, M, 1); +E = matrix(4, 1, N); +C = 5 +F = 6 +while(FALSE){} + +R = A * B * C * D * E * F + +while(FALSE){} +s = as.matrix(sum(R)); +write(s,$1); diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml new file mode 100644 index 0000000000..043e07741c --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse1.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +M = 100; +A = diag(matrix(1, M, 1)); +B = matrix(seq(1,M*M), M, M); +D = matrix(0, M, 1) +D[4,1] = 1 +D[1,1] = 1 +D[M,1] = 1 +D = diag(D) +while(FALSE){} + +R = A * B * D + +while(FALSE){} +s = as.matrix(sum(R)); +write(s,$1); diff --git a/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml new file mode 100644 index 0000000000..07180a33d8 --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewriteNaryMultSparse2.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +M = 100; N = 10; +A = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1) +B = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=1.0) +C = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1) +D = rand(rows=M, cols=N, min=0, max=1, pdf="uniform", sparsity=0.1) +while(FALSE){} + +R = A * B * C * D + +while(FALSE){} +R2 = B * C +while(FALSE){} +R3 = A * D +while(FALSE){} +R4 = R2*R3 +while(FALSE){} +R[1,1] = R[1,1] + 1 +s = as.matrix(sum(R - R4)); +write(s,$1); diff --git a/src/test/scripts/functions/ternary/TernaryAggregatePow.R b/src/test/scripts/functions/ternary/TernaryAggregatePow.R new file mode 100644 index 0000000000..d0f2227869 --- /dev/null +++ b/src/test/scripts/functions/ternary/TernaryAggregatePow.R @@ -0,0 +1,33 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) +options(digits=22) + +library("Matrix") + +A = as.matrix(readMM(paste(args[1], "A.mtx", sep=""))) +B = A * 2; +C = A * 3; + +R = t(as.matrix(colSums(A^3))); + +writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); diff --git a/src/test/scripts/functions/ternary/TernaryAggregatePow.dml b/src/test/scripts/functions/ternary/TernaryAggregatePow.dml new file mode 100644 index 0000000000..e8e2fe5ca9 --- /dev/null +++ b/src/test/scripts/functions/ternary/TernaryAggregatePow.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($1); + +while(FALSE){} + +R = colSums(A^3); + +write(R, $2);