This is an automated email from the ASF dual-hosted git repository. jwfromm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 2d57470  [ONNX]fix datatype on Reciprocal op (#7519)

2d57470 is described below

commit 2d5747054ca05a0863236b317e2fed281b455a00
Author: Matthew Brookhart <mbrookh...@octoml.ai>
AuthorDate: Fri Feb 26 14:05:22 2021 -0700

    [ONNX]fix datatype on Reciprocal op (#7519)

    * fix datatype on Reciprocal op

    * clean up test case
---
 python/tvm/relay/frontend/onnx.py          |  3 ++-
 tests/python/frontend/onnx/test_forward.py | 11 +++++++----
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 58c2dbc..860753d 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -839,7 +839,8 @@ class Reciprocal(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return _expr.const(1.0) / inputs[0]
+        dtype = infer_type(inputs[0]).checked_type.dtype
+        return _expr.const(1.0, dtype=dtype) / inputs[0]
 
 
 class Flatten(OnnxOpConverter):
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 8dbd049..1e13416 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -1830,23 +1830,26 @@ def test_unary_ops():
     dtype = "float32"
     out_shape = in_shape
 
-    def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5):
+    def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"):
+        x = x.astype(dtype)
+        ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
         z = helper.make_node(op, ["in1"], ["out"])
         graph = helper.make_graph(
             [z],
             "_test",
             inputs=[
-                helper.make_tensor_value_info("in1", TensorProto.FLOAT, list(in_shape)),
+                helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)),
             ],
-            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
+            outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))],
         )
         model = helper.make_model(graph, producer_name="_test")
         verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol)
 
-    x = np.random.uniform(size=in_shape).astype(dtype)
+    x = np.random.uniform(size=in_shape)
     verify_unary_ops("Neg", x)
     verify_unary_ops("Abs", x)
     verify_unary_ops("Reciprocal", x)
+    verify_unary_ops("Reciprocal", x, dtype="float16")
     verify_unary_ops("Sqrt", x)
     verify_unary_ops("Relu", x)
     verify_unary_ops("Exp", x)