This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e74f7eee32 [Relax] Add support to ingest Tensor.expand_as() (#17724)
e74f7eee32 is described below

commit e74f7eee32059d816c491983a8a4651d936cfee8
Author: Hugo Latendresse <[email protected]>
AuthorDate: Tue Mar 11 10:37:44 2025 -0400

    [Relax] Add support to ingest Tensor.expand_as() (#17724)
    
    Support ingesting Tensor.expand_as(), with a unit test for correctness.
---
 .../frontend/torch/base_fx_graph_translator.py     |  8 ++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 python/tvm/relax/frontend/torch/fx_translator.py   |  1 +
 tests/python/relax/test_from_exported_to_cuda.py   | 47 ++++++++++++++++++++++
 4 files changed, 57 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d5cad2381b..a0f00e1f4b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -883,6 +883,14 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
                 broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
 
+    def _expand_as(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        # args[0] is the 'self' tensor
+        # args[1] is the 'other' tensor
+        data = args[0]
+        other_shape = self.shape_of(args[1])  # the shape of 'other'
+        return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))
+
     def _flip(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4ff31ea1d7..2103365c6c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -298,6 +298,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
+            "expand_as.default": self._expand_as,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
             "select.int": self._select,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 29d959818f..abda5088db 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -749,6 +749,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "contiguous": lambda node: self.env[node.args[0]],
             "cumsum": self._cumsum,
             "expand": self._expand,
+            "expand_as": self._expand_as,
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index bd4bdcf617..e8b5da0dc2 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -56,6 +56,53 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor_expand_as(target, dev):
+    class ExpandAs0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((1, 1, 1, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs1(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 1, 4, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs2(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 1, 1, 10))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs3(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 3, 1, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)
+
+    torch_module0 = ExpandAs0().eval()
+    torch_module1 = ExpandAs1().eval()
+    torch_module2 = ExpandAs2().eval()
+    torch_module3 = ExpandAs3().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_copy_(target, dev):
     class CopyTester(nn.Module):

Reply via email to