[SYSTEMML-1702] Fix robustness of codegen row-wise template (unknowns, scalar/vect)

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ddcb9e01
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ddcb9e01
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ddcb9e01

Branch: refs/heads/master
Commit: ddcb9e0190989ea1837af8dd44676d52497c3e15
Parents: 1e6639c
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Thu Jun 15 23:57:44 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Jun 16 10:01:58 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/template/TemplateCell.java    |  2 +-
 .../sysml/hops/codegen/template/TemplateRow.java     | 15 ++++++++++-----
 .../sysml/hops/codegen/template/TemplateUtils.java   |  5 +++++
 3 files changed, 16 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ddcb9e01/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index d6dcdf6..91c61c2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -75,7 +75,7 @@ public class TemplateCell extends TemplateBase
 
        @Override
        public boolean open(Hop hop) {
-               return isValidOperation(hop)
+               return hop.dimsKnown() && isValidOperation(hop)
                        || (hop instanceof IndexingOp && (((IndexingOp)hop)
                                .isColLowerEqualsUpper() || hop.getDim2()==1));
        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ddcb9e01/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index a0b4572..8364fba 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -73,9 +73,9 @@ public class TemplateRow extends TemplateBase
        
        @Override
        public boolean open(Hop hop) {
-               return (hop instanceof BinaryOp && hop.getInput().get(0).getDim2()>1 
+               return (hop instanceof BinaryOp && hop.dimsKnown() && hop.getInput().get(0).getDim2()>1 
                                && hop.getInput().get(1).getDim2()==1 && TemplateCell.isValidOperation(hop)) 
-                       || (hop instanceof AggBinaryOp && hop.getDim2()==1
+                       || (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1
                                && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
                        || (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol 
                                && hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1
@@ -164,7 +164,7 @@ public class TemplateRow extends TemplateBase
                MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateType.RowTpl);
                for( int i=0; i<hop.getInput().size(); i++ ) {
                        Hop c = hop.getInput().get(i);
-                       if( me.isPlanRef(i) )
+                       if( me!=null && me.isPlanRef(i) )
                                rConstructCplan(c, memo, tmp, inHops, inHops2, compileLiterals);
                        else {
                                CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);    
@@ -258,7 +258,8 @@ public class TemplateRow extends TemplateBase
                        CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
                        
                        // if one input is a matrix then we need to do vector by scalar operations
-                       if(hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1 )
+                       if( (hop.getInput().get(0).getDim1() > 1 && hop.getInput().get(0).getDim2() > 1)
+                               || (hop.getInput().get(1).getDim1() > 1 && hop.getInput().get(1).getDim2() > 1))
                        {
                                if( HopRewriteUtils.isBinary(hop, SUPPORTED_VECT_BINARY) ) {
                                        if( TemplateUtils.isMatrix(cdata1) && TemplateUtils.isMatrix(cdata2) ) {
@@ -267,6 +268,8 @@ public class TemplateRow extends TemplateBase
                                        }
                                        else {
                                                String opname = "VECT_"+((BinaryOp)hop).getOp().name()+"_SCALAR";
+                                               if( TemplateUtils.isColVector(cdata1) )
+                                                       cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
                                                if( TemplateUtils.isColVector(cdata2) )
                                                        cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                                                out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(opname));
@@ -281,7 +284,9 @@ public class TemplateRow extends TemplateBase
                                String primitiveOpName = ((BinaryOp)hop).getOp().toString();
                                if( TemplateUtils.isColVector(cdata1) )
                                        cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
-                               if( TemplateUtils.isColVector(cdata2) )
+                               if( TemplateUtils.isColVector(cdata2) //vector or vector can be inferred from lhs
+                                       || (TemplateUtils.isColVector(hop.getInput().get(0)) && cdata2 instanceof CNodeData
+                                               && hop.getInput().get(1).getDataType().isMatrix()))
                                        cdata2 = new CNodeUnary(cdata2, UnaryType.LOOKUP_R);
                                out = new CNodeBinary(cdata1, cdata2, BinType.valueOf(primitiveOpName));        
                        }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ddcb9e01/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 89a02da..fbaaab6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -65,6 +65,11 @@ public class TemplateUtils
                          || hop.getDim1() == 1 && hop.getDim2() != 1 ) );
        }
        
+       public static boolean isColVector(Hop hop) {
+               return (hop.getDataType() == DataType.MATRIX 
+                       && hop.getDim1() != 1 && hop.getDim2() == 1 );
+       }
+       
        public static boolean isColVector(CNode hop) {
                return (hop.getDataType() == DataType.MATRIX 
                        && hop.getNumRows() != 1 && hop.getNumCols() == 1);

Reply via email to