This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5842bdbd30 [Relax][PyTorch] Add support for broadcast_to, narrow ops
(#17820)
5842bdbd30 is described below
commit 5842bdbd30070d79d823bf906b590cbc1f6d6f0d
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Apr 11 03:39:04 2025 +0800
[Relax][PyTorch] Add support for broadcast_to, narrow ops (#17820)
* Update fx_translator.py
* Update base_fx_graph_translator.py
* Update test_frontend_from_fx.py
* Update test_frontend_from_fx.py
---
.../frontend/torch/base_fx_graph_translator.py | 6 +++
python/tvm/relax/frontend/torch/fx_translator.py | 9 +++++
tests/python/relax/test_frontend_from_fx.py | 43 ++++++++++++++++++++++
3 files changed, 58 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d1a42d645c..c9c6afd71a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -972,6 +972,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
descending = node.args[2] if len(node.args) > 2 else
node.kwargs.get("descending", False)
return self.block_builder.emit(relax.op.argsort(x, dim, descending))
+ def _broadcast_to(self, node: fx.Node) -> relax.Var:
+     """Translate a ``broadcast_to`` node into ``relax.op.broadcast_to``.
+
+     The first positional argument is the tensor to broadcast. The target
+     shape is the second positional argument, or the ``size`` keyword when
+     the caller passed it by name.
+
+     Raises
+     ------
+     KeyError
+         If the target shape is given neither positionally nor as ``size``.
+     """
+     args = self.retrieve_args(node)
+     x = args[0]
+     if len(args) > 1:
+         shape = args[1]
+     else:
+         # Do not fall back to args[0]: that is the input tensor itself, not a
+         # shape. The keyword value is taken as-is from node.kwargs.
+         shape = node.kwargs["size"]
+     return self.block_builder.emit(relax.op.broadcast_to(x, shape))
+
def _cat(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index f3732b3472..a5b50a7d1d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -429,6 +429,13 @@ class TorchFXImporter(BaseFXGraphImporter):
end_dim = module.end_dim
return self._flatten_impl(x, start_dim, end_dim)
+ def _narrow(self, node: fx.Node) -> relax.Var:
+     """Translate ``narrow(input, dim, start, length)`` into a strided slice.
+
+     ``narrow`` selects ``length`` elements along ``dim`` beginning at
+     ``start``. ``relax.op.strided_slice`` takes a begin and an end index
+     per axis, so the end index is computed as ``start + length``.
+
+     ``dim``, ``start`` and ``length`` are read positionally when present
+     and from the node's keyword arguments otherwise.
+     """
+     x = self.env[node.args[0]]
+     dim = node.args[1] if len(node.args) > 1 else node.kwargs["dim"]
+     start = node.args[2] if len(node.args) > 2 else node.kwargs["start"]
+     length = node.args[3] if len(node.args) > 3 else node.kwargs["length"]
+     # The slice end is an index, not a count: passing `length` directly is
+     # only correct when start == 0.
+     # NOTE(review): a negative `start` is not normalized here; confirm
+     # whether that case needs handling.
+     end = start + length
+     return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [end]))
+
def _numel(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
shape = self.shape_of(x)
@@ -764,6 +771,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"where": self._where,
# tensor manipulation
"argsort": self._argsort,
+ "broadcast_to": self._broadcast_to,
"cat": self._cat,
"chunk": self._chunk,
"concat": self._cat,
@@ -775,6 +783,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"flatten": self._flatten,
"flip": self._flip,
"gather": self._gather,
+ "narrow": self._narrow,
"numel": self._numel,
"permute": self._permute,
"repeat": self._repeat,
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index fd9bfdf633..ee5a5c78c7 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4470,5 +4470,48 @@ def test_topk():
verify_model(Topk(), [([5, 3], "float32")], {}, Expected)
+def test_broadcast_to():
+ # Model under test: broadcasts its input to the fixed shape (5, 3).
+ class BroadcastTo(Module):
+ def forward(self, x):
+ return torch.broadcast_to(x, (5, 3))
+
+ # Reference module: a single R.broadcast_to of the (5, 1) input to (5, 3).
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 1), dtype="float32"),
+ ) -> R.Tensor((5, 3), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(inp_0,
(5, 3))
+ gv: R.Tensor((5, 3), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ # Input spec is one float32 tensor of shape [5, 1]; presumably verify_model
+ # compares the imported model against Expected (helper not shown here).
+ verify_model(BroadcastTo(), [([5, 1], "float32")], {}, Expected)
+
+
+def test_narrow():
+ # Model under test: narrow along dim 1 with start=0 and length=2.
+ class Narrow(Module):
+ def forward(self, x):
+ return torch.narrow(x, 1, 0, 2)
+
+ # Reference module: a strided slice on axis 1 from index 0 to index 2,
+ # turning the (5, 3) input into a (5, 2) result.
+ # NOTE(review): start is 0 here, so end == length == start + length; a
+ # nonzero start would be needed to tell those two apart.
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 3), dtype="float32"),
+ ) -> R.Tensor((5, 2), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
+ inp_0, axes=[1], begin=[0], end=[2]
+ )
+ gv: R.Tensor((5, 2), dtype="float32") = lv
+ R.output(gv)
+
+ return gv
+
+ # Input spec is one float32 tensor of shape [5, 3]; presumably verify_model
+ # compares the imported model against Expected (helper not shown here).
+ verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()