Repository: systemml Updated Branches: refs/heads/master dbb0b48a9 -> 8dcb487e4
[SYSTEMML-1812] Improved codegen candidate exploration algorithm This patch makes two minor improvements to the codegen candidate exploration algorithm for simplification and slightly better performance. The performance improvements are due to iterating over distinct templates and avoiding unnecessary object creation. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/41849701 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/41849701 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/41849701 Branch: refs/heads/master Commit: 418497019b461b3f3de4c0b453cee76fc9b73d80 Parents: dbb0b48 Author: Matthias Boehm <mboe...@gmail.com> Authored: Wed Jul 26 19:24:07 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Wed Jul 26 21:48:59 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/SpoofCompiler.java | 39 +++++++------------- .../hops/codegen/template/CPlanMemoTable.java | 35 +++++++++++------- .../hops/codegen/template/TemplateBase.java | 25 +++++++++---- .../hops/codegen/template/TemplateUtils.java | 4 +- .../functions/codegen/rowAggPattern30.dml | 2 - 5 files changed, 56 insertions(+), 49 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java index 8ab2240..dadb318 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java +++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java @@ -506,23 +506,14 @@ public class SpoofCompiler //open initial operator plans, if possible for( TemplateBase tpl : TemplateUtils.TEMPLATES ) - if( tpl.open(hop) ) { - MemoTableEntrySet P = new MemoTableEntrySet(hop, tpl.getType(), false); - memo.addAll(hop, enumPlans(hop, -1, P, tpl, memo)); - } + if( tpl.open(hop) ) + memo.addAll(hop, enumPlans(hop, null, tpl, memo)); //fuse and merge operator plans - for( Hop c : hop.getInput() ) { - if( memo.contains(c.getHopID()) ) - for( MemoTableEntry me : memo.getDistinct(c.getHopID()) ) { - TemplateBase tpl = TemplateUtils.createTemplate(me.type, me.closed); - if( tpl.fuse(hop, c) ) { - int pos = hop.getInput().indexOf(c); - MemoTableEntrySet P = new MemoTableEntrySet(hop, tpl.getType(), pos, c.getHopID(), tpl.isClosed()); - memo.addAll(hop, enumPlans(hop, pos, P, tpl, memo)); - } - } - } + for( Hop c : hop.getInput() ) + for( TemplateBase tpl : memo.getDistinctTemplates(c.getHopID()) ) + if( tpl.fuse(hop, c) ) + memo.addAll(hop, enumPlans(hop, c, tpl, memo)); //close operator plans, if required if( memo.contains(hop.getHopID()) ) { @@ -546,16 +537,14 @@ public class SpoofCompiler memo.addHop(hop); } - private static MemoTableEntrySet enumPlans(Hop hop, int pos, MemoTableEntrySet P, TemplateBase tpl, CPlanMemoTable memo) { - for(int k=0; k<hop.getInput().size(); k++) - if( k != pos ) { - Hop input2 = hop.getInput().get(k); - if( memo.contains(input2.getHopID(), true, tpl.getType(), TemplateType.CellTpl) - && tpl.merge(hop, input2) ) - P.crossProduct(k, -1L, input2.getHopID()); - else - P.crossProduct(k, -1L); - } + private static MemoTableEntrySet enumPlans(Hop hop, Hop c, TemplateBase tpl, CPlanMemoTable memo) { + MemoTableEntrySet P = new MemoTableEntrySet(hop, c, tpl); + for(int k=0; k<hop.getInput().size(); k++) { + Hop input2 = hop.getInput().get(k); + if( input2 != c && tpl.merge(hop, input2) + && memo.contains(input2.getHopID(), true, tpl.getType(), TemplateType.CellTpl)) + P.crossProduct(k, -1L, input2.getHopID()); + } return P; } http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java index 074d29c..edbcdf9 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/CPlanMemoTable.java @@ -219,9 +219,16 @@ public class CPlanMemoTable } public List<MemoTableEntry> getDistinct(long hopID) { + return _plans.get(hopID).stream() + .distinct().collect(Collectors.toList()); + } + + public List<TemplateBase> getDistinctTemplates(long hopID) { + if(!contains(hopID)) + return Collections.emptyList(); //return distinct entries wrt type and closed attributes return _plans.get(hopID).stream() - .map(p -> new MemoTableEntry(p.type,-1,-1,-1,p.size,p.closed)) + .map(p -> TemplateUtils.createTemplate(p.type, p.closed)) .distinct().collect(Collectors.toList()); } @@ -327,11 +334,14 @@ public class CPlanMemoTable && !(!isPlanRef(1) && that.isPlanRef(1)) && !(!isPlanRef(2) && that.isPlanRef(2))); } - @Override public int hashCode() { - return Arrays.hashCode( - new long[]{(long)type.ordinal(), input1, input2, input3}); + int h = UtilFunctions.intHashCode(type.ordinal(), Long.hashCode(input1)); + h = UtilFunctions.intHashCode(h, Long.hashCode(input2)); + h = UtilFunctions.intHashCode(h, Long.hashCode(input3)); + h = UtilFunctions.intHashCode(h, size); + h = UtilFunctions.intHashCode(h, Boolean.hashCode(closed)); + return h; } @Override public boolean equals(Object obj) { @@ -339,7 +349,8 @@ public class CPlanMemoTable return false; MemoTableEntry that = (MemoTableEntry)obj; return type == that.type && input1 == that.input1 - && input2 == that.input2 && input3 == that.input3; + && input2 == that.input2 && input3 == that.input3 + && size == that.size && closed == that.closed; } @Override public String toString() { @@ -360,18 +371,16 @@ public class CPlanMemoTable { public ArrayList<MemoTableEntry> plans = new ArrayList<MemoTableEntry>(); - public MemoTableEntrySet(Hop hop, TemplateType type, boolean close) { - int size = (hop instanceof IndexingOp) ? 1 : hop.getInput().size(); - plans.add(new MemoTableEntry(type, -1, -1, -1, size, close)); - } - - public MemoTableEntrySet(Hop hop, TemplateType type, int pos, long hopID, boolean close) { + public MemoTableEntrySet(Hop hop, Hop c, TemplateBase tpl) { + int pos = (c != null) ? hop.getInput().indexOf(c) : -1; int size = (hop instanceof IndexingOp) ? 1 : hop.getInput().size(); - plans.add(new MemoTableEntry(type, (pos==0)?hopID:-1, - (pos==1)?hopID:-1, (pos==2)?hopID:-1, size)); + plans.add(new MemoTableEntry(tpl.getType(), (pos==0)?c.getHopID():-1, + (pos==1)?c.getHopID():-1, (pos==2)?c.getHopID():-1, size, tpl.isClosed())); } public void crossProduct(int pos, Long... refs) { + if( refs.length==1 && refs[0] == -1 ) + return; //unmodified plan set ArrayList<MemoTableEntry> tmp = new ArrayList<MemoTableEntry>(); for( MemoTableEntry me : plans ) for( Long ref : refs ) http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java index f0fe3fa..a4f2b91 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateBase.java @@ -22,6 +22,7 @@ package org.apache.sysml.hops.codegen.template; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.codegen.cplan.CNodeTpl; import org.apache.sysml.runtime.matrix.data.Pair; +import org.apache.sysml.runtime.util.UtilFunctions; public abstract class TemplateBase { @@ -43,7 +44,7 @@ public abstract class TemplateBase } protected final TemplateType _type; - protected boolean _closed = false; + protected final boolean _closed; protected TemplateBase(TemplateType type) { this(type, false); @@ -62,6 +63,21 @@ public abstract class TemplateBase return _closed; } + @Override + public int hashCode() { + return UtilFunctions.intHashCode( + _type.ordinal(), Boolean.hashCode(_closed)); + } + + @Override + public boolean equals(Object obj) { + if( !(obj instanceof TemplateBase) ) + return false; + TemplateBase that = (TemplateBase)obj; + return _type == that._type + && _closed == that._closed; + } + ///////////////////////////////////////////// // Open-Fuse-Merge-Close interface // (for candidate generation and exploration) @@ -106,13 +122,6 @@ public abstract class TemplateBase */ public abstract CloseType close(Hop hop); - /** - * Mark the template as closed either invalid or valid. - */ - public void close() { - _closed = true; - } - ///////////////////////////////////////////// // CPlan construction interface // (for plan creation of selected candidates) http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 402f9fe..b461c5e 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -54,7 +54,9 @@ import org.apache.sysml.runtime.util.UtilFunctions; public class TemplateUtils { - public static final TemplateBase[] TEMPLATES = new TemplateBase[]{new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()}; + public static final TemplateBase[] TEMPLATES = new TemplateBase[]{ + new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()}; + //note: multiagg not included because it's a composite template public static boolean isVector(Hop hop) { return (hop.getDataType() == DataType.MATRIX http://git-wip-us.apache.org/repos/asf/systemml/blob/41849701/src/test/scripts/functions/codegen/rowAggPattern30.dml ---------------------------------------------------------------------- diff --git a/src/test/scripts/functions/codegen/rowAggPattern30.dml b/src/test/scripts/functions/codegen/rowAggPattern30.dml index a130b82..60e5e54 100644 --- a/src/test/scripts/functions/codegen/rowAggPattern30.dml +++ b/src/test/scripts/functions/codegen/rowAggPattern30.dml @@ -29,6 +29,4 @@ if(1==1){} Q = P[,1:K] * (X %*% ssX_V); R = t(X) %*% (Q - P[,1:K] * rowSums(Q)); -print(max(R)); - write(R, $1)