This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new fce8bea2ca [SYSTEMDS-3712] New rewrite for pull-up abs over binary mult
fce8bea2ca is described below

commit fce8bea2ca3489a272bdd49cd95ebdf0359fb264
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sat Jun 29 20:09:30 2024 +0200

    [SYSTEMDS-3712] New rewrite for pull-up abs over binary mult
    
    This patch adds a simple new rewrite 'abs(X) * abs(Y) --> abs(X*Y)'
    as well as its tests.
---
 .../apache/sysds/hops/rewrite/HopRewriteRule.java  |  2 +-
 .../RewriteAlgebraicSimplificationStatic.java      | 20 +++++
 .../functions/rewrite/RewritePullupAbsTest.java    | 87 ++++++++++++++++++++++
 .../scripts/functions/rewrite/RewritePullupAbs.dml | 30 ++++++++
 4 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java
index 10b5e6fb72..e204510b81 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java
@@ -33,7 +33,7 @@ import org.apache.sysds.hops.Hop;
 public abstract class HopRewriteRule 
 {
        protected static final Log LOG = 
LogFactory.getLog(HopRewriteRule.class.getName());
-               
+       
        /**
         * Handle a generic (last-level) hop DAG with multiple roots.
         * 
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 8fed2481ed..76691d6480 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -164,6 +164,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                        hi = pushdownUnaryAggTransposeOperation(hop, hi, i); 
//e.g., colSums(t(X)) -> t(rowSums(X))
                        hi = pushdownCSETransposeScalarOperation(hop, hi, 
i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
                        hi = pushdownSumBinaryMult(hop, hi, i);              
//e.g., sum(lambda*X) -> lambda*sum(X)
+                       hi = pullupAbs(hop, hi, i);                          
//e.g., abs(X)*abs(Y) --> abs(X*Y)
                        hi = simplifyUnaryPPredOperation(hop, hi, i);        
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
                        hi = simplifyTransposedAppend(hop, hi, i);           
//e.g., t(cbind(t(A),t(B))) -> rbind(A,B);
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
@@ -1122,6 +1123,25 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
                return hi;
        }
 
+       /**
+        * Pulls abs up over a binary multiply: abs(X)*abs(Y) --> abs(X*Y).
+        * This replaces two abs operators by a single one and relies on the
+        * identity |a|*|b| == |a*b|.
+        * 
+        * @param parent parent hop whose input at position pos is hi
+        * @param hi     candidate hop; rewritten only if it is a MULT whose two
+        *               inputs are ABS hops that each have hi as their only parent
+        * @param pos    input position of hi within parent
+        * @return the new abs(X*Y) hop if the rewrite applied, otherwise hi
+        */
+       private static Hop pullupAbs(Hop parent, Hop hi, int pos ) {
+               // Each abs(...) input must have hi as its single parent. This presumably
+               // prevents the rewrite from keeping a shared abs(...) alive while adding
+               // another abs on top, which would not reduce the number of operators.
+               if( HopRewriteUtils.isBinary(hi, OpOp2.MULT)
+                       && HopRewriteUtils.isUnary(hi.getInput(0), OpOp1.ABS) 
+                       && hi.getInput(0).getParent().size()==1
+                       && HopRewriteUtils.isUnary(hi.getInput(1), OpOp1.ABS) 
+                       && hi.getInput(1).getParent().size()==1)
+               {
+                       // unwrap the two abs hops to obtain X and Y
+                       Hop operand1 = hi.getInput(0).getInput(0);
+                       Hop operand2 = hi.getInput(1).getInput(0);
+                       // build abs(X*Y) and redirect parent's input at pos to it
+                       Hop bop = HopRewriteUtils.createBinary(operand1, 
operand2, OpOp2.MULT);
+                       Hop uop = HopRewriteUtils.createUnary(bop, OpOp1.ABS);
+                       HopRewriteUtils.replaceChildReference(parent, hi, uop, 
pos);
+                       // NOTE(review): only parent's reference is redirected; any other
+                       // parents of hi still point to the original multiply - confirm intended.
+                       
+                       LOG.debug("Applied pullupAbs (line 
"+hi.getBeginLine()+").");
+                       return uop;
+               }
+               return hi;
+       }
+       
        private static Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int 
pos )
        {
                if( hi instanceof UnaryOp && hi.getDataType()==DataType.MATRIX  
//unaryop
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePullupAbsTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePullupAbsTest.java
new file mode 100644
index 0000000000..4fe76b7634
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePullupAbsTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+/**
+ * Tests the pullupAbs rewrite (abs(X)*abs(Y) --> abs(X*Y)) by running
+ * RewritePullupAbs.dml with OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION
+ * set to false and to true. Each run checks the scalar result and the CP
+ * heavy-hitter count reported for "abs".
+ */
+public class RewritePullupAbsTest extends AutomatedTestBase 
+{
+       private static final String TEST_NAME1 = "RewritePullupAbs";
+       private static final String TEST_DIR = "functions/rewrite/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
RewritePullupAbsTest.class.getSimpleName() + "/";
+       
+       //input matrix dimensions, passed to the script as $1 and $2
+       private static final int rows = 1000;
+       private static final int cols = 1;
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R1"}) );
+       }
+
+       @Test
+       public void testNoRewrite() {
+               testRewritePullupAbs( TEST_NAME1, false );
+       }
+       
+       @Test
+       public void testRewrite() {
+               testRewritePullupAbs( TEST_NAME1, true );
+       }
+       
+       /**
+        * Runs the given DML test script and checks its result and the number
+        * of abs operators executed.
+        * 
+        * @param testname name of the test configuration / DML script
+        * @param rewrites value to assign to ALLOW_ALGEBRAIC_SIMPLIFICATION
+        */
+       private void testRewritePullupAbs( String testname, boolean rewrites )
+       {
+               //remember the global flag so it can be restored in finally
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               
+               try {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ "-stats", "-explain", 
"-args",
+                               String.valueOf(rows), String.valueOf(cols), 
output("R1") };
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = 
rewrites;
+
+                       //run the DML script (args: rows, cols, output path of R1)
+                       runTest(true, false, null, -1);
+                       
+                       //compare the scalar result: every cell of abs(X1)*abs(X2) is
+                       //|-1|*|7| = 7, so the sum is 7*rows*cols = 7*rows (cols=1);
+                       //7*rows is already an integer, so Math.round does not change it
+                       long expect = Math.round(7*rows);
+                       HashMap<CellIndex, Double> dmlfile1 = 
readDMLScalarFromOutputDir("R1");
+                       Assert.assertEquals(expect, dmlfile1.get(new 
CellIndex(1,1)), 1e-8);
+                       //check rewrite application: one abs with the rewrite, two without
+                       int expect2 = rewrites ? 1 : 2;
+                       Assert.assertEquals(expect2, 
Statistics.getCPHeavyHitterCount("abs"));
+               }
+               finally {
+                       //restore the global flag so other tests are unaffected
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/rewrite/RewritePullupAbs.dml b/src/test/scripts/functions/rewrite/RewritePullupAbs.dml
new file mode 100644
index 0000000000..e3a10d94ec
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewritePullupAbs.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# X1 is a $1 x $2 matrix of -1, so X2 = -7 * X1 is all 7
+X1 = matrix(-1, $1, $2);
+X2 = -7 * X1;
+# NOTE(review): the empty while(FALSE){} loops presumably split the script
+# into separate statement blocks, so that abs(X1) * abs(X2) is compiled as
+# its own DAG and not combined with the surrounding statements - confirm
+while(FALSE){}
+R0 = abs(X1) * abs(X2);
+while(FALSE){}
+# expected result: every cell of R0 is 7, so R1 = 7 * $1 * $2
+R1 = sum(R0);
+write(R1, $3);
+

Reply via email to