This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 427b6d4  Fix shape inference pass (#14153)
427b6d4 is described below

commit 427b6d47b7a02987edc80c8118da10a14544e664
Author: Przemyslaw Tredak <ptre...@gmail.com>
AuthorDate: Tue Mar 5 15:26:16 2019 -0800

    Fix shape inference pass (#14153)
    
    * Fix InferShape pass
    
    * nnvm::TShape -> mxnet::TShape in InferShapeAttr
    
    * More nnvm->mxnet namespace changes
    
    * And more nnvm -> mxnet
    
    * Retrigger CI
---
 src/executor/infer_graph_attr_pass.cc | 76 ++++++++++++++++++++++++++++++-----
 tests/python/unittest/test_symbol.py  | 13 ++++++
 2 files changed, 80 insertions(+), 9 deletions(-)

diff --git a/src/executor/infer_graph_attr_pass.cc 
b/src/executor/infer_graph_attr_pass.cc
index af8094a..6a7fde6 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -68,6 +68,26 @@ bool ApplyOpInferAttr<int, FInferStorageType>(const 
nnvm::Graph& g,
  * shape/type inference functions'. The nnvm InferAttr will be deprecated
  * in the future. Please use interfaces InferShape, InferType, and 
InferStorageType
  * to call this function.
+ *
+ * \param ret graph used for attribute inference
+ * \param empty_val empty value of the attribute
+ * \param infer_name name of the function used for attribute inference
+ * \param input_name name of the attribute in the graph used to store the
+ *                   input data for attribute inference
+ * \param attr_key_name name of the attribute used for inference for variable 
nodes
+ * \param attr_name name of the inferred attribute
+ * \param unknown_name name of the attribute storing number of entries
+ *                     impossible to infer
+ * \param fis_none function returning true for not fully inferred values
+ * \param fdefault default function used for inference if the node does not
+ *                 provide its own implementation.
+ * \param bwd_identity_assign whether the attributes of forward NDArray and 
backward
+ *                            NDArray have to be the same. False only for 
storage
+ *                            type inference
+ * \param dispatch_mode_name name of the dispatch mode attribute on the node. 
Used for
+ *                           storage type inference
+ * \param default_mode_val default value of the dispatch mode attribute on the 
node. Used
+ *                         for storage type inference
  */
 template<typename AttrType, typename FInferType, typename IsNone, typename 
FDefault>
 nnvm::Graph InferAttr(nnvm::Graph &&ret,
@@ -322,23 +342,49 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret,
   return ret;
 }
 
-template<typename IsNone, typename FDefault>
+/*!\brief
+ * This is a version of the InferAttr function specifically for shape 
inference.
+ *
+ * \param ret graph used for attribute inference
+ * \param empty_val empty value of the attribute
+ * \param infer_name name of the function used for attribute inference
+ * \param input_name name of the attribute in the graph used to store the
+ *                   input data for attribute inference
+ * \param attr_key_name name of the attribute used for inference for variable 
nodes
+ * \param attr_name name of the inferred attribute
+ * \param unknown_name name of the attribute storing number of entries
+ *                     impossible to infer
+ * \param fis_none function returning true for not fully inferred values
+ * \param fnum_unknown function returning how many elements are unknown in
+ *                     partially inferred value of the attribute
+ * \param fdefault default function used for inference if the node does not
+ *                 provide its own implementation.
+ * \param bwd_identity_assign whether the attributes of forward NDArray and 
backward
+ *                            NDArray have to be the same. False only for 
storage
+ *                            type inference
+ * \param dispatch_mode_name name of the dispatch mode attribute on the node. 
Used for
+ *                           storage type inference
+ * \param default_mode_val default value of the dispatch mode attribute on the 
node. Used
+ *                         for storage type inference
+ */
+template<typename IsNone, typename FDefault, typename FNumUnknown>
 nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
-                           const nnvm::TShape empty_val,
+                           const mxnet::TShape empty_val,
                            const char* infer_name,
                            const char* input_name,
                            const char* attr_key_name,
                            const char* attr_name,
                            const char* unknown_name,
                            IsNone fis_none,
+                           FNumUnknown fnum_unknown,
                            FDefault fdefault,
                            bool bwd_identity_assign,
                            const char* dispatch_mode_name,
                            const DispatchMode default_mode_val = 
DispatchMode::kUndefined) {
   using nnvm::IndexedGraph;
   using nnvm::Op;
-  using AttrType = nnvm::TShape;
-  using FInferType = nnvm::FInferShape;
+  using AttrType = mxnet::TShape;
+  using FInferType = mxnet::FInferShape;
   using AttrVector = std::vector<AttrType>;
   using NodeAttrVector = std::vector<DispatchMode>;
   using dmlc::any;
@@ -548,12 +594,12 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
   };
 
   size_t last_num_unknown;
-  size_t num_unknown_dispatch_mode = dispatch_mode_name ? node_end - 
node_start : 0;
-  size_t num_unknown_entry_attr = entry_end - entry_start;
-  size_t num_unknown = num_unknown_entry_attr + num_unknown_dispatch_mode;
+  size_t num_unknown = static_cast<size_t>(-1);  // Infinity
+
   int i = 0;
   do {
     if (i % 2 == 0) {
+      // forward inference
       for (uint32_t nid = node_start; nid < node_end; ++nid) {
         infer_step(nid, false);
       }
@@ -567,7 +613,7 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
     num_unknown = 0;
     for (size_t j = entry_start; j < entry_end; ++j) {
       if (fis_none(rshape[j])) {
-        ++num_unknown;
+        num_unknown += fnum_unknown(rshape[j]);
       }
     }
     if (dispatch_mode_name) {
@@ -598,11 +644,23 @@ nnvm::Graph InferShape(nnvm::Graph&& graph,
   if (shape_attr_key.length() != 0) {
     graph.attrs["shape_attr_key"] = std::make_shared<any>(shape_attr_key);
   }
-  return InferAttr<mxnet::TShape, mxnet::FInferShape>(
+  return InferShapeAttr(
       std::move(graph), mxnet::TShape(),
       "FInferShape", "shape_inputs", "shape_attr_key",
       "shape", "shape_num_unknown_nodes",
       [](const mxnet::TShape& s) { return s.ndim() == 0 || s.Size() == 0; },
+      [](const mxnet::TShape& s) {
+        if (s.ndim() == 0) {  // TODO(reminisce): Usage of ndim
+          return static_cast<size_t>(1);
+        }
+        size_t ret = 0;
+        for (const auto& val : s) {
+          if (val == 0) {
+            ++ret;
+          }
+        }
+        return ret;
+      },
       nullptr, true, nullptr);
 }
 
diff --git a/tests/python/unittest/test_symbol.py 
b/tests/python/unittest/test_symbol.py
index ac4564b..b290ff3 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -157,6 +157,19 @@ def test_symbol_infer_shape():
     assert arg_shapes['x2h_weight'] == (num_hidden, num_dim)
     assert arg_shapes['h2h_weight'] == (num_hidden, num_hidden)
 
+    # Partial shape inference with some unknown dimensions
+    data_shape = (1, 0, 0, 0)
+    data = mx.sym.Variable('data', shape=data_shape)
+    weight = mx.sym.Variable('weight')
+    cdata = mx.sym.cast(data, dtype='float16')
+    cweight = mx.sym.cast(weight, dtype='float16')
+    test = mx.sym.Convolution(data=cdata, weight=cweight, pad=(3, 3), 
num_filter=64, stride=(2, 2), no_bias=True, kernel=(7, 7))
+
+    arg, _, _ = test.infer_shape_partial()
+    arg_shapes = dict(zip(test.list_arguments(), arg))
+    assert arg_shapes['data'] == data_shape
+    assert arg_shapes['weight'] == (64, 0, 7, 7)
+
 
 def test_symbol_infer_shape_var():
     "Test specifying shape information when constructing a variable"

Reply via email to