This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5842bdbd30 [Relax][PyTorch] Add support for broadcast_to, narrow ops (#17820)
5842bdbd30 is described below

commit 5842bdbd30070d79d823bf906b590cbc1f6d6f0d
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Apr 11 03:39:04 2025 +0800

    [Relax][PyTorch] Add support for broadcast_to, narrow ops (#17820)
    
    * Update fx_translator.py
    
    * Update base_fx_graph_translator.py
    
    * Update test_frontend_from_fx.py
    
    * Update test_frontend_from_fx.py
---
 .../frontend/torch/base_fx_graph_translator.py     |  6 +++
 python/tvm/relax/frontend/torch/fx_translator.py   |  9 +++++
 tests/python/relax/test_frontend_from_fx.py        | 43 ++++++++++++++++++++++
 3 files changed, 58 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d1a42d645c..c9c6afd71a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -972,6 +972,12 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         descending = node.args[2] if len(node.args) > 2 else node.kwargs.get("descending", False)
         return self.block_builder.emit(relax.op.argsort(x, dim, descending))
 
+    def _broadcast_to(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        shape = args[1] if len(args) > 1 else node.kwargs["size"]
+        return self.block_builder.emit(relax.op.broadcast_to(x, shape))
+
     def _cat(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f3732b3472..a5b50a7d1d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -429,6 +429,13 @@ class TorchFXImporter(BaseFXGraphImporter):
         end_dim = module.end_dim
         return self._flatten_impl(x, start_dim, end_dim)
 
+    def _narrow(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs["dim"]
+        start = node.args[2] if len(node.args) > 2 else node.kwargs["start"]
+        length = node.args[3] if len(node.args) > 3 else node.kwargs["length"]
+        return self.block_builder.emit(relax.op.strided_slice(x, [dim], [start], [start + length]))
+
     def _numel(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -764,6 +771,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "where": self._where,
             # tensor manipulation
             "argsort": self._argsort,
+            "broadcast_to": self._broadcast_to,
             "cat": self._cat,
             "chunk": self._chunk,
             "concat": self._cat,
@@ -775,6 +783,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
+            "narrow": self._narrow,
             "numel": self._numel,
             "permute": self._permute,
             "repeat": self._repeat,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index fd9bfdf633..ee5a5c78c7 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4470,5 +4470,48 @@ def test_topk():
     verify_model(Topk(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_broadcast_to():
+    class BroadcastTo(Module):
+        def forward(self, x):
+            return torch.broadcast_to(x, (5, 3))
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 1), dtype="float32"),
+        ) -> R.Tensor((5, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 3), dtype="float32") = R.broadcast_to(inp_0, (5, 3))
+                gv: R.Tensor((5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(BroadcastTo(), [([5, 1], "float32")], {}, Expected)
+
+
+def test_narrow():
+    class Narrow(Module):
+        def forward(self, x):
+            return torch.narrow(x, 1, 1, 2)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((5, 2), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 2), dtype="float32") = R.strided_slice(
+                    inp_0, axes=[1], begin=[1], end=[3]
+                )
+                gv: R.Tensor((5, 2), dtype="float32") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to