This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6bd55f0c90 [Relax][PyTorch] full.default, full_like.default, ones.default  (#17832)
6bd55f0c90 is described below

commit 6bd55f0c90c74d667afe9b2aba887a33a90ae84d
Author: Hugo Latendresse <85399628+hugolatendre...@users.noreply.github.com>
AuthorDate: Mon Apr 14 05:52:30 2025 -0400

    [Relax][PyTorch] full.default, full_like.default, ones.default  (#17832)
    
    * unit test
    
    * full.default
    
    * linting
    
    * ones ok
    
    * tests for ones, full, and full_like work
---
 .../frontend/torch/base_fx_graph_translator.py     | 38 +++++++++++++++++
 .../frontend/torch/exported_program_translator.py  |  3 ++
 python/tvm/relax/frontend/torch/fx_translator.py   | 33 ---------------
 tests/python/relax/test_from_exported_to_cuda.py   | 48 ++++++++++++++++++++++
 4 files changed, 89 insertions(+), 33 deletions(-)

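For context, torch.export lowers the Python-level calls torch.full, torch.full_like, and torch.ones to the ATen overloads aten.full.default, aten.full_like.default, and aten.ones.default; those overload names are the handler keys registered in exported_program_translator.py below. A minimal sketch of how to observe this (the module and inputs are illustrative, not taken from this patch):

    import torch

    class ConstTensors(torch.nn.Module):
        def forward(self, x):
            a = torch.full((2, 3), 3.141592)
            b = torch.full_like(x, 7.0)
            c = torch.ones((2, 3))
            return a + b + c

    # The printed graph should show calls to torch.ops.aten.full.default,
    # torch.ops.aten.full_like.default, and torch.ops.aten.ones.default.
    ep = torch.export.export(ConstTensors(), (torch.zeros(2, 3),))
    print(ep.graph_module.graph)
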
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c9c6afd71a..3018b0db77 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1271,6 +1271,28 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
         return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
 
+    def _full(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                value,
+                dtype,
+            )
+        )
+
+    def _full_like(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        fill_value = relax.const(node.args[1])
+        return self.block_builder.emit(relax.op.full_like(x, fill_value))
+
     def _index_select(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dim = node.args[1]
@@ -1292,6 +1314,22 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             )
         )
 
+    def _ones(self, node: fx.Node) -> relax.Var:
+        import torch
+
+        args = self.retrieve_args(node)
+        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, dtype),
+                dtype,
+            )
+        )
+
     ########## DataType ##########
 
     def _to(self, node: fx.Node) -> relax.Var:
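
Each of the new handlers bottoms out in relax.op.full / relax.op.full_like on the Relax side. A standalone sketch of roughly what _full emits for a static size (the BlockBuilder usage and the shape/value/dtype are illustrative assumptions, not code from this patch):

    from tvm import relax

    bb = relax.BlockBuilder()
    with bb.function("main", params=[]):
        # Rough equivalent of importing torch.full((2, 3), 3.141592).
        out = bb.emit(relax.op.full(relax.ShapeExpr([2, 3]), relax.const(3.141592, "float32"), "float32"))
        bb.emit_func_output(out)
    bb.get().show()
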
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c398cc4558..7b9587b675 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -442,10 +442,13 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "empty.memory_format": self._empty,
             "empty_like.default": self._empty_like,
             "fill.Scalar": self._fill,
+            "full.default": self._full,
+            "full_like.default": self._full_like,
             "index_select.default": self._index_select,
             "lift_fresh_copy.default": self._to_copy,
             "new_ones.default": self._new_ones,
             "one_hot.default": self._one_hot,
+            "ones.default": self._ones,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f6dd235d5a..d24d67105e 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -468,23 +468,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         self.env[node.args[0]] = filled
         return filled
 
-    def _full(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype)
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                value,
-                dtype,
-            )
-        )
-
     def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         mask = self.env[node.args[1]]
@@ -527,22 +510,6 @@ class TorchFXImporter(BaseFXGraphImporter):
             mask = self.block_builder.emit(relax.op.broadcast_to(mask, x.struct_info.shape))
         return self.block_builder.emit(relax.op.where(mask, gathered_source, x))
 
-    def _ones(self, node: fx.Node) -> relax.Var:
-        import torch
-
-        args = self.retrieve_args(node)
-        size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],))
-        dtype = self._convert_data_type(
-            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
-        )
-        return self.block_builder.emit(
-            relax.op.full(
-                size,
-                relax.const(1, dtype),
-                dtype,
-            )
-        )
-
     def _one_hot(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 8405f48576..e92855885e 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -63,6 +63,54 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
         np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_full(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.full((2, 3), 3.141592)
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(3, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_full_like(target, dev):
+    class FullLike(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fill_value = 7.0
+
+        def forward(self, x):
+            return torch.full_like(x, self.fill_value)
+
+    torch_module = FullLike().eval()
+    raw_data = np.random.rand(2, 3).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
+@tvm.testing.parametrize_targets("cuda")
+def test_ones(target, dev):
+    class FullModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.ones((2, 3))
+
+    torch_module = FullModel().eval()
+
+    raw_data = np.random.rand(1, 1).astype("float32")
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_tensor_clamp(target, dev):
     class ClampBothTensor(torch.nn.Module):

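The new tests are CUDA-only. One way to run just them locally, assuming a CUDA-enabled TVM build (the -k expression below is illustrative):

    import pytest

    pytest.main([
        "tests/python/relax/test_from_exported_to_cuda.py",
        "-k", "test_full or test_full_like or test_ones",
    ])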