This is an automated email from the ASF dual-hosted git repository.

thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cb627bc  Onnx multi output (#13390)
cb627bc is described below

commit cb627bcaccc127a00ab035a2a3006e5cbb6d501d
Author: Sina Afrooze <sina....@gmail.com>
AuthorDate: Mon Nov 26 00:00:08 2018 -0800

    Onnx multi output (#13390)
    
    * Fix ONNX export to support multi-output graphs
    
    * Add ONNX unit-test
    
    * Added multi-output shape inference.
    
    - Removed unnecessary forward_pass() call
    - Replaced infer_output_shape() with get_outputs(), which returns the shapes of all outputs together with their output names.
    
    * Fixed pylint warnings
---
 python/mxnet/contrib/onnx/mx2onnx/export_onnx.py   | 128 +++++++--------------
 .../python-pytest/onnx/export/mxnet_export_test.py |  76 ++++++++++++
 2 files changed, 119 insertions(+), 85 deletions(-)

diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
index b02d970..14c674f 100644
--- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
+++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
@@ -53,12 +53,8 @@ from __future__ import print_function
 from __future__ import unicode_literals
 import logging
 import json
-import numpy as np
 
-from .... import context
 from .... import ndarray as nd
-from .... import io
-from .... import module as mod
 
 
 class MXNetGraph(object):
@@ -96,60 +92,6 @@ class MXNetGraph(object):
         return convert_func(node, **kwargs)
 
     @staticmethod
-    def forward_pass(inputs, sym, arg_params, aux_params, output_label):
-        """Do a forward pass based on the sym and params to get the shape
-        of the output using dummy data
-
-        Parameters
-        ----------
-        inputs   : json string
-
-        sym : :class:`~mxnet.symbol.Symbol`
-            MXNet symbol object
-        arg_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
-            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
-        aux_params : dict of ``str`` to :class:`~mxnet.ndarray.NDArray`
-            Dict of converted parameters stored in ``mxnet.ndarray.NDArray`` format
-
-        Returns
-        -------
-        shape : Shape
-            Output shape
-        """
-        # if label is not provided, MXNet adds label "softmax_label" by default
-        # while running load_checkpoint which is not actually a graph input. So ignoring it here
-        data_names = [graph_input for graph_input in sym.list_inputs()
-                      if graph_input not in arg_params and graph_input not in aux_params
-                      and graph_input != output_label]
-
-        data_shapes = []
-        # Adding extra dimension of batch_size 1 if the batch_size is different for multiple inputs.
-        for idx, input_name in enumerate(data_names):
-            data_shapes.append((input_name, inputs[idx].shape))
-
-        # create module, passing cpu context
-        ctx = context.cpu()
-        test_mod = mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None)
-        test_mod.bind(for_training=False, data_shapes=data_shapes, label_shapes=None)
-
-        # initializing parameters for calculating result of each individual node
-        if arg_params is None and aux_params is None:
-            test_mod.init_params()
-        else:
-            test_mod.set_params(arg_params=arg_params, aux_params=aux_params, allow_missing=True)
-
-        data_forward = []
-        for idx, input_name in enumerate(data_names):
-            val = inputs[idx]
-            data_forward.append(nd.array(val))
-
-        test_mod.forward(io.DataBatch(data_forward))
-        result = test_mod.get_outputs()[0].asnumpy()
-
-        return result.shape
-
-
-    @staticmethod
     def split_params(sym, params):
         """Helper function to split params dictionary into args and aux params
 
@@ -177,15 +119,40 @@ class MXNetGraph(object):
                 aux_params.update({aux: nd.array(params[aux])})
         return arg_params, aux_params
 
-
     @staticmethod
-    def infer_output_shape(sym, params, in_shape, output_label):
-        """Infer output shape by doing a forward pass using dummy inputs """
-        # create dummy input
-        inputs = [np.random.randn(*input_shape) for input_shape in in_shape]
-        arg, aux = MXNetGraph.split_params(sym, params)
-        return MXNetGraph.forward_pass(inputs, sym, arg, aux, output_label)
+    def get_outputs(sym, params, in_shape, in_label):
+        """ Infer output shapes and return dictionary of output name to shape
+
+        :param :class:`~mxnet.symbol.Symbol` sym: symbol to perform infer shape on
+        :param dic of (str, nd.NDArray) params:
+        :param list of tuple(int, ...) in_shape: list of all input shapes
+        :param  in_label: name of label typically used in loss that may be left in graph. This name is
+            removed from list of inputs required by symbol
+        :return: dictionary of output name to shape
+        :rtype: dict of (str, tuple(int, ...))
+        """
+        # remove any input listed in params from sym.list_inputs() and bind them to the input shapes provided
+        # by user. Also remove in_label, which is the name of the label symbol that may have been used
+        # as the label for loss during training.
+        inputs = {n: s for n, s in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_shape)}
+        # Add params and their shape to list of inputs
+        inputs.update({n: v.shape for n, v in params.items()})
+        # Provide input data as well as input params to infer_shape()
+        _, out_shapes, _ = sym.infer_shape(**inputs)
+
+        out_names = list()
+        for name in sym.list_outputs():
+            if name.endswith('_output'):
+                out_names.append(name[:-len('_output')])
+            else:
+                logging.warning("output '%s' does not end with '_output'", name)
+                out_names.append(name)
 
+        assert len(out_shapes) == len(out_names)
+        # bind output shapes with output names
+        graph_outputs = {n: s for n, s in zip(out_names, out_shapes)}
+
+        return graph_outputs
 
     @staticmethod
     def convert_weights_to_numpy(weights_dict):
@@ -228,9 +195,6 @@ class MXNetGraph(object):
         # Deriving the output_label name.
         output_label = sym.get_internals()[len(sym.get_internals()) - 1].name + "_label"
 
-        # Determine output shape
-        output_shape = MXNetGraph.infer_output_shape(sym, params, in_shape, output_label)
-
         weights = MXNetGraph.convert_weights_to_numpy(params)
 
         mx_graph = json.loads(sym.tojson())["nodes"]
@@ -242,6 +206,9 @@ class MXNetGraph(object):
         onnx_processed_outputs = []
         index_lookup = []
 
+        # Determine output shape
+        graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label)
+
         graph_input_idx = 0
         for idx, node in enumerate(mx_graph):
             op = node["op"]
@@ -294,24 +261,15 @@ class MXNetGraph(object):
                     # If converted node is NodeProto, add it in processed nodes list
                     elif isinstance(converted_node, NodeProto):
                         onnx_processed_nodes.append(converted_node)
-                        if idx == (len(mx_graph) - 1):
-                            # If converted node doesnt have name, use it from output field
-                            if not converted_node.name:
-                                onnx_processed_outputs.append(
-                                    make_tensor_value_info(
-                                        name=converted_node.output[0],
-                                        elem_type=in_type,
-                                        shape=output_shape
-                                    )
-                                )
-                            else:
-                                onnx_processed_outputs.append(
-                                    make_tensor_value_info(
-                                        name=converted_node.name,
-                                        elem_type=in_type,
-                                        shape=output_shape
-                                    )
+                        node_name = converted_node.name if converted_node.name else converted_node.output[0]
+                        if node_name in graph_outputs:
+                            onnx_processed_outputs.append(
+                                make_tensor_value_info(
+                                    name=node_name,
+                                    elem_type=in_type,
+                                    shape=graph_outputs[node_name]
                                 )
+                            )
                             if verbose:
                                 logging.info("Output node is: %s", converted_node.name)
                     elif isinstance(converted_node, TensorProto):
diff --git a/tests/python-pytest/onnx/export/mxnet_export_test.py b/tests/python-pytest/onnx/export/mxnet_export_test.py
index 9f91369..bbff783 100644
--- a/tests/python-pytest/onnx/export/mxnet_export_test.py
+++ b/tests/python-pytest/onnx/export/mxnet_export_test.py
@@ -28,11 +28,14 @@ import os
 import unittest
 import logging
 import tarfile
+import tempfile
 from collections import namedtuple
 import numpy as np
 import numpy.testing as npt
 from onnx import numpy_helper, helper
 from onnx import TensorProto
+from mxnet import nd, sym
+from mxnet.gluon import nn
 from mxnet.test_utils import download
 from mxnet.contrib import onnx as onnx_mxnet
 import mxnet as mx
@@ -238,6 +241,79 @@ def test_square():
 
     npt.assert_almost_equal(result, numpy_op)
 
+
+def _assert_sym_equal(lhs, rhs):
+    assert lhs.list_inputs() == rhs.list_inputs()  # input names must be identical
+    assert len(lhs.list_outputs()) == len(rhs.list_outputs())  # number of outputs must be identical
+
+
+def _force_list(output):
+    if isinstance(output, nd.NDArray):
+        return [output]
+    return list(output)
+
+
+def _optional_group(symbols, group=False):
+    if group:
+        return sym.Group(symbols)
+    else:
+        return symbols
+
+
+def _check_onnx_export(net, group_outputs=False):
+    net.initialize()
+    data = nd.random.uniform(0, 1, (1, 1024))
+    output = _force_list(net(data))  # initialize weights
+    net_sym = _optional_group(net(sym.Variable('data')), group_outputs)
+    net_params = {name:param._reduce() for name, param in net.collect_params().items()}
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        onnx_file_path = os.path.join(tmpdirname, 'net.onnx')
+        export_path = onnx_mxnet.export_model(
+            sym=net_sym,
+            params=net_params,
+            input_shape=[data.shape],
+            onnx_file_path=onnx_file_path)
+        assert export_path == onnx_file_path
+        # Try importing the model to symbol
+        _assert_sym_equal(net_sym, onnx_mxnet.import_model(export_path)[0])
+
+        # Try importing the model to gluon
+        imported_net = onnx_mxnet.import_to_gluon(export_path, ctx=None)
+        _assert_sym_equal(net_sym, _optional_group(imported_net(sym.Variable('data')), group_outputs))
+
+        # Confirm network outputs are the same
+        imported_net_output = _force_list(imported_net(data))
+        for out, imp_out in zip(output, imported_net_output):
+            mx.test_utils.assert_almost_equal(out.asnumpy(), imp_out.asnumpy())
+
+
+@with_seed()
+def test_onnx_export_single_output():
+    net = nn.HybridSequential(prefix='single_output_net')
+    with net.name_scope():
+        net.add(nn.Dense(100, activation='relu'), nn.Dense(10))
+    _check_onnx_export(net)
+
+
+@with_seed()
+def test_onnx_export_multi_output():
+    class MultiOutputBlock(nn.HybridBlock):
+        def __init__(self):
+            super(MultiOutputBlock, self).__init__()
+            with self.name_scope():
+                self.net = nn.HybridSequential()
+                for i in range(10):
+                    self.net.add(nn.Dense(100 + i * 10, activation='relu'))
+
+        def hybrid_forward(self, F, x):
+            out = tuple(block(x) for block in self.net._children.values())
+            return out
+
+    net = MultiOutputBlock()
+    assert len(sym.Group(net(sym.Variable('data'))).list_outputs()) == 10
+    _check_onnx_export(net, group_outputs=True)
+
+
 if __name__ == '__main__':
     test_models("bvlc_googlenet", (1, 3, 224, 224), (1, 1000))
     test_models("bvlc_reference_caffenet", (1, 3, 224, 224), (1, 1000))

Reply via email to