This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new fce8bea2ca [SYSTEMDS-3712] New rewrite for pull-up abs over binary mult fce8bea2ca is described below commit fce8bea2ca3489a272bdd49cd95ebdf0359fb264 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sat Jun 29 20:09:30 2024 +0200 [SYSTEMDS-3712] New rewrite for pull-up abs over binary mult This patch adds a simple new rewrite 'abs(X) * abs(Y) --> abs(X*Y)' as well as its tests. --- .../apache/sysds/hops/rewrite/HopRewriteRule.java | 2 +- .../RewriteAlgebraicSimplificationStatic.java | 20 +++++ .../functions/rewrite/RewritePullupAbsTest.java | 87 ++++++++++++++++++++++ .../scripts/functions/rewrite/RewritePullupAbs.dml | 30 ++++++++ 4 files changed, 138 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java index 10b5e6fb72..e204510b81 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteRule.java @@ -33,7 +33,7 @@ import org.apache.sysds.hops.Hop; public abstract class HopRewriteRule { protected static final Log LOG = LogFactory.getLog(HopRewriteRule.class.getName()); - + /** * Handle a generic (last-level) hop DAG with multiple roots. * diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 8fed2481ed..76691d6480 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -164,6 +164,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = pushdownUnaryAggTransposeOperation(hop, hi, i); //e.g., colSums(t(X)) -> t(rowSums(X)) hi = pushdownCSETransposeScalarOperation(hop, hi, i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X) hi = pushdownSumBinaryMult(hop, hi, i); //e.g., sum(lambda*X) -> lambda*sum(X) + hi = pullupAbs(hop, hi, i); //e.g., abs(X)*abs(Y) --> abs(X*Y) hi = simplifyUnaryPPredOperation(hop, hi, i); //e.g., abs(ppred()) -> ppred(), others: round, ceil, floor hi = simplifyTransposedAppend(hop, hi, i); //e.g., t(cbind(t(A),t(B))) -> rbind(A,B); if(OptimizerUtils.ALLOW_OPERATOR_FUSION) @@ -1122,6 +1123,25 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule return hi; } + private static Hop pullupAbs(Hop parent, Hop hi, int pos ) { + if( HopRewriteUtils.isBinary(hi, OpOp2.MULT) + && HopRewriteUtils.isUnary(hi.getInput(0), OpOp1.ABS) + && hi.getInput(0).getParent().size()==1 + && HopRewriteUtils.isUnary(hi.getInput(1), OpOp1.ABS) + && hi.getInput(1).getParent().size()==1) + { + Hop operand1 = hi.getInput(0).getInput(0); + Hop operand2 = hi.getInput(1).getInput(0); + Hop bop = HopRewriteUtils.createBinary(operand1, operand2, OpOp2.MULT); + Hop uop = HopRewriteUtils.createUnary(bop, OpOp1.ABS); + HopRewriteUtils.replaceChildReference(parent, hi, uop, pos); + + LOG.debug("Applied pullupAbs (line "+hi.getBeginLine()+")."); + return uop; + } + return hi; + } + private static Hop simplifyUnaryPPredOperation( Hop parent, Hop hi, int pos ) { if( hi instanceof UnaryOp && hi.getDataType()==DataType.MATRIX //unaryop diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePullupAbsTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePullupAbsTest.java new file mode 100644 index 0000000000..4fe76b7634 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePullupAbsTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.rewrite; + +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; + +public class RewritePullupAbsTest extends AutomatedTestBase +{ + private static final String TEST_NAME1 = "RewritePullupAbs"; + private static final String TEST_DIR = "functions/rewrite/"; + private static final String TEST_CLASS_DIR = TEST_DIR + RewritePullupAbsTest.class.getSimpleName() + "/"; + + private static final int rows = 1000; + private static final int cols = 1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R1"}) ); + } + + @Test + public void testNoRewrite() { + testRewritePullupAbs( TEST_NAME1, false ); + } + + @Test + public void testRewrite() { + testRewritePullupAbs( TEST_NAME1, true ); + } + + private void testRewritePullupAbs( String testname, boolean rewrites ) + { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-stats", "-explain", "-args", + String.valueOf(rows), String.valueOf(cols), output("R1") }; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + //run performance tests + runTest(true, false, null, -1); + + //compare matrices + long expect = Math.round(7*rows); + HashMap<CellIndex, Double> dmlfile1 = readDMLScalarFromOutputDir("R1"); + Assert.assertEquals(expect, dmlfile1.get(new CellIndex(1,1)), 1e-8); + //check rewrite application + int expect2 = rewrites ? 1 : 2; + Assert.assertEquals(expect2, Statistics.getCPHeavyHitterCount("abs")); + } + finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } +} diff --git a/src/test/scripts/functions/rewrite/RewritePullupAbs.dml b/src/test/scripts/functions/rewrite/RewritePullupAbs.dml new file mode 100644 index 0000000000..e3a10d94ec --- /dev/null +++ b/src/test/scripts/functions/rewrite/RewritePullupAbs.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +X1 = matrix(-1, $1, $2); +X2 = -7 * X1; +while(FALSE){} +R0 = abs(X1) * abs(X2); +while(FALSE){} +R1 = sum(R0) +write(R1, $3); +