Repository: systemml
Updated Branches:
  refs/heads/master dbb0b48a9 -> 8dcb487e4


[SYSTEMML-1812] Improved codegen candidate exploration algorithm

This patch makes two minor improvements to the codegen candidate
exploration algorithm, which simplify the code and slightly improve
performance. The performance gains come from iterating over distinct
templates (instead of distinct memo table entries) and from avoiding
unnecessary object creation.
 

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/41849701
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/41849701
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/41849701

Branch: refs/heads/master
Commit: 418497019b461b3f3de4c0b453cee76fc9b73d80
Parents: dbb0b48
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Wed Jul 26 19:24:07 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Wed Jul 26 21:48:59 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       | 39 +++++++-------------
 .../hops/codegen/template/CPlanMemoTable.java   | 35 +++++++++++-------
 .../hops/codegen/template/TemplateBase.java     | 25 +++++++++----
 .../hops/codegen/template/TemplateUtils.java    |  4 +-
 .../functions/codegen/rowAggPattern30.dml       |  2 -
 5 files changed, 56 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java 
b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index 8ab2240..dadb318 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -506,23 +506,14 @@ public class SpoofCompiler
                
                //open initial operator plans, if possible
                for( TemplateBase tpl : TemplateUtils.TEMPLATES )
-                       if( tpl.open(hop) ) {
-                               MemoTableEntrySet P = new 
MemoTableEntrySet(hop, tpl.getType(), false);
-                               memo.addAll(hop, enumPlans(hop, -1, P, tpl, 
memo));
-                       }
+                       if( tpl.open(hop) )
+                               memo.addAll(hop, enumPlans(hop, null, tpl, 
memo));
                
                //fuse and merge operator plans
-               for( Hop c : hop.getInput() ) {
-                       if( memo.contains(c.getHopID()) )
-                               for( MemoTableEntry me : 
memo.getDistinct(c.getHopID()) ) {
-                                       TemplateBase tpl = 
TemplateUtils.createTemplate(me.type, me.closed);
-                                       if( tpl.fuse(hop, c) ) {
-                                               int pos = 
hop.getInput().indexOf(c);
-                                               MemoTableEntrySet P = new 
MemoTableEntrySet(hop, tpl.getType(), pos, c.getHopID(), tpl.isClosed());
-                                               memo.addAll(hop, enumPlans(hop, 
pos, P, tpl, memo));
-                                       }
-                               }       
-               }
+               for( Hop c : hop.getInput() )
+                       for( TemplateBase tpl : 
memo.getDistinctTemplates(c.getHopID()) )
+                               if( tpl.fuse(hop, c) )
+                                       memo.addAll(hop, enumPlans(hop, c, tpl, 
memo));
                
                //close operator plans, if required
                if( memo.contains(hop.getHopID()) ) {
@@ -546,16 +537,14 @@ public class SpoofCompiler
                memo.addHop(hop);
        }
        
-       private static MemoTableEntrySet enumPlans(Hop hop, int pos, 
MemoTableEntrySet P, TemplateBase tpl, CPlanMemoTable memo) {
-               for(int k=0; k<hop.getInput().size(); k++)
-                       if( k != pos ) {
-                               Hop input2 = hop.getInput().get(k);
-                               if( memo.contains(input2.getHopID(), true, 
tpl.getType(), TemplateType.CellTpl) 
-                                       && tpl.merge(hop, input2) )
-                                       P.crossProduct(k, -1L, 
input2.getHopID());
-                               else
-                                       P.crossProduct(k, -1L);
-                       }
+       private static MemoTableEntrySet enumPlans(Hop hop, Hop c, TemplateBase 
tpl, CPlanMemoTable memo) {
+               MemoTableEntrySet P = new MemoTableEntrySet(hop, c, tpl);
+               for(int k=0; k<hop.getInput().size(); k++) {
+                       Hop input2 = hop.getInput().get(k);
+                       if( input2 != c && tpl.merge(hop, input2) 
+                               && memo.contains(input2.getHopID(), true, 
tpl.getType(), TemplateType.CellTpl))
+                               P.crossProduct(k, -1L, input2.getHopID());
+               }
                return P;
        }
        

http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
index 074d29c..edbcdf9 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java
@@ -219,9 +219,16 @@ public class CPlanMemoTable
        }
        
        public List<MemoTableEntry> getDistinct(long hopID) {
+               return _plans.get(hopID).stream()
+                       .distinct().collect(Collectors.toList());
+       }
+       
+       public List<TemplateBase> getDistinctTemplates(long hopID) {
+               if(!contains(hopID))
+                       return Collections.emptyList();
                //return distinct entries wrt type and closed attributes
                return _plans.get(hopID).stream()
-                       .map(p -> new 
MemoTableEntry(p.type,-1,-1,-1,p.size,p.closed))
+                       .map(p -> TemplateUtils.createTemplate(p.type, 
p.closed))
                        .distinct().collect(Collectors.toList());
        }
        
@@ -327,11 +334,14 @@ public class CPlanMemoTable
                                && !(!isPlanRef(1) && that.isPlanRef(1))
                                && !(!isPlanRef(2) && that.isPlanRef(2)));
                }
-               
                @Override
                public int hashCode() {
-                       return Arrays.hashCode(
-                               new long[]{(long)type.ordinal(), input1, 
input2, input3});
+                       int h = UtilFunctions.intHashCode(type.ordinal(), 
Long.hashCode(input1));
+                       h = UtilFunctions.intHashCode(h, Long.hashCode(input2));
+                       h = UtilFunctions.intHashCode(h, Long.hashCode(input3));
+                       h = UtilFunctions.intHashCode(h, size);
+                       h = UtilFunctions.intHashCode(h, 
Boolean.hashCode(closed));
+                       return h;
                }
                @Override
                public boolean equals(Object obj) {
@@ -339,7 +349,8 @@ public class CPlanMemoTable
                                return false;
                        MemoTableEntry that = (MemoTableEntry)obj;
                        return type == that.type && input1 == that.input1
-                               && input2 == that.input2 && input3 == 
that.input3;
+                               && input2 == that.input2 && input3 == 
that.input3
+                               && size == that.size && closed == that.closed;
                }
                @Override
                public String toString() {
@@ -360,18 +371,16 @@ public class CPlanMemoTable
        {
                public ArrayList<MemoTableEntry> plans = new 
ArrayList<MemoTableEntry>();
                
-               public MemoTableEntrySet(Hop hop, TemplateType type, boolean 
close) {
-                       int size = (hop instanceof IndexingOp) ? 1 : 
hop.getInput().size();
-                       plans.add(new MemoTableEntry(type, -1, -1, -1, size, 
close));
-               }
-               
-               public MemoTableEntrySet(Hop hop, TemplateType type, int pos, 
long hopID, boolean close) {
+               public MemoTableEntrySet(Hop hop, Hop c, TemplateBase tpl) {
+                       int pos = (c != null) ? hop.getInput().indexOf(c) : -1;
                        int size = (hop instanceof IndexingOp) ? 1 : 
hop.getInput().size();
-                       plans.add(new MemoTableEntry(type, (pos==0)?hopID:-1, 
-                                       (pos==1)?hopID:-1, (pos==2)?hopID:-1, 
size));
+                       plans.add(new MemoTableEntry(tpl.getType(), 
(pos==0)?c.getHopID():-1,
+                               (pos==1)?c.getHopID():-1, 
(pos==2)?c.getHopID():-1, size, tpl.isClosed()));
                }
                
                public void crossProduct(int pos, Long... refs) {
+                       if( refs.length==1 && refs[0] == -1 )
+                               return; //unmodified plan set
                        ArrayList<MemoTableEntry> tmp = new 
ArrayList<MemoTableEntry>();
                        for( MemoTableEntry me : plans )
                                for( Long ref : refs )

http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
index f0fe3fa..a4f2b91 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java
@@ -22,6 +22,7 @@ package org.apache.sysml.hops.codegen.template;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysml.runtime.matrix.data.Pair;
+import org.apache.sysml.runtime.util.UtilFunctions;
 
 public abstract class TemplateBase 
 {      
@@ -43,7 +44,7 @@ public abstract class TemplateBase
        }
 
        protected final TemplateType _type;
-       protected boolean _closed = false;
+       protected final boolean _closed;
        
        protected TemplateBase(TemplateType type) {
                this(type, false);
@@ -62,6 +63,21 @@ public abstract class TemplateBase
                return _closed;
        }
        
+       @Override
+       public int hashCode() {
+               return UtilFunctions.intHashCode(
+                       _type.ordinal(), Boolean.hashCode(_closed));
+       }
+       
+       @Override
+       public boolean equals(Object obj) {
+               if( !(obj instanceof TemplateBase) )
+                       return false;
+               TemplateBase that = (TemplateBase)obj;
+               return _type == that._type 
+                       && _closed == that._closed;
+       }
+       
        /////////////////////////////////////////////
        // Open-Fuse-Merge-Close interface 
        // (for candidate generation and exploration)
@@ -106,13 +122,6 @@ public abstract class TemplateBase
         */
        public abstract CloseType close(Hop hop);
        
-       /**
-        * Mark the template as closed either invalid or valid.
-        */
-       public void close() {
-               _closed = true;
-       }
-       
        /////////////////////////////////////////////
        // CPlan construction interface
        // (for plan creation of selected candidates)

http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 402f9fe..b461c5e 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -54,7 +54,9 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 
 public class TemplateUtils 
 {
-       public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new 
TemplateRow(), new TemplateCell(), new TemplateOuterProduct()};
+       public static final TemplateBase[] TEMPLATES = new TemplateBase[]{
+               new TemplateRow(), new TemplateCell(), new 
TemplateOuterProduct()};
+       //note: multiagg not included because it's a composite template
        
        public static boolean isVector(Hop hop) {
                return (hop.getDataType() == DataType.MATRIX 

http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/test/scripts/functions/codegen/rowAggPattern30.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern30.dml 
b/src/test/scripts/functions/codegen/rowAggPattern30.dml
index a130b82..60e5e54 100644
--- a/src/test/scripts/functions/codegen/rowAggPattern30.dml
+++ b/src/test/scripts/functions/codegen/rowAggPattern30.dml
@@ -29,6 +29,4 @@ if(1==1){}
 Q = P[,1:K] * (X %*% ssX_V);
 R = t(X) %*% (Q - P[,1:K] * rowSums(Q));
 
-print(max(R));
-
 write(R, $1)

Reply via email to