[SYSTEMML-1990] Generalized ctable rewrites (seq-table, const inputs)

This patch generalizes the existing rewrite of table(seq(),X,...) to
rexpand(X,...) so that it also handles cases with unknown dimensions,
including common scenarios with column indexing on X. Additionally, it
introduces a new rewrite for table with constant matrix inputs (e.g.,
table(X, matrix(7,...)) -> table(X, 7)).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c9614324
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c9614324
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c9614324

Branch: refs/heads/master
Commit: c96143248349b6c68253ef9b3777afd5e5ed62f2
Parents: d696862
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Nov 9 16:31:58 2017 -0800
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Thu Nov 9 22:08:02 2017 -0800

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 27 ++++++-
 .../RewriteAlgebraicSimplificationDynamic.java  | 11 ++-
 .../RewriteAlgebraicSimplificationStatic.java   | 22 +++++-
 .../misc/RewriteCTableToRExpandTest.java        | 83 ++++++++++++++------
 .../RewriteCTableToRExpandLeftUnknownPos.dml    | 28 +++++++
 .../RewriteCTableToRExpandRightUnknownPos.dml   | 28 +++++++
 6 files changed, 167 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 28b2189..66f4fc7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -965,6 +965,15 @@ public class HopRewriteUtils
                        || isLiteralOfValue(hop.getInput().get(1), val));
        }
        
+       public static boolean isTernary(Hop hop, OpOp3 type) {
+               return hop instanceof TernaryOp && 
((TernaryOp)hop).getOp()==type;
+       }
+       
+       public static boolean isTernary(Hop hop, OpOp3... types) {
+               return ( hop instanceof TernaryOp 
+                       && ArrayUtils.contains(types, ((TernaryOp) 
hop).getOp()));
+       }
+       
        public static boolean containsInput(Hop current, Hop probe) {
                return rContainsInput(current, probe, new HashSet<Long>());     
        }
@@ -1052,6 +1061,15 @@ public class HopRewriteUtils
                return true;
        }
        
+       public static boolean isColumnRightIndexing(Hop hop) {
+               return hop instanceof IndexingOp
+                       && ((IndexingOp) hop).isColLowerEqualsUpper()
+                       && ((hop.dimsKnown() && hop.getDim1() == 
hop.getInput().get(0).getDim1())
+                       || (isLiteralOfValue(hop.getInput().get(1), 1) 
+                               && isUnary(hop.getInput().get(2), OpOp1.NROW) 
+                               && 
hop.getInput().get(2).getInput().get(0)==hop.getInput().get(0)));
+       }
+       
        public static boolean isFullColumnIndexing(LeftIndexingOp hop) {
                return hop.isColLowerEqualsUpper()
                        && isLiteralOfValue(hop.getInput().get(2), 1)
@@ -1112,9 +1130,7 @@ public class HopRewriteUtils
                        Hop to = 
dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO));
                        Hop incr = 
dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR));
                        return isLiteralOfValue(from, 1) && 
isLiteralOfValue(incr, 1)
-                               && (isLiteralOfValue(to, 
row?input.getDim1():input.getDim2())
-                                       || (to instanceof UnaryOp && 
((UnaryOp)to).getOp()==(row?
-                                               OpOp1.NROW:OpOp1.NCOL) && 
to.getInput().get(0)==input));
+                               && isSizeExpressionOf(to, input, row);
                }
                return false;
        }
@@ -1149,6 +1165,11 @@ public class HopRewriteUtils
                throw new HopsException("Failed to retrieve 'to' argument from 
basic 1-N sequence.");
        }
        
+       public static boolean isSizeExpressionOf(Hop size, Hop input, boolean 
row) {
+               return (input.dimsKnown() && isLiteralOfValue(size, 
row?input.getDim1():input.getDim2()))
+                       || ((row ? isUnary(size, OpOp1.NROW) : isUnary(size, 
OpOp1.NCOL)) && (size.getInput().get(0)==input 
+                       || (isColumnRightIndexing(input) && 
size.getInput().get(0)==input.getInput().get(0))));
+       }
        
        public static boolean hasOnlyWriteParents( Hop hop, boolean 
inclTransient, boolean inclPersistent )
        {

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 0fa1aed..e07f97c 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2540,15 +2540,14 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                //pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, 
max=m, dir=row, ignore=false, cast=true)
                //note: this rewrite supports both left/right sequence 
                if(    hi instanceof TernaryOp && hi.getInput().size()==5 
//table without weights 
-                       && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) //i.e., weight of 1
-                       && hi.getInput().get(3) instanceof LiteralOp && 
hi.getInput().get(4) instanceof LiteralOp)
+                       && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) ) //i.e., weight of 1
                {
                        Hop first = hi.getInput().get(0);
                        Hop second = hi.getInput().get(1);
                        
                        //pattern a: table(seq(1,nrow(v)), v, nrow(v), m, 1)
-                       if( HopRewriteUtils.isBasic1NSequence(first, second, 
true) && second.dimsKnown() 
-                               && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), second.getDim1()) )
+                       if( HopRewriteUtils.isBasic1NSequence(first, second, 
true) 
+                               && 
HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(3), second, true) )
                        {
                                //setup input parameter hops
                                HashMap<String,Hop> args = new HashMap<>();
@@ -2568,8 +2567,8 @@ public class RewriteAlgebraicSimplificationDynamic 
extends HopRewriteRule
                                LOG.debug("Applied simplifyTableSeqExpand1 
(line "+hi.getBeginLine()+")");      
                        }
                        //pattern b: table(v, seq(1,nrow(v)), m, nrow(v))
-                       else if( HopRewriteUtils.isBasic1NSequence(second, 
first, true) && first.dimsKnown() 
-                               && 
HopRewriteUtils.isLiteralOfValue(hi.getInput().get(4), first.getDim1()) )
+                       else if( HopRewriteUtils.isBasic1NSequence(second, 
first, true)
+                               && 
HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(4), first, true) )
                        {
                                //setup input parameter hops
                                HashMap<String,Hop> args = new HashMap<>();

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 4c68fe2..cbfb527 100644
--- 
a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -152,6 +152,7 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        hi = foldMultipleAppendOperations(hi);               
//e.g., cbind(X,cbind(Y,Z)) -> cbind(X,Y,Z)
                        hi = simplifyBinaryToUnaryOperation(hop, hi, i);     
//e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
                        hi = canonicalizeMatrixMultScalarAdd(hi);            
//e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) 
+                       hi = simplifyCTableWithConstMatrixInputs(hi);        
//e.g., table(X, matrix(1,...)) -> table(X, 1)
                        hi = simplifyReverseOperation(hop, hi, i);           
//e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
                        if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
                                hi = simplifyMultiBinaryToBinaryOperation(hi);  
     //e.g., 1-X*Y -> X 1-* Y
@@ -664,13 +665,32 @@ public class RewriteAlgebraicSimplificationStatic extends 
HopRewriteRule
                        {
                                bop.setOp(OpOp2.PLUS);
                                HopRewriteUtils.replaceChildReference(bop,  
right,
-                                               
HopRewriteUtils.createBinaryMinus(right), 1);                           
+                                               
HopRewriteUtils.createBinaryMinus(right), 1);
                                LOG.debug("Applied 
canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+").");
                        }
                }
                
                return hi;
        }
+       
+       private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) 
+               throws HopsException
+       {
+               //pattern: table(X, matrix(1,...), matrix(7, ...)) -> table(X, 
1, 7)
+               if( HopRewriteUtils.isTernary(hi, OpOp3.CTABLE) ) {
+                       //note: the first input always expected to be a matrix
+                       for( int i=1; i<hi.getInput().size(); i++ ) {
+                               Hop inCurr = hi.getInput().get(i);
+                               if( 
HopRewriteUtils.isDataGenOpWithConstantValue(inCurr) ) {
+                                       Hop inNew = 
((DataGenOp)inCurr).getInput(DataExpression.RAND_MIN);
+                                       
HopRewriteUtils.replaceChildReference(hi, inCurr, inNew, i);
+                                       LOG.debug("Applied 
simplifyCTableWithConstMatrixInputs"
+                                               + i + " (line 
"+hi.getBeginLine()+").");
+                               }
+                       }
+               }
+               return hi;
+       }
 
        /**
         * NOTE: this would be by definition a dynamic rewrite; however, we 
apply it as a static

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
index b42a978..838fbb1 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
@@ -22,6 +22,7 @@ package org.apache.sysml.test.integration.functions.misc;
 import org.junit.Test;
 
 import org.junit.Assert;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
@@ -33,6 +34,8 @@ public class RewriteCTableToRExpandTest extends 
AutomatedTestBase
        private static final String TEST_NAME2 = 
"RewriteCTableToRExpandRightPos"; 
        private static final String TEST_NAME3 = 
"RewriteCTableToRExpandLeftNeg"; 
        private static final String TEST_NAME4 = 
"RewriteCTableToRExpandRightNeg"; 
+       private static final String TEST_NAME5 = 
"RewriteCTableToRExpandLeftUnknownPos";
+       private static final String TEST_NAME6 = 
"RewriteCTableToRExpandRightUnknownPos";
        
        private static final String TEST_DIR = "functions/misc/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RewriteCTableToRExpandTest.class.getSimpleName() + "/";
@@ -52,6 +55,8 @@ public class RewriteCTableToRExpandTest extends 
AutomatedTestBase
                addTestConfiguration( TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
                addTestConfiguration( TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME5, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) );
+               addTestConfiguration( TEST_NAME6, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) );
        }
 
        @Test
@@ -94,6 +99,25 @@ public class RewriteCTableToRExpandTest extends 
AutomatedTestBase
                testRewriteCTableRExpand( TEST_NAME4, CropType.PAD );
        }
        
+       @Test
+       public void testRewriteCTableRExpandLeftUnknownDenseCrop()  {
+               testRewriteCTableRExpand( TEST_NAME5, CropType.CROP );
+       }
+       
+       @Test
+       public void testRewriteCTableRExpandLeftUnknownDensePad()  {
+               testRewriteCTableRExpand( TEST_NAME5, CropType.PAD );
+       }
+       
+       @Test
+       public void testRewriteCTableRExpandRightUnknownDenseCrop()  {
+               testRewriteCTableRExpand( TEST_NAME6, CropType.CROP );
+       }
+       
+       @Test
+       public void testRewriteCTableRExpandRightUnknownDensePad()  {
+               testRewriteCTableRExpand( TEST_NAME6, CropType.PAD );
+       }
        
        private void testRewriteCTableRExpand( String testname, CropType type )
        {       
@@ -101,30 +125,45 @@ public class RewriteCTableToRExpandTest extends 
AutomatedTestBase
                loadTestConfiguration(config);
 
                int outDim = maxVal + ((type==CropType.CROP) ? -7 : 7);
+               boolean unknownTests = ( testname.equals(TEST_NAME5) || 
testname.equals(TEST_NAME6) );
+                       
                
-               String HOME = SCRIPT_DIR + TEST_DIR;
-               fullDMLScriptName = HOME + testname + ".dml";
-               programArgs = new String[]{ "-stats","-args", 
-                       input("A"), String.valueOf(outDim), output("R") };
-               
-               fullRScriptName = HOME + testname + ".R";
-               rCmd = getRCmd(inputDir(), String.valueOf(outDim), 
expectedDir());                      
-
-               double[][] A = getRandomMatrix(rows, 1, 1, 10, 1.0, 7);
-               writeInputMatrixWithMTD("A", A, false);
-               
-               //run performance tests
-               runTest(true, false, null, -1); 
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               if( unknownTests )
+                       rtplatform = RUNTIME_PLATFORM.SINGLE_NODE;
                
-               //compare output meta data
-               boolean left = (testname.equals(TEST_NAME1) || 
testname.equals(TEST_NAME3));
-               boolean pos = (testname.equals(TEST_NAME1) || 
testname.equals(TEST_NAME2));
-               int rrows = (left && pos) ? rows : outDim;
-               int rcols = (!left && pos) ? rows : outDim;
-               checkDMLMetaDataFile("R", new MatrixCharacteristics(rrows, 
rcols, 1, 1));
-               
-               //check for applied rewrite
-               Assert.assertEquals(Boolean.valueOf(testname.equals(TEST_NAME1) 
|| testname.equals(TEST_NAME2)),
+               try 
+               {
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{ 
"-explain","-stats","-args", 
+                               input("A"), String.valueOf(outDim), output("R") 
};
+                       
+                       fullRScriptName = HOME + testname + ".R";
+                       rCmd = getRCmd(inputDir(), String.valueOf(outDim), 
expectedDir());
+       
+                       double[][] A = getRandomMatrix(rows, 1, 1, 10, 1.0, 7);
+                       writeInputMatrixWithMTD("A", A, false);
+                       
+                       //run performance tests
+                       runTest(true, false, null, -1); 
+                       
+                       //compare output meta data
+                       boolean left = (testname.equals(TEST_NAME1) || 
testname.equals(TEST_NAME3) 
+                               || testname.equals(TEST_NAME5) || 
testname.equals(TEST_NAME6));
+                       boolean pos = (testname.equals(TEST_NAME1) || 
testname.equals(TEST_NAME2));
+                       int rrows = (left && pos) ? rows : outDim;
+                       int rcols = (!left && pos) ? rows : outDim;
+                       if( !unknownTests )
+                               checkDMLMetaDataFile("R", new 
MatrixCharacteristics(rrows, rcols, 1, 1));
+                       
+                       //check for applied rewrite
+                       
Assert.assertEquals(Boolean.valueOf(testname.equals(TEST_NAME1) 
+                               || testname.equals(TEST_NAME2) || unknownTests),
                                
Boolean.valueOf(heavyHittersContainsSubString("rexpand")));
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml 
b/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml
new file mode 100644
index 0000000..4b07462
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+T = matrix(1, nrow(A), 2);
+A2 = rand(rows=sum(T)/2, cols=100, min=1, max=10);
+R = table(seq(1,nrow(A2)), A2[,1], nrow(A2), $2);
+
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml
----------------------------------------------------------------------
diff --git 
a/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml 
b/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml
new file mode 100644
index 0000000..68d2860
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+T = matrix(1, nrow(A), 2);
+A2 = rand(rows=sum(T)/2, cols=100, min=1, max=10);
+R = table(A2[,1], seq(1,nrow(A2)), $2, nrow(A2));
+
+write(R, $3);

Reply via email to