This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new c35c9fd3a5 [Relay][PyTorch] Add aten::lerp (#12167) c35c9fd3a5 is described below commit c35c9fd3a5249cfb01093b08b35979db846dfa33 Author: xndcn <xnd...@gmail.com> AuthorDate: Thu Jul 28 12:59:30 2022 +0800 [Relay][PyTorch] Add aten::lerp (#12167) --- python/tvm/relay/frontend/pytorch.py | 11 +++++++++++ tests/python/frontend/pytorch/test_forward.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index b88e08b719..1bd3232871 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -343,6 +343,16 @@ class PyTorchOpConverter: diag_input = _op.zeros(input_shape, dtype=input_types[0]) return _op.matrix_set_diag(data, diag_input, k=(k1, k2)) + def lerp(self, inputs, input_types): + if len(inputs) != 3: + msg = "Wrong number of arguments (%d) to parse." % (len(inputs)) + raise AssertionError(msg) + + start = inputs[0] + end = inputs[1] + weight = inputs[2] + return start + weight * (end - start) + def arange(self, inputs, input_types): def _get_value(val, dtype): # dtype is a tvm dtype @@ -3412,6 +3422,7 @@ class PyTorchOpConverter: "aten::stft": self.stft, "aten::mul": self.make_elemwise("multiply"), "aten::pow": self.make_elemwise("power"), + "aten::lerp": self.lerp, "aten::arange": self.arange, "aten::meshgrid": self.meshgrid, "aten::div": self.make_elemwise("divide"), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6d7926396a..4332f3efe5 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4596,5 +4596,20 @@ def test_softmax_fuse(): tvm.testing.assert_allclose(out, output_torch, rtol=1e-5, atol=1e-5) +@tvm.testing.uses_gpu +def test_lerp(): + def test_fn(x, y, w): + return torch.lerp(x, y, w) + + input_shape = [16] + x = torch.rand(input_shape).float() + y = torch.rand(input_shape).float() + w = torch.rand(input_shape).float() + + # weight can be tensor or scalar + verify_model(test_fn, [x, y, w]) + verify_model(test_fn, [x, y, w[0]]) + + if __name__ == "__main__": pytest.main([__file__])