Repository: systemml
Updated Branches:
  refs/heads/master db13eac1a -> 50dafa038
[SYSTEMML-1678] Fix rewrite 'fuse axpy binary ops' for outer products This patch fixes the dynamic simplification rewrite fuseAxpyBinaryOperationChain to not trigger on outer products of vectors and adds related negative test cases. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/50dafa03 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/50dafa03 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/50dafa03 Branch: refs/heads/master Commit: 50dafa038ff3282f327260f2d413bdfd907bfe04 Parents: db13eac Author: Matthias Boehm <mboe...@gmail.com> Authored: Mon Jun 26 22:33:23 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Mon Jun 26 22:33:23 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/ipa/FunctionCallSizeInfo.java | 15 ++++++++++ .../RewriteAlgebraicSimplificationDynamic.java | 2 +- .../misc/RewriteFuseBinaryOpChainTest.java | 24 +++++++++++++--- .../misc/RewriteFuseBinaryOpChainTest4.R | 30 ++++++++++++++++++++ .../misc/RewriteFuseBinaryOpChainTest4.dml | 29 +++++++++++++++++++ 5 files changed, 95 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java index fb668b5..9f76e32 100644 --- a/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java +++ b/src/main/java/org/apache/sysml/hops/ipa/FunctionCallSizeInfo.java @@ -345,6 +345,21 @@ public class FunctionCallSizeInfo sb.append("\n"); } + sb.append("Valid #non-zeros for propagation: \n"); + for( Entry<String, Set<Integer>> e : _fcandSafeNNZ.entrySet() ) { + 
sb.append("--"); + sb.append(e.getKey()); + sb.append(": "); + for( Integer pos : e.getValue() ) { + sb.append(pos); + sb.append(":"); + sb.append(_fgraph.getFunctionCalls(e.getKey()) + .get(0).getInput().get(pos).getName()); + sb.append(" "); + } + sb.append("\n"); + } + return sb.toString(); } } http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 91c5972..9681e44 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2131,7 +2131,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule private Hop fuseAxpyBinaryOperationChain(Hop parent, Hop hi, int pos) { //patterns: (a) X + s*Y -> X +* sY, (b) s*Y+X -> X +* sY, (c) X - s*Y -> X -* sY - if( hi instanceof BinaryOp + if( hi instanceof BinaryOp && !((BinaryOp) hi).isOuterVectorOperator() && (((BinaryOp)hi).getOp()==OpOp2.PLUS || ((BinaryOp)hi).getOp()==OpOp2.MINUS) ) { BinaryOp bop = (BinaryOp) hi; http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java index 4c21587..f1d2a6a 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java +++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFuseBinaryOpChainTest.java @@ -32,7 +32,6 @@ import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; import org.apache.sysml.test.utils.TestUtils; -import org.apache.sysml.utils.Statistics; /** * Regression test for function recompile-once issue with literal replacement. @@ -43,7 +42,8 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase private static final String TEST_NAME1 = "RewriteFuseBinaryOpChainTest1"; //+* (X+s*Y) private static final String TEST_NAME2 = "RewriteFuseBinaryOpChainTest2"; //-* (X-s*Y) private static final String TEST_NAME3 = "RewriteFuseBinaryOpChainTest3"; //+* (s*Y+X) - + private static final String TEST_NAME4 = "RewriteFuseBinaryOpChainTest4"; //outer(X, s*Y, "+") not applied + private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFuseBinaryOpChainTest.class.getSimpleName() + "/"; @@ -55,6 +55,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) ); addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); } @Test @@ -147,6 +148,18 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase testFuseBinaryChain( TEST_NAME3, true, ExecType.MR ); } + //negative tests + + @Test + public void testOuterBinaryPlusNoRewriteCP() { + testFuseBinaryChain( TEST_NAME4, false, ExecType.CP ); + } + + @Test + public void testOuterBinaryPlusRewriteCP() { + 
testFuseBinaryChain( TEST_NAME4, true, ExecType.CP); + } + /** * * @param testname @@ -182,7 +195,7 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase fullRScriptName = HOME + testname + ".R"; rCmd = getRCmd(inputDir(), expectedDir()); - runTest(true, false, null, -1); + runTest(true, false, null, -1); runRScript(true); //compare matrices @@ -199,7 +212,10 @@ public class RewriteFuseBinaryOpChainTest extends AutomatedTestBase prefix = Instruction.SP_INST_PREFIX; String opcode = (testname.equals(TEST_NAME1)||testname.equals(TEST_NAME3)) ? prefix+"+*" : prefix+"-*"; - Assert.assertTrue("Rewrite not applied.",Statistics.getCPHeavyHitterOpCodes().contains(opcode)); + if( testname.equals(TEST_NAME4) ) + Assert.assertFalse("Rewrite applied.", heavyHittersContainsSubString(opcode)); + else + Assert.assertTrue("Rewrite not applied.", heavyHittersContainsSubString(opcode)); } } finally http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R new file mode 100644 index 0000000..7e9a392 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.R @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args<-commandArgs(TRUE) +options(digits=22) +library("Matrix") + +X = matrix(1, 10, 1); +Y = matrix(2, 1, 10); +lambda = 7; + +S = outer(as.vector(X), as.vector(lambda*Y), "+"); + +writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); http://git-wip-us.apache.org/repos/asf/systemml/blob/50dafa03/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml new file mode 100644 index 0000000..0599f02 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteFuseBinaryOpChainTest4.dml @@ -0,0 +1,29 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix(1, 10, 1); +Y = matrix(2, 1, 10); +lambda = 7; +if(1==1){} + +S = outer(X, lambda*Y, "+"); + +write(S,$1);