This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 7de36573fc [SYSTEMDS-3785] Fix rewrite test for simplify bushy binary ops
7de36573fc is described below

commit 7de36573fc6e6e22957145cbb40cfc402d5978f8
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Oct 24 19:54:48 2024 +0200

    [SYSTEMDS-3785] Fix rewrite test for simplify bushy binary ops
    
    This patch resolves a FIXME that remained after the improved rewrite
    code coverage work: it fixes the test expressions and adjusts other
    rewrite configurations so that the test actually triggers the existing
    rewrite.
---
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  1 +
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |  3 +-
 .../RewriteAlgebraicSimplificationStatic.java      |  4 +--
 .../RewriteSimplifyBushyBinaryOperationTest.java   | 38 ++++++++++++----------
 4 files changed, 25 insertions(+), 21 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index de8e7809ca..6338ff7a70 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -195,6 +195,7 @@ public class OptimizerUtils
         * all sum-product related rewrites.
         */
        public static boolean ALLOW_SUM_PRODUCT_REWRITES = true;
+       public static boolean ALLOW_SUM_PRODUCT_REWRITES2 = true;
 
        /**
         * Enables additional mmchain optimizations. in the future, this might 
be merged with
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java 
b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
index cd440a6bcf..03633d06a8 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java
@@ -126,7 +126,8 @@ public class ProgramRewriter{
                        }
                        if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES) {
                                _dagRuleSet.add( new 
RewriteMatrixMultChainOptimization()         ); //dependency: cse
-                               _dagRuleSet.add( new 
RewriteElementwiseMultChainOptimization()    ); //dependency: cse
+                               if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 )
+                                       _dagRuleSet.add( new 
RewriteElementwiseMultChainOptimization()); //dependency: cse
                        }
                        if(OptimizerUtils.ALLOW_ADVANCED_MMCHAIN_REWRITES){
                                _dagRuleSet.add( new 
RewriteMatrixMultChainOptimizationTranspose()      ); //dependency: cse
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 76691d6480..a18a2b7466 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -855,8 +855,8 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
        }
        
        /**
-        * (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
-        * (X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
+        * t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%((X*Y)*(Z%*%v))
+        * t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
         * 
         * Note: Restriction ba() at leaf and root instead of data at leaf to 
not reorganize too
         * eagerly, which would loose additional rewrite potential. This 
rewrite has two goals
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
index 105dfa8cbc..fb1bcc3630 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyBushyBinaryOperationTest.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.HashMap;
@@ -37,7 +38,7 @@ public class RewriteSimplifyBushyBinaryOperationTest extends 
AutomatedTestBase {
                TEST_DIR + 
RewriteSimplifyBushyBinaryOperationTest.class.getSimpleName() + "/";
 
        private static final int rows = 500;
-       private static final int cols = 500;
+       private static final int cols = 100;
        private static final double eps = Math.pow(10, -10);
 
        @Override
@@ -46,28 +47,28 @@ public class RewriteSimplifyBushyBinaryOperationTest 
extends AutomatedTestBase {
                addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
        }
 
+       //pattern: t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%((X*Y)*(Z%*%v))
        @Test
        public void testBushyBinaryOperationMultNoRewrite() {
                testSimplifyBushyBinaryOperation(1, false);
        }
 
        @Test
-       public void testBushyBinaryOperationMultRewrite() {     //pattern: 
(X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
+       public void testBushyBinaryOperationMultRewrite() { 
                testSimplifyBushyBinaryOperation(1, true);
        }
 
+       //pattern: t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
        @Test
        public void testBushyBinaryOperationAddNoRewrite() {
                testSimplifyBushyBinaryOperation(2, false);
        }
 
        @Test
-       public void testBushyBinaryOperationAddtRewrite() {     //pattern: 
(X+(Y+(Z%*%v))) -> (X+Y)+(Z%*%v)
+       public void testBushyBinaryOperationAddtRewrite() { 
                testSimplifyBushyBinaryOperation(2, true);
        }
 
-
-
        private void testSimplifyBushyBinaryOperation(int ID, boolean rewrites) 
{
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
                try {
@@ -76,19 +77,21 @@ public class RewriteSimplifyBushyBinaryOperationTest 
extends AutomatedTestBase {
 
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[] {"-stats", "-args", 
input("X"), input("Y"), input("Z"), input("v"), String.valueOf(ID), 
output("R")};
+                       programArgs = new String[] {"-stats", "-explain", 
"-args", 
+                               input("X"), input("Y"), input("Z"), input("v"), 
String.valueOf(ID), output("R")};
                        fullRScriptName = HOME + TEST_NAME + ".R";
                        rCmd = getRCmd(inputDir(), String.valueOf(ID), 
expectedDir());
 
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
-                       //OptimizerUtils.ALLOW_OPERATOR_FUSION = rewrites;
-                       //OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-
+                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = false; 
//disable nary mult
+                       OptimizerUtils.ALLOW_OPERATOR_FUSION = false; //disable 
emult reordering
+                       //TODO improved phase ordering 
+                       
                        //create matrices
-                       double[][] X = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 3);
-                       double[][] Y = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 5);
+                       double[][] X = getRandomMatrix(rows, 1, -1, 1, 0.60d, 
3);
+                       double[][] Y = getRandomMatrix(rows, 1, -1, 1, 0.60d, 
5);
                        double[][] Z = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 6);
-                       double[][] v = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 8);
+                       double[][] v = getRandomMatrix(cols, 1, -1, 1, 0.60d, 
8);
                        writeInputMatrixWithMTD("X", X, true);
                        writeInputMatrixWithMTD("Y", Y, true);
                        writeInputMatrixWithMTD("Z", Z, true);
@@ -101,15 +104,14 @@ public class RewriteSimplifyBushyBinaryOperationTest 
extends AutomatedTestBase {
                        HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
                        HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("R");
                        TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
-
-                       /**
-                        * The rewrite in RewriteAlgebraicSimplificationStatic 
is not entered. Hence, we fail
-                        * the assertions for this rewrite so that we can 
revisit this issue later.
-                        */
-                       //FIXME
+               
+                       if( ID == 1 && rewrites ) //check mmchain, enabled by 
bushy join 
+                               
Assert.assertTrue(heavyHittersContainsString("mmchain"));
                }
                finally {
                        OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+                       OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+                       OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES2 = true;
                        Recompiler.reinitRecompiler();
                }
        }

Reply via email to