[SYSTEMML-1702] Fix robustness of row-wise codegen (unknown dims, scalar/vector ops) Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ddcb9e01 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ddcb9e01 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ddcb9e01
Branch: refs/heads/master Commit: ddcb9e0190989ea1837af8dd44676d52497c3e15 Parents: 1e6639c Author: Matthias Boehm <mboe...@gmail.com> Authored: Thu Jun 15 23:57:44 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Jun 16 10:01:58 2017 -0700 ---------------------------------------------------------------------- .../sysml/hops/codegen/template/TemplateCell.java | 2 +- .../sysml/hops/codegen/template/TemplateRow.java | 15 ++++++++++----- .../sysml/hops/codegen/template/TemplateUtils.java | 5 +++++ 3 files changed, 16 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/ddcb9e01/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java index d6dcdf6..91c61c2 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java @@ -75,7 +75,7 @@ public class TemplateCell extends TemplateBase @Override public boolean open(Hop hop) { - return isValidOperation(hop) + return hop.dimsKnown() && isValidOperation(hop) || (hop instanceof IndexingOp && (((IndexingOp)hop) .isColLowerEqualsUpper() || hop.getDim2()==1)); } http://git-wip-us.apache.org/repos/asf/systemml/blob/ddcb9e01/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java index a0b4572..8364fba 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java +++ 
b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java @@ -73,9 +73,9 @@ public class TemplateRow extends TemplateBase @Override public boolean open(Hop hop) { - return (hop instanceof BinaryOp && hop.getInput().get(0).getDim2()>1 + return (hop instanceof BinaryOp && hop.dimsKnown() && hop.getInput().get(0).getDim2()>1 && hop.getInput().get(1).getDim2()==1 && TemplateCell.isValidOperation(hop)) - || (hop instanceof AggBinaryOp && hop.getDim2()==1 + || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1) || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1 @@ -164,7 +164,7 @@ public class TemplateRow extends TemplateBase MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.RowTpl); for( int i=0; i<hop.getInput().size(); i++ ) { Hop c = hop.getInput().get(i); - if( me.isPlanRef(i) ) + if( me!=null && me.isPlanRef(i) ) rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals); else { CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals); @@ -258,7 +258,8 @@ public class TemplateRow extends TemplateBase CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID()); // if one input is a matrix then we need to do vector by scalar operations - if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 ) + if( (hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1) + || (hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1)) { if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) { if( TemplateUtils.isMatrix(cdata1) && TemplateUtils.isMatrix(cdata2) ) { @@ -267,6 +268,8 @@ public class TemplateRow extends TemplateBase } else { String opname = "VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR"; + if( TemplateUtils.isColVector(cdata1) ) + cdata1 = new CNodeUnary(cdata1, 
UnaryType.LOOKUP_R); if( TemplateUtils.isColVector(cdata2) ) cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname)); @@ -281,7 +284,9 @@ public class TemplateRow extends TemplateBase String primitiveOpName = ((BinaryOp)hop).getOp().toString(); if( TemplateUtils.isColVector(cdata1) ) cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R); - if( TemplateUtils.isColVector(cdata2) ) + if( TemplateUtils.isColVector(cdata2) //vector or vector can be inferred from lhs + || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData + && hop.getInput().get(1).getDataType().isMatrix())) cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R); out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName)); } http://git-wip-us.apache.org/repos/asf/systemml/blob/ddcb9e01/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java index 89a02da..fbaaab6 100644 --- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java +++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java @@ -65,6 +65,11 @@ public class TemplateUtils || hop.getDim1() == 1 && hop.getDim2() != 1 ) ); } + public static boolean isColVector(Hop hop) { + return (hop.getDataType() == DataType.MATRIX + && hop.getDim1() != 1 && hop.getDim2() == 1 ); + } + public static boolean isColVector(CNode hop) { return (hop.getDataType() == DataType.MATRIX && hop.getNumRows() != 1 && hop.getNumCols() == 1);