This is an automated email from the ASF dual-hosted git repository. anirudh2290 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 7641759 [MXNET-309] [ONNX-MXNet] Model Metadata API (#10512) 7641759 is described below commit 76417594e56a85ec0cc9412b9dd2c7e2ab581d8b Author: Anirudh <anirudhk...@gmail.com> AuthorDate: Wed May 16 20:36:54 2018 -0700 [MXNET-309] [ONNX-MXNet] Model Metadata API (#10512) * metadata api * pylint changes * move logic to import_onnx * test fix * doc API * rerun CI. * fix comments * docs fix --- docs/api/python/contrib/onnx.md | 6 +++-- docs/tutorials/onnx/inference_on_onnx_model.md | 19 ++++++++++---- example/onnx/super_resolution.py | 6 ++--- python/mxnet/contrib/onnx/__init__.py | 3 +-- python/mxnet/contrib/onnx/_import/import_model.py | 30 +++++++++++++++++++++++ python/mxnet/contrib/onnx/_import/import_onnx.py | 23 +++++++++++++++++ tests/python-pytest/onnx/onnx_test.py | 27 +++++++++++++++----- 7 files changed, 95 insertions(+), 19 deletions(-) diff --git a/docs/api/python/contrib/onnx.md b/docs/api/python/contrib/onnx.md index 44aabaf..6fb546f 100644 --- a/docs/api/python/contrib/onnx.md +++ b/docs/api/python/contrib/onnx.md @@ -13,7 +13,7 @@ With ONNX format support for MXNet, developers can build and train models with a ``` ### Installation Instructions -- To use this module developers need to **install ONNX**, which requires protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). Once installed, you can go through the tutorials on how to use this module. +- To use this module developers need to **install ONNX**, which requires the protobuf compiler to be installed separately. Please follow the [instructions to install ONNX and its dependencies](https://github.com/onnx/onnx#installation). **MXNet currently supports ONNX v1.1.1**. Once installed, you can go through the tutorials on how to use this module. This document describes all the ONNX-MXNet APIs. @@ -23,6 +23,7 @@ This document describes all the ONNX-MXNet APIs. :nosignatures: mxnet.contrib.onnx.import_model + mxnet.contrib.onnx.get_model_metadata ``` ## ONNX Tutorials @@ -43,7 +44,8 @@ This document describes all the ONNX-MXNet APIs. ```eval_rst .. automodule:: mxnet.contrib.onnx - :members: import_model + :members: import_model + :members: get_model_metadata ``` diff --git a/docs/tutorials/onnx/inference_on_onnx_model.md b/docs/tutorials/onnx/inference_on_onnx_model.md index f342dad..3d4072a 100644 --- a/docs/tutorials/onnx/inference_on_onnx_model.md +++ b/docs/tutorials/onnx/inference_on_onnx_model.md @@ -104,17 +104,26 @@ We pick a context, GPU if available, otherwise CPU ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu() ``` -We obtain the data names of the inputs to the model, by listing all the inputs to the symbol graph and excluding the argument and auxiliary parameters from that list: +We obtain the data names of the inputs to the model by using the model metadata API: ```python -data_names = [graph_input for graph_input in sym.list_inputs() - if graph_input not in arg_params and graph_input not in aux_params] -print(data_names) +model_metadata = onnx_mxnet.get_model_metadata(onnx_path) +print(model_metadata) ``` +``` +{'output_tensor_data': [(u'gpu_0/softmax_1', (1L, 1000L))], + 'input_tensor_data': [(u'gpu_0/data_0', (1L, 3L, 224L, 224L))]} +``` -```['gpu_0/data_0']``` +```python +data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')] +print(data_names) +``` +``` +[u'gpu_0/data_0'] +``` And load them into a MXNet Gluon symbol block. diff --git a/example/onnx/super_resolution.py b/example/onnx/super_resolution.py index a52f1a8..fcb8ccc 100644 --- a/example/onnx/super_resolution.py +++ b/example/onnx/super_resolution.py @@ -55,10 +55,8 @@ def get_test_image(): def perform_inference(sym, arg_params, aux_params, input_img, img_cb, img_cr): """Perform inference on image using mxnet""" - # To fetch the data names of the input to the model we list the inputs of the symbol graph - # and exclude the argument and auxiliary parameters from the list - data_names = [graph_input for graph_input in sym.list_inputs() - if graph_input not in arg_params and graph_input not in aux_params] + metadata = onnx_mxnet.get_model_metadata('super_resolution.onnx') + data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')] # create module mod = mx.mod.Module(symbol=sym, data_names=data_names, label_names=None) mod.bind(for_training=False, data_shapes=[(data_names[0], input_img.shape)]) diff --git a/python/mxnet/contrib/onnx/__init__.py b/python/mxnet/contrib/onnx/__init__.py index 169ac67..fb8488c 100644 --- a/python/mxnet/contrib/onnx/__init__.py +++ b/python/mxnet/contrib/onnx/__init__.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - """Module for ONNX model format support for Apache MXNet.""" -from ._import.import_model import import_model +from ._import.import_model import import_model, get_model_metadata diff --git a/python/mxnet/contrib/onnx/_import/import_model.py b/python/mxnet/contrib/onnx/_import/import_model.py index 1bd4b41..4e4d786 100644 --- a/python/mxnet/contrib/onnx/_import/import_model.py +++ b/python/mxnet/contrib/onnx/_import/import_model.py @@ -52,3 +52,33 @@ def import_model(model_file): model_proto = onnx.load(model_file) sym, arg_params, aux_params = graph.from_onnx(model_proto.graph) return sym, arg_params, aux_params + +def get_model_metadata(model_file): + """ + Returns the name and shape information of input and output tensors of the given ONNX model file. + + Parameters + ---------- + model_file : str + ONNX model file name + + Returns + ------- + model_metadata : dict + A dictionary object mapping various metadata to its corresponding value. + The dictionary will have the following template. + { + 'input_tensor_data' : <list of tuples representing the shape of the input paramters>, + 'output_tensor_data' : <list of tuples representing the shape of the output + of the model> + } + """ + graph = GraphProto() + try: + import onnx + except ImportError: + raise ImportError("Onnx and protobuf need to be installed. " + + "Instructions to install - https://github.com/onnx/onnx") + model_proto = onnx.load(model_file) + metadata = graph.get_graph_metadata(model_proto.graph) + return metadata diff --git a/python/mxnet/contrib/onnx/_import/import_onnx.py b/python/mxnet/contrib/onnx/_import/import_onnx.py index 5192c6f..db23357 100644 --- a/python/mxnet/contrib/onnx/_import/import_onnx.py +++ b/python/mxnet/contrib/onnx/_import/import_onnx.py @@ -132,6 +132,29 @@ class GraphProto(object): # pylint: disable=too-few-public-methods out = out[0] return out, argDict, auxDict + def get_graph_metadata(self, graph): + """ + Get the model metadata from a given onnx graph. + """ + _params = set() + for tensor_vals in graph.initializer: + _params.add(tensor_vals.name) + + input_data = [] + for graph_input in graph.input: + if graph_input.name not in _params: + shape = [val.dim_value for val in graph_input.type.tensor_type.shape.dim] + input_data.append((graph_input.name, tuple(shape))) + + output_data = [] + for graph_out in graph.output: + shape = [val.dim_value for val in graph_out.type.tensor_type.shape.dim] + output_data.append((graph_out.name, tuple(shape))) + metadata = {'input_tensor_data' : input_data, + 'output_tensor_data' : output_data + } + return metadata + def _parse_array(self, tensor_proto): """Grab data in TensorProto and convert to numpy array.""" try: diff --git a/tests/python-pytest/onnx/onnx_test.py b/tests/python-pytest/onnx/onnx_test.py index e75ef69..b3718c9 100644 --- a/tests/python-pytest/onnx/onnx_test.py +++ b/tests/python-pytest/onnx/onnx_test.py @@ -186,12 +186,17 @@ def test_bvlc_googlenet(): model_path, inputs, outputs = get_test_files('bvlc_googlenet') logging.info("Translating Googlenet model from ONNX to Mxnet") sym, arg_params, aux_params = onnx_mxnet.import_model(model_path) + metadata = onnx_mxnet.get_model_metadata(model_path) + assert len(metadata) == 2 + assert metadata.get('input_tensor_data') + assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))] + assert metadata.get('output_tensor_data') + assert metadata.get('output_tensor_data') == [(u'prob_1', (1, 1000))] + data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')] # run test for each test file for input_data, output_data in zip(inputs, outputs): # create module - data_names = [graph_input for graph_input in sym.list_inputs() - if graph_input not in arg_params and graph_input not in aux_params] mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) mod.set_params(arg_params=arg_params, aux_params=aux_params, @@ -210,12 +215,17 @@ def test_bvlc_reference_caffenet(): model_path, inputs, outputs = get_test_files('bvlc_reference_caffenet') logging.info("Translating Caffenet model from ONNX to Mxnet") sym, arg_params, aux_params = onnx_mxnet.import_model(model_path) + metadata = onnx_mxnet.get_model_metadata(model_path) + assert len(metadata) == 2 + assert metadata.get('input_tensor_data') + assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))] + assert metadata.get('output_tensor_data') + assert metadata.get('output_tensor_data') == [(u'prob_1', (1, 1000))] + data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')] # run test for each test file for input_data, output_data in zip(inputs, outputs): # create module - data_names = [graph_input for graph_input in sym.list_inputs() - if graph_input not in arg_params and graph_input not in aux_params] mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) mod.set_params(arg_params=arg_params, aux_params=aux_params, @@ -234,12 +244,17 @@ def test_bvlc_rcnn_ilsvrc13(): model_path, inputs, outputs = get_test_files('bvlc_reference_rcnn_ilsvrc13') logging.info("Translating rcnn_ilsvrc13 model from ONNX to Mxnet") sym, arg_params, aux_params = onnx_mxnet.import_model(model_path) + metadata = onnx_mxnet.get_model_metadata(model_path) + assert len(metadata) == 2 + assert metadata.get('input_tensor_data') + assert metadata.get('input_tensor_data') == [(u'data_0', (1, 3, 224, 224))] + assert metadata.get('output_tensor_data') + assert metadata.get('output_tensor_data') == [(u'fc-rcnn_1', (1, 200))] + data_names = [input_name[0] for input_name in metadata.get('input_tensor_data')] # run test for each test file for input_data, output_data in zip(inputs, outputs): # create module - data_names = [graph_input for graph_input in sym.list_inputs() - if graph_input not in arg_params and graph_input not in aux_params] mod = mx.mod.Module(symbol=sym, data_names=data_names, context=mx.cpu(), label_names=None) mod.bind(for_training=False, data_shapes=[(data_names[0], input_data.shape)], label_shapes=None) mod.set_params(arg_params=arg_params, aux_params=aux_params, -- To stop receiving notification emails like this one, please contact anirudh2...@apache.org.