[SYSTEMML-1369] New loop vectorization rewrite for indexed copies As explained in https://issues.apache.org/jira/browse/SYSTEMML-1369, this patch introduces an auto loop vectorization rewrite for indexed copies. For example, we now rewrite the following loop
parfor (i in 1:ncol(labels)) topics[id, i] = labels[1, i]; to a simple left indexing of topics[id, 1:ncol(labels)] = labels[1, 1:ncol(labels)]. This applies to for and parfor loops. Furthermore, this patch also fixes size update issues of the existing loop vectorization rewrites vectorizeElementwiseBinary and vectorizeElementwiseUnary. In a scenario with a 1K x 1K dense topics matrix, an outer loop over id, and regular for loops, the runtime improved from 33s (3438s without update-in-place) to 0.2s. Note that on large out-of-core matrices, the improvements are even larger because we cannot apply update-in-place there. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ed3a1588 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ed3a1588 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ed3a1588 Branch: refs/heads/master Commit: ed3a1588287e22ab3e8a3e9e971b505a9157ba59 Parents: cd6685a Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Mar 2 23:01:28 2017 -0800 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Thu Mar 2 23:01:28 2017 -0800 ---------------------------------------------------------------------- .../rewrite/RewriteForLoopVectorization.java | 177 ++++++++++++------- 1 file changed, 114 insertions(+), 63 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ed3a1588/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java index 991dedd..273436e 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java +++ 
b/src/main/java/org/apache/sysml/hops/rewrite/RewriteForLoopVectorization.java @@ -76,10 +76,20 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule || csb instanceof IfStatementBlock || csb instanceof ForStatementBlock ) ) { - //auto vectorization pattern - sb = vectorizeScalarAggregate(sb, csb, from, to, incr, iterVar); //e.g., for(i){s = s + as.scalar(X[i,2])} + //AUTO VECTORIZATION PATTERNS + //Note: unnecessary row or column indexing then later removed via hop rewrites + + //e.g., for(i in a:b){s = s + as.scalar(X[i,2])} -> s = sum(X[a:b,2]) + sb = vectorizeScalarAggregate(sb, csb, from, to, incr, iterVar); + + //e.g., for(i in a:b){X[i,2] = Y[i,1] + Z[i,3]} -> X[a:b,2] = Y[a:b,1] + Z[a:b,3]; sb = vectorizeElementwiseBinary(sb, csb, from, to, incr, iterVar); + + //e.g., for(i in a:b){X[i,2] = abs(Y[i,1])} -> X[a:b,2] = abs(Y[a:b,1]); sb = vectorizeElementwiseUnary(sb, csb, from, to, incr, iterVar); + + //e.g., for(i in a:b){X[7,i] = Y[1,i]} -> X[7,a:b] = Y[1,a:b]; + sb = vectorizeIndexedCopy(sb, csb, from, to, incr, iterVar); } } } @@ -91,19 +101,6 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule return ret; } - /** - * Note: unnecessary row or column indexing then later removed via - * dynamic rewrites - * - * @param sb statement block? - * @param csb statement boock? - * @param from high-level operator? - * @param to high-level operator? - * @param increment high-level operator? - * @param itervar ? - * @return statement block - * @throws HopsException if HopsException occurs - */ private StatementBlock vectorizeScalarAggregate( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) throws HopsException { @@ -206,19 +203,6 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule return ret; } - /** - * Note: unnecessary row or column indexing then later removed via - * dynamic rewrites - * - * @param sb ? - * @param csb ? - * @param from ? 
- * @param to ? - * @param increment ? - * @param itervar ? - * @return statement block - * @throws HopsException if HopsException occurs - */ private StatementBlock vectorizeElementwiseBinary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) throws HopsException { @@ -291,8 +275,9 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule HopRewriteUtils.replaceChildReference(rix0, rix0.getInput().get(index2-1), to, index2-1); HopRewriteUtils.replaceChildReference(rix1, rix1.getInput().get(index1-1), from, index1-1); HopRewriteUtils.replaceChildReference(rix1, rix1.getInput().get(index2-1), to, index2-1); + updateLeftAndRightIndexingSizes(rowIx, lix, rix0, rix1); bop.refreshSizeInformation(); - lix.refreshSizeInformation(); + lix.refreshSizeInformation(); //after bop update ret = csb; //ret.liveIn().removeVariable(itervar); @@ -302,19 +287,6 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule return ret; } - /** - * Note: unnecessary row or column indexing then later removed via - * dynamic rewrites - * - * @param sb ? - * @param csb ? - * @param from ? - * @param to ? - * @param increment ? - * @param itervar ? 
- * @return statement block - * @throws HopsException if HopsException occurs - */ private StatementBlock vectorizeElementwiseUnary( StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) throws HopsException { @@ -342,30 +314,16 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule && lixrhs.getInput().get(0) instanceof IndexingOp && lixrhs.getInput().get(0).getInput().get(0) instanceof DataOp ) { - IndexingOp rix = (IndexingOp) lixrhs.getInput().get(0); - //check for rowwise - if( lix.getRowLowerEqualsUpper() && rix.getRowLowerEqualsUpper() - && lix.getInput().get(2).getName().equals(itervar) - && rix.getInput().get(1).getName().equals(itervar) ) - { - apply = true; - rowIx = true; - } - //check for colwise - if( lix.getColLowerEqualsUpper() && rix.getColLowerEqualsUpper() - && lix.getInput().get(4).getName().equals(itervar) - && rix.getInput().get(3).getName().equals(itervar) ) - { - apply = true; - rowIx = false; - } + boolean[] tmp = checkLeftAndRightIndexing(lix, + (IndexingOp) lixrhs.getInput().get(0), itervar); + apply = tmp[0]; + rowIx = tmp[1]; } } } //apply rewrite if possible - if( apply ) - { + if( apply ) { Hop root = csb.get_hops().get(0); LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0); UnaryOp uop = (UnaryOp) lix.getInput().get(1); @@ -378,14 +336,107 @@ public class RewriteForLoopVectorization extends StatementBlockRewriteRule //modify right indexing HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(index1-1), from, index1-1); HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(index2-1), to, index2-1); + updateLeftAndRightIndexingSizes(rowIx, lix, rix); uop.refreshSizeInformation(); - lix.refreshSizeInformation(); + lix.refreshSizeInformation(); //after uop update ret = csb; - //ret.liveIn().removeVariable(itervar); LOG.debug("Applied vectorizeElementwiseUnaryForLoop."); } return ret; } + + private StatementBlock vectorizeIndexedCopy( StatementBlock sb, 
StatementBlock csb, Hop from, Hop to, Hop increment, String itervar ) + throws HopsException + { + StatementBlock ret = sb; + + //check supported increment values + if( !(increment instanceof LiteralOp && ((LiteralOp)increment).getDoubleValue()==1.0) ) { + return ret; + } + + //check for applicability + boolean apply = false; + boolean rowIx = false; //row or col + if( csb.get_hops()!=null && csb.get_hops().size()==1 ) + { + Hop root = csb.get_hops().get(0); + + if( root.getDataType()==DataType.MATRIX && root.getInput().get(0) instanceof LeftIndexingOp ) + { + LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0); + Hop lixlhs = lix.getInput().get(0); + Hop lixrhs = lix.getInput().get(1); + + if( lixlhs instanceof DataOp && lixrhs instanceof IndexingOp + && lixrhs.getInput().get(0) instanceof DataOp ) + { + boolean[] tmp = checkLeftAndRightIndexing(lix, (IndexingOp)lixrhs, itervar); + apply = tmp[0]; + rowIx = tmp[1]; + } + } + } + + //apply rewrite if possible + if( apply ) { + Hop root = csb.get_hops().get(0); + LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0); + IndexingOp rix = (IndexingOp) lix.getInput().get(1); + int index1 = rowIx ? 2 : 4; + int index2 = rowIx ? 
3 : 5; + //modify left indexing bounds + HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(index1), from, index1); + HopRewriteUtils.replaceChildReference(lix, lix.getInput().get(index2), to, index2); + //modify right indexing + HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(index1-1), from, index1-1); + HopRewriteUtils.replaceChildReference(rix, rix.getInput().get(index2-1), to, index2-1); + updateLeftAndRightIndexingSizes(rowIx, lix, rix); + + ret = csb; + LOG.debug("Applied vectorizeIndexedCopy."); + } + + return ret; + } + + private static boolean[] checkLeftAndRightIndexing(LeftIndexingOp lix, IndexingOp rix, String itervar) { + boolean[] ret = new boolean[2]; //apply, rowIx + + //check for rowwise + if( lix.getRowLowerEqualsUpper() && rix.getRowLowerEqualsUpper() + && lix.getInput().get(2).getName().equals(itervar) + && rix.getInput().get(1).getName().equals(itervar) ) { + ret[0] = true; + ret[1] = true; + } + //check for colwise + if( lix.getColLowerEqualsUpper() && rix.getColLowerEqualsUpper() + && lix.getInput().get(4).getName().equals(itervar) + && rix.getInput().get(3).getName().equals(itervar) ) { + ret[0] = true; + ret[1] = false; + } + + return ret; + } + + private static void updateLeftAndRightIndexingSizes(boolean rowIx, LeftIndexingOp lix, IndexingOp... rix) { + //unset special flags + if( rowIx ) { + lix.setRowLowerEqualsUpper(false); + for( IndexingOp rixi : rix ) + rixi.setRowLowerEqualsUpper(false); + } + else { + lix.setColLowerEqualsUpper(false); + for( IndexingOp rixi : rix ) + rixi.setColLowerEqualsUpper(false); + } + for( IndexingOp rixi : rix ) + rixi.refreshSizeInformation(); + lix.refreshSizeInformation(); + } }