This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 6bd55f0c90 [Relax][PyTorch] full.default, full_like.default, ones.default (#17832)
6bd55f0c90 is described below

commit 6bd55f0c90c74d667afe9b2aba887a33a90ae84d
Author: Hugo Latendresse <85399628+hugolatendre...@users.noreply.github.com>
AuthorDate: Mon Apr 14 05:52:30 2025 -0400

    [Relax][PyTorch] full.default, full_like.default, ones.default (#17832)

    * unit test

    * full.default

    * linting

    * ones ok

    * tests for ones, full, and full like work
---
 .../frontend/torch/base_fx_graph_translator.py    | 38 +++++++++++++++++
 .../frontend/torch/exported_program_translator.py |  3 ++
 python/tvm/relax/frontend/torch/fx_translator.py  | 33 ---------------
 tests/python/relax/test_from_exported_to_cuda.py  | 48 ++++++++++++++++++++++
 4 files changed, 89 insertions(+), 33 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c9c6afd71a..3018b0db77 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1271,6 +1271,28 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _full(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
+
+    def _full_like(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        fill_value = relax.const(node.args[1])
+        return self.block_builder.emit(relax.op.full_like(x, fill_value))
+
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
@@ -1292,6 +1314,22 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             )
         )
 
+    def _ones(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, dtype),
+                dtype,
+            )
+        )
+
     ########## DataType ##########
 
     def _to(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c398cc4558..7b9587b675 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -442,10 +442,13 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "empty.memory_format": self._empty,
             "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
+            "full.default": self._full,
+            "full_like.default": self._full_like,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
+            "ones.default": self._ones,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
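
For reference, a minimal sketch of how the "full.default" mapping registered above is exercised through the exported-program frontend; the module name, input shape, and fill value below are illustrative, not taken from this patch:

    # Hypothetical repro: export a module that calls torch.full and import it
    # into Relax via from_exported_program, which dispatches to the new _full.
    import torch
    from torch import nn
    from torch.export import export
    from tvm.relax.frontend.torch import from_exported_program

    class FullModel(nn.Module):
        def forward(self, x):
            return torch.full((2, 3), 3.141592)

    exported = export(FullModel().eval(), (torch.randn(3, 3),))
    mod = from_exported_program(exported)
    mod.show()  # the function body should contain a relax.op.full call
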
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f6dd235d5a..d24d67105e 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -468,23 +468,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         self.env[node.args[0]] = filled
         return filled
 
-    def _full(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                value,
-                dtype,
-            )
-        )
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
@@ -527,22 +510,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
         return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
 
-    def _ones(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, dtype),
-                dtype,
-            )
-        )
-
     def _one_hot(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 8405f48576..e92855885e 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -63,6 +63,54 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_full(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.full((2, 3), 3.141592)
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(3, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_full_like(target, dev):
+    class FullLike(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fill_value = 7.0
+
+        def forward(self, x):
+            return torch.full_like(x, self.fill_value)
+
+    torch_module = FullLike().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_ones(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.ones((2, 3))
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(1, 1).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):
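
For reference, both new handlers ultimately lower to relax.op.full, as the hunks above show. A small sketch of the equivalent IR built directly with relax.BlockBuilder follows; the builder usage, function name, and the float32 dtype are illustrative assumptions, not part of this patch:

    # Sketch of the Relax construct _ones emits for torch.ones((2, 3)):
    # full(shape, const(1, dtype), dtype); float32 mirrors the stock
    # torch.get_default_dtype() fallback used by the handler.
    from tvm import relax

    bb = relax.BlockBuilder()
    with bb.function("ones_2x3", params=[]):
        out = bb.emit(
            relax.op.full(relax.ShapeExpr((2, 3)), relax.const(1, "float32"), "float32")
        )
        bb.emit_func_output(out)
    print(bb.get())  # IRModule containing a single relax.op.full binding
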