This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 51b53c397e [SYSTEMDS-3836] Fix distributive binary ops rewrite for
broadcasting
51b53c397e is described below
commit 51b53c397e8edd9e35cb4cb6c6dbc44c78f53a10
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Feb 15 14:47:04 2025 +0100
[SYSTEMDS-3836] Fix distributive binary ops rewrite for broadcasting
The distributive binary operation rewrite transforms the pattern
X-X*Y into (1-Y)*X but was so far not aware of broadcasting semantics.
If Y is a row or column vector but X is a matrix, the rewritten
expression raises dimension-mismatch exceptions at runtime. We now
rewrite the pattern to X*(1-Y) if Y is a vector and X is not. Always
rewriting the pattern to the latter form would prevent the mmchain
rewrite from triggering (which is crucial for many end-to-end algorithms).
The tests have, however, also shown that for the multiLogReg test
we are not compiling mmchain (independent of this rewrite change),
which needs fixing before the release.
---
.../rewrite/RewriteAlgebraicSimplificationStatic.java | 15 +++++++++++----
.../RewriteSimplifyDistributiveBinaryOperationTest.java | 7 ++++++-
.../rewrite/RewriteSimplifyDistributiveBinaryOperation.R | 2 ++
.../RewriteSimplifyDistributiveBinaryOperation.dml | 3 +++
4 files changed, 22 insertions(+), 5 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index de14b5c5ec..5d867bf0ff 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -886,10 +886,14 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
X = right;
Y = ( right == leftC1 ) ?
leftC2 : leftC1;
}
- if( X != null ){ //rewrite 'binary +/-'
+ if( X != null && Y.dimsKnown() ){
//rewrite 'binary +/-'
LiteralOp literal = new
LiteralOp(1);
BinaryOp plus =
HopRewriteUtils.createBinary(Y, literal, bop.getOp());
- BinaryOp mult =
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
+
+ BinaryOp mult =
(plus.getDim1()==1 || plus.getDim2() == 1)
+ &&
(X.getDim1()>1 && X.getDim2()>1) ?
+
HopRewriteUtils.createBinary(X, plus, OpOp2.MULT) :
+
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
HopRewriteUtils.cleanupUnreferenced(hi, left);
hi = mult;
@@ -908,10 +912,13 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
X = left;
Y = ( left == rightC1 ) ?
rightC2 : rightC1;
}
- if( X != null ){ //rewrite '+/- binary'
+ if( X != null && Y.dimsKnown() ){
//rewrite '+/- binary'
LiteralOp literal = new
LiteralOp(1);
BinaryOp plus =
HopRewriteUtils.createBinary(literal, Y, bop.getOp());
- BinaryOp mult =
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
+ BinaryOp mult =
(plus.getDim1()==1 || plus.getDim2() == 1)
+ &&
(X.getDim1()>1 && X.getDim2()>1) ?
+
HopRewriteUtils.createBinary(X, plus, OpOp2.MULT) :
+
HopRewriteUtils.createBinary(plus, X, OpOp2.MULT);
HopRewriteUtils.replaceChildReference(parent, hi, mult, pos);
HopRewriteUtils.cleanupUnreferenced(hi, right);
hi = mult;
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
index f130404989..795db7421b 100644
---
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyDistributiveBinaryOperationTest.java
@@ -86,6 +86,11 @@ public class RewriteSimplifyDistributiveBinaryOperationTest
extends AutomatedTes
public void testDistrBinaryOpMultAddRewrite() {
testSimplifyDistributiveBinaryOperation(4, true); //pattern:
(Y*X+X) -> (Y+1)*X
}
+
+ @Test
+ public void testDistrBinaryOpMultMinusVectorRewrite() {
+ testSimplifyDistributiveBinaryOperation(5, true); //pattern:
(X*Y-X) -> X*(Y-1), Y vector
+ }
private void testSimplifyDistributiveBinaryOperation(int ID, boolean
rewrites) {
boolean oldFlag1 =
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -104,7 +109,7 @@ public class RewriteSimplifyDistributiveBinaryOperationTest
extends AutomatedTes
//create matrices
double[][] X = getRandomMatrix(rows, cols, -1, 1,
0.60d, 3);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1,
0.60d, 5);
+ double[][] Y = getRandomMatrix(rows, ID==5?1:cols, -1,
1, 0.60d, 5);
writeInputMatrixWithMTD("X", X, true);
writeInputMatrixWithMTD("Y", Y, true);
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
index 0c544da5f9..1d21848b06 100644
---
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
+++
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.R
@@ -44,6 +44,8 @@ if( type == 1 ) {
R = (X+Y*X)
} else if( type == 4 ) {
R = (Y*X+X)
+} else if( type == 5 ) {
+ R = (X*(Y%*%matrix(1,1,ncol(X)))-X) * 1
}
diff --git
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
index 943b2f444b..bd93104d24 100644
---
a/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
+++
b/src/test/scripts/functions/rewrite/RewriteSimplifyDistributiveBinaryOperation.dml
@@ -38,6 +38,9 @@ else if( type == 3 ) {
else if( type == 4 ) {
R = (Y*X+X) * 1
}
+else if( type == 5 ) {
+ R = (X*Y-X) * 1
+}
# Write the result matrix R