This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 2972d6df5f2453e091343b59708343a4c562f185
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon May 20 19:22:42 2024 +0200

    [MINOR] Fix simplification rewrite binary ops (robustness for strings)
---
 .../rewrite/RewriteAlgebraicSimplificationStatic.java     | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index f1065ea832..8fed2481ed 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -279,7 +279,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                        Hop right = bop.getInput().get(1);
                        //X/1 or X*1 -> X 
                        if(    left.getDataType()==DataType.MATRIX 
-                               && right instanceof LiteralOp && 
((LiteralOp)right).getDoubleValue()==1.0 )
+                               && right instanceof LiteralOp && 
right.getValueType().isNumeric()
+                               && ((LiteralOp)right).getDoubleValue()==1.0 )
                        {
                                if( bop.getOp()==OpOp2.DIV || 
bop.getOp()==OpOp2.MULT )
                                {
@@ -291,7 +292,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                        }
                        //X-0 -> X 
                        else if(    left.getDataType()==DataType.MATRIX 
-                                       && right instanceof LiteralOp && 
((LiteralOp)right).getDoubleValue()==0.0 )
+                                       && right instanceof LiteralOp && 
right.getValueType().isNumeric()
+                                       && 
((LiteralOp)right).getDoubleValue()==0.0 )
                        {
                                if( bop.getOp()==OpOp2.MINUS )
                                {
@@ -303,7 +305,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                        }
                        //1*X -> X
                        else if(   right.getDataType()==DataType.MATRIX 
-                                       && left instanceof LiteralOp && 
((LiteralOp)left).getDoubleValue()==1.0 )
+                                       && left instanceof LiteralOp && 
left.getValueType().isNumeric()
+                                       && 
((LiteralOp)left).getDoubleValue()==1.0 )
                        {
                                if( bop.getOp()==OpOp2.MULT )
                                {
@@ -317,7 +320,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                        //note: this rewrite is necessary since the new antlr 
parser always converts 
                        //-X to -1*X due to mechanical reasons
                        else if(   right.getDataType()==DataType.MATRIX 
-                                       && left instanceof LiteralOp && 
((LiteralOp)left).getDoubleValue()==-1.0 )
+                                       && left instanceof LiteralOp && 
left.getValueType().isNumeric()
+                                       && 
((LiteralOp)left).getDoubleValue()==-1.0 )
                        {
                                if( bop.getOp()==OpOp2.MULT )
                                {
@@ -330,7 +334,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                        }
                        //X*-1 -> -X (see comment above)
                        else if(   left.getDataType()==DataType.MATRIX 
-                                       && right instanceof LiteralOp && 
((LiteralOp)right).getDoubleValue()==-1.0 )
+                                       && right instanceof LiteralOp && 
right.getValueType().isNumeric()
+                                       && 
((LiteralOp)right).getDoubleValue()==-1.0 )
                        {
                                if( bop.getOp()==OpOp2.MULT )
                                {

Reply via email to