This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2cafa87b10 [Bugfix][Relay] Fix threshold calculation logic in PyTorch frontend (#14820)
2cafa87b10 is described below

commit 2cafa87b10c6124f1a08af7ead712f29b9039762
Author: Qingchao Shen <qingchaos...@outlook.com>
AuthorDate: Thu May 11 18:59:04 2023 +0800

    [Bugfix][Relay] Fix threshold calculation logic in PyTorch frontend (#14820)
    
    * fix threshold
    
    * add test case
    
    * Update pytorch.py
    
    * Update pytorch.py
---
 python/tvm/relay/frontend/pytorch.py          | 6 +++++-
 tests/python/frontend/pytorch/test_forward.py | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index afd46b2001..5e2e6a5f5e 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1333,7 +1333,11 @@ class PyTorchOpConverter:
 
     def threshold(self, inputs, input_types):
         data = inputs[0]
-        return _op.nn.relu(data)
+        threshold_f = float(inputs[1])
+        threshold_ = _op.full_like(inputs[0], 
fill_value=_expr.const(threshold_f))
+        value_f = float(inputs[2])
+        value = _op.full_like(inputs[0], fill_value=_expr.const(value_f))
+        return _op.where(_op.greater(data, threshold_), data, value)
 
     def contiguous(self, inputs, input_types):
         return inputs[0]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9e5e9e22bc..fcaf7b7847 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1348,6 +1348,8 @@ def test_forward_threshold():
     input_shape = [1, 3]
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.Threshold(0, 0).float().eval(), 
input_data=input_data)
+    input_data = torch.tensor([[-1.0, 2.0]], dtype=torch.float32)
+    verify_model(torch.nn.Threshold(1, 1).float().eval(), 
input_data=input_data)
 
 
 @tvm.testing.uses_gpu

Reply via email to