simplifyDotProductSum shall not interfere with tak+*

Added conditions to the dynamic algebraic rewrite simplifyDotProductSum so that it does not apply the optimization to (A^2)*B or B*(A^2), since TernaryAggregate handles these patterns.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/a5846bbb Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/a5846bbb Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/a5846bbb Branch: refs/heads/master Commit: a5846bbb383c655189963bffefed1c0db4ffcc89 Parents: edbac3b Author: Dylan Hutchison <dhutc...@cs.washington.edu> Authored: Fri Jun 9 23:58:16 2017 -0700 Committer: Dylan Hutchison <dhutc...@cs.washington.edu> Committed: Sun Jun 18 17:43:30 2017 -0700 ---------------------------------------------------------------------- .../hops/rewrite/RewriteAlgebraicSimplificationDynamic.java | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/a5846bbb/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index ad80c05..166af2f 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2050,7 +2050,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule else if( HopRewriteUtils.isBinary(hi2, OpOp2.MULT, 1) //no other consumer than sum && hi2.getInput().get(0).getDim2()==1 && hi2.getInput().get(1).getDim2()==1 && !HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.MULT) - && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) ) + && !HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.MULT) + && !(HopRewriteUtils.isBinary(hi2.getInput().get(0), OpOp2.POW) // do not rewrite (A^2)*B + && 
hi2.getInput().get(0).getInput().get(1) instanceof LiteralOp // let tak+* handle it + && ((LiteralOp)hi2.getInput().get(0).getInput().get(1)).getLongValue() == 2) + && !(HopRewriteUtils.isBinary(hi2.getInput().get(1), OpOp2.POW) // do not rewrite B*(A^2) + && hi2.getInput().get(1).getInput().get(1) instanceof LiteralOp // let tak+* handle it + && ((LiteralOp)hi2.getInput().get(1).getInput().get(1)).getLongValue() == 2) + ) { baLeft = hi2.getInput().get(0); baRight = hi2.getInput().get(1);