This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e11eb24412 [SYSTEMDS-3333] Extended mmchain-optimization w/ additional
rewrites
e11eb24412 is described below
commit e11eb24412e24b3e77043b958ead35afd8be3f74
Author: gogokotsev00 <[email protected]>
AuthorDate: Sat Mar 16 19:45:42 2024 +0100
[SYSTEMDS-3333] Extended mmchain-optimization w/ additional rewrites
Closes #1948.
Co-Authored-By: krutarth <[email protected]>
---
...ewriteMatrixMultChainOptimizationTranspose.java | 511 +++++++++++++++++++++
.../RewriteMatrixMultChainOptTransposeTest.java | 103 +++++
.../rewrite/RewriteMMChainTestTranspose1.R | 35 ++
.../rewrite/RewriteMMChainTestTranspose1.dml | 31 ++
.../rewrite/RewriteMMChainTestTranspose2.R | 34 ++
.../rewrite/RewriteMMChainTestTranspose2.dml | 30 ++
.../rewrite/RewriteMMChainTestTranspose3.R | 34 ++
.../rewrite/RewriteMMChainTestTranspose3.dml | 31 ++
.../rewrite/RewriteMMChainTestTranspose4.R | 34 ++
.../rewrite/RewriteMMChainTestTranspose4.dml | 31 ++
10 files changed, 874 insertions(+)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
new file mode 100644
index 0000000000..56702542a8
--- /dev/null
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+
+import org.apache.commons.lang3.mutable.MutableInt;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.HopsException;
+import org.apache.sysds.runtime.util.CollectionUtils;
+import org.apache.sysds.utils.Explain;
+
+/**
+ * <strong>Rule</strong>: Determine the optimal order of execution for a chain
of
+ * matrix multiplications <br>
+ * <strong>Solution</strong>: Classic Dynamic Programming <br>
+ * <strong>Approach</strong>: Currently, the approach based only on matrix
dimensions <br>
+ * <strong>Goal</strong>: To reduce the number of computations in the run-time
+ * (map-reduce) layer
+ */
+public class RewriteMatrixMultChainOptimizationTranspose extends HopRewriteRule
+{
+ private static final Boolean PUSH_DOWN_TRANSPOSE = true;
+
+ @Override
+ public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state)
+ {
+ if( roots == null )
+ return null;
+
+ // Find the optimal order for the chain whose result is the
current HOP
+ for( Hop h : roots )
+ rule_OptimizeMMChains(h, state);
+
+ return roots;
+ }
+
+ @Override
+ public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state)
+ {
+ if( root == null )
+ return null;
+
+ // Find the optimal order for the chain whose result is the
current HOP
+ rule_OptimizeMMChains(root, state);
+
+ return root;
+ }
+
+ /**
+ * rule_OptimizeMMChains(): This method goes through all Hops in the DAG
+ * to find chains that need to be optimized.
+ *
+ * @param hop high-level operator
+ */
+ private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state)
+ {
+ if( !hop.isVisited() ) {
+
+ if (HopRewriteUtils.isMatrixMultiply(hop) &&
!((AggBinaryOp) hop).hasLeftPMInput()) {
+ // Try to find and optimize the chain in which
current Hop is the
+ // last operator
+ prepAndOptimizeMMChain(hop, state);
+ }
+
+ for (Hop hi : hop.getInput())
+ rule_OptimizeMMChains(hi, state);
+
+ hop.setVisited();
+ }
+ }
+
+ /**
+ * optimizeMMChain(): It optimizes the matrix multiplication chain in
which
+ * the last Hop is "this".
+ * <ul><li>Step 1: Identify the chain (mmChain).</li>
+ * <li>Step 2: Clear all links among the Hops that are involved in
mmChain.</li>
+ * <li>Step 3: Find the optimal ordering via dynamic programming.</li>
+ * <li>Step 4: Relink the hops in mmChain.</li></ul>
+ * @param hop high-level operator
+ */
+ private void prepAndOptimizeMMChain( Hop hop, ProgramRewriteStatus
state )
+ {
+ if( LOG.isTraceEnabled() ) {
+ LOG.trace("MM Chain Optimization for HOP: (" +
hop.getClass().getSimpleName()
+ + ", " + hop.getHopID() + ", " + hop.getName()
+ ")");
+ }
+
+ // Step 1: Identify the chain (mmChain) & clear all links among
the Hops
+ // that are involved in mmChain.
+
+ // Initialize mmChain with current hop's inputs
+ ArrayList<Hop> mmOperators = new ArrayList<>();
+ mmOperators.add(hop);
+ ArrayList<Hop> mmChain = new ArrayList<>(hop.getInput());
+
+ if (PUSH_DOWN_TRANSPOSE) {
+ checkChainForTransposeAndRewrite(mmChain, hop);
+ }
+
+ int mmChainIndex = 0;
+
+ // Expand each Hop in mmChain to find the entire matrix
multiplication chain
+ while( mmChainIndex < mmChain.size() )
+ {
+ boolean expandable = false;
+
+ Hop h = mmChain.get(mmChainIndex);
+ /*
+ * Check if mmChain[i] is expandable:
+ * 1) It must be MATMULT
+ * 2) It must not have been visited already
+ * (one MATMULT should get expanded only in one
chain)
+ * 3) Its output should not be used in multiple places
+ * (either within chain or outside the chain)
+ */
+
+ if ( HopRewriteUtils.isMatrixMultiply(h) &&
!h.isVisited() )
+ {
+ // check if the output of "h" is used at
multiple places. If yes, it can
+ // not be expanded.
+ expandable = !(h.getParent().size() > 1 ||
inputCount(h.getParent().get(0), h) > 1);
+ if( !expandable )
+ break;
+ }
+
+ h.setVisited();
+
+ if( !expandable ) {
+ mmChainIndex++;
+ }
+ else {
+ ArrayList<Hop> tempList =
mmChain.get(mmChainIndex).getInput();
+ if( tempList.size() != 2 ) {
+ throw new
HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain():
AggBinary must have exactly two inputs.");
+ }
+
+ // add current operator to mmOperators, and its
input nodes to mmChain
+ mmOperators.add(mmChain.get(mmChainIndex));
+ mmChain.set(mmChainIndex, tempList.get(0));
+ mmChain.add(mmChainIndex + 1, tempList.get(1));
+ }
+ }
+
+ // print the MMChain
+ if( LOG.isTraceEnabled() ) {
+ LOG.trace("Identified MM Chain: ");
+ for( Hop h : mmChain ) {
+ logTraceHop(h, 1);
+ }
+ }
+
+ //core mmchain optimization (potentially overridden)
+ if( mmChain.size() != 2 )
+ optimizeMMChain(hop, mmChain, mmOperators, state);
+ }
+
+ protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain,
ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
+ // Step 2: construct dims array
+ double[] dimsArray = new double[mmChain.size() + 1];
+ boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
+
+ if( dimsKnown ) {
+ // Step 3: Clear the links among Hops within the
identified chain
+ clearLinksWithinChain ( hop, mmOperators );
+
+ // Step 4: Find the optimal ordering via dynamic
programming.
+
+ // Invoke Dynamic Programming
+ int size = mmChain.size();
+ int[][] split = mmChainDP(dimsArray, mmChain.size());
+
+ // Step 5: Relink the hops using the optimal ordering
(split[][]) found from DP.
+ LOG.trace("Optimal MM Chain: ");
+ mmChainRelinkHops(mmOperators.get(0), 0, size - 1,
mmChain, mmOperators, new MutableInt(1), split, 1);
+ }
+ }
+
+ /**
+ * mmChainDP(): Core method to perform dynamic programming on a given
array
+ * of matrix dimensions. <br>
+ *
+ * Thomas H. Cormen, Charles E. Leiserson, Ronald L. Rivest, Clifford
Stein
+ * Introduction to Algorithms, Third Edition, MIT Press, page 395.
+ */
+ private static int[][] mmChainDP(double[] dimArray, int size)
+ {
+ double[][] dpMatrix = new double[size][size]; //min cost table
+ int[][] split = new int[size][size]; //min cost index table
+
+ //init minimum costs for chains of length 1
+ for( int i = 0; i < size; i++ ) {
+ Arrays.fill(dpMatrix[i], 0);
+ Arrays.fill(split[i], -1);
+ }
+
+ //compute cost-optimal chains for increasing chain sizes
+ for( int l = 2; l <= size; l++ ) { // chain length
+ for( int i = 0; i < size - l + 1; i++ ) {
+ int j = i + l - 1;
+ // find cost of (i,j)
+ dpMatrix[i][j] = Double.MAX_VALUE;
+ for( int k = i; k <= j - 1; k++ )
+ {
+ //recursive cost computation
+ double cost = dpMatrix[i][k] +
dpMatrix[k + 1][j]
+ + (dimArray[i] * dimArray[k +
1] * dimArray[j + 1]);
+
+ //prune suboptimal
+ if( cost < dpMatrix[i][j] ) {
+ dpMatrix[i][j] = cost;
+ split[i][j] = k;
+ }
+ }
+
+ if( LOG.isTraceEnabled() ){
+ LOG.trace("mmchainopt
[i="+(i+1)+",j="+(j+1)+"]: costs = "+dpMatrix[i][j]+", split =
"+(split[i][j]+1));
+ }
+ }
+ }
+
+ return split;
+ }
+
+ /**
+ * mmChainRelinkHops(): This method gets invoked after finding the
optimal
+ * order (split[][]) from dynamic programming. It relinks the Hops that
are
+ * part of the mmChain.
+ * @param mmChain basic operands in the entire matrix multiplication
chain
+ * @param mmOperators Hops that store the intermediate results in the
chain.
+ * <strong>For example:</strong> A = B %*% (C %*%
D) there will be three
+ * Hops in mmChain (B,C,D), and two Hops in
mmOperators
+ * (one for each * %*%).
+ * @param h high level operator
+ * @param i array index i
+ * @param j array index j
+ * @param opIndex operator index
+ * @param split optimal order
+ * @param level log level
+ */
+ protected final void mmChainRelinkHops(Hop h, int i, int j,
ArrayList<Hop> mmChain,
+ ArrayList<Hop> mmOperators, MutableInt opIndex, int[][] split,
int level)
+ {
+ //NOTE: the opIndex is a MutableInt in order to get the correct
positions
+ //in ragged chains like ((((a, b), c), (D, E), f), e) that
might be given
+ //like that by the original scripts variable assignments
+
+ //single matrix - end of recursion
+ if( i == j ) {
+ logTraceHop(h, level);
+ return;
+ }
+
+ if( LOG.isTraceEnabled() ){
+ String offset = Explain.getIdentation(level);
+ LOG.trace(offset + "(");
+ }
+
+ // Set Input1 for current Hop h
+ if( i == split[i][j] ) {
+ h.getInput().add(mmChain.get(i));
+ mmChain.get(i).getParent().add(h);
+ }
+ else {
+ int ix = opIndex.getValue();
+ opIndex.increment();
+ h.getInput().add(mmOperators.get(ix));
+ mmOperators.get(ix).getParent().add(h);
+ }
+
+ // Set Input2 for current Hop h
+ if( split[i][j] + 1 == j ) {
+ h.getInput().add(mmChain.get(j));
+ mmChain.get(j).getParent().add(h);
+ }
+ else {
+ int ix = opIndex.getValue();
+ opIndex.increment();
+ h.getInput().add(mmOperators.get(ix));
+ mmOperators.get(ix).getParent().add(h);
+ }
+
+ // Find children for both the inputs
+ mmChainRelinkHops(h.getInput(0), i, split[i][j], mmChain,
mmOperators, opIndex, split, level+1);
+ mmChainRelinkHops(h.getInput(1), split[i][j] + 1, j, mmChain,
mmOperators, opIndex, split, level+1);
+
+ // Propagate properties of input hops to current hop h
+ h.refreshSizeInformation();
+
+ if( LOG.isTraceEnabled() ){
+ String offset = Explain.getIdentation(level);
+ LOG.trace(offset + ")");
+ }
+ }
+
+ protected static void clearLinksWithinChain( Hop hop, ArrayList<Hop>
operators )
+ {
+ for( int i=0; i < operators.size(); i++ ) {
+ Hop op = operators.get(i);
+ if( op.getInput().size() != 2 || (i > 0 &&
op.getParent().size() > 1 ) ) {
+ throw new
HopsException(hop.printErrorLocation() +
+ "Unexpected error while applying
optimization on matrix-mult chain. \n");
+ }
+ Hop input1 = op.getInput(0);
+ Hop input2 = op.getInput(1);
+
+ op.getInput().clear();
+ input1.getParent().remove(op);
+ input2.getParent().remove(op);
+ }
+ }
+
+ /**
+ * Obtains all dimension information of the chain and constructs the
dimArray.
+ * If all dimensions are known it returns true; otherwise the mmchain
rewrite
+ * should be ended without modifications.
+ *
+ * @param hop high-level operator
+ * @param chain list of high-level operators
+ * @param dimsArray dimension array
+ * @return true if all dimensions known
+ */
+ protected static boolean getDimsArray( Hop hop, ArrayList<Hop> chain,
double[] dimsArray )
+ {
+ boolean dimsKnown = true;
+
+ // Build the array containing dimensions from all matrices in
the chain
+ // check the dimensions in the matrix chain to insure all
dimensions are known
+ for (Hop value : chain)
+ if (value.getDim1() <= 0 || value.getDim2() <= 0)
+ dimsKnown = false;
+
+ if( dimsKnown ) { //populate dims array if all dims known
+ for( int i = 0; i < chain.size(); i++ ) {
+ if (i == 0) {
+ dimsArray[i] = chain.get(i).getDim1();
+ if (dimsArray[i] <= 0) {
+ throw new
HopsException(hop.printErrorLocation() +
+
"Hops::optimizeMMChain() : Invalid Matrix Dimension: "+ dimsArray[i]);
+ }
+ }
+ else if (chain.get(i - 1).getDim2() !=
chain.get(i).getDim1()) {
+ throw new
HopsException(hop.printErrorLocation() +
+ "Hops::optimizeMMChain() :
Matrix Dimension Mismatch: " +
+ chain.get(i - 1).getDim2()+" !=
"+chain.get(i).getDim1());
+ }
+
+ dimsArray[i + 1] = chain.get(i).getDim2();
+ if( dimsArray[i + 1] <= 0 ) {
+ throw new
HopsException(hop.printErrorLocation() +
+
"Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i + 1]);
+ }
+ }
+ }
+
+ return dimsKnown;
+ }
+
+ private static int inputCount( Hop p, Hop h ) {
+ return CollectionUtils.cardinality(h, p.getInput());
+ }
+
+ private static void logTraceHop( Hop hop, int level ) {
+ if( LOG.isTraceEnabled() ) {
+ String offset = Explain.getIdentation(level);
+ LOG.trace(offset+ "Hop " + hop.getName() + "(" +
hop.getClass().getSimpleName()
+ + ", " + hop.getHopID() + ")" + " " +
hop.getDim1() + "x" + hop.getDim2());
+ }
+ }
+
+ /**
+ * Transforms a transpose operator into matrixmult and adjusts
+ * all the respective attributes of the other operators, also creates a
second transpose operator.
+ * Thus, we can achieve larger optimization space for the transformed
chain.<br>
+ * <strong>Idea:</strong> t(A %*% B) -> t(B) %*% t(A)
+ *
+ * @param transposeHop the transpose operator, which contains all
useful data for the transformation
+ * @return the new matrixmult operator
+ */
+ private Hop rewriteChainOnTransposeOperator(Hop transposeHop) {
+ Hop matrixMultHop = transposeHop.getInput(0);
+ Hop firstMatrix = matrixMultHop.getInput(0);
+ Hop secondMatrix = matrixMultHop.getInput(1);
+
+ // Clone transpose operator for the overwritten chain
+ Hop secondTransposeHop = null;
+ try {
+ secondTransposeHop = (Hop) transposeHop.clone();
+ } catch (CloneNotSupportedException ex) {
+ System.err.println("Error on cloning transpose
operator: " + ex.getMessage());
+ }
+ assert secondTransposeHop!= null;
+
+ // Set parent to the other operators accordingly
+ updateParentOfHop(firstMatrix, transposeHop);
+ updateParentOfHop(secondMatrix, secondTransposeHop);
+ updateParentOfHop(transposeHop, matrixMultHop);
+ updateParentOfHop(secondTransposeHop, matrixMultHop);
+
+ // Set input to all operators and update attributes accordingly
+ ArrayList<Hop> inputList = new ArrayList<>();
+ inputList.add(firstMatrix);
+ updateAttributesOfHop(transposeHop, inputList,
firstMatrix.getName());
+
+ inputList.set(0, secondMatrix);
+ updateAttributesOfHop(secondTransposeHop, inputList,
secondMatrix.getName());
+
+ inputList.set(0, secondTransposeHop);
+ inputList.add(transposeHop);
+ updateAttributesOfHop(matrixMultHop, inputList,
firstMatrix.getName());
+
+ return matrixMultHop;
+ }
+
+ private void checkChainForTransposeAndRewrite(ArrayList<Hop> mmChain,
Hop parentOfChain) {
+ int mmChainIndex = 0;
+ while (mmChainIndex < mmChain.size())
+ {
+ Hop currentChainHop = mmChain.get(mmChainIndex);
+
+ // Check if current hop is a transpose operator,
+ // if it has been visited,
+ // and if it has only one input, which is a matrixmult
operator
+ boolean isTransposeOperator =
HopRewriteUtils.isReorg(currentChainHop, Types.ReOrgOp.TRANS);
+
+ if (isTransposeOperator && !currentChainHop.isVisited()
&& currentChainHop.getInput().size() == 1)
+ {
+ Hop transposeOperatorChild =
currentChainHop.getInput(0);
+ if
(HopRewriteUtils.isMatrixMultiply(transposeOperatorChild)
+ &&
hasOnlyTwoReadsAsInput(transposeOperatorChild) &&
transposeOperatorChild.getParent().size() == 1)
+ {
+ int indexInParentInput =
parentOfChain.getInput().indexOf(currentChainHop);
+
+ // Set transpose operator's parent as
new one for matrix multiplication operator
+ Hop matrixMultHop =
rewriteChainOnTransposeOperator(currentChainHop);
+ updateParentOfHop(matrixMultHop,
parentOfChain);
+
+ // Update input of transpose operator's
parent
+
parentOfChain.getInput().set(indexInParentInput, matrixMultHop);
+
+ // Replace transpose operator with the
matrixmult one in the mmchain
+ mmChain.set(mmChainIndex,
matrixMultHop);
+ }
+ }
+ mmChainIndex++;
+ }
+ }
+
+ private void updateParentOfHop(Hop hopToUpdate, Hop parentToSet) {
+ hopToUpdate.getParent().clear();
+ hopToUpdate.getParent().add(parentToSet);
+ }
+
+ /**
+ * Updates input list, dimensions of matrix and text of a given Hop.
+ *
+ * @param hopToUpdate the hop that will be updated
+ * @param inputList new input list that will be set
+ * @param text new text of the operator
+ */
+ private void updateAttributesOfHop(Hop hopToUpdate, ArrayList<Hop>
inputList, String text) {
+ hopToUpdate.getInput().clear();
+
+ for (Hop input : inputList) {
+ hopToUpdate.getInput().add(input);
+ }
+
+ if (HopRewriteUtils.isMatrixMultiply(hopToUpdate)) {
+ // Here we add dimensions of a matrixmult operator
+ hopToUpdate.setDim1(inputList.get(0).getDim1());
+ hopToUpdate.setDim2(inputList.get(1).getDim2());
+ } else {
+ // Here we add dimensions of a transpose operator
+ hopToUpdate.setDim1(inputList.get(0).getDim2());
+ hopToUpdate.setDim2(inputList.get(0).getDim1());
+ }
+
+ //hopToUpdate.setText(String.format("t(%s)", text));
+ }
+
+ private boolean hasOnlyTwoReadsAsInput(Hop transposeOperatorChild) {
+ if (transposeOperatorChild.getInput().size() == 2) {
+ for(Hop hop: transposeOperatorChild.getInput()) {
+ if (!HopRewriteUtils.isData(hop,
Types.OpOpData.TRANSIENTREAD, Types.OpOpData.PERSISTENTREAD))
+ return false;
+ }
+ return true;
+ }
+ return false;
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptTransposeTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptTransposeTest.java
new file mode 100644
index 0000000000..d55ad0ddf7
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteMatrixMultChainOptTransposeTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class RewriteMatrixMultChainOptTransposeTest extends AutomatedTestBase
+{
+ //TODO enable experimental mmchain-opt rewrite and debug in detail
+
+ private static final String TEST_NAME = "RewriteMMChainTestTranspose";
+ protected static final int TEST_VARIANTS = 4;
+ private static final String TEST_DIR = "functions/rewrite/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RewriteMatrixMultChainOptTransposeTest.class.getSimpleName() + "/";
+
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ for( int i=1; i<=TEST_VARIANTS; i++ )
+ addTestConfiguration(TEST_NAME+i, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] {"R"}));
+ }
+
+ @Test
+ public void testSameInputUsedMultipleTimes() {
+ testMMChainWithTransposeOperator(TEST_NAME + "1", 4);
+ }
+
+ @Test
+ public void testTwoMultiplicationsInTransposeOperator() {
+ testMMChainWithTransposeOperator(TEST_NAME + "2", 2);
+ }
+
+ @Test
+ public void testTransposeInTranspose() {
+ testMMChainWithTransposeOperator(TEST_NAME + "3", 4);
+ }
+
+ @Test
+ public void testMMChainFour() {
+ testMMChainWithTransposeOperator(TEST_NAME + "4", 2);
+ }
+
+ private void testMMChainWithTransposeOperator(String testname, int
numberOfTransposeOperators)
+ {
+ ExecMode etOld = setExecMode(ExecMode.SINGLE_NODE);
+
+ try
+ {
+ TestConfiguration config =
getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[]{ "-explain", "hops",
"-stats",
+ "-args", output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ //execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
+ HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
+
+ Assert.assertEquals(numberOfTransposeOperators,
Statistics.getCPHeavyHitterCount(Types.ReOrgOp.TRANS.toString()));
+ }
+ finally {
+ resetExecMode(etOld);
+ }
+ }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose1.R
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose1.R
new file mode 100644
index 0000000000..e158bb5a5a
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose1.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+a = matrix(1.0, 5, 3)
+b = matrix(2.0, 3, 7)
+x = matrix(3.0, 9, 4)
+y = matrix(2.0, 4, 5)
+
+m1 = t(a %*% b) %*% t(y)
+R = t(b) %*% b %*% m1 %*% t(x)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
diff --git
a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose1.dml
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose1.dml
new file mode 100644
index 0000000000..c6ff97ac8d
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose1.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = matrix(1.0, 5, 3)
+b = matrix(2.0, 3, 7)
+x = matrix(3.0, 9, 4)
+y = matrix(2.0, 4, 5)
+
+while(FALSE){}
+m = t(a %*% b) %*% t(y)
+R = t(b) %*% b %*% m %*% t(x)
+
+write(R, $1);
diff --git a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose2.R
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose2.R
new file mode 100644
index 0000000000..b6b77d633d
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose2.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+a = matrix(1.0, 5, 3)
+b = matrix(2.0, 3, 7)
+x = matrix(3.0, 9, 4)
+y = matrix(2.0, 4, 5)
+
+R = t(a %*% b) %*% t(x %*% y)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
diff --git
a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose2.dml
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose2.dml
new file mode 100644
index 0000000000..938e1445c3
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose2.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = matrix(1.0, 5, 3)
+b = matrix(2.0, 3, 7)
+x = matrix(3.0, 9, 4)
+y = matrix(2.0, 4, 5)
+
+while(FALSE){}
+R = t(a %*% b) %*% t(x %*% y)
+
+write(R, $1);
diff --git a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose3.R
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose3.R
new file mode 100644
index 0000000000..59b800ec95
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose3.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+a = matrix(1.0, 5, 5)
+b = matrix(2.0, 5, 9)
+x = matrix(3.0, 9, 4)
+y = matrix(2.0, 4, 3)
+
+R = t(a %*% a) %*% t(t(x %*% y) %*% t(b))
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
diff --git
a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose3.dml
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose3.dml
new file mode 100644
index 0000000000..d4cb4a84dc
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose3.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = matrix(1.0, 5, 5)
+b = matrix(2.0, 5, 9)
+x = matrix(3.0, 9, 4)
+y = matrix(2.0, 4, 3)
+
+while(FALSE){}
+
+R = t(a %*% a) %*% t(t(x %*% y) %*% t(b))
+
+write(R, $1);
diff --git a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose4.R
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose4.R
new file mode 100644
index 0000000000..db71f31944
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose4.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+a = matrix(1.0, 5, 5)
+b = matrix(2.0, 5, 9)
+x = matrix(3.0, 9, 4)
+y = matrix(1.0, 5, 4)
+
+R = t(a) %*% b %*% x %*% t(a %*% y)
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
diff --git
a/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose4.dml
b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose4.dml
new file mode 100644
index 0000000000..87af217c3b
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteMMChainTestTranspose4.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+a = matrix(1.0, 5, 5)
+b = matrix(2.0, 5, 9)
+x = matrix(3.0, 9, 4)
+y = matrix(1.0, 5, 4)
+
+while(FALSE){}
+
+R = t(a) %*% b %*% x %*% t(a %*% y)
+
+write(R, $1);