Repository: incubator-systemml
Updated Branches:
  refs/heads/master 29c307c9a -> 76f3ca5d3


[HOTFIX] Fix output-dimension metadata of conv2d_* and maxpool_* operations


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/76f3ca5d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/76f3ca5d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/76f3ca5d

Branch: refs/heads/master
Commit: 76f3ca5d39e492fc3075c4bd8240ec5339647001
Parents: 29c307c
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Wed May 3 21:02:14 2017 -0800
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Wed May 3 22:02:14 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/hops/ConvolutionOp.java    | 49 ++++++++++++++------
 .../org/apache/sysml/parser/DMLTranslator.java  | 10 +++-
 2 files changed, 43 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/76f3ca5d/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java 
b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
index a18aada..cb67d65 100644
--- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
+++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java
@@ -242,30 +242,43 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                switch(op) 
                {
                        case MAX_POOLING: {
-                               ret[0] = getInput().get(0)._dim1;
+                               // input
+                               long N = getInput().get(0)._dim1;
+                               ret[0] = N;
                                ret[1] = getExtractedVal(params.C, params.P, 
params.Q);
                                ret[2] = -1;
                                break;
                        }
                        case DIRECT_CONV2D: {
-                               ret[0] = getInput().get(0)._dim1;
-                               ret[1] = 
getExtractedVal(getInput().get(1)._dim1, params.P, params.Q);
+                               // input, filter
+                               long N = getInput().get(0)._dim1;
+                               ret[0] = N;
+                               ret[1] = getExtractedVal(params.K, params.P, 
params.Q);
                                ret[2] = -1;
                                break;
                        }
                        case DIRECT_CONV2D_BACKWARD_FILTER: {
-                               ret[0] = getInput().get(1)._dim1;
-                               ret[1] = getInput().get(1)._dim2;
+                               // input, dout  
+                               ret[0] = params.K;
+                               ret[1] = getExtractedVal(params.C, params.R, 
params.S);
                                ret[2] = -1;
                                break;
                        }
-                       case MAX_POOLING_BACKWARD:
-                       case DIRECT_CONV2D_BACKWARD_DATA: {
+                       case MAX_POOLING_BACKWARD: {
+                               // input, dout
                                ret[0] = getInput().get(0)._dim1;
                                ret[1] = getInput().get(0)._dim2;
                                ret[2] = -1;
                                break;
                        }
+                       case DIRECT_CONV2D_BACKWARD_DATA: {
+                               // filter, dout
+                               long N = getInput().get(1)._dim1;
+                               ret[0] = N;
+                               ret[1] = getExtractedVal(params.C, params.H, 
params.W);
+                               ret[2] = -1;
+                               break;
+                       }
                        default:
                                throw new RuntimeException("Unsupported op:" + 
op.name());
                }
@@ -390,13 +403,16 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                {
                        case MAX_POOLING:
                        {       
-                               _dim1 = getInput().get(0)._dim1;
+                               // input
+                               long N = getInput().get(0)._dim1;
+                               _dim1 = N;
                                _dim2 = getExtractedVal(params.C, params.P, 
params.Q);
                                _nnz = -1; // cannot infer stats
                                break;
                        }
                        case MAX_POOLING_BACKWARD:
                        {
+                               // input, dout
                                _dim1 = getInput().get(0)._dim1;
                                _dim2 = getInput().get(0)._dim2;
                                _nnz = -1;
@@ -404,22 +420,27 @@ public class ConvolutionOp extends Hop  implements 
MultiThreadedHop
                        }
                        case DIRECT_CONV2D:
                        {
-                               _dim1 = getInput().get(0)._dim1;
-                               _dim2 = 
getExtractedVal(getInput().get(1)._dim1, params.P, params.Q);
+                               // input, filter
+                               long N = getInput().get(0)._dim1;
+                               _dim1 = N;
+                               _dim2 = getExtractedVal(params.K, params.P, 
params.Q);
                                _nnz = -1; // cannot infer stats
                                break;
                        }
                        case DIRECT_CONV2D_BACKWARD_DATA:
                        {
-                               _dim1 = getInput().get(0)._dim1;
-                               _dim2 = getInput().get(0)._dim2;
+                               // filter, dout
+                               long N = getInput().get(1)._dim1;
+                               _dim1 = N;
+                               _dim2 = getExtractedVal(params.C, params.H, 
params.W);
                                _nnz = -1; // cannot infer stats
                                break;
                        }
                        case DIRECT_CONV2D_BACKWARD_FILTER:
                        {
-                               _dim1 = getInput().get(1)._dim1;
-                               _dim2 = getInput().get(1)._dim2;
+                               // input, dout  
+                               _dim1 = params.K;
+                               _dim2 = getExtractedVal(params.C, params.R, 
params.S);
                                _nnz = -1; // cannot infer stats
                                break;
                        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/76f3ca5d/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index 3373e98..9f63038 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 
+import org.antlr.v4.parse.ANTLRParser.option_return;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -2152,7 +2153,7 @@ public class DMLTranslator
                if (target == null) {
                        target = createTarget(source);
                }
-
+               
                // Construct the hop based on the type of Builtin function
                switch (source.getOpCode()) {
 
@@ -2785,7 +2786,12 @@ public class DMLTranslator
                        throw new ParseException("Unsupported builtin function 
type: "+source.getOpCode());
                }
                
-               setIdentifierParams(currBuiltinOp, source.getOutput());
+               if( !(source.getOpCode() == BuiltinFunctionOp.CONV2D || 
source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA ||
+                               source.getOpCode() == 
BuiltinFunctionOp.CONV2D_BACKWARD_FILTER || source.getOpCode() == 
BuiltinFunctionOp.MAX_POOL ||
+                               source.getOpCode() == 
BuiltinFunctionOp.MAX_POOL_BACKWARD) ) {
+                       // Since the dimension of output doesnot match that of 
input variable for these operations
+                       setIdentifierParams(currBuiltinOp, source.getOutput());
+               }
                currBuiltinOp.setAllPositions(source.getBeginLine(), 
source.getBeginColumn(), source.getEndLine(), source.getEndColumn());
                return currBuiltinOp;
        }

Reply via email to