This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7641759  [MXNET-309] [ONNX-MXNet] Model Metadata API (#10512)
7641759 is described below

commit 76417594e56a85ec0cc9412b9dd2c7e2ab581d8b
Author: Anirudh <anirudhk...@gmail.com>
AuthorDate: Wed May 16 20:36:54 2018 -0700

    [MXNET-309] [ONNX-MXNet] Model Metadata API (#10512)
    
    * metadata api
    
    * pylint changes
    
    * move logic to import_onnx
    
    * test fix
    
    * doc API
    
    * rerun CI.
    
    * fix comments
    
    * docs fix
---
 docs/api/python/contrib/onnx.md                   |  6 +++--
 docs/tutorials/onnx/inference_on_onnx_model.md    | 19 ++++++++++----
 example/onnx/super_resolution.py                  |  6 ++---
 python/mxnet/contrib/onnx/__init__.py             |  3 +--
 python/mxnet/contrib/onnx/_import/import_model.py | 30 +++++++++++++++++++++++
 python/mxnet/contrib/onnx/_import/import_onnx.py  | 23 +++++++++++++++++
 tests/python-pytest/onnx/onnx_test.py             | 27 +++++++++++++++-----
 7 files changed, 95 insertions(+), 19 deletions(-)

diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md
index 44aabaf..6fb546f 100644
--- a/docs/api/python/contrib/onnx.md
+++ b/docs/api/python/contrib/onnx.md
@@ -13,7 +13,7 @@ With ONNX format support for MXNet, developers can build and 
train models with a
 ```
 
 ### Installation Instructions
-- To use this module developers need to **install ONNX**, which requires 
protobuf compiler to be installed separately. Please follow the [instructions 
to install ONNX and its 
dependencies](https://github.com/onnx/onnx#installation). Once installed, you 
can go through the tutorials on how to use this module.
+- To use this module developers need to **install ONNX**, which requires the 
protobuf compiler to be installed separately. Please follow the [instructions 
to install ONNX and its 
dependencies](https://github.com/onnx/onnx#installation). **MXNet currently 
supports ONNX v1.1.1**. Once installed, you can go through the tutorials on how 
to use this module.
 
 
 This document describes all the ONNX-MXNet APIs.
@@ -23,6 +23,7 @@ This document describes all the ONNX-MXNet APIs.
     :nosignatures:
 
     mxnet.contrib.onnx.import_model
+    mxnet.contrib.onnx.get_model_metadata
 ```
 
 ## ONNX Tutorials
@@ -43,7 +44,8 @@ This document describes all the ONNX-MXNet APIs.
 ```eval_rst
 
 .. automodule:: mxnet.contrib.onnx
-    :members: import_model 
+    :members: import_model
+    :members: get_model_metadata
 
 ```
 
diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md 
b/docs/tutorials/onnx/inference_on_onnx_model.md
index f342dad..3d4072a 100644
--- a/docs/tutorials/onnx/inference_on_onnx_model.md
+++ b/docs/tutorials/onnx/inference_on_onnx_model.md
@@ -104,17 +104,26 @@ We pick a context, GPU if available, otherwise CPU
 ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
 ```
 
-We obtain the data names of the inputs to the model, by listing all the inputs 
to the symbol graph and excluding the argument and auxiliary parameters from 
that list:
+We obtain the data names of the inputs to the model by using the model 
metadata API: 
 
 ```python
-data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in 
aux_params]
-print(data_names)
+model_metadata = onnx_mxnet.get_model_metadata(onnx_path)
+print(model_metadata)
 ```
 
+```
+{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))],
+ 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]}
+```
 
-```['gpu_0/data_0']```
+```python
+data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
+print(data_names)
+```
 
+```
+[u'gpu_0/data_0']
+```
 
 And load them into an MXNet Gluon symbol block. 
 
diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py
index a52f1a8..fcb8ccc 100644
--- a/example/onnx/super_resolution.py
+++ b/example/onnx/super_resolution.py
@@ -55,10 +55,8 @@ def get_test_image():
 
 def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr):
     """Perform inference on image using mxnet"""
-    # To fetch the data names of the input to the model we list the inputs of 
the symbol graph
-    # and exclude the argument and auxiliary parameters from the list
-    data_names = [graph_input for graph_input in sym.list_inputs()
-                  if graph_input not in arg_params and graph_input not in 
aux_params]
+    metadata = onnx_mxnet.get_model_metadata('super_resolution.onnx')
+    data_names = [input_name[0] for input_name in 
metadata.get('input_tensor_data')]
     # create module
     mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None)
     mod.bind(for_training=False, data_shapes=[(data_names[0], 
input_img.shape)])
diff --git a/python/mxnet/contrib/onnx/__init__.py 
b/python/mxnet/contrib/onnx/__init__.py
index 169ac67..fb8488c 100644
--- a/python/mxnet/contrib/onnx/__init__.py
+++ b/python/mxnet/contrib/onnx/__init__.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 """Module for ONNX model format support for Apache MXNet."""
 
-from ._import.import_model import import_model
+from ._import.import_model import import_model, get_model_metadata
diff --git a/python/mxnet/contrib/onnx/_import/import_model.py 
b/python/mxnet/contrib/onnx/_import/import_model.py
index 1bd4b41..4e4d786 100644
--- a/python/mxnet/contrib/onnx/_import/import_model.py
+++ b/python/mxnet/contrib/onnx/_import/import_model.py
@@ -52,3 +52,33 @@ def import_model(model_file):
     model_proto = onnx.load(model_file)
     sym, arg_params, aux_params = graph.from_onnx(model_proto.graph)
     return sym, arg_params, aux_params
+
+def get_model_metadata(model_file):
+    """
+    Returns the name and shape information of input and output tensors of the 
given ONNX model file.
+
+    Parameters
+    ----------
+    model_file : str
+        ONNX model file name
+
+    Returns
+    -------
+    model_metadata : dict
+        A dictionary object mapping various metadata to its corresponding 
value.
+        The dictionary will have the following template.
+        {
+            'input_tensor_data' : <list of tuples representing the shape of 
the input parameters>,
+            'output_tensor_data' : <list of tuples representing the shape of 
the output
+                                    of the model>
+        }
+    """
+    graph = GraphProto()
+    try:
+        import onnx
+    except ImportError:
+        raise ImportError("Onnx and protobuf need to be installed. "
+                          + "Instructions to install - 
https://github.com/onnx/onnx")
+    model_proto = onnx.load(model_file)
+    metadata = graph.get_graph_metadata(model_proto.graph)
+    return metadata
diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py 
b/python/mxnet/contrib/onnx/_import/import_onnx.py
index 5192c6f..db23357 100644
--- a/python/mxnet/contrib/onnx/_import/import_onnx.py
+++ b/python/mxnet/contrib/onnx/_import/import_onnx.py
@@ -132,6 +132,29 @@ class GraphProto(object): # pylint: 
disable=too-few-public-methods
             out = out[0]
         return out, argDict, auxDict
 
+    def get_graph_metadata(self, graph):
+        """
+        Get the model metadata from a given onnx graph.
+        """
+        _params = set()
+        for tensor_vals in graph.initializer:
+            _params.add(tensor_vals.name)
+
+        input_data = []
+        for graph_input in graph.input:
+            if graph_input.name not in _params:
+                shape = [val.dim_value for val in 
graph_input.type.tensor_type.shape.dim]
+                input_data.append((graph_input.name, tuple(shape)))
+
+        output_data = []
+        for graph_out in graph.output:
+            shape = [val.dim_value for val in 
graph_out.type.tensor_type.shape.dim]
+            output_data.append((graph_out.name, tuple(shape)))
+        metadata = {'input_tensor_data' : input_data,
+                    'output_tensor_data' : output_data
+                   }
+        return metadata
+
     def _parse_array(self, tensor_proto):
         """Grab data in TensorProto and convert to numpy array."""
         try:
diff --git a/tests/python-pytest/onnx/onnx_test.py 
b/tests/python-pytest/onnx/onnx_test.py
index e75ef69..b3718c9 100644
--- a/tests/python-pytest/onnx/onnx_test.py
+++ b/tests/python-pytest/onnx/onnx_test.py
@@ -186,12 +186,17 @@ def test_bvlc_googlenet():
     model_path, inputs, outputs = get_test_files('bvlc_googlenet')
     logging.info("Translating Googlenet model from ONNX to Mxnet")
     sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+    metadata = onnx_mxnet.get_model_metadata(model_path)
+    assert len(metadata) == 2
+    assert metadata.get('input_tensor_data')
+    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
+    assert metadata.get('output_tensor_data')
+    assert metadata.get('output_tensor_data') == [(u'prob_1', (1, 1000))]
+    data_names = [input_name[0] for input_name in 
metadata.get('input_tensor_data')]
 
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in 
aux_params]
         mod = mx.mod.Module(symbol=sym, data_names=data_names, 
context=mx.cpu(), label_names=None)
         mod.bind(for_training=False, data_shapes=[(data_names[0], 
input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
@@ -210,12 +215,17 @@ def test_bvlc_reference_caffenet():
     model_path, inputs, outputs = get_test_files('bvlc_reference_caffenet')
     logging.info("Translating Caffenet model from ONNX to Mxnet")
     sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+    metadata = onnx_mxnet.get_model_metadata(model_path)
+    assert len(metadata) == 2
+    assert metadata.get('input_tensor_data')
+    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
+    assert metadata.get('output_tensor_data')
+    assert metadata.get('output_tensor_data') == [(u'prob_1', (1, 1000))]
+    data_names = [input_name[0] for input_name in 
metadata.get('input_tensor_data')]
 
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in 
aux_params]
         mod = mx.mod.Module(symbol=sym, data_names=data_names, 
context=mx.cpu(), label_names=None)
         mod.bind(for_training=False, data_shapes=[(data_names[0], 
input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,
@@ -234,12 +244,17 @@ def test_bvlc_rcnn_ilsvrc13():
     model_path, inputs, outputs = 
get_test_files('bvlc_reference_rcnn_ilsvrc13')
     logging.info("Translating rcnn_ilsvrc13 model from ONNX to Mxnet")
     sym, arg_params, aux_params = onnx_mxnet.import_model(model_path)
+    metadata = onnx_mxnet.get_model_metadata(model_path)
+    assert len(metadata) == 2
+    assert metadata.get('input_tensor_data')
+    assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))]
+    assert metadata.get('output_tensor_data')
+    assert metadata.get('output_tensor_data') == [(u'fc-rcnn_1', (1, 200))]
+    data_names = [input_name[0] for input_name in 
metadata.get('input_tensor_data')]
 
     # run test for each test file
     for input_data, output_data in zip(inputs, outputs):
         # create module
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in 
aux_params]
         mod = mx.mod.Module(symbol=sym, data_names=data_names, 
context=mx.cpu(), label_names=None)
         mod.bind(for_training=False, data_shapes=[(data_names[0], 
input_data.shape)], label_shapes=None)
         mod.set_params(arg_params=arg_params, aux_params=aux_params,

-- 
To stop receiving notification emails like this one, please contact
anirudh2...@apache.org.

Reply via email to