This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 51b53c397e [SYSTEMDS-3836] Fix distributive binary ops rewrite for 
broadcasting
51b53c397e is described below

commit 51b53c397e8edd9e35cb4cb6c6dbc44c78f53a10
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 15 14:47:04 2025 +0100

    [SYSTEMDS-3836] Fix distributive binary ops rewrite for broadcasting
    
    The distributive binary operation rewrite transforms the pattern
    X-X*Y into (1-Y)*X but was so far not aware of broadcasting semantics.
    If Y is a row or column vector but X a matrix, the rewrite yields
    mismatching dimension exceptions during runtime. We now simply rewrite
    the pattern to X*(1-Y) if Y is indeed a vector and X is not. Always
    rewriting the pattern to the latter would cause the mmchain rewrite to
    no longer trigger (which is crucial for many end-to-end algorithms).
    
    The tests have, however, also shown that for the multiLogReg test
    we are not compiling mmchain (independent of this rewrite change),
    something that needs fixing before the release.
---
 .../rewrite/RewriteAlgebraicSimplificationStatic.java     | 15 +++++++++++----
 .../RewriteSimplifyDistributiveBinaryOperationTest.java   |  7 ++++++-
 .../rewrite/RewriteSimplifyDistributiveBinaryOperation.R  |  2 ++
 .../RewriteSimplifyDistributiveBinaryOperation.dml        |  3 +++
 4 files changed, 22 insertions(+), 5 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index de14b5c5ec..5d867bf0ff 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -886,10 +886,14 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                X = right;
                                                Y = ( right == leftC1 ) ? 
leftC2 : leftC1;
                                        }
-                                       if( X != null ){ //rewrite 'binary +/-' 
+                                       if( X != null && Y.dimsKnown() ){ 
//rewrite 'binary +/-' 
                                                LiteralOp literal = new 
LiteralOp(1);
                                                BinaryOp plus = 
HopRewriteUtils.createBinary(Y, literal, bop.getOp());
-                                               BinaryOp mult = 
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
+                                               
+                                               BinaryOp mult = 
(plus.getDim1()==1 || plus.getDim2() == 1)
+                                                               && 
(X.getDim1()>1 && X.getDim2()>1) ?
+                                                       
HopRewriteUtils.createBinary(X, plus, OpOp2.MULT) :
+                                                       
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
                                                
HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(hi, left);
                                                hi = mult;
@@ -908,10 +912,13 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                                                X = left;
                                                Y = ( left == rightC1 ) ? 
rightC2 : rightC1;
                                        }
-                                       if( X != null ){ //rewrite '+/- binary'
+                                       if( X != null && Y.dimsKnown() ){ 
//rewrite '+/- binary'
                                                LiteralOp literal = new 
LiteralOp(1);
                                                BinaryOp plus = 
HopRewriteUtils.createBinary(literal, Y, bop.getOp());
-                                               BinaryOp mult = 
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
+                                               BinaryOp mult = 
(plus.getDim1()==1 || plus.getDim2() == 1) 
+                                                               && 
(X.getDim1()>1 && X.getDim2()>1) ?
+                                                       
HopRewriteUtils.createBinary(X, plus, OpOp2.MULT) :
+                                                       
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
                                                
HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
                                                
HopRewriteUtils.cleanupUnreferenced(hi, right);
                                                hi = mult;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
index f130404989..795db7421b 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
@@ -86,6 +86,11 @@ public class RewriteSimplifyDistributiveBinaryOperationTest 
extends AutomatedTes
        public void testDistrBinaryOpMultAddRewrite() {
                testSimplifyDistributiveBinaryOperation(4, true);    //pattern: 
(Y*X+X) -> (Y+1)*X
        }
+       
+       @Test
+       public void testDistrBinaryOpMultMinusVectorRewrite() {
+               testSimplifyDistributiveBinaryOperation(5, true);    //pattern: 
(X*Y-X) -> X*(Y-1), Y vector
+       }
 
        private void testSimplifyDistributiveBinaryOperation(int ID, boolean 
rewrites) {
                boolean oldFlag1 = 
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -104,7 +109,7 @@ public class RewriteSimplifyDistributiveBinaryOperationTest 
extends AutomatedTes
 
                        //create matrices
                        double[][] X = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 3);
-                       double[][] Y = getRandomMatrix(rows, cols, -1, 1, 
0.60d, 5);
+                       double[][] Y = getRandomMatrix(rows, ID==5?1:cols, -1, 
1, 0.60d, 5);
                        writeInputMatrixWithMTD("X", X, true);
                        writeInputMatrixWithMTD("Y", Y, true);
 
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
 
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
index 0c544da5f9..1d21848b06 100644
--- 
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
+++ 
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
@@ -44,6 +44,8 @@ if( type == 1 ) {
     R = (X+Y*X)
 } else if( type == 4 ) {
     R = (Y*X+X)
+} else if( type == 5 ) {
+    R = (X*(Y%*%matrix(1,1,ncol(X)))-X) * 1
 }
 
 
diff --git 
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
 
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
index 943b2f444b..bd93104d24 100644
--- 
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
+++ 
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
@@ -38,6 +38,9 @@ else if( type == 3 ) {
 else if( type == 4 ) {
     R = (Y*X+X) * 1
 }
+else if( type == 5 ) {
+    R = (X*Y-X) * 1
+}
 
 
 # Write the result matrix R

Reply via email to