This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 2cafa87b10 [Bugfix][Relay] Fix threshold calculation logic in PyTorch frontend (#14820)

2cafa87b10 is described below

commit 2cafa87b10c6124f1a08af7ead712f29b9039762
Author: Qingchao Shen <qingchaos...@outlook.com>
AuthorDate: Thu May 11 18:59:04 2023 +0800

    [Bugfix][Relay] Fix threshold calculation logic in PyTorch frontend (#14820)

    * fix threshold

    * add test case

    * Update pytorch.py

    * Update pytorch.py
---
 python/tvm/relay/frontend/pytorch.py          | 6 +++++-
 tests/python/frontend/pytorch/test_forward.py | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index afd46b2001..5e2e6a5f5e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1333,7 +1333,11 @@ class PyTorchOpConverter:
 
     def threshold(self, inputs, input_types):
         data = inputs[0]
-        return _op.nn.relu(data)
+        threshold_f = float(inputs[1])
+        threshold_ = _op.full_like(inputs[0], fill_value=_expr.const(threshold_f))
+        value_f = float(inputs[2])
+        value = _op.full_like(inputs[0], fill_value=_expr.const(value_f))
+        return _op.where(_op.greater(data, threshold_), data, value)
 
     def contiguous(self, inputs, input_types):
         return inputs[0]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9e5e9e22bc..fcaf7b7847 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1348,6 +1348,8 @@ def test_forward_threshold():
     input_shape = [1, 3]
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.Threshold(0, 0).float().eval(), input_data=input_data)
+    input_data = torch.tensor([[-1.0, 2.0]], dtype=torch.float32)
+    verify_model(torch.nn.Threshold(1, 1).float().eval(), input_data=input_data)
 
 
 @tvm.testing.uses_gpu