This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f3b638a48f [SYSTEMDS-3812] Improved rewrites pushdown-sum and rm-reorg
f3b638a48f is described below

commit f3b638a48f843dca3ac9b963100e2146c88e7751
Author: aarna <[email protected]>
AuthorDate: Fri Jan 10 15:46:57 2025 +0100

    [SYSTEMDS-3812] Improved rewrites pushdown-sum and rm-reorg
    
    Closes #2176.
---
 .../RewriteAlgebraicSimplificationDynamic.java     | 106 ++++++++++++++-------
 .../functions/aggregate/PushdownSumBinaryTest.java |  36 ++-----
 .../rewrite/RewritePushdownSumBinaryMult.java      |   5 -
 .../rewrite/RewritePushdownSumOnBinaryTest.java    |  84 +++++++++++-----
 .../rewrite/RewritePushdownSumOnBinary.dml         |  27 ++++--
 5 files changed, 160 insertions(+), 98 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index d73f8489b6..ddb2252f51 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -381,30 +381,28 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                
                return hi;
        }
-       
-       private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, 
int pos)
-       {
-               if( hi instanceof ReorgOp ) 
-               {
+
+       private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, 
int pos) {
+               if( hi instanceof ReorgOp ) {
                        ReorgOp rop = (ReorgOp) hi;
-                       Hop input = hi.getInput(0); 
+                       Hop input = hi.getInput(0);
                        boolean apply = false;
-                       
-                       //equal dims of reshape input and output -> no need for 
reshape because 
+
+                       //equal dims of reshape input and output -> no need for 
reshape because
                        //byrow always refers to both input/output and hence 
gives the same result
                        apply |= (rop.getOp()==ReOrgOp.RESHAPE && 
HopRewriteUtils.isEqualSize(hi, input));
-                       
-                       //1x1 dimensions of transpose/reshape -> no need for 
reorg      
-                       apply |= ((rop.getOp()==ReOrgOp.TRANS || 
rop.getOp()==ReOrgOp.RESHAPE) 
-                                       && rop.getDim1()==1 && 
rop.getDim2()==1);
-                       
+
+                       //1x1 dimensions of transpose/reshape/roll -> no need 
for reorg
+                       apply |= ((rop.getOp()==ReOrgOp.TRANS || 
rop.getOp()==ReOrgOp.RESHAPE
+                                       || rop.getOp()==ReOrgOp.ROLL) && 
rop.getDim1()==1 && rop.getDim2()==1);
+
                        if( apply ) {
                                HopRewriteUtils.replaceChildReference(parent, 
hi, input, pos);
                                hi = input;
                                LOG.debug("Applied removeUnnecessaryReorg.");
                        }
                }
-               
+
                return hi;
        }
        
@@ -1356,44 +1354,78 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
         * @param pos position
         * @return high-level operator
         */
-       private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int 
pos) 
+       private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int 
pos)
        {
                //all patterns headed by full sum over binary operation
                if(    hi instanceof AggUnaryOp //full sum root over binaryop
-                       && ((AggUnaryOp)hi).getDirection()==Direction.RowCol
-                       && ((AggUnaryOp)hi).getOp() == AggOp.SUM 
-                       && hi.getInput(0) instanceof BinaryOp   
-                       && hi.getInput(0).getParent().size()==1 ) //single 
parent
+                               && 
((AggUnaryOp)hi).getDirection()==Direction.RowCol
+                               && ((AggUnaryOp)hi).getOp() == AggOp.SUM
+                               && hi.getInput(0) instanceof BinaryOp
+                               && hi.getInput(0).getParent().size()==1 ) 
//single parent
                {
                        BinaryOp bop = (BinaryOp) hi.getInput(0);
                        Hop left = bop.getInput(0);
                        Hop right = bop.getInput(1);
-                       
-                       if( HopRewriteUtils.isEqualSize(left, right)  //dims(A) 
== dims(B)
-                               && left.getDataType() == DataType.MATRIX
-                               && right.getDataType() == DataType.MATRIX )
+
+                       if( left.getDataType() == DataType.MATRIX
+                                       && right.getDataType() == 
DataType.MATRIX )
                        {
                                OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS 
//pattern a: sum(A+B)->sum(A)+sum(B)
                                                || bop.getOp() == OpOp2.MINUS ) 
    //pattern b: sum(A-B)->sum(A)-sum(B)
                                                ? bop.getOp() : null;
-                               
+
                                if( applyOp != null ) {
-                                       //create new subdag sum(A) bop sum(B)
-                                       AggUnaryOp sum1 = 
HopRewriteUtils.createSum(left);
-                                       AggUnaryOp sum2 = 
HopRewriteUtils.createSum(right);
-                                       BinaryOp newBin = 
HopRewriteUtils.createBinary(sum1, sum2, applyOp);
-
-                                       //rewire new subdag
-                                       
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
-                                       HopRewriteUtils.cleanupUnreferenced(hi, 
bop);
-                                       
-                                       hi = newBin;
-                                       
-                                       LOG.debug("Applied 
pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
+                                       if (HopRewriteUtils.isEqualSize(left, 
right)) {
+                                               //create new subdag sum(A) bop 
sum(B) for equal-sized matrices
+                                               AggUnaryOp sum1 = 
HopRewriteUtils.createSum(left);
+                                               AggUnaryOp sum2 = 
HopRewriteUtils.createSum(right);
+                                               BinaryOp newBin = 
HopRewriteUtils.createBinary(sum1, sum2, applyOp);
+                                               //rewire new subdag
+                                               
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
+                                               
HopRewriteUtils.cleanupUnreferenced(hi, bop);
+
+                                               hi = newBin;
+
+                                               LOG.debug("Applied 
pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
+                                       }
+                                       // Check if right operand is a vector 
(has dimension of 1 in either rows or columns)
+                                       else if (right.getDim1() == 1 || 
right.getDim2() == 1) {
+                                               AggUnaryOp sum1 = 
HopRewriteUtils.createSum(left);
+                                               AggUnaryOp sum2 = 
HopRewriteUtils.createSum(right);
+
+                                               // Row vector case (1 x n)
+                                               if (right.getDim1() == 1) {
+                                                       // Create nrow(A) 
operation using dimensions
+                                                       LiteralOp nRows = new 
LiteralOp(left.getDim1());
+                                                       BinaryOp scaledSum = 
HopRewriteUtils.createBinary(nRows, sum2, OpOp2.MULT);
+                                                       BinaryOp newBin = 
HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
+                                                       //rewire new subdag
+                                                       
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
+                                                       
HopRewriteUtils.cleanupUnreferenced(hi, bop);
+
+                                                       hi = newBin;
+
+                                                       LOG.debug("Applied 
pushdownSumOnAdditiveBinary with row vector (line "+hi.getBeginLine()+").");
+                                               }
+                                               // Column vector case (n x 1)
+                                               else if (right.getDim2() == 1) {
+                                                       // Create ncol(A) 
operation using dimensions
+                                                       LiteralOp nCols = new 
LiteralOp(left.getDim2());
+                                                       BinaryOp scaledSum = 
HopRewriteUtils.createBinary(nCols, sum2, OpOp2.MULT);
+                                                       BinaryOp newBin = 
HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
+                                                       //rewire new subdag
+                                                       
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
+                                                       
HopRewriteUtils.cleanupUnreferenced(hi, bop);
+
+                                                       hi = newBin;
+
+                                                       LOG.debug("Applied 
pushdownSumOnAdditiveBinary with column vector (line "+hi.getBeginLine()+").");
+                                               }
+                                       }
                                }
                        }
                }
-       
+
                return hi;
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
 
b/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
index d4ac5fc6dc..3e7286c274 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/aggregate/PushdownSumBinaryTest.java
@@ -25,10 +25,8 @@ import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -89,39 +87,24 @@ public class PushdownSumBinaryTest extends AutomatedTestBase
        }
        
        @Test
-       public void testPushDownSumPlusNoRewriteSP() {
+       public void testPushDownSumPlusBroadcastSP() {
                runPushdownSumOnBinaryTest(TEST_NAME1, false, ExecType.SPARK);
        }
        
        @Test
-       public void testPushDownSumMinusNoRewriteSP() {
+       public void testPushDownSumMinusBroadcastSP() {
                runPushdownSumOnBinaryTest(TEST_NAME2, false, ExecType.SPARK);
        }
-               
-       /**
-        * 
-        * @param testname
-        * @param type
-        * @param sparse
-        * @param instType
-        */
+       
        private void runPushdownSumOnBinaryTest( String testname, boolean 
equiDims, ExecType instType) 
        {
                //rtplatform for MR
-               ExecMode platformOld = rtplatform;
-               switch( instType ){
-                       case SPARK: rtplatform = ExecMode.SPARK; break;
-                       default: rtplatform = ExecMode.HYBRID; break;
-               }
-       
-               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
-               if( rtplatform == ExecMode.SPARK )
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               ExecMode platformOld = setExecMode(instType);
                        
                try
                {
                        //determine script and function name
-                       String TEST_NAME = testname;                    
+                       String TEST_NAME = testname;
                        String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME 
+ "_" + String.valueOf(equiDims) + "/" : "";
                        
                        TestConfiguration config = 
getTestConfiguration(TEST_NAME);
@@ -150,13 +133,10 @@ public class PushdownSumBinaryTest extends 
AutomatedTestBase
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
                        
                        String lopcode = TEST_NAME.equals(TEST_NAME1) ? "+" : 
"-";
-                       String opcode = equiDims ? lopcode : 
Instruction.SP_INST_PREFIX+"map"+lopcode;
-                       Assert.assertTrue("Non-applied rewrite", 
Statistics.getCPHeavyHitterOpCodes().contains(opcode));        
+                       Assert.assertTrue("Non-applied rewrite", 
Statistics.getCPHeavyHitterOpCodes().contains(lopcode));
                }
-               finally
-               {
-                       rtplatform = platformOld;
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               finally {
+                       resetExecMode(platformOld);
                }
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
index 60ce24f105..cb135e21c8 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumBinaryMult.java
@@ -68,11 +68,6 @@ public class RewritePushdownSumBinaryMult extends 
AutomatedTestBase
                testRewritePushdownSumBinaryMult( TEST_NAME2, true );
        }
        
-       /**
-        * 
-        * @param testname
-        * @param rewrites
-        */
        private void testRewritePushdownSumBinaryMult( String testname, boolean 
rewrites )
        {       
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
index 9391af719a..d9459b03a9 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java
@@ -29,54 +29,94 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
-public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase 
+public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase
 {
        private static final String TEST_NAME1 = "RewritePushdownSumOnBinary";
        private static final String TEST_DIR = "functions/rewrite/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewritePushdownSumOnBinaryTest.class.getSimpleName() + "/";
-       
+
        private static final int rows = 1000;
        private static final int cols = 1;
-       
+       private static final double eps = 1e-8;
+
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R1", "R2" }) );
+               addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,
+                               new String[] { "R1", "R2", "R3", "R4" }));
+       }
+
+       @Test
+       public void testRewritePushdownSumOnBinaryNoRewrite() {
+               testRewritePushdownSumOnBinary(TEST_NAME1, false);
+       }
+
+       @Test
+       public void testRewritePushdownSumOnBinary() {
+               testRewritePushdownSumOnBinary(TEST_NAME1, true);
        }
 
        @Test
-       public void testRewritePushdownSumOnBinaryNoRewrite()  {
-               testRewritePushdownSumOnBinary( TEST_NAME1, false );
+       public void testRewritePushdownSumOnBinaryRowVector() {
+               testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, true);
        }
-       
+
        @Test
-       public void testRewritePushdownSumOnBinary()  {
-               testRewritePushdownSumOnBinary( TEST_NAME1, true );
+       public void testRewritePushdownSumOnBinaryColVector() {
+               testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, false);
        }
-       
-       private void testRewritePushdownSumOnBinary( String testname, boolean 
rewrites )
-       {       
+
+       private void testRewritePushdownSumOnBinary(String testname, boolean 
rewrites) {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
-               
+
                try {
                        TestConfiguration config = 
getTestConfiguration(testname);
                        loadTestConfiguration(config);
-                       
+
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{ "-args", 
String.valueOf(rows), 
-                                       String.valueOf(cols), output("R1"), 
output("R2") };
+
+                       programArgs = new String[]{ "-args", 
String.valueOf(rows),
+                                       String.valueOf(cols), output("R1"), 
output("R2"),
+                                       String.valueOf(rows), 
String.valueOf(cols) };  // Assuming row and col vectors
+
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
 
-                       //run performance tests
+                       // Run performance tests
                        runTest(true, false, null, -1);
-                       
-                       //compare matrices 
-                       long expect = Math.round(0.5*rows);
+
+                       // Compare matrices
+                       long expect = Math.round(0.5 * rows);
                        HashMap<CellIndex, Double> dmlfile1 = 
readDMLScalarFromOutputDir("R1");
-                       Assert.assertEquals(expect, dmlfile1.get(new 
CellIndex(1,1)), expect*0.01);
+                       Assert.assertEquals(expect, dmlfile1.get(new 
CellIndex(1, 1)), eps);
                        HashMap<CellIndex, Double> dmlfile2 = 
readDMLScalarFromOutputDir("R2");
-                       Assert.assertEquals(expect, dmlfile2.get(new 
CellIndex(1,1)), expect*0.01);
+                       Assert.assertEquals(expect, dmlfile2.get(new 
CellIndex(1, 1)), eps);
+               } finally {
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+
+
+       private void testRewritePushdownSumOnBinaryVector(String testname, 
boolean rewrites, boolean isRow) {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-args", 
String.valueOf(rows),
+                                       String.valueOf(cols), output("R3"), 
output("R4"),
+                                       String.valueOf(isRow ? 1 : rows), 
String.valueOf(isRow ? cols : 1) };
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       runTest(true, false, null, -1);
+
+                       long expect = Math.round(500); // Expected value for 
0.5 + 0.5
+                       HashMap<CellIndex, Double> dmlfile3 = 
readDMLScalarFromOutputDir("R3");
+                       Assert.assertEquals(expect, dmlfile3.get(new 
CellIndex(1,1)), eps);
+                       HashMap<CellIndex, Double> dmlfile4 = 
readDMLScalarFromOutputDir("R4");
+                       Assert.assertEquals(expect, dmlfile4.get(new 
CellIndex(1,1)), eps);
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
diff --git a/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml 
b/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
index d48ac0aad8..0d1b812397 100644
--- a/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
+++ b/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
@@ -19,15 +19,30 @@
 #
 #-------------------------------------------------------------
 
-A = rand(rows=$1, cols=$2, seed=1);
-B = rand(rows=$1, cols=$2, seed=2);
-C = rand(rows=$1, cols=$2, seed=3);
-D = rand(rows=$1, cols=$2, seed=4);
+# Required parameters
+A = matrix(0.5, rows=$1, cols=$2);
+B = matrix(0.5, rows=$1, cols=$2);
+C = matrix(0.5, rows=$1, cols=$2);
+D = matrix(0.5, rows=$1, cols=$2);
 
+# Set defaults for optional parameters
+rowsV = ifdef($5, 0)
+colsV = ifdef($6, 0)
+
+# Original matrix tests
 r1 = sum(A*B + C*D);
 r2 = r1;
 
-print("r1="+r1+", r2="+r2);
+# Vector tests
+if (rowsV != 0 & colsV != 0) {
+    V = matrix(0.5, rows=rowsV, cols=colsV);
+    r3 = sum(A + V);
+    r4 = r3;
+}
+
 write(r1, $3);
 write(r2, $4);
-
+if (rowsV != 0 & colsV != 0) {
+    write(r3, $5);
+    write(r4, $6);
+}
\ No newline at end of file

Reply via email to