This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e74f7eee32 [Relax] Add support to ingest Tensor.expand_as() (#17724)
e74f7eee32 is described below
commit e74f7eee32059d816c491983a8a4651d936cfee8
Author: Hugo Latendresse <[email protected]>
AuthorDate: Tue Mar 11 10:37:44 2025 -0400
[Relax] Add support to ingest Tensor.expand_as() (#17724)
Add support for ingesting Tensor.expand_as(), with a unit test for correctness
---
.../frontend/torch/base_fx_graph_translator.py | 8 ++++
.../frontend/torch/exported_program_translator.py | 1 +
python/tvm/relax/frontend/torch/fx_translator.py | 1 +
tests/python/relax/test_from_exported_to_cuda.py | 47 ++++++++++++++++++++++
4 files changed, 57 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d5cad2381b..a0f00e1f4b 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -883,6 +883,14 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
broadcast_shape.append(i)
return self.block_builder.emit(relax.op.broadcast_to(args[0],
broadcast_shape))
+ def _expand_as(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ # args[0] is the 'self' tensor
+ # args[1] is the 'other' tensor
+ data = args[0]
+ other_shape = self.shape_of(args[1]) # the shape of 'other'
+ return self.block_builder.emit(relax.op.broadcast_to(data,
other_shape))
+
def _flip(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims",
None)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4ff31ea1d7..2103365c6c 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -298,6 +298,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"copy_.default": self._copy_,
"cumsum.default": self._cumsum,
"expand.default": self._expand,
+ "expand_as.default": self._expand_as,
"permute.default": self._permute,
"repeat.default": self._repeat,
"select.int": self._select,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index 29d959818f..abda5088db 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -749,6 +749,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"contiguous": lambda node: self.env[node.args[0]],
"cumsum": self._cumsum,
"expand": self._expand,
+ "expand_as.default": self._expand_as,
"flatten": self._flatten,
"flip": self._flip,
"gather": self._gather,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py
b/tests/python/relax/test_from_exported_to_cuda.py
index bd4bdcf617..e8b5da0dc2 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -56,6 +56,53 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
@tvm.testing.parametrize_targets("cuda")
def test_tensor_expand_as(target, dev):
    """Check Tensor.expand_as against PyTorch for several broadcastable templates."""

    def make_module(template_shape):
        # Each module holds a fixed all-ones template tensor and broadcasts
        # it to the shape of the forward input via expand_as.
        class ExpandAsModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.template = torch.ones(template_shape)

            def forward(self, x):
                return self.template.expand_as(x)

        return ExpandAsModule().eval()

    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)

    # Template shapes exercise broadcasting along different axis combinations
    # against the fixed (2, 3, 4, 10) input.
    template_shapes = [
        (1, 1, 1, 1),
        (2, 1, 4, 1),
        (2, 1, 1, 10),
        (2, 3, 1, 1),
    ]
    for shape in template_shapes:
        torch_module = make_module(shape)
        assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev)
+
@tvm.testing.parametrize_targets("cuda")
def test_copy_(target, dev):
class CopyTester(nn.Module):