This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 63b54c4547 [Relax][PyTorch] Add support for numel, empty_like and one_hot ops (#17726)
63b54c4547 is described below

commit 63b54c4547df84c07ec998437b616616a6f16323
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Mar 10 21:44:49 2025 +0800

    [Relax][PyTorch] Add support for numel, empty_like and one_hot ops (#17726)
    
    This PR supports PyTorch `numel`, `empty_like` and `one_hot` for Relax.
---
 .../frontend/torch/base_fx_graph_translator.py     |  4 ++
 python/tvm/relax/frontend/torch/fx_translator.py   | 20 +++++++
 tests/python/relax/test_frontend_from_fx.py        | 62 ++++++++++++++++++++++
 3 files changed, 86 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 003ceebec6..a9f54d91e3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1018,6 +1018,10 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
         return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
 
+    def _empty_like(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        return self.block_builder.emit(relax.op.zeros_like(x))
+
     def _fill(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index ef98d3c025..29d959818f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -409,6 +409,11 @@ class TorchFXImporter(BaseFXGraphImporter):
         end_dim = module.end_dim
         return self._flatten_impl(x, start_dim, end_dim)
 
+    def _numel(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        shape = self.shape_of(x)
+        return relax.const(reduce(lambda x, y: x * y, [s.value for s in shape]), "int32")
+
     def _size(self, node: fx.Node) -> relax.Expr:
         x = self.env[node.args[0]]
         shape = self.shape_of(x)
@@ -511,6 +516,18 @@ class TorchFXImporter(BaseFXGraphImporter):
             )
         )
 
+    def _one_hot(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
+        if num_classes is None:
+            raise ValueError("num_classes not found in node.args or node.kwargs")
+        on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1)
+        off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0)
+        axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1)
+        on_value = relax.PrimValue(on_value)
+        off_value = relax.PrimValue(off_value)
+        return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
+
     def _tensor(self, node: fx.Node) -> relax.Var:
         dtype = node.kwargs.get("dtype", None)
         if isinstance(node.args[0], float):
@@ -735,6 +752,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
+            "numel": self._numel,
             "permute": self._permute,
             "repeat": self._repeat,
             "reshape": self._reshape,
@@ -753,6 +771,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             # tensor creation
             "arange": self._arange,
             "empty": self._empty,
+            "empty_like": self._empty_like,
             "fill_": self._inplace_fill,
             "full": self._full,
             "index_select": self._index_select,
@@ -761,6 +780,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             "masked_scatter": self._masked_scatter,
             "new_ones": self._new_ones,
             "ones": self._ones,
+            "one_hot": self._one_hot,
             "tensor": self._tensor,
             # datatype
             "astype": self._type,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 0b4b34e0c9..020fc8f5b3 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4037,5 +4037,67 @@ def test_take():
     verify_model(Take(), [([5], "float32"), ([3], "int32")], {}, Expected)
 
 
+def test_one_hot():
+    class OneHot(Module):
+        def forward(self, indices):
+            return torch.nn.functional.one_hot(indices, num_classes=10)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="int32"),
+        ) -> R.Tensor((5, 10), dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
+                    inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
+                )
+                gv: R.Tensor((5, 10), dtype="int64") = lv
+                R.output(gv)
+
+            return gv
+
+    verify_model(OneHot(), [([5], "int32")], {}, Expected)
+
+
+def test_empty_like():
+    class EmptyLike(Module):
+        def forward(self, data):
+            return torch.empty_like(data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"),
+        ) -> R.Tensor((5,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0)
+                gv: R.Tensor((5,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(EmptyLike(), [([5], "float32")], {}, Expected)
+
+
+def test_numel():
+    class Numel(Module):
+        def forward(self, data):
+            return torch.numel(data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            inp_0: R.Tensor((5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="int32"):
+            with R.dataflow():
+                gv: R.Tensor((), dtype="int32") = R.const(15, "int32")
+                R.output(gv)
+            return gv
+
+    verify_model(Numel(), [([5, 3], "float32")], {}, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to