Repository: incubator-systemml Updated Branches: refs/heads/master 29c307c9a -> 76f3ca5d3
[HOTFIX] Bugfix for metadata of conv2d_* and maxpool_* operations Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/76f3ca5d Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/76f3ca5d Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/76f3ca5d Branch: refs/heads/master Commit: 76f3ca5d39e492fc3075c4bd8240ec5339647001 Parents: 29c307c Author: Niketan Pansare <npan...@us.ibm.com> Authored: Wed May 3 21:02:14 2017 -0800 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Wed May 3 22:02:14 2017 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/hops/ConvolutionOp.java | 49 ++++++++++++++------ .../org/apache/sysml/parser/DMLTranslator.java | 10 +++- 2 files changed, 43 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/76f3ca5d/src/main/java/org/apache/sysml/hops/ConvolutionOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java index a18aada..cb67d65 100644 --- a/src/main/java/org/apache/sysml/hops/ConvolutionOp.java +++ b/src/main/java/org/apache/sysml/hops/ConvolutionOp.java @@ -242,30 +242,43 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop switch(op) { case MAX_POOLING: { - ret[0] = getInput().get(0)._dim1; + // input + long N = getInput().get(0)._dim1; + ret[0] = N; ret[1] = getExtractedVal(params.C, params.P, params.Q); ret[2] = -1; break; } case DIRECT_CONV2D: { - ret[0] = getInput().get(0)._dim1; - ret[1] = getExtractedVal(getInput().get(1)._dim1, params.P, params.Q); + // input, filter + long N = getInput().get(0)._dim1; + ret[0] = N; + ret[1] = getExtractedVal(params.K, params.P, params.Q); ret[2] = -1; break; } case DIRECT_CONV2D_BACKWARD_FILTER: { - ret[0] = getInput().get(1)._dim1; - ret[1] = getInput().get(1)._dim2; + // input, dout + ret[0] = params.K; + ret[1] = getExtractedVal(params.C, params.R, params.S); ret[2] = -1; break; } - case MAX_POOLING_BACKWARD: - case DIRECT_CONV2D_BACKWARD_DATA: { + case MAX_POOLING_BACKWARD: { + // input, dout ret[0] = getInput().get(0)._dim1; ret[1] = getInput().get(0)._dim2; ret[2] = -1; break; } + case DIRECT_CONV2D_BACKWARD_DATA: { + // filter, dout + long N = getInput().get(1)._dim1; + ret[0] = N; + ret[1] = getExtractedVal(params.C, params.H, params.W); + ret[2] = -1; + break; + } default: throw new RuntimeException("Unsupported op:" + op.name()); } @@ -390,13 +403,16 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop { case MAX_POOLING: { - _dim1 = getInput().get(0)._dim1; + // input + long N = getInput().get(0)._dim1; + _dim1 = N; _dim2 = getExtractedVal(params.C, params.P, params.Q); _nnz = -1; // cannot infer stats break; } case MAX_POOLING_BACKWARD: { + // input, dout _dim1 = getInput().get(0)._dim1; _dim2 = getInput().get(0)._dim2; _nnz = -1; @@ -404,22 +420,27 @@ public class ConvolutionOp extends Hop implements MultiThreadedHop } case DIRECT_CONV2D: { - _dim1 = getInput().get(0)._dim1; - _dim2 = getExtractedVal(getInput().get(1)._dim1, params.P, params.Q); + // input, filter + long N = getInput().get(0)._dim1; + _dim1 = N; + _dim2 = getExtractedVal(params.K, params.P, params.Q); _nnz = -1; // cannot infer stats break; } case DIRECT_CONV2D_BACKWARD_DATA: { - _dim1 = getInput().get(0)._dim1; - _dim2 = getInput().get(0)._dim2; + // filter, dout + long N = getInput().get(1)._dim1; + _dim1 = N; + _dim2 = getExtractedVal(params.C, params.H, params.W); _nnz = -1; // cannot infer stats break; } case DIRECT_CONV2D_BACKWARD_FILTER: { - _dim1 = getInput().get(1)._dim1; - _dim2 = getInput().get(1)._dim2; + // input, dout + _dim1 = params.K; + _dim2 = getExtractedVal(params.C, params.R, params.S); _nnz = -1; // cannot infer stats break; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/76f3ca5d/src/main/java/org/apache/sysml/parser/DMLTranslator.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java index 3373e98..9f63038 100644 --- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; +import org.antlr.v4.parse.ANTLRParser.option_return; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.conf.ConfigurationManager; @@ -2152,7 +2153,7 @@ public class DMLTranslator if (target == null) { target = createTarget(source); } - + // Construct the hop based on the type of Builtin function switch (source.getOpCode()) { @@ -2785,7 +2786,12 @@ public class DMLTranslator throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } - setIdentifierParams(currBuiltinOp, source.getOutput()); + if( !(source.getOpCode() == BuiltinFunctionOp.CONV2D || source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_DATA || + source.getOpCode() == BuiltinFunctionOp.CONV2D_BACKWARD_FILTER || source.getOpCode() == BuiltinFunctionOp.MAX_POOL || + source.getOpCode() == BuiltinFunctionOp.MAX_POOL_BACKWARD) ) { + // Since the dimension of output doesnot match that of input variable for these operations + setIdentifierParams(currBuiltinOp, source.getOutput()); + } currBuiltinOp.setAllPositions(source.getBeginLine(), source.getBeginColumn(), source.getEndLine(), source.getEndColumn()); return currBuiltinOp; }