[SYSTEMML-1990] Generalized ctable rewrites (seq-table, const inputs) This patch generalizes the existing rewrite of table(seq(),X,...) to rexpand(X,...) so that it also handles cases with unknown dimensions, including common scenarios with column indexing on X. Additionally, this patch introduces a new rewrite for table with constant matrix inputs (e.g., table(X, matrix(7)) -> table(X,7)).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c9614324 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c9614324 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c9614324 Branch: refs/heads/master Commit: c96143248349b6c68253ef9b3777afd5e5ed62f2 Parents: d696862 Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Nov 9 16:31:58 2017 -0800 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Thu Nov 9 22:08:02 2017 -0800 ---------------------------------------------------------------------- .../sysml/hops/rewrite/HopRewriteUtils.java | 27 ++++++- .../RewriteAlgebraicSimplificationDynamic.java | 11 ++- .../RewriteAlgebraicSimplificationStatic.java | 22 +++++- .../misc/RewriteCTableToRExpandTest.java | 83 ++++++++++++++------ .../RewriteCTableToRExpandLeftUnknownPos.dml | 28 +++++++ .../RewriteCTableToRExpandRightUnknownPos.dml | 28 +++++++ 6 files changed, 167 insertions(+), 32 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java index 28b2189..66f4fc7 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java @@ -965,6 +965,15 @@ public class HopRewriteUtils || isLiteralOfValue(hop.getInput().get(1), val)); } + public static boolean isTernary(Hop hop, OpOp3 type) { + return hop instanceof TernaryOp && ((TernaryOp)hop).getOp()==type; + } + + public static boolean isTernary(Hop hop, OpOp3... 
types) { + return ( hop instanceof TernaryOp + && ArrayUtils.contains(types, ((TernaryOp) hop).getOp())); + } + public static boolean containsInput(Hop current, Hop probe) { return rContainsInput(current, probe, new HashSet<Long>()); } @@ -1052,6 +1061,15 @@ public class HopRewriteUtils return true; } + public static boolean isColumnRightIndexing(Hop hop) { + return hop instanceof IndexingOp + && ((IndexingOp) hop).isColLowerEqualsUpper() + && ((hop.dimsKnown() && hop.getDim1() == hop.getInput().get(0).getDim1()) + || (isLiteralOfValue(hop.getInput().get(1), 1) + && isUnary(hop.getInput().get(2), OpOp1.NROW) + && hop.getInput().get(2).getInput().get(0)==hop.getInput().get(0))); + } + public static boolean isFullColumnIndexing(LeftIndexingOp hop) { return hop.isColLowerEqualsUpper() && isLiteralOfValue(hop.getInput().get(2), 1) @@ -1112,9 +1130,7 @@ public class HopRewriteUtils Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO)); Hop incr = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR)); return isLiteralOfValue(from, 1) && isLiteralOfValue(incr, 1) - && (isLiteralOfValue(to, row?input.getDim1():input.getDim2()) - || (to instanceof UnaryOp && ((UnaryOp)to).getOp()==(row? - OpOp1.NROW:OpOp1.NCOL) && to.getInput().get(0)==input)); + && isSizeExpressionOf(to, input, row); } return false; } @@ -1149,6 +1165,11 @@ public class HopRewriteUtils throw new HopsException("Failed to retrieve 'to' argument from basic 1-N sequence."); } + public static boolean isSizeExpressionOf(Hop size, Hop input, boolean row) { + return (input.dimsKnown() && isLiteralOfValue(size, row?input.getDim1():input.getDim2())) + || ((row ? 
isUnary(size, OpOp1.NROW) : isUnary(size, OpOp1.NCOL)) && (size.getInput().get(0)==input + || (isColumnRightIndexing(input) && size.getInput().get(0)==input.getInput().get(0)))); + } public static boolean hasOnlyWriteParents( Hop hop, boolean inclTransient, boolean inclPersistent ) { http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index 0fa1aed..e07f97c 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -2540,15 +2540,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule //pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true) //note: this rewrite supports both left/right sequence if( hi instanceof TernaryOp && hi.getInput().size()==5 //table without weights - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) //i.e., weight of 1 - && hi.getInput().get(3) instanceof LiteralOp && hi.getInput().get(4) instanceof LiteralOp) + && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) ) //i.e., weight of 1 { Hop first = hi.getInput().get(0); Hop second = hi.getInput().get(1); //pattern a: table(seq(1,nrow(v)), v, nrow(v), m, 1) - if( HopRewriteUtils.isBasic1NSequence(first, second, true) && second.dimsKnown() - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), second.getDim1()) ) + if( HopRewriteUtils.isBasic1NSequence(first, second, true) + && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(3), second, true) ) { //setup input parameter hops HashMap<String,Hop> args = new 
HashMap<>(); @@ -2568,8 +2567,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule LOG.debug("Applied simplifyTableSeqExpand1 (line "+hi.getBeginLine()+")"); } //pattern b: table(v, seq(1,nrow(v)), m, nrow(v)) - else if( HopRewriteUtils.isBasic1NSequence(second, first, true) && first.dimsKnown() - && HopRewriteUtils.isLiteralOfValue(hi.getInput().get(4), first.getDim1()) ) + else if( HopRewriteUtils.isBasic1NSequence(second, first, true) + && HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(4), first, true) ) { //setup input parameter hops HashMap<String,Hop> args = new HashMap<>(); http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index 4c68fe2..cbfb527 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -152,6 +152,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule hi = foldMultipleAppendOperations(hi); //e.g., cbind(X,cbind(Y,Z)) -> cbind(X,Y,Z) hi = simplifyBinaryToUnaryOperation(hop, hi, i); //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X) hi = canonicalizeMatrixMultScalarAdd(hi); //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) + hi = simplifyCTableWithConstMatrixInputs(hi); //e.g., table(X, matrix(1,...)) -> table(X, 1) hi = simplifyReverseOperation(hop, hi, i); //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X) if(OptimizerUtils.ALLOW_OPERATOR_FUSION) hi = simplifyMultiBinaryToBinaryOperation(hi); //e.g., 1-X*Y -> X 1-* Y @@ -664,13 +665,32 @@ public class 
RewriteAlgebraicSimplificationStatic extends HopRewriteRule { bop.setOp(OpOp2.PLUS); HopRewriteUtils.replaceChildReference(bop, right, - HopRewriteUtils.createBinaryMinus(right), 1); + HopRewriteUtils.createBinaryMinus(right), 1); LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+")."); } } return hi; } + + private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) + throws HopsException + { + //pattern: table(X, matrix(1,...), matrix(7, ...)) -> table(X, 1, 7) + if( HopRewriteUtils.isTernary(hi, OpOp3.CTABLE) ) { + //note: the first input always expected to be a matrix + for( int i=1; i<hi.getInput().size(); i++ ) { + Hop inCurr = hi.getInput().get(i); + if( HopRewriteUtils.isDataGenOpWithConstantValue(inCurr) ) { + Hop inNew = ((DataGenOp)inCurr).getInput(DataExpression.RAND_MIN); + HopRewriteUtils.replaceChildReference(hi, inCurr, inNew, i); + LOG.debug("Applied simplifyCTableWithConstMatrixInputs" + + i + " (line "+hi.getBeginLine()+")."); + } + } + } + return hi; + } /** * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java index b42a978..838fbb1 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java @@ -22,6 +22,7 @@ package org.apache.sysml.test.integration.functions.misc; import org.junit.Test; import org.junit.Assert; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import 
org.apache.sysml.runtime.matrix.MatrixCharacteristics; import org.apache.sysml.test.integration.AutomatedTestBase; import org.apache.sysml.test.integration.TestConfiguration; @@ -33,6 +34,8 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase private static final String TEST_NAME2 = "RewriteCTableToRExpandRightPos"; private static final String TEST_NAME3 = "RewriteCTableToRExpandLeftNeg"; private static final String TEST_NAME4 = "RewriteCTableToRExpandRightNeg"; + private static final String TEST_NAME5 = "RewriteCTableToRExpandLeftUnknownPos"; + private static final String TEST_NAME6 = "RewriteCTableToRExpandRightUnknownPos"; private static final String TEST_DIR = "functions/misc/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewriteCTableToRExpandTest.class.getSimpleName() + "/"; @@ -52,6 +55,8 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) ); addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) ); addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) ); + addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) ); } @Test @@ -94,6 +99,25 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase testRewriteCTableRExpand( TEST_NAME4, CropType.PAD ); } + @Test + public void testRewriteCTableRExpandLeftUnknownDenseCrop() { + testRewriteCTableRExpand( TEST_NAME5, CropType.CROP ); + } + + @Test + public void testRewriteCTableRExpandLeftUnknownDensePad() { + testRewriteCTableRExpand( TEST_NAME5, CropType.PAD ); + } + + @Test + public void testRewriteCTableRExpandRightUnknownDenseCrop() { + testRewriteCTableRExpand( TEST_NAME6, CropType.CROP ); + 
} + + @Test + public void testRewriteCTableRExpandRightUnknownDensePad() { + testRewriteCTableRExpand( TEST_NAME6, CropType.PAD ); + } private void testRewriteCTableRExpand( String testname, CropType type ) { @@ -101,30 +125,45 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase loadTestConfiguration(config); int outDim = maxVal + ((type==CropType.CROP) ? -7 : 7); + boolean unknownTests = ( testname.equals(TEST_NAME5) || testname.equals(TEST_NAME6) ); + - String HOME = SCRIPT_DIR + TEST_DIR; - fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{ "-stats","-args", - input("A"), String.valueOf(outDim), output("R") }; - - fullRScriptName = HOME + testname + ".R"; - rCmd = getRCmd(inputDir(), String.valueOf(outDim), expectedDir()); - - double[][] A = getRandomMatrix(rows, 1, 1, 10, 1.0, 7); - writeInputMatrixWithMTD("A", A, false); - - //run performance tests - runTest(true, false, null, -1); + RUNTIME_PLATFORM platformOld = rtplatform; + if( unknownTests ) + rtplatform = RUNTIME_PLATFORM.SINGLE_NODE; - //compare output meta data - boolean left = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME3)); - boolean pos = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME2)); - int rrows = (left && pos) ? rows : outDim; - int rcols = (!left && pos) ? 
rows : outDim; - checkDMLMetaDataFile("R", new MatrixCharacteristics(rrows, rcols, 1, 1)); - - //check for applied rewrite - Assert.assertEquals(Boolean.valueOf(testname.equals(TEST_NAME1) || testname.equals(TEST_NAME2)), + try + { + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-explain","-stats","-args", + input("A"), String.valueOf(outDim), output("R") }; + + fullRScriptName = HOME + testname + ".R"; + rCmd = getRCmd(inputDir(), String.valueOf(outDim), expectedDir()); + + double[][] A = getRandomMatrix(rows, 1, 1, 10, 1.0, 7); + writeInputMatrixWithMTD("A", A, false); + + //run performance tests + runTest(true, false, null, -1); + + //compare output meta data + boolean left = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME3) + || testname.equals(TEST_NAME5) || testname.equals(TEST_NAME6)); + boolean pos = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME2)); + int rrows = (left && pos) ? rows : outDim; + int rcols = (!left && pos) ? 
rows : outDim; + if( !unknownTests ) + checkDMLMetaDataFile("R", new MatrixCharacteristics(rrows, rcols, 1, 1)); + + //check for applied rewrite + Assert.assertEquals(Boolean.valueOf(testname.equals(TEST_NAME1) + || testname.equals(TEST_NAME2) || unknownTests), Boolean.valueOf(heavyHittersContainsSubString("rexpand"))); + } + finally { + rtplatform = platformOld; + } } } http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml b/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml new file mode 100644 index 0000000..4b07462 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +A = read($1); + +T = matrix(1, nrow(A), 2); +A2 = rand(rows=sum(T)/2, cols=100, min=1, max=10); +R = table(seq(1,nrow(A2)), A2[,1], nrow(A2), $2); + +write(R, $3); http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml b/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml new file mode 100644 index 0000000..68d2860 --- /dev/null +++ b/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A = read($1); + +T = matrix(1, nrow(A), 2); +A2 = rand(rows=sum(T)/2, cols=100, min=1, max=10); +R = table(A2[,1], seq(1,nrow(A2)), $2, nrow(A2)); + +write(R, $3);