This is an automated email from the ASF dual-hosted git repository.

janniklinde pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c2d0e5a47c [SYSTEMDS-3941] New Algebraic Rewrites
c2d0e5a47c is described below

commit c2d0e5a47c6a4b22d3979d79f2260e81eb87b9a9
Author: Süleyman Melih Portakal 
<[email protected]>
AuthorDate: Fri Apr 24 17:38:18 2026 +0200

    [SYSTEMDS-3941] New Algebraic Rewrites
    
    Closes #2460.
---
 .../RewriteAlgebraicSimplificationDynamic.java     |  26 ++++++
 .../RewriteAlgebraicSimplificationStatic.java      |  54 +++++++++++
 .../functions/rewrite/RewriteFusedRandTest.java    |   2 +-
 .../RewritePushdownColSumBinaryMultTest.java       | 100 +++++++++++++++++++++
 .../RewritePushdownRowSumBinaryMultTest.java       | 100 +++++++++++++++++++++
 .../RewriteSimplifySumConstantMatrixTest.java      |  86 ++++++++++++++++++
 .../functions/rewrite/RewriteFusedRandLit.dml      |   2 +-
 ...ndLit.dml => RewritePushdownColSumBinaryMult.R} |  19 ++--
 ...Lit.dml => RewritePushdownColSumBinaryMult.dml} |  16 ++--
 ...dLit.dml => RewritePushdownColSumBinaryMult2.R} |  19 ++--
 ...it.dml => RewritePushdownColSumBinaryMult2.dml} |  16 ++--
 ...ndLit.dml => RewritePushdownRowSumBinaryMult.R} |  19 ++--
 ...Lit.dml => RewritePushdownRowSumBinaryMult.dml} |  16 ++--
 ...dLit.dml => RewritePushdownRowSumBinaryMult2.R} |  19 ++--
 ...it.dml => RewritePushdownRowSumBinaryMult2.dml} |  16 ++--
 ...dLit.dml => RewriteSimplifySumConstantMatrix.R} |  15 ++--
 ...it.dml => RewriteSimplifySumConstantMatrix.dml} |  14 +--
 17 files changed, 431 insertions(+), 108 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index eb51348a8e..79b6d8a39b 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -184,6 +184,7 @@ public class RewriteAlgebraicSimplificationDynamic extends 
HopRewriteRule
                        hi = simplifySumDiagToTrace(hi);                  
//e.g., sum(diag(X)) -> trace(X); if col vector
                        hi = simplifyLowerTriExtraction(hop, hi, i);      
//e.g., X * cumsum(diag(matrix(1,nrow(X),1))) -> lower.tri
                        hi = simplifyConstantCumsum(hop, hi, i);          
//e.g., cumsum(matrix(1/n,n,1)) -> seq(1/n, 1, 1/n)
+                       hi = simplifySumConstantMatrix(hop, hi, i);       
//e.g., sum(matrix(a,rows=b,cols=c)) -> a*b*c
                        hi = pushdownBinaryOperationOnDiag(hop, hi, i);   
//e.g., diag(X)*7 -> diag(X*7); if col vector
                        hi = pushdownSumOnAdditiveBinary(hop, hi, i);     
//e.g., sum(A+B) -> sum(A)+sum(B); if dims(A)==dims(B)
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION) {
@@ -1273,6 +1274,31 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                }
                return hi;
        }
+
+       private static Hop simplifySumConstantMatrix(Hop parent, Hop hi, int 
pos) {
+               //pattern: sum(matrix(a, rows=b, cols=c)) -> a*b*c
+               if( HopRewriteUtils.isAggUnaryOp(hi, AggOp.SUM, 
Direction.RowCol)
+                       && 
HopRewriteUtils.isDataGenOpWithConstantValue(hi.getInput(0))
+                       && hi.getInput(0).dimsKnown()
+                       && hi.getInput(0).getDim1() >= 1
+                       && hi.getInput(0).getDim2() >= 1
+                       && hi.getInput(0).getParent().size() == 1 )
+               {
+                       DataGenOp datagen = (DataGenOp) hi.getInput(0);
+                       Hop constVal = datagen.getConstantValue();
+                       Hop rows = new LiteralOp(datagen.getDim1());
+                       Hop cols = new LiteralOp(datagen.getDim2());
+
+                       Hop hnew = HopRewriteUtils.createBinary(
+                               HopRewriteUtils.createBinary(constVal, rows, 
OpOp2.MULT), cols, OpOp2.MULT);
+                       HopRewriteUtils.replaceChildReference(parent, hi, hnew, 
pos);
+                       HopRewriteUtils.cleanupUnreferenced(hi, datagen);
+
+                       hi = hnew;
+                       LOG.debug("Applied simplifySumConstantMatrix (line 
"+hi.getBeginLine()+").");
+               }
+               return hi;
+       }
        
        private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi, 
int pos) 
        {
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 2ae1550257..b014ab7920 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -170,6 +170,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = pushdownDetMultOperation(hop, hi, i);           
//e.g., det(X%*%Y) -> det(X)*det(Y)
                        hi = pushdownDetScalarMatrixMultOperation(hop, hi, i);  
//e.g., det(lambda*X) -> lambda^nrow(X)*det(X)
                        hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lambda*X) -> lambda*sum(X)
+                       hi = pushdownRowSumBinaryMult(hop, hi, i);           
//e.g., rowSums(lambda*X) -> lambda*rowSums(X)
+                       hi = pushdownColSumBinaryMult(hop, hi, i);           
//e.g., colSums(lambda*X) -> lambda*colSums(X)
                        hi = pullupAbs(hop, hi, i);                          
//e.g., abs(X)*abs(Y) --> abs(X*Y)
                        hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
                        hi = simplifyTransposedAppend(hop, hi, i);           
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
@@ -1447,6 +1449,58 @@ public class RewriteAlgebraicSimplificationStatic 
extends HopRewriteRule
                return hi;
        }
 
+       private static Hop pushdownRowSumBinaryMult(Hop parent, Hop hi, int pos 
) {
+               //pattern:  rowSums(lambda*X) -> lambda*rowSums(X)
+               if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.Row
+                               && ((AggUnaryOp)hi).getOp()==AggOp.SUM // only 
one parent which is the rowSums
+                               && HopRewriteUtils.isBinary(hi.getInput(0), 
OpOp2.MULT, 1)
+                               && 
((hi.getInput(0).getInput(0).getDataType()==DataType.SCALAR && 
hi.getInput(0).getInput(1).getDataType()==DataType.MATRIX)
+                               
||(hi.getInput(0).getInput(0).getDataType()==DataType.MATRIX && 
hi.getInput(0).getInput(1).getDataType()==DataType.SCALAR)))
+               {
+                       Hop operand1 = hi.getInput(0).getInput(0);
+                       Hop operand2 = hi.getInput(0).getInput(1);
+
+                       //check which operand is the Scalar and which is the 
matrix
+                       Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? 
operand1 : operand2;
+                       Hop matrix = (operand1.getDataType()==DataType.MATRIX) 
? operand1 : operand2;
+
+                       AggUnaryOp 
aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.Row);
+                       Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, 
OpOp2.MULT);
+
+                       HopRewriteUtils.replaceChildReference(parent, hi, bop, 
pos);
+
+                       LOG.debug("Applied pushdownRowSumBinaryMult (line 
"+hi.getBeginLine()+").");
+                       return bop;
+               }
+               return hi;
+       }
+
+       private static Hop pushdownColSumBinaryMult(Hop parent, Hop hi, int pos 
) {
+               //pattern:  colSums(lambda*X) -> lambda*colSums(X)
+               if( hi instanceof AggUnaryOp && 
((AggUnaryOp)hi).getDirection()==Direction.Col
+                               && ((AggUnaryOp)hi).getOp()==AggOp.SUM // only 
one parent which is the colSums
+                               && HopRewriteUtils.isBinary(hi.getInput(0), 
OpOp2.MULT, 1)
+                               && 
((hi.getInput(0).getInput(0).getDataType()==DataType.SCALAR && 
hi.getInput(0).getInput(1).getDataType()==DataType.MATRIX)
+                               
||(hi.getInput(0).getInput(0).getDataType()==DataType.MATRIX && 
hi.getInput(0).getInput(1).getDataType()==DataType.SCALAR)))
+               {
+                       Hop operand1 = hi.getInput(0).getInput(0);
+                       Hop operand2 = hi.getInput(0).getInput(1);
+
+                       //check which operand is the Scalar and which is the 
matrix
+                       Hop lamda = (operand1.getDataType()==DataType.SCALAR) ? 
operand1 : operand2;
+                       Hop matrix = (operand1.getDataType()==DataType.MATRIX) 
? operand1 : operand2;
+
+                       AggUnaryOp 
aggOp=HopRewriteUtils.createAggUnaryOp(matrix, AggOp.SUM, Direction.Col);
+                       Hop bop = HopRewriteUtils.createBinary(lamda, aggOp, 
OpOp2.MULT);
+
+                       HopRewriteUtils.replaceChildReference(parent, hi, bop, 
pos);
+
+                       LOG.debug("Applied pushdownColSumBinaryMult (line 
"+hi.getBeginLine()+").");
+                       return bop;
+               }
+               return hi;
+       }
+
        private static Hop pullupAbs(Hop parent, Hop hi, int pos ) {
                if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
                                && HopRewriteUtils.isUnary(hi.getInput(0), 
OpOp1.ABS)
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
index ef580848fb..e840accd06 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteFusedRandTest.java
@@ -121,7 +121,7 @@ public class RewriteFusedRandTest extends AutomatedTestBase
                        //compare matrices 
                        Double ret = readDMLMatrixFromOutputDir("R").get(new 
CellIndex(1,1));
                        if( testname.equals(TEST_NAME1) )
-                               Assert.assertEquals("Wrong result", 
Double.valueOf(rows), ret);
+                               Assert.assertEquals("Wrong result", 
Double.valueOf(rows*cols), ret);
                        else if( testname.equals(TEST_NAME2) )
                                Assert.assertEquals("Wrong result", 
Double.valueOf(Math.pow(rows*cols, 2)), ret);
                        else if( testname.equals(TEST_NAME3) )
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java
new file mode 100644
index 0000000000..eb31f12c3d
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownColSumBinaryMultTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+
+public class RewritePushdownColSumBinaryMultTest extends AutomatedTestBase
+{
+       private static final String TEST_NAME1 = 
"RewritePushdownColSumBinaryMult";
+       private static final String TEST_NAME2 = 
"RewritePushdownColSumBinaryMult2";
+
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewritePushdownColSumBinaryMultTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }));
+       }
+
+       @Test
+       public void testPushdownColSumBinaryMultNoRewrite() {
+               testRewritePushdownColSumBinaryMult(TEST_NAME1, false);
+       }
+
+       @Test
+       public void testPushdownColSumBinaryMultRewrite() {
+               testRewritePushdownColSumBinaryMult(TEST_NAME1, true);
+       }
+
+       @Test
+       public void testPushdownColSumBinaryMultNoRewrite2() {
+               testRewritePushdownColSumBinaryMult(TEST_NAME2, false);
+       }
+
+       @Test
+       public void testPushdownColSumBinaryMultRewrite2() {
+               testRewritePushdownColSumBinaryMult(TEST_NAME2, true);
+       }
+
+       private void testRewritePushdownColSumBinaryMult(String testname, 
boolean rewrites) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[] { "-stats", "-args", 
output("R") };
+
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
+                       HashMap<CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "DML", 
"R");
+
+                       if(rewrites)
+                               Assert.assertEquals(1, 
Statistics.getCPHeavyHitterCount("n*"));
+                       else
+                               Assert.assertEquals(2, 
Statistics.getCPHeavyHitterCount("*"));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java
new file mode 100644
index 0000000000..cfa18ee335
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownRowSumBinaryMultTest.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+
+public class RewritePushdownRowSumBinaryMultTest extends AutomatedTestBase
+{
+       private static final String TEST_NAME1 = 
"RewritePushdownRowSumBinaryMult";
+       private static final String TEST_NAME2 = 
"RewritePushdownRowSumBinaryMult2";
+
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewritePushdownRowSumBinaryMultTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }));
+       }
+
+       @Test
+       public void testPushdownRowSumBinaryMultNoRewrite() {
+               testRewritePushdownRowSumBinaryMult(TEST_NAME1, false);
+       }
+
+       @Test
+       public void testPushdownRowSumBinaryMultRewrite() {
+               testRewritePushdownRowSumBinaryMult(TEST_NAME1, true);
+       }
+
+       @Test
+       public void testPushdownRowSumBinaryMultNoRewrite2() {
+               testRewritePushdownRowSumBinaryMult(TEST_NAME2, false);
+       }
+
+       @Test
+       public void testPushdownRowSumBinaryMultRewrite2() {
+               testRewritePushdownRowSumBinaryMult(TEST_NAME2, true);
+       }
+
+       private void testRewritePushdownRowSumBinaryMult(String testname, 
boolean rewrites) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[] { "-stats", "-args", 
output("R") };
+
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
+                       HashMap<CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
+                       TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "DML", 
"R");
+
+                       if(rewrites)
+                               Assert.assertEquals(1, 
Statistics.getCPHeavyHitterCount("n*"));
+                       else
+                               Assert.assertEquals(2, 
Statistics.getCPHeavyHitterCount("*"));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java
new file mode 100644
index 0000000000..9530740ff8
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifySumConstantMatrixTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+public class RewriteSimplifySumConstantMatrixTest extends AutomatedTestBase
+{
+       private static final String TEST_NAME = 
"RewriteSimplifySumConstantMatrix";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteSimplifySumConstantMatrixTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "R" }));
+       }
+
+       @Test
+       public void testSimplifySumConstantMatrixNoRewritePositive() {
+               testRewriteSimplifySumConstantMatrix(2.5, 7, 11, false);
+       }
+
+       @Test
+       public void testSimplifySumConstantMatrixRewritePositive() {
+               testRewriteSimplifySumConstantMatrix(2.5, 7, 11, true);
+       }
+
+
+       private void testRewriteSimplifySumConstantMatrix(double value, long 
rows, long cols, boolean rewrites) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+
+               try {
+                       TestConfiguration config = 
getTestConfiguration(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {
+                               "-stats", "-args",
+                               String.valueOf(value), String.valueOf(rows), 
String.valueOf(cols), output("R")
+                       };
+
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = getRCmd(String.valueOf(value), 
String.valueOf(rows), String.valueOf(cols), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       double actual = readDMLScalarFromOutputDir("R").get(new 
CellIndex(1, 1));
+                       double expected = 
readRScalarFromExpectedDir("R").get(new CellIndex(1, 1));
+                       Assert.assertEquals(expected, actual, 1e-15);
+
+                       if(rewrites)
+                               
Assert.assertFalse(heavyHittersContainsString("rand"));
+               }
+               finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
index ab00f04772..2e97afdba9 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
@@ -25,5 +25,5 @@ while(FALSE){} #prevent cse
 
 X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
 
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
+R = as.matrix(sum(abs(X1)==abs(X2)));
 write(R, $5);
\ No newline at end of file
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R
index ab00f04772..8813cc5a50 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -7,9 +5,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args<-commandArgs(TRUE)
+library("Matrix")
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+R=matrix(2*colSums(3*X), nrow=1)
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml
index ab00f04772..6494dc45e4 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+while(FALSE){}
+R=2*colSums(3*X)
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R
index ab00f04772..a951ea42c4 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -7,9 +5,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args<-commandArgs(TRUE)
+library("Matrix")
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+R=matrix(2*colSums(X*3), nrow=1)
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml
index ab00f04772..c69492b571 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownColSumBinaryMult2.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+while(FALSE){}
+R=2*colSums(X*3)
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R
index ab00f04772..2b9f7b4b16 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -7,9 +5,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args<-commandArgs(TRUE)
+library("Matrix")
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+R=matrix(2*rowSums(3*X), ncol=1)
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml
index ab00f04772..31f5e0bd1e 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+while(FALSE){}
+R=2*rowSums(3*X)
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R
similarity index 76%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R
index ab00f04772..782e2fa687 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.R
@@ -1,5 +1,3 @@
-#-------------------------------------------------------------
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -7,9 +5,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +17,8 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+args<-commandArgs(TRUE)
+library("Matrix")
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+R=matrix(2*rowSums(X*3), ncol=1)
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml
index ab00f04772..d579df7758 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownRowSumBinaryMult2.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,7 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+X=matrix(1, 100, 1) %*% t(seq(1,100))
+while(FALSE){}
+R=2*rowSums(X*3)
+write(R, $1)
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R
index ab00f04772..d6c07b4baf 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.R
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,10 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
+args <- commandArgs(TRUE)
 
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
+value <- as.numeric(args[1])
+rows <- as.integer(args[2])
+cols <- as.integer(args[3])
 
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+write(sum(matrix(value, nrow=rows, ncol=cols)), paste(args[4], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml 
b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml
similarity index 82%
copy from src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
copy to src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml
index ab00f04772..0b54eeb12e 100644
--- a/src/test/scripts/functions/rewrite/RewriteFusedRandLit.dml
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifySumConstantMatrix.dml
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
-# 
+#
 #   http://www.apache.org/licenses/LICENSE-2.0
-# 
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,11 +19,5 @@
 #
 #-------------------------------------------------------------
 
-X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-while(FALSE){} #prevent cse
-
-X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7;
-
-R = as.matrix(sum(rowSums(X1)==rowSums(X2)));
-write(R, $5);
\ No newline at end of file
+R = sum(matrix($1, rows=$2, cols=$3))
+write(R, $4)


Reply via email to