This is an automated email from the ASF dual-hosted git repository. skm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 991bf3b ONNX import: Hardmax (#13717) 991bf3b is described below commit 991bf3b64f186295345a14f3ce4e6a8f364e8bed Author: Vandana Kannan <vandan...@users.noreply.github.com> AuthorDate: Sat Dec 29 08:39:42 2018 -0800 ONNX import: Hardmax (#13717) * ONNX import: Hardmax * Fix lint errors * add github link for issue with reshape --- .../mxnet/contrib/onnx/onnx2mx/_import_helper.py | 5 +++-- .../mxnet/contrib/onnx/onnx2mx/_op_translations.py | 26 ++++++++++++++++++++++ tests/python-pytest/onnx/test_cases.py | 3 ++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py index 2ceabae..5b33f9f 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_import_helper.py @@ -23,7 +23,7 @@ from ._op_translations import add, subtract, multiply, divide, absolute, negativ from ._op_translations import tanh, arccos, arcsin, arctan, _cos, _sin, _tan from ._op_translations import softplus, shape, gather, lp_pooling, size from ._op_translations import ceil, floor, hardsigmoid, global_lppooling -from ._op_translations import concat +from ._op_translations import concat, hardmax from ._op_translations import leaky_relu, _elu, _prelu, _selu, softmax, fully_connected from ._op_translations import global_avgpooling, global_maxpooling, linalg_gemm from ._op_translations import sigmoid, pad, relu, matrix_multiplication, batch_norm @@ -144,5 +144,6 @@ _convert_map = { 'HardSigmoid' : hardsigmoid, 'LpPool' : lp_pooling, 'DepthToSpace' : depthtospace, - 'SpaceToDepth' : spacetodepth + 'SpaceToDepth' : spacetodepth, + 'Hardmax' : hardmax } diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 7028325..ce0e0e5 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ 
b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -714,3 +714,29 @@ def spacetodepth(attrs, inputs, proto_obj): new_attrs = translation_utils._fix_attribute_names(attrs, {'blocksize':'block_size'}) return "space_to_depth", new_attrs, inputs + + +def hardmax(attrs, inputs, proto_obj): + """Returns batched one-hot vectors.""" + input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0] + input_shape = input_tensor_data[1] + + axis = int(attrs.get('axis', 1)) + axis = axis if axis >= 0 else len(input_shape) + axis + + if axis == len(input_shape) - 1: + amax = symbol.argmax(inputs[0], axis=-1) + one_hot = symbol.one_hot(amax, depth=input_shape[-1]) + return one_hot, attrs, inputs + + # since reshape doesn't take a tensor for shape, + # computing with np.prod. This needs to be changed + # to use mx.sym.prod() when mx.sym.reshape() is fixed. + # (https://github.com/apache/incubator-mxnet/issues/10789) + new_shape = (int(np.prod(input_shape[:axis])), + int(np.prod(input_shape[axis:]))) + reshape_op = symbol.reshape(inputs[0], new_shape) + amax = symbol.argmax(reshape_op, axis=-1) + one_hot = symbol.one_hot(amax, depth=new_shape[-1]) + hardmax_op = symbol.reshape(one_hot, input_shape) + return hardmax_op, attrs, inputs diff --git a/tests/python-pytest/onnx/test_cases.py index 92e80e0..6a189b6 100644 --- a/tests/python-pytest/onnx/test_cases.py +++ b/tests/python-pytest/onnx/test_cases.py @@ -90,7 +90,8 @@ IMPLEMENTED_OPERATORS_TEST = { 'test_averagepool_2d_strides', 'test_averagepool_3d', 'test_LpPool_', - 'test_split_equal' + 'test_split_equal', + 'test_hardmax' ], 'export': ['test_random_uniform', 'test_random_normal',