This is an automated email from the ASF dual-hosted git repository.
janniklinde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c2d0e5a47c [SYSTEMDS-3941] New Algebraic Rewrites
c2d0e5a47c is described below
commit c2d0e5a47c6a4b22d3979d79f2260e81eb87b9a9
Author: Süleyman Melih Portakal
<[email protected]>
AuthorDate: Fri Apr 24 17:38:18 2026 +0200
[SYSTEMDS-3941] New Algebraic Rewrites
Closes #2460.
---
.../RewriteAlgebraicSimplificationDynamic.java | 26 ++++++
.../RewriteAlgebraicSimplificationStatic.java | 54 +++++++++++
.../functions/rewrite/RewriteFusedRandTest.java | 2 +-
.../RewritePushdownColSumBinaryMultTest.java | 100 +++++++++++++++++++++
.../RewritePushdownRowSumBinaryMultTest.java | 100 +++++++++++++++++++++
.../RewriteSimplifySumConstantMatrixTest.java | 86 ++++++++++++++++++
.../functions/rewrite/RewriteFusedRandLit.dml | 2 +-
...ndLit.dml => RewritePushdownColSumBinaryMult.R} | 19 ++--
...Lit.dml => RewritePushdownColSumBinaryMult.dml} | 16 ++--
...dLit.dml => RewritePushdownColSumBinaryMult2.R} | 19 ++--
...it.dml => RewritePushdownColSumBinaryMult2.dml} | 16 ++--
...ndLit.dml => RewritePushdownRowSumBinaryMult.R} | 19 ++--
...Lit.dml => RewritePushdownRowSumBinaryMult.dml} | 16 ++--
...dLit.dml => RewritePushdownRowSumBinaryMult2.R} | 19 ++--
...it.dml => RewritePushdownRowSumBinaryMult2.dml} | 16 ++--
...dLit.dml => RewriteSimplifySumConstantMatrix.R} | 15 ++--
...it.dml => RewriteSimplifySumConstantMatrix.dml} | 14 +--
17 files changed, 431 insertions(+), 108 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index eb51348a8e..79b6d8a39b 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -184,6 +184,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
hi = simplifySumDiagToTrace(hi);
//e.g., sum(diag(X)) -> trace(X); if col vector
hi = simplifyLowerTriExtraction(hop, hi, i);
//e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
hi = simplifyConstantCumsum(hop, hi, i);
//e.g., cumsum(matrix(1/n,n,1)) -> seq(1/n, 1, 1/n)
+ hi = simplifySumConstantMatrix(hop, hi, i);
//e.g., sum(matrix(a,rows=b,cols=c)) -> a*b*c
hi = pushdownBinaryOperationOnDiag(hop, hi, i);
//e.g., diag(X)*7 -> diag(X*7); if col vector
hi = pushdownSumOnAdditiveBinary(hop, hi, i);
//e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1273,6 +1274,31 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
}
return hi;
}
+
+	/**
+	 * Replaces the full sum over a constant-filled matrix by a scalar product,
+	 * i.e., sum(matrix(a, rows=b, cols=c)) -> (b*c)*a.
+	 * Applies only for known, non-zero dimensions and if the datagen is consumed
+	 * exclusively by this sum, so that the datagen can be cleaned up afterwards.
+	 *
+	 * @param parent parent hop of hi
+	 * @param hi     candidate hop (rewritten if it is a full sum aggregate)
+	 * @param pos    input position of hi within parent
+	 * @return the replacement hop if the rewrite applied, otherwise hi
+	 */
+	private static Hop simplifySumConstantMatrix(Hop parent, Hop hi, int pos) {
+		if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, Direction.RowCol)
+			&& HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(0))
+			&& hi.getInput(0).dimsKnown()
+			&& hi.getInput(0).getDim1() >= 1   //non-empty inputs only
+			&& hi.getInput(0).getDim2() >= 1
+			&& hi.getInput(0).getParent().size() == 1 ) //datagen only consumed by this sum
+		{
+			DataGenOp datagen = (DataGenOp) hi.getInput(0);
+			Hop constVal = datagen.getConstantValue();
+			//number of cells as an FP64 literal, placed as the FIRST operand: this avoids a
+			//long overflow of rows*cols and keeps the product FP64 like the replaced sum even
+			//for integer fill values (NOTE(review): relies on createBinary taking the value
+			//type of its first scalar input -- confirm in HopRewriteUtils)
+			Hop cells = new LiteralOp((double) datagen.getDim1() * datagen.getDim2());
+			Hop hnew = HopRewriteUtils.createBinary(cells, constVal, OpOp2.MULT);
+
+			HopRewriteUtils.replaceChildReference(parent, hi, hnew, pos);
+			HopRewriteUtils.cleanupUnreferenced(hi, datagen);
+
+			//log the line number of the replaced sum, not of the newly created hop
+			LOG.debug("Applied simplifySumConstantMatrix (line "+hi.getBeginLine()+").");
+			hi = hnew;
+		}
+		return hi;
+	}
private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi,
int pos)
{
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 2ae1550257..b014ab7920 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -170,6 +170,8 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = pushdownDetMultOperation(hop, hi, i);
//e.g., det(X%*%Y) -> det(X)*det(Y)
hi = pushdownDetScalarMatrixMultOperation(hop, hi, i);
//e.g., det(lambda*X) -> lambda^nrow(X)*det(X)
hi = pushdownSumBinaryMult(hop, hi, i);
//e.g., sum(lambda*X) -> lambda*sum(X)
+ hi = pushdownRowSumBinaryMult(hop, hi, i);
//e.g., rowSums(lambda*X) -> lambda*rowSums(X)
+ hi = pushdownColSumBinaryMult(hop, hi, i);
//e.g., colSums(lambda*X) -> lambda*colSums(X)
hi = pullupAbs(hop, hi, i);
//e.g., abs(X)*abs(Y) --> abs(X*Y)
hi = simplifyUnaryPPredOperation(hop, hi, i);
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
hi = simplifyTransposedAppend(hop, hi, i);
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
@@ -1447,6 +1449,58 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
return hi;
}
+	/**
+	 * Pushes a scalar multiplication below a row-wise sum,
+	 * i.e., rowSums(lambda*X) -> lambda*rowSums(X).
+	 *
+	 * @param parent parent hop of hi
+	 * @param hi     candidate hop (rewritten if it is a rowSums aggregate)
+	 * @param pos    input position of hi within parent
+	 * @return the replacement hop if the rewrite applied, otherwise hi
+	 */
+	private static Hop pushdownRowSumBinaryMult(Hop parent, Hop hi, int pos ) {
+		return pushdownDirectionalSumBinaryMult(parent, hi, pos,
+			Direction.Row, "pushdownRowSumBinaryMult");
+	}
+
+	/**
+	 * Pushes a scalar multiplication below a column-wise sum,
+	 * i.e., colSums(lambda*X) -> lambda*colSums(X).
+	 *
+	 * @param parent parent hop of hi
+	 * @param hi     candidate hop (rewritten if it is a colSums aggregate)
+	 * @param pos    input position of hi within parent
+	 * @return the replacement hop if the rewrite applied, otherwise hi
+	 */
+	private static Hop pushdownColSumBinaryMult(Hop parent, Hop hi, int pos ) {
+		return pushdownDirectionalSumBinaryMult(parent, hi, pos,
+			Direction.Col, "pushdownColSumBinaryMult");
+	}
+
+	/**
+	 * Shared implementation of the row- and column-wise pushdowns:
+	 * sum_dir(lambda*X) -> lambda*sum_dir(X) for a scalar lambda and a matrix X
+	 * (either operand order). Applies only if the multiplication has the
+	 * aggregate as its single parent.
+	 *
+	 * @param parent      parent hop of hi
+	 * @param hi          candidate hop
+	 * @param pos         input position of hi within parent
+	 * @param dir         aggregation direction to match (Row or Col)
+	 * @param rewriteName rewrite name reported in the debug log
+	 * @return the replacement hop if the rewrite applied, otherwise hi
+	 */
+	private static Hop pushdownDirectionalSumBinaryMult(Hop parent, Hop hi, int pos,
+		Direction dir, String rewriteName)
+	{
+		if( !HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, dir)
+			|| !HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MULT, 1) ) //mult: only parent is this sum
+			return hi;
+
+		Hop mult = hi.getInput(0);
+		Hop operand1 = mult.getInput(0);
+		Hop operand2 = mult.getInput(1);
+		boolean scalarLeft = operand1.getDataType()==DataType.SCALAR
+			&& operand2.getDataType()==DataType.MATRIX;
+		boolean scalarRight = operand1.getDataType()==DataType.MATRIX
+			&& operand2.getDataType()==DataType.SCALAR;
+		if( !scalarLeft && !scalarRight )
+			return hi;
+
+		Hop lambda = scalarLeft ? operand1 : operand2;
+		Hop matrix = scalarLeft ? operand2 : operand1;
+
+		AggUnaryOp aggOp = HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, dir);
+		Hop bop = HopRewriteUtils.createBinary(lambda, aggOp, OpOp2.MULT);
+		HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
+		//detach the replaced aggregate and multiplication so that lambda and X do not keep
+		//stale parent references (NOTE(review): assumed to be a no-op for hops that are
+		//still referenced elsewhere, as relied upon by simplifySumConstantMatrix)
+		HopRewriteUtils.cleanupUnreferenced(hi, mult);
+
+		LOG.debug("Applied "+rewriteName+" (line "+hi.getBeginLine()+").");
+		return bop;
+	}
+
private static Hop pullupAbs(Hop parent, Hop hi, int pos ) {
if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
&& HopRewriteUtils.isUnary(hi.getInput(0),
OpOp1.ABS)
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
index ef580848fb..e840accd06 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
@@ -121,7 +121,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase
//compare matrices
Double ret = readDMLMatrixFromOutputDir("R").get(new
CellIndex(1,1));
if( testname.equals(TEST_NAME1) )
- Assert.assertEquals("Wrong result",
Double.valueOf(rows), ret);
+ Assert.assertEquals("Wrong result",
Double.valueOf(rows*cols), ret);
else if( testname.equals(TEST_NAME2) )
Assert.assertEquals("Wrong result",
Double.valueOf(Math.pow(rows*cols, 2)), ret);
else if( testname.equals(TEST_NAME3) )
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java
new file mode 100644
index 0000000000..eb31f12c3d
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+
+public class RewritePushdownColSumBinaryMultTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewritePushdownColSumBinaryMult";  //2*colSums(3*X)
+	private static final String TEST_NAME2 = "RewritePushdownColSumBinaryMult2"; //2*colSums(X*3)
+
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR =
+		TEST_DIR + RewritePushdownColSumBinaryMultTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		for( String name : new String[] {TEST_NAME1, TEST_NAME2} )
+			addTestConfiguration(name, new TestConfiguration(TEST_CLASS_DIR, name, new String[] {"R"}));
+	}
+
+	@Test
+	public void testPushdownColSumBinaryMultNoRewrite() {
+		runColSumPushdown(TEST_NAME1, false);
+	}
+
+	@Test
+	public void testPushdownColSumBinaryMultRewrite() {
+		runColSumPushdown(TEST_NAME1, true);
+	}
+
+	@Test
+	public void testPushdownColSumBinaryMultNoRewrite2() {
+		runColSumPushdown(TEST_NAME2, false);
+	}
+
+	@Test
+	public void testPushdownColSumBinaryMultRewrite2() {
+		runColSumPushdown(TEST_NAME2, true);
+	}
+
+	/**
+	 * Executes a DML script and its R counterpart, compares both results, and
+	 * checks the CP heavy-hitter counts of the multiplication opcodes.
+	 *
+	 * @param testname name of the DML/R script pair
+	 * @param rewrites whether algebraic simplification rewrites are enabled
+	 */
+	private void runColSumPushdown(String testname, boolean rewrites) {
+		final boolean previous = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		try {
+			loadTestConfiguration(getTestConfiguration(testname));
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			final String home = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = home + testname + ".dml";
+			fullRScriptName = home + testname + ".R";
+			programArgs = new String[] {"-stats", "-args", output("R")};
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			HashMap<CellIndex, Double> fromDml = readDMLMatrixFromOutputDir("R");
+			HashMap<CellIndex, Double> fromR = readRMatrixFromExpectedDir("R");
+			TestUtils.compareMatrices(fromDml, fromR, 1e-10, "DML", "R");
+
+			//expected: one 'n*' with rewrites (presumably an n-ary multiply combining both
+			//scalar factors), two binary '*' without rewrites
+			final String opcode = rewrites ? "n*" : "*";
+			final long expectedCount = rewrites ? 1 : 2;
+			Assert.assertEquals(expectedCount, Statistics.getCPHeavyHitterCount(opcode));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = previous;
+		}
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java
new file mode 100644
index 0000000000..cfa18ee335
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+
+public class RewritePushdownRowSumBinaryMultTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME1 = "RewritePushdownRowSumBinaryMult";  //2*rowSums(3*X)
+	private static final String TEST_NAME2 = "RewritePushdownRowSumBinaryMult2"; //2*rowSums(X*3)
+
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR =
+		TEST_DIR + RewritePushdownRowSumBinaryMultTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		for( String name : new String[] {TEST_NAME1, TEST_NAME2} )
+			addTestConfiguration(name, new TestConfiguration(TEST_CLASS_DIR, name, new String[] {"R"}));
+	}
+
+	@Test
+	public void testPushdownRowSumBinaryMultNoRewrite() {
+		runRowSumPushdown(TEST_NAME1, false);
+	}
+
+	@Test
+	public void testPushdownRowSumBinaryMultRewrite() {
+		runRowSumPushdown(TEST_NAME1, true);
+	}
+
+	@Test
+	public void testPushdownRowSumBinaryMultNoRewrite2() {
+		runRowSumPushdown(TEST_NAME2, false);
+	}
+
+	@Test
+	public void testPushdownRowSumBinaryMultRewrite2() {
+		runRowSumPushdown(TEST_NAME2, true);
+	}
+
+	/**
+	 * Executes a DML script and its R counterpart, compares both results, and
+	 * checks the CP heavy-hitter counts of the multiplication opcodes.
+	 *
+	 * @param testname name of the DML/R script pair
+	 * @param rewrites whether algebraic simplification rewrites are enabled
+	 */
+	private void runRowSumPushdown(String testname, boolean rewrites) {
+		final boolean previous = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+		try {
+			loadTestConfiguration(getTestConfiguration(testname));
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			final String home = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = home + testname + ".dml";
+			fullRScriptName = home + testname + ".R";
+			programArgs = new String[] {"-stats", "-args", output("R")};
+			rCmd = getRCmd(inputDir(), expectedDir());
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			HashMap<CellIndex, Double> fromDml = readDMLMatrixFromOutputDir("R");
+			HashMap<CellIndex, Double> fromR = readRMatrixFromExpectedDir("R");
+			TestUtils.compareMatrices(fromDml, fromR, 1e-10, "DML", "R");
+
+			//expected: one 'n*' with rewrites (presumably an n-ary multiply combining both
+			//scalar factors), two binary '*' without rewrites
+			final String opcode = rewrites ? "n*" : "*";
+			final long expectedCount = rewrites ? 1 : 2;
+			Assert.assertEquals(expectedCount, Statistics.getCPHeavyHitterCount(opcode));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = previous;
+		}
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java
new file mode 100644
index 0000000000..9530740ff8
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+public class RewriteSimplifySumConstantMatrixTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME = "RewriteSimplifySumConstantMatrix";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteSimplifySumConstantMatrixTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }));
+	}
+
+	@Test
+	public void testSimplifySumConstantMatrixNoRewritePositive() {
+		testRewriteSimplifySumConstantMatrix(2.5, 7, 11, false);
+	}
+
+	@Test
+	public void testSimplifySumConstantMatrixRewritePositive() {
+		testRewriteSimplifySumConstantMatrix(2.5, 7, 11, true);
+	}
+
+	/**
+	 * Computes sum(matrix(value, rows, cols)) in DML and R, compares the scalar
+	 * results, and checks whether the constant matrix is still generated at runtime.
+	 *
+	 * @param value    fill value of the constant matrix
+	 * @param rows     number of rows
+	 * @param cols     number of columns
+	 * @param rewrites whether algebraic simplification rewrites are enabled
+	 */
+	private void testRewriteSimplifySumConstantMatrix(double value, long rows, long cols, boolean rewrites) {
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+
+		try {
+			TestConfiguration config = getTestConfiguration(TEST_NAME);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {
+				"-stats", "-args",
+				String.valueOf(value), String.valueOf(rows), String.valueOf(cols), output("R")
+			};
+
+			fullRScriptName = HOME + TEST_NAME + ".R";
+			rCmd = getRCmd(String.valueOf(value), String.valueOf(rows), String.valueOf(cols), expectedDir());
+
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			double actual = readDMLScalarFromOutputDir("R").get(new CellIndex(1, 1));
+			double expected = readRScalarFromExpectedDir("R").get(new CellIndex(1, 1));
+			Assert.assertEquals(expected, actual, 1e-15);
+
+			//with rewrites the datagen ('rand') must be gone; without rewrites it is expected
+			//to execute, which gives the rewrite assertion a baseline (otherwise it could pass
+			//trivially if 'rand' never showed up in the heavy hitters)
+			if(rewrites)
+				Assert.assertFalse(heavyHittersContainsString("rand"));
+			else
+				Assert.assertTrue(heavyHittersContainsString("rand"));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
index ab00f04772..2e97afdba9 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
@@ -25,5 +25,5 @@ while(FALSE){} #prevent cse
X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
+R = as.matrix(sum(abs(X1)==abs(X2)));
write(R, $5);
\ No newline at end of file
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R
index ab00f04772..8813cc5a50 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -7,9 +5,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args <- commandArgs(TRUE)
+library(Matrix)
+# 100x100 matrix with X[i,j] = j, built exactly like in the DML script
+X <- matrix(1, 100, 1) %*% t(seq(1, 100))
+R <- matrix(2 * colSums(3 * X), nrow = 1)
+# args[2]: expected-output directory (second argument passed via getRCmd)
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml
index ab00f04772..6494dc45e4 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+# X: 100x100 matrix with X[i,j] = j (ones column vector times the row vector 1..100)
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+# empty loop introduces a statement-block boundary (presumably so that X enters the
+# next block as a plain variable rather than being merged into the pattern -- confirm)
+while(FALSE){}
+# rewrite target colSums(3*X) -> 3*colSums(X); the outer 2* adds a second scalar factor
+R=2*colSums(3*X)
+# $1: output path supplied by the Java test
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R
index ab00f04772..a951ea42c4 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -7,9 +5,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args <- commandArgs(TRUE)
+library(Matrix)
+# 100x100 matrix with X[i,j] = j, built exactly like in the DML script
+X <- matrix(1, 100, 1) %*% t(seq(1, 100))
+R <- matrix(2 * colSums(X * 3), nrow = 1)
+# args[2]: expected-output directory (second argument passed via getRCmd)
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml
index ab00f04772..c69492b571 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+# X: 100x100 matrix with X[i,j] = j (ones column vector times the row vector 1..100)
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+# empty loop introduces a statement-block boundary (presumably so that X enters the
+# next block as a plain variable rather than being merged into the pattern -- confirm)
+while(FALSE){}
+# rewrite target with the scalar on the right: colSums(X*3) -> 3*colSums(X)
+R=2*colSums(X*3)
+# $1: output path supplied by the Java test
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R
index ab00f04772..2b9f7b4b16 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -7,9 +5,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args <- commandArgs(TRUE)
+library(Matrix)
+# 100x100 matrix with X[i,j] = j, built exactly like in the DML script
+X <- matrix(1, 100, 1) %*% t(seq(1, 100))
+R <- matrix(2 * rowSums(3 * X), ncol = 1)
+# args[2]: expected-output directory (second argument passed via getRCmd)
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml
index ab00f04772..31f5e0bd1e 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+# X: 100x100 matrix with X[i,j] = j (ones column vector times the row vector 1..100)
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+# empty loop introduces a statement-block boundary (presumably so that X enters the
+# next block as a plain variable rather than being merged into the pattern -- confirm)
+while(FALSE){}
+# rewrite target rowSums(3*X) -> 3*rowSums(X); the outer 2* adds a second scalar factor
+R=2*rowSums(3*X)
+# $1: output path supplied by the Java test
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R
index ab00f04772..782e2fa687 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -7,9 +5,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args <- commandArgs(TRUE)
+library(Matrix)
+# 100x100 matrix with X[i,j] = j, built exactly like in the DML script
+X <- matrix(1, 100, 1) %*% t(seq(1, 100))
+R <- matrix(2 * rowSums(X * 3), ncol = 1)
+# args[2]: expected-output directory (second argument passed via getRCmd)
+writeMM(as(R, "CsparseMatrix"), paste0(args[2], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml
index ab00f04772..d579df7758 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+# X: 100x100 matrix with X[i,j] = j (ones column vector times the row vector 1..100)
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+# empty loop introduces a statement-block boundary (presumably so that X enters the
+# next block as a plain variable rather than being merged into the pattern -- confirm)
+while(FALSE){}
+# rewrite target with the scalar on the right: rowSums(X*3) -> 3*rowSums(X)
+R=2*rowSums(X*3)
+# $1: output path supplied by the Java test
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R
index ab00f04772..d6c07b4baf 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,10 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
+args <- commandArgs(TRUE)
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
+# fill value and dimensions, passed by the Java test in the same order as for DML
+value <- as.double(args[1])
+rows <- as.integer(args[2])
+cols <- as.integer(args[3])
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+total <- sum(matrix(value, nrow = rows, ncol = cols))
+# args[4]: expected-output directory
+write(total, paste0(args[4], "R"))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml
index ab00f04772..0b54eeb12e 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml
@@ -7,9 +7,9 @@
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
-#
+#
# http://www.apache.org/licenses/LICENSE-2.0
-#
+#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,5 @@
#
#-------------------------------------------------------------
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+# $1: fill value, $2: rows, $3: cols, $4: output path (see programArgs in the Java test)
+# rewrite target: full sum of a constant matrix, replaced by a scalar product of a, b, c
+R = sum(matrix($1, rows=$2, cols=$3))
+write(R, $4)