This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4b2d83e915 [SYSTEMDS-3861] Fix redundant transposes due to multi-level
rewrites
4b2d83e915 is described below
commit 4b2d83e915c40b7433580cddfec68aa8c440ba05
Author: aarna <[email protected]>
AuthorDate: Fri Apr 18 12:43:04 2025 +0200
[SYSTEMDS-3861] Fix redundant transposes due to multi-level rewrites
Closes #2249.
---
.../java/org/apache/sysds/hops/AggBinaryOp.java | 487 ++++++++++-----------
.../hops/fedplanner/FederatedMemoTablePrinter.java | 19 +
.../functions/rewrite/RewriteTransposeTest.java | 86 ++++
.../functions/rewrite/RewriteTransposeCase1.R | 32 ++
.../functions/rewrite/RewriteTransposeCase1.dml | 27 ++
.../functions/rewrite/RewriteTransposeCase2.R | 32 ++
.../functions/rewrite/RewriteTransposeCase2.dml | 28 ++
.../functions/rewrite/RewriteTransposeCase3.R | 33 ++
.../functions/rewrite/RewriteTransposeCase3.dml | 28 ++
9 files changed, 519 insertions(+), 253 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 2cf651f189..5f9c6b41b3 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -43,6 +43,7 @@ import org.apache.sysds.lops.MatMultCP;
import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.lops.PMapMult;
import org.apache.sysds.lops.Transform;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -65,7 +66,7 @@ public class AggBinaryOp extends MultiThreadedHop {
public static final double MAPMULT_MEM_MULTIPLIER = 1.0;
public static MMultMethod FORCED_MMULT_METHOD = null;
- public enum MMultMethod {
+ public enum MMultMethod {
CPMM, //cross-product matrix multiplication (mr)
RMM, //replication matrix multiplication (mr)
MAPMM_L, //map-side matrix-matrix multiplication using
distributed cache (mr/sp)
@@ -78,27 +79,27 @@ public class AggBinaryOp extends MultiThreadedHop {
ZIPMM, //zip matrix multiplication (sp)
MM //in-memory matrix multiplication (cp)
}
-
- public enum SparkAggType{
+
+ public enum SparkAggType {
NONE,
SINGLE_BLOCK,
MULTI_BLOCK,
}
-
+
private OpOp2 innerOp;
private AggOp outerOp;
private MMultMethod _method = null;
-
+
//hints set by previous to operator selection
private boolean _hasLeftPMInput = false; //left input is permutation
matrix
-
+
private AggBinaryOp() {
//default constructor for clone
}
-
+
public AggBinaryOp(String l, DataType dt, ValueType vt, OpOp2 innOp,
- AggOp outOp, Hop in1, Hop in2) {
+ AggOp outOp, Hop in1, Hop in2) {
super(l, dt, vt);
innerOp = innOp;
outerOp = outOp;
@@ -106,7 +107,7 @@ public class AggBinaryOp extends MultiThreadedHop {
getInput().add(1, in2);
in1.getParent().add(this);
in2.getParent().add(this);
-
+
//compute unknown dims and nnz
refreshSizeInformation();
}
@@ -114,30 +115,30 @@ public class AggBinaryOp extends MultiThreadedHop {
public void setHasLeftPMInput(boolean flag) {
_hasLeftPMInput = flag;
}
-
- public boolean hasLeftPMInput(){
+
+ public boolean hasLeftPMInput() {
return _hasLeftPMInput;
}
- public MMultMethod getMMultMethod(){
+ public MMultMethod getMMultMethod() {
return _method;
}
-
+
@Override
public boolean isGPUEnabled() {
- if(!DMLScript.USE_ACCELERATOR)
+ if (!DMLScript.USE_ACCELERATOR)
return false;
-
+
Hop input1 = getInput().get(0);
Hop input2 = getInput().get(1);
//matrix mult operation selection part 2 (specific pattern)
MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern
ChainType chain = checkMapMultChain(); //determine mmchain
pattern
-
- _method = optFindMMultMethodCP ( input1.getDim1(),
input1.getDim2(),
- input2.getDim1(), input2.getDim2(), mmtsj, chain,
_hasLeftPMInput );
- switch( _method ){
- case TSMM:
+
+ _method = optFindMMultMethodCP(input1.getDim1(),
input1.getDim2(),
+ input2.getDim1(), input2.getDim2(), mmtsj,
chain, _hasLeftPMInput);
+ switch (_method) {
+ case TSMM:
//return false; // TODO: Disabling any fused
transa optimization in 1.0 release.
return true;
case MAPMM_CHAIN:
@@ -150,50 +151,47 @@ public class AggBinaryOp extends MultiThreadedHop {
throw new RuntimeException("Unsupported
method:" + _method);
}
}
-
+
/**
* NOTE: overestimated mem in case of transpose-identity matmult, but
3/2 at worst
- * and existing mem estimate advantageous in terms of consistency
hops/lops,
- * and some special cases internally materialize the transpose
for better cache locality
+ * and existing mem estimate advantageous in terms of consistency
hops/lops,
+ * and some special cases internally materialize the transpose for
better cache locality
*/
@Override
- public Lop constructLops()
- {
+ public Lop constructLops() {
//return already created lops
- if( getLops() != null )
+ if (getLops() != null)
return getLops();
-
+
//construct matrix mult lops (currently only supported
aggbinary)
- if ( isMatrixMultiply() )
- {
+ if (isMatrixMultiply()) {
Hop input1 = getInput().get(0);
Hop input2 = getInput().get(1);
-
+
//matrix mult operation selection part 1 (CP vs MR vs
Spark)
ExecType et = optFindExecType();
-
+
//matrix mult operation selection part 2 (specific
pattern)
MMTSJType mmtsj = checkTransposeSelf(); //determine
tsmm pattern
ChainType chain = checkMapMultChain(); //determine
mmchain pattern
- if(mmtsj == MMTSJType.LEFT &&
input2.isCompressedOutput()){
+ if (mmtsj == MMTSJType.LEFT &&
input2.isCompressedOutput()) {
// if tsmm and input is compressed. (using
input2, since input1 is transposed and therefore not compressed.)
et = ExecType.CP;
}
- if( et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED )
- {
+ if (et == ExecType.CP || et == ExecType.GPU || et ==
ExecType.FED) {
//matrix mult operation selection part 3 (CP
type)
- _method = optFindMMultMethodCP (
input1.getDim1(), input1.getDim2(),
- input2.getDim1(),
input2.getDim2(), mmtsj, chain, _hasLeftPMInput );
-
+ _method =
optFindMMultMethodCP(input1.getDim1(), input1.getDim2(),
+ input2.getDim1(),
input2.getDim2(), mmtsj, chain, _hasLeftPMInput);
+
//dispatch CP lops construction
- switch( _method ){
- case TSMM:
- constructCPLopsTSMM( mmtsj, et
);
+ switch (_method) {
+ case TSMM:
+ constructCPLopsTSMM(mmtsj, et);
break;
case MAPMM_CHAIN:
- constructCPLopsMMChain( chain );
+ constructCPLopsMMChain(chain);
break;
case PMM:
constructCPLopsPMM();
@@ -204,53 +202,49 @@ public class AggBinaryOp extends MultiThreadedHop {
default:
throw new
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" +
_method + ") while constructing CP lops.");
}
- }
- else if( et == ExecType.SPARK )
- {
+ } else if (et == ExecType.SPARK) {
//matrix mult operation selection part 3 (SPARK
type)
boolean tmmRewrite =
HopRewriteUtils.isTransposeOperation(input1);
- _method = optFindMMultMethodSpark (
+ _method = optFindMMultMethodSpark(
input1.getDim1(),
input1.getDim2(), input1.getBlocksize(), input1.getNnz(),
input2.getDim1(),
input2.getDim2(), input2.getBlocksize(), input2.getNnz(),
- mmtsj, chain, _hasLeftPMInput,
tmmRewrite );
+ mmtsj, chain, _hasLeftPMInput,
tmmRewrite);
//dispatch SPARK lops construction
- switch( _method )
- {
+ switch (_method) {
case TSMM:
- case TSMM2:
- constructSparkLopsTSMM( mmtsj,
_method==MMultMethod.TSMM2 );
+ case TSMM2:
+ constructSparkLopsTSMM(mmtsj,
_method == MMultMethod.TSMM2);
break;
case MAPMM_L:
case MAPMM_R:
- constructSparkLopsMapMM(
_method );
+
constructSparkLopsMapMM(_method);
break;
case MAPMM_CHAIN:
- constructSparkLopsMapMMChain(
chain );
+
constructSparkLopsMapMMChain(chain);
break;
case PMAPMM:
constructSparkLopsPMapMM();
break;
- case CPMM:
+ case CPMM:
constructSparkLopsCPMM();
break;
- case RMM:
+ case RMM:
constructSparkLopsRMM();
break;
case PMM:
- constructSparkLopsPMM();
+ constructSparkLopsPMM();
break;
case ZIPMM:
constructSparkLopsZIPMM();
break;
-
+
default:
- throw new
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" +
_method + ") while constructing SPARK lops.");
+ throw new
HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" +
_method + ") while constructing SPARK lops.");
}
}
- }
- else
+ } else
throw new HopsException(this.printErrorLocation() +
"Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ")
while constructing lops.");
-
+
//add reblock/checkpoint lops if necessary
constructAndSetLopsDataFlowProperties();
@@ -260,30 +254,28 @@ public class AggBinaryOp extends MultiThreadedHop {
@Override
public String getOpString() {
//ba - binary aggregate, for consistency with runtime
- return "ba(" + outerOp.toString() + innerOp.toString()+")";
+ return "ba(" + outerOp.toString() + innerOp.toString() + ")";
}
-
+
@Override
- public void computeMemEstimate(MemoTable memo)
- {
+ public void computeMemEstimate(MemoTable memo) {
//extension of default compute memory estimate in order to
//account for smaller tsmm memory requirements.
super.computeMemEstimate(memo);
-
+
//tsmm left is guaranteed to require only X but not t(X), while
//tsmm right might have additional requirements to transpose X
if sparse
//NOTE: as a heuristic this correction is only applied if not a
column vector because
//most other vector operations require memory for at least two
vectors (we aim for
//consistency in order to prevent anomalies in parfor opt
leading to small degree of par)
MMTSJType mmtsj = checkTransposeSelf();
- if( mmtsj.isLeft() && getInput().get(1).dimsKnown() &&
getInput().get(1).getDim2()>1 ) {
+ if (mmtsj.isLeft() && getInput().get(1).dimsKnown() &&
getInput().get(1).getDim2() > 1) {
_memEstimate = _memEstimate -
getInput().get(0)._outputMemEstimate;
}
}
@Override
- protected double computeOutputMemEstimate( long dim1, long dim2, long
nnz )
- {
+ protected double computeOutputMemEstimate(long dim1, long dim2, long
nnz) {
//NOTES:
// * The estimate for transpose-self is the same as for normal
matrix multiplications
// because (1) this decouples the decision of TSMM over
default MM and (2) some cases
@@ -314,10 +306,9 @@ public class AggBinaryOp extends MultiThreadedHop {
return ret;
}
-
+
@Override
- protected double computeIntermediateMemEstimate( long dim1, long dim2,
long nnz )
- {
+ protected double computeIntermediateMemEstimate(long dim1, long dim2,
long nnz) {
double ret = 0;
if (isGPUEnabled()) {
@@ -327,277 +318,254 @@ public class AggBinaryOp extends MultiThreadedHop {
double in2Sparsity =
OptimizerUtils.getSparsity(in2.getDim1(), in2.getDim2(), in2.getNnz());
boolean in1Sparse = in1Sparsity <
MatrixBlock.SPARSITY_TURN_POINT;
boolean in2Sparse = in2Sparsity <
MatrixBlock.SPARSITY_TURN_POINT;
- if(in1Sparse && !in2Sparse) {
+ if (in1Sparse && !in2Sparse) {
// Only in sparse-dense cases, we need
additional memory budget for GPU
ret +=
OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
}
}
//account for potential final dense-sparse transformation
(worst-case sparse representation)
- if( dim2 >= 2 && nnz != 0 ) //vectors always dense
+ if (dim2 >= 2 && nnz != 0) //vectors always dense
ret += MatrixBlock.estimateSizeSparseInMemory(dim1,
dim2,
- MatrixBlock.SPARSITY_TURN_POINT -
UtilFunctions.DOUBLE_EPS);
-
+ MatrixBlock.SPARSITY_TURN_POINT -
UtilFunctions.DOUBLE_EPS);
+
return ret;
}
-
+
@Override
- protected DataCharacteristics inferOutputCharacteristics( MemoTable
memo )
- {
+ protected DataCharacteristics inferOutputCharacteristics(MemoTable
memo) {
DataCharacteristics[] dc = memo.getAllInputStats(getInput());
DataCharacteristics ret = null;
- if( dc[0].rowsKnown() && dc[1].colsKnown() ) {
+ if (dc[0].rowsKnown() && dc[1].colsKnown()) {
ret = new MatrixCharacteristics(dc[0].getRows(),
dc[1].getCols());
- double sp1 = (dc[0].getNonZeros()>0) ?
OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(),
dc[0].getNonZeros()) : 1.0;
- double sp2 = (dc[1].getNonZeros()>0) ?
OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(),
dc[1].getNonZeros()) : 1.0;
- ret.setNonZeros((long)(ret.getLength() *
OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(),
ret.getCols(), true)));
+ double sp1 = (dc[0].getNonZeros() > 0) ?
OptimizerUtils.getSparsity(dc[0].getRows(), dc[0].getCols(),
dc[0].getNonZeros()) : 1.0;
+ double sp2 = (dc[1].getNonZeros() > 0) ?
OptimizerUtils.getSparsity(dc[1].getRows(), dc[1].getCols(),
dc[1].getNonZeros()) : 1.0;
+ ret.setNonZeros((long) (ret.getLength() *
OptimizerUtils.getMatMultSparsity(sp1, sp2, ret.getRows(), dc[0].getCols(),
ret.getCols(), true)));
}
return ret;
}
-
+
public boolean isMatrixMultiply() {
- return ( this.innerOp == OpOp2.MULT && this.outerOp ==
AggOp.SUM );
+ return (this.innerOp == OpOp2.MULT && this.outerOp ==
AggOp.SUM);
}
-
+
private boolean isOuterProduct() {
- return ( getInput().get(0).isVector() &&
getInput().get(1).isVector() )
- && ( getInput().get(0).getDim1() == 1 &&
getInput().get(0).getDim1() > 1
- && getInput().get(1).getDim1() > 1 &&
getInput().get(1).getDim2() == 1 );
+ return (getInput().get(0).isVector() &&
getInput().get(1).isVector())
+ && (getInput().get(0).getDim1() == 1 &&
getInput().get(0).getDim1() > 1
+ && getInput().get(1).getDim1() > 1 &&
getInput().get(1).getDim2() == 1);
}
-
+
@Override
public boolean isMultiThreadedOpType() {
return isMatrixMultiply();
}
-
+
@Override
- public boolean allowsAllExecTypes()
- {
+ public boolean allowsAllExecTypes() {
return true;
}
-
+
@Override
- protected ExecType optFindExecType(boolean transitive)
- {
+ protected ExecType optFindExecType(boolean transitive) {
checkAndSetForcedPlatform();
-
- if( _etypeForced != null ) {
+
+ if (_etypeForced != null) {
setExecType(_etypeForced);
- }
- else
- {
- if ( OptimizerUtils.isMemoryBasedOptLevel() ) {
+ } else {
+ if (OptimizerUtils.isMemoryBasedOptLevel()) {
setExecType(findExecTypeByMemEstimate());
}
// choose CP if the dimensions of both inputs are below
Hops.CPThreshold
// OR if it is vector-vector inner product
- else if ( (getInput().get(0).areDimsBelowThreshold() &&
getInput().get(1).areDimsBelowThreshold())
- ||
(getInput().get(0).isVector() && getInput().get(1).isVector() &&
!isOuterProduct()) )
- {
+ else if ((getInput().get(0).areDimsBelowThreshold() &&
getInput().get(1).areDimsBelowThreshold())
+ || (getInput().get(0).isVector() &&
getInput().get(1).isVector() && !isOuterProduct())) {
setExecType(ExecType.CP);
- }
- else
- {
+ } else {
setExecType(ExecType.SPARK);
}
-
+
//check for valid CP mmchain, send invalid memory
requirements to remote
- if( _etype == ExecType.CP
- && checkMapMultChain() != ChainType.NONE
- && OptimizerUtils.getLocalMemBudget() <
-
getInput().get(0).getInput().get(0).getOutputMemEstimate() ) {
+ if (_etype == ExecType.CP
+ && checkMapMultChain() != ChainType.NONE
+ && OptimizerUtils.getLocalMemBudget() <
+
getInput().get(0).getInput().get(0).getOutputMemEstimate()) {
setExecType(ExecType.SPARK);
}
-
+
//check for valid CP dimensions and matrix size
checkAndSetInvalidCPDimsAndSize();
}
-
+
//spark-specific decision refinement (execute binary aggregate
w/ left or right spark input and
//single parent also in spark because it's likely cheap and
reduces data transfer)
MMTSJType mmtsj = checkTransposeSelf(); //determine tsmm pattern
- if( transitive && _etype == ExecType.CP && _etypeForced !=
ExecType.CP
- && ((!mmtsj.isLeft() &&
isApplicableForTransitiveSparkExecType(true))
- || ( !mmtsj.isRight() &&
isApplicableForTransitiveSparkExecType(false))) )
- {
+ if (transitive && _etype == ExecType.CP && _etypeForced !=
ExecType.CP
+ && ((!mmtsj.isLeft() &&
isApplicableForTransitiveSparkExecType(true))
+ || (!mmtsj.isRight() &&
isApplicableForTransitiveSparkExecType(false)))) {
//pull binary aggregate into spark
setExecType(ExecType.SPARK);
}
//mark for recompile (forever)
setRequiresRecompileIfNecessary();
-
+
return _etype;
}
-
- private boolean isApplicableForTransitiveSparkExecType(boolean left)
- {
+
+ private boolean isApplicableForTransitiveSparkExecType(boolean left) {
int index = left ? 0 : 1;
- return !(getInput(index) instanceof DataOp &&
((DataOp)getInput(index)).requiresCheckpoint())
- &&
(!HopRewriteUtils.isTransposeOperation(getInput(index))
+ return !(getInput(index) instanceof DataOp && ((DataOp)
getInput(index)).requiresCheckpoint())
+ &&
(!HopRewriteUtils.isTransposeOperation(getInput(index))
|| (left &&
!isLeftTransposeRewriteApplicable(true)))
- && getInput(index).getParent().size()==1 //bagg is only
parent
- && !getInput(index).areDimsBelowThreshold()
- && (getInput(index).optFindExecType() == ExecType.SPARK
- || (getInput(index) instanceof DataOp &&
((DataOp)getInput(index)).hasOnlyRDD()))
- &&
getInput(index).getOutputMemEstimate()>getOutputMemEstimate();
+ && getInput(index).getParent().size() == 1
//bagg is only parent
+ && !getInput(index).areDimsBelowThreshold()
+ && (getInput(index).optFindExecType() ==
ExecType.SPARK
+ || (getInput(index) instanceof DataOp &&
((DataOp) getInput(index)).hasOnlyRDD()))
+ && getInput(index).getOutputMemEstimate() >
getOutputMemEstimate();
}
-
+
/**
* TSMM: Determine if XtX pattern applies for this aggbinary and if yes
- * which type.
- *
+ * which type.
+ *
* @return MMTSJType
*/
- public MMTSJType checkTransposeSelf()
- {
+ public MMTSJType checkTransposeSelf() {
MMTSJType ret = MMTSJType.NONE;
-
+
Hop in1 = getInput().get(0);
Hop in2 = getInput().get(1);
-
- if( HopRewriteUtils.isTransposeOperation(in1)
- && in1.getInput().get(0) == in2 )
- {
+
+ if (HopRewriteUtils.isTransposeOperation(in1)
+ && in1.getInput().get(0) == in2) {
ret = MMTSJType.LEFT;
}
-
- if( HopRewriteUtils.isTransposeOperation(in2)
- && in2.getInput().get(0) == in1 )
- {
+
+ if (HopRewriteUtils.isTransposeOperation(in2)
+ && in2.getInput().get(0) == in1) {
ret = MMTSJType.RIGHT;
}
-
+
return ret;
}
/**
- * MapMultChain: Determine if XtwXv/XtXv pattern applies for this
aggbinary
- * and if yes which type.
- *
+ * MapMultChain: Determine if XtwXv/XtXv pattern applies for this
aggbinary
+ * and if yes which type.
+ *
* @return ChainType
*/
- public ChainType checkMapMultChain()
- {
+ public ChainType checkMapMultChain() {
ChainType chainType = ChainType.NONE;
-
+
Hop in1 = getInput().get(0);
Hop in2 = getInput().get(1);
-
+
//check for transpose left input (both chain types)
- if( HopRewriteUtils.isTransposeOperation(in1) )
- {
+ if (HopRewriteUtils.isTransposeOperation(in1)) {
Hop X = in1.getInput().get(0);
-
+
//check mapmultchain patterns
//t(X)%*%(w*(X%*%v))
- if( in2 instanceof BinaryOp &&
((BinaryOp)in2).getOp()==OpOp2.MULT )
- {
+ if (in2 instanceof BinaryOp && ((BinaryOp) in2).getOp()
== OpOp2.MULT) {
Hop in3b = in2.getInput().get(1);
- if( in3b instanceof AggBinaryOp )
- {
+ if (in3b instanceof AggBinaryOp) {
Hop in4 = in3b.getInput().get(0);
- if( X == in4 ) //common input
+ if (X == in4) //common input
chainType = ChainType.XtwXv;
}
}
//t(X)%*%((X%*%v)-y)
- else if( in2 instanceof BinaryOp &&
((BinaryOp)in2).getOp()==OpOp2.MINUS )
- {
+ else if (in2 instanceof BinaryOp && ((BinaryOp)
in2).getOp() == OpOp2.MINUS) {
Hop in3a = in2.getInput().get(0);
- Hop in3b = in2.getInput().get(1);
- if( in3a instanceof AggBinaryOp &&
in3b.getDataType()==DataType.MATRIX )
- {
+ Hop in3b = in2.getInput().get(1);
+ if (in3a instanceof AggBinaryOp &&
in3b.getDataType() == DataType.MATRIX) {
Hop in4 = in3a.getInput().get(0);
- if( X == in4 ) //common input
+ if (X == in4) //common input
chainType = ChainType.XtXvy;
}
}
//t(X)%*%(X%*%v)
- else if( in2 instanceof AggBinaryOp )
- {
+ else if (in2 instanceof AggBinaryOp) {
Hop in3 = in2.getInput().get(0);
- if( X == in3 ) //common input
+ if (X == in3) //common input
chainType = ChainType.XtXv;
}
}
-
+
return chainType;
}
-
+
//////////////////////////
// CP Lops generation
- /////////////////////////
-
- private void constructCPLopsTSMM( MMTSJType mmtsj, ExecType et ) {
+
+ /// //////////////////////
+
+ private void constructCPLopsTSMM(MMTSJType mmtsj, ExecType et) {
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
- Lop matmultCP = new
MMTSJ(getInput().get(mmtsj.isLeft()?1:0).constructLops(),
- getDataType(), getValueType(), et, mmtsj, false, k);
+ Lop matmultCP = new MMTSJ(getInput().get(mmtsj.isLeft() ? 1 :
0).constructLops(),
+ getDataType(), getValueType(), et, mmtsj,
false, k);
matmultCP.getOutputParameters().setDimensions(getDim1(),
getDim2(), getBlocksize(), getNnz());
- setLineNumbers( matmultCP );
+ setLineNumbers(matmultCP);
setLops(matmultCP);
}
- private void constructCPLopsMMChain( ChainType chain )
- {
+ private void constructCPLopsMMChain(ChainType chain) {
MapMultChain mapmmchain = null;
- if( chain == ChainType.XtXv ) {
+ if (chain == ChainType.XtXv) {
Hop hX = getInput().get(0).getInput().get(0);
Hop hv = getInput().get(1).getInput().get(1);
- mapmmchain = new MapMultChain( hX.constructLops(),
hv.constructLops(), getDataType(), getValueType(), ExecType.CP);
- }
- else { //ChainType.XtwXv / ChainType.XtwXvy
+ mapmmchain = new MapMultChain(hX.constructLops(),
hv.constructLops(), getDataType(), getValueType(), ExecType.CP);
+ } else { //ChainType.XtwXv / ChainType.XtwXvy
int wix = (chain == ChainType.XtwXv) ? 0 : 1;
int vix = (chain == ChainType.XtwXv) ? 1 : 0;
Hop hX = getInput().get(0).getInput().get(0);
Hop hw = getInput().get(1).getInput().get(wix);
Hop hv =
getInput().get(1).getInput().get(vix).getInput().get(1);
- mapmmchain = new MapMultChain( hX.constructLops(),
hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(),
ExecType.CP);
+ mapmmchain = new MapMultChain(hX.constructLops(),
hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(),
ExecType.CP);
}
-
+
//set degree of parallelism
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
- mapmmchain.setNumThreads( k );
-
+ mapmmchain.setNumThreads(k);
+
//set basic lop properties
setOutputDimensions(mapmmchain);
setLineNumbers(mapmmchain);
setLops(mapmmchain);
}
-
+
/**
* NOTE: exists for consistency since removeEmtpy might be scheduled to
MR
- * but matrix mult on small output might be scheduled to CP. Hence, we
+ * but matrix mult on small output might be scheduled to CP. Hence, we
* need to handle directly passed selection vectors in CP as well.
*/
- private void constructCPLopsPMM()
- {
+ private void constructCPLopsPMM() {
Hop pmInput = getInput().get(0);
Hop rightInput = getInput().get(1);
-
+
Hop nrow = HopRewriteUtils.createValueHop(pmInput, true); //NROW
nrow.setBlocksize(0);
nrow.setForcedExecType(ExecType.CP);
HopRewriteUtils.copyLineNumbers(this, nrow);
Lop lnrow = nrow.constructLops();
-
+
PMMJ pmm = new PMMJ(pmInput.constructLops(),
rightInput.constructLops(), lnrow, getDataType(), getValueType(), false, false,
ExecType.CP);
-
+
//set degree of parallelism
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
pmm.setNumThreads(k);
-
+
pmm.getOutputParameters().setDimensions(getDim1(), getDim2(),
getBlocksize(), getNnz());
setLineNumbers(pmm);
-
+
setLops(pmm);
-
+
HopRewriteUtils.removeChildReference(pmInput, nrow);
}
- private void constructCPLopsMM(ExecType et)
- {
+ private void constructCPLopsMM(ExecType et) {
Lop matmultCP = null;
String cla =
ConfigurationManager.getDMLConfig().getTextValue("sysds.compressed.linalg");
if (et == ExecType.GPU) {
@@ -610,72 +578,85 @@ public class AggBinaryOp extends MultiThreadedHop {
boolean leftTrans = false; //
HopRewriteUtils.isTransposeOperation(h1);
boolean rightTrans = false; //
HopRewriteUtils.isTransposeOperation(h2);
Lop left = !leftTrans ? h1.constructLops() :
- h1.getInput().get(0).constructLops();
+ h1.getInput().get(0).constructLops();
Lop right = !rightTrans ? h2.constructLops() :
- h2.getInput().get(0).constructLops();
+ h2.getInput().get(0).constructLops();
matmultCP = new MatMultCP(left, right, getDataType(),
getValueType(), et, leftTrans, rightTrans);
setOutputDimensions(matmultCP);
- }
- else if (cla.equals("true") || cla.equals("cost")){
+ } else if (cla.equals("true") || cla.equals("cost")) {
Hop h1 = getInput().get(0);
Hop h2 = getInput().get(1);
int k =
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
boolean leftTrans =
HopRewriteUtils.isTransposeOperation(h1);
- boolean rightTrans =
HopRewriteUtils.isTransposeOperation(h2);
+ boolean rightTrans =
HopRewriteUtils.isTransposeOperation(h2);
Lop left = !leftTrans ? h1.constructLops() :
- h1.getInput().get(0).constructLops();
+ h1.getInput().get(0).constructLops();
Lop right = !rightTrans ? h2.constructLops() :
- h2.getInput().get(0).constructLops();
+ h2.getInput().get(0).constructLops();
matmultCP = new MatMultCP(left, right, getDataType(),
getValueType(), et, k, leftTrans, rightTrans);
- }
- else {
- if( isLeftTransposeRewriteApplicable(true) ) {
+ } else {
+ if (isLeftTransposeRewriteApplicable(true)) {
matmultCP =
constructCPLopsMMWithLeftTransposeRewrite(et);
- }
- else {
+ } else {
int k =
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
matmultCP = new
MatMultCP(getInput().get(0).constructLops(),
- getInput().get(1).constructLops(),
getDataType(), getValueType(), et, k);
+
getInput().get(1).constructLops(), getDataType(), getValueType(), et, k);
updateLopFedOut(matmultCP);
}
setOutputDimensions(matmultCP);
}
-
+
setLineNumbers(matmultCP);
setLops(matmultCP);
}
- private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et)
- {
- Hop X = getInput().get(0).getInput().get(0); //guaranteed to
exists
+ private Lop constructCPLopsMMWithLeftTransposeRewrite(ExecType et) {
+ Hop X = getInput().get(0).getInput().get(0); // guaranteed to
exist
Hop Y = getInput().get(1);
int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
-
+
+ //Check if X is already a transpose operation
+ boolean isXTransposed = X instanceof ReorgOp &&
((ReorgOp)X).getOp() == ReOrgOp.TRANS;
+ Hop actualX = isXTransposed ? X.getInput().get(0) : X;
+
+ //Check if Y is a transpose operation
+ boolean isYTransposed = Y instanceof ReorgOp &&
((ReorgOp)Y).getOp() == ReOrgOp.TRANS;
+ Hop actualY = isYTransposed ? Y.getInput().get(0) : Y;
+
+ //Handle Y or actualY for transpose
+ Lop yLop = isYTransposed ? actualY.constructLops() :
Y.constructLops();
+ ExecType inputReorgExecType = (Y.hasFederatedOutput()) ?
ExecType.FED : ExecType.CP;
+
//right vector transpose
- Lop lY = Y.constructLops();
- ExecType inputReorgExecType = ( Y.hasFederatedOutput() ) ?
ExecType.FED : ExecType.CP;
- Lop tY = (lY instanceof Transform &&
((Transform)lY).getOp()==ReOrgOp.TRANS ) ?
- lY.getInputs().get(0) : //if input is already a
transpose, avoid redundant transpose ops
- new Transform(lY, ReOrgOp.TRANS, getDataType(),
getValueType(), inputReorgExecType, k);
- tY.getOutputParameters().setDimensions(Y.getDim2(),
Y.getDim1(), getBlocksize(), Y.getNnz());
+ Lop tY = (yLop instanceof Transform &&
((Transform)yLop).getOp() == ReOrgOp.TRANS) ?
+ yLop.getInputs().get(0) : //if input is already
a transpose, avoid redundant transpose ops
+ new Transform(yLop, ReOrgOp.TRANS,
getDataType(), getValueType(), inputReorgExecType, k);
+
+ //Set dimensions for tY
+ long tYRows = isYTransposed ? actualY.getDim1() : Y.getDim2();
+ long tYCols = isYTransposed ? actualY.getDim2() : Y.getDim1();
+ tY.getOutputParameters().setDimensions(tYRows, tYCols,
getBlocksize(), Y.getNnz());
setLineNumbers(tY);
if (Y.hasFederatedOutput())
updateLopFedOut(tY);
-
+
+ //Construct X lops for matrix multiplication
+ Lop xLop = isXTransposed ? actualX.constructLops() :
X.constructLops();
+
//matrix mult
- Lop mult = new MatMultCP(tY, X.constructLops(), getDataType(),
getValueType(), et, k); //CP or FED
- mult.getOutputParameters().setDimensions(Y.getDim2(),
X.getDim2(), getBlocksize(), getNnz());
+ Lop mult = new MatMultCP(tY, xLop, getDataType(),
getValueType(), et, k);
+ mult.getOutputParameters().setDimensions(tYRows, isXTransposed
? actualX.getDim1() : X.getDim2(), getBlocksize(), getNnz());
mult.setFederatedOutput(_federatedOutput);
setLineNumbers(mult);
//result transpose (dimensions set outside)
- ExecType outTransposeExecType = ( _federatedOutput ==
FederatedOutput.FOUT ) ?
- ExecType.FED : ExecType.CP;
+ ExecType outTransposeExecType = (_federatedOutput ==
FederatedOutput.FOUT) ?
+ ExecType.FED : ExecType.CP;
Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(),
getValueType(), outTransposeExecType, k);
return out;
}
-
+
//////////////////////////
// Spark Lops generation
/////////////////////////
@@ -718,25 +699,25 @@ public class AggBinaryOp extends MultiThreadedHop {
{
Hop X = getInput().get(0).getInput().get(0); //guaranteed to
exists
Hop Y = getInput().get(1);
-
+
//right vector transpose
Lop tY = new Transform(Y.constructLops(), ReOrgOp.TRANS,
getDataType(), getValueType(), ExecType.CP);
tY.getOutputParameters().setDimensions(Y.getDim2(),
Y.getDim1(), getBlocksize(), Y.getNnz());
setLineNumbers(tY);
-
+
//matrix mult spark
- boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R);
+ boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R);
SparkAggType aggtype = getSparkMMAggregationType(needAgg);
- _outputEmptyBlocks =
!OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
-
- Lop mult = new MapMult( tY, X.constructLops(), getDataType(),
getValueType(),
- false, false, _outputEmptyBlocks,
aggtype);
+ _outputEmptyBlocks =
!OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
+
+ Lop mult = new MapMult( tY, X.constructLops(), getDataType(),
getValueType(),
+ false, false, _outputEmptyBlocks,
aggtype);
mult.getOutputParameters().setDimensions(Y.getDim2(),
X.getDim2(), getBlocksize(), getNnz());
setLineNumbers(mult);
-
+
//result transpose (dimensions set outside)
Lop out = new Transform(mult, ReOrgOp.TRANS, getDataType(),
getValueType(), ExecType.CP);
-
+
return out;
}
@@ -892,13 +873,13 @@ public class AggBinaryOp extends MultiThreadedHop {
setLineNumbers( zipmm );
setLops(zipmm);
}
-
+
/**
* Determines if the rewrite t(X)%*%Y -> t(t(Y)%*%X) is applicable
* and cost effective. Whenever X is a wide matrix and Y is a vector
* this has huge impact, because the transpose of X would dominate
* the entire operation costs.
- *
+ *
* @param CP true if CP
* @return true if left transpose rewrite applicable
*/
@@ -910,38 +891,38 @@ public class AggBinaryOp extends MultiThreadedHop {
{
return false;
}
-
+
boolean ret = false;
Hop h1 = getInput().get(0);
Hop h2 = getInput().get(1);
-
+
//check for known dimensions and cost for t(X) vs t(v) + t(tvX)
//(for both CP/MR, we explicitly check that new transposes fit
in memory,
//even a ba in CP does not imply that both transposes can be
executed in CP)
- if( CP ) //in-memory ba
+ if( CP ) //in-memory ba
{
if( HopRewriteUtils.isTransposeOperation(h1) )
{
long m = h1.getDim1();
long cd = h1.getDim2();
long n = h2.getDim2();
-
+
//check for known dimensions (necessary
condition for subsequent checks)
- ret = (m>0 && cd>0 && n>0);
-
- //check operation memory with changed transpose
(this is important if we have
+ ret = (m>0 && cd>0 && n>0);
+
+ //check operation memory with changed transpose
(this is important if we have
//e.g., t(X) %*% v, where X is sparse and tX
fits in memory but X does not
double memX =
h1.getInput().get(0).getOutputMemEstimate();
double memtv =
OptimizerUtils.estimateSizeExactSparsity(n, cd, 1.0);
double memtXv =
OptimizerUtils.estimateSizeExactSparsity(n, m, 1.0);
double newMemEstimate = memtv + memX + memtXv;
ret &= ( newMemEstimate <
OptimizerUtils.getLocalMemBudget() );
-
+
//check for cost benefit of t(X) vs t(v) +
t(tvX) and memory of additional transpose ops
ret &= ( m*cd > (cd*n + m*n) &&
- 2 *
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) <
OptimizerUtils.getLocalMemBudget() &&
- 2 *
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) <
OptimizerUtils.getLocalMemBudget() );
-
+ 2 *
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) <
OptimizerUtils.getLocalMemBudget() &&
+ 2 *
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) <
OptimizerUtils.getLocalMemBudget() );
+
			//update operation memory estimate (e.g., for parfor optimizer)
if( ret )
_memEstimate = newMemEstimate;
@@ -955,14 +936,14 @@ public class AggBinaryOp extends MultiThreadedHop {
long n = h2.getDim2();
			//note: output size constraint for mapmult already checked by optfindmmultmethod
if( m>0 && cd>0 && n>0 && (m*cd > (cd*n + m*n))
&&
- 2 *
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) <
OptimizerUtils.getLocalMemBudget() &&
- 2 *
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) <
OptimizerUtils.getLocalMemBudget() )
+ 2 *
OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) <
OptimizerUtils.getLocalMemBudget() &&
+ 2 *
OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0) <
OptimizerUtils.getLocalMemBudget() )
{
ret = true;
}
}
}
-
+
return ret;
}
diff --git
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
index 2841256607..05e8d171b7 100644
---
a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
+++
b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTablePrinter.java
@@ -1,3 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
package org.apache.sysds.hops.fedplanner;
import org.apache.commons.lang3.tuple.Pair;
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java
new file mode 100644
index 0000000000..ac28b12caf
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteTransposeTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+import java.util.HashMap;
+
+public class RewriteTransposeTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "RewriteTransposeCase1"; // t(X)%*%Y
+	private final static String TEST_NAME2 = "RewriteTransposeCase2"; // X=t(A); t(X)%*%Y
+	private final static String TEST_NAME3 = "RewriteTransposeCase3"; // Y=t(A); t(X)%*%Y
+
+ private final static String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteTransposeTest.class.getSimpleName() + "/";
+
+ private static final double eps = 1e-9;
+
+ @Override
+ public void setUp() {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION=false;
+
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"}));
+		addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[]{"R"}));
+ }
+
+ @Test
+ public void testTransposeRewrite1CP() {
+ runTransposeRewriteTest(TEST_NAME1, false);
+ }
+
+ @Test
+ public void testTransposeRewrite2CP() {
+ runTransposeRewriteTest(TEST_NAME2, true);
+ }
+
+ @Test
+ public void testTransposeRewrite3CP() {
+ runTransposeRewriteTest(TEST_NAME3, false);
+ }
+
+	private void runTransposeRewriteTest(String testname, boolean expectedMerge) {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+
+		programArgs = new String[]{"-explain", "-stats", "-args", output("R")};
+
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(expectedDir());
+
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+		HashMap<MatrixValue.CellIndex, Double> dmlOutput = readDMLMatrixFromOutputDir("R");
+		HashMap<MatrixValue.CellIndex, Double> rOutput = readRMatrixFromExpectedDir("R");
+		TestUtils.compareMatrices(dmlOutput, rOutput, eps, "Stat-DML", "Stat-R");
+
+ Assert.assertTrue(Statistics.getCPHeavyHitterCount("r'") <= 2);
+ }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R
b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R
new file mode 100644
index 0000000000..5b0e19dca2
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("matrixStats")
+
+X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE)
+Y <- matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE)
+
+R <- t(t(Y)%*%X)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml
b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml
new file mode 100644
index 0000000000..83cfb65dc6
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase1.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1, 20), rows=4, cols=5);
+Y = matrix(seq(1, 12), rows=4, cols=3);
+
+R = t(X)%*%Y;
+
+write(R, $1);
\ No newline at end of file
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R
b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R
new file mode 100644
index 0000000000..fea8c26669
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("matrixStats")
+A = matrix(seq(1, 20), nrow=5, ncol=4, byrow=TRUE)
+Y = matrix(seq(1, 12), nrow=4, ncol=3, byrow=TRUE)
+X = t(A)
+
+R <- t(t(Y)%*%X)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml
b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml
new file mode 100644
index 0000000000..cb9332423b
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = matrix(seq(1, 20), rows=5, cols=4);
+Y = matrix(seq(1, 12), rows=4, cols=3);
+X = t(A);
+
+R = t(X) %*% Y;
+
+write(R, $1);
\ No newline at end of file
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R
b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R
new file mode 100644
index 0000000000..2bdd22f674
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("matrixStats")
+
+X <- matrix(seq(1, 20), nrow=4, ncol=5, byrow=TRUE)
+A <- matrix(seq(1, 12), nrow=3, ncol=4, byrow=TRUE)
+Y <- t(A)
+
+R <- t(t(Y)%*%X)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[1], "R", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml
b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml
new file mode 100644
index 0000000000..2e26920aed
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteTransposeCase3.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1, 20), rows=4, cols=5);
+A = matrix(seq(1, 12), rows=3, cols=4);
+Y = t(A);
+
+R = t(X) %*% Y;
+
+write(R, $1);
\ No newline at end of file