This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 63b54c4547 [Relax][PyTorch] Add support for numel, empty_like and one_hot ops (#17726)
63b54c4547 is described below
commit 63b54c4547df84c07ec998437b616616a6f16323
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Mar 10 21:44:49 2025 +0800
[Relax][PyTorch] Add support for numel, empty_like and one_hot ops (#17726)
This PR adds support for the PyTorch `numel`, `empty_like` and `one_hot` ops in the Relax frontend.
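For reference, a minimal module exercising the three ops (illustrative only; the class and tensor names are not part of the patch). Traced through the FX frontend, `numel` lowers to a constant, `empty_like` to `zeros_like`, and `one_hot` to `relax.op.one_hot`:

    import torch
    import torch.nn.functional as F

    class Sample(torch.nn.Module):
        def forward(self, x, idx):
            n = torch.numel(x)                  # -> R.const(x.numel(), "int32")
            buf = torch.empty_like(x)           # -> R.zeros_like(x)
            oh = F.one_hot(idx, num_classes=4)  # -> R.one_hot(idx, 1, 0, depth=4)
            return n, buf, oh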
---
.../frontend/torch/base_fx_graph_translator.py | 4 ++
python/tvm/relax/frontend/torch/fx_translator.py | 20 +++++++
tests/python/relax/test_frontend_from_fx.py | 62 ++++++++++++++++++++++
3 files changed, 86 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 003ceebec6..a9f54d91e3 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1018,6 +1018,10 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+ def _empty_like(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ return self.block_builder.emit(relax.op.zeros_like(x))
+
def _fill(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
x = args[0]
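Note that `torch.empty_like` guarantees only shape, dtype, and layout; its element values are unspecified, so emitting `zeros_like` is one valid choice of those values. A minimal illustration of the contract this lowering relies on (not part of the patch):

    import torch

    x = torch.randn(5)
    y = torch.empty_like(x)
    # Only the metadata is defined; zeros satisfy the empty_like
    # contract as well as any other values would.
    assert y.shape == x.shape and y.dtype == x.dtype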
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index ef98d3c025..29d959818f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -409,6 +409,11 @@ class TorchFXImporter(BaseFXGraphImporter):
end_dim = module.end_dim
return self._flatten_impl(x, start_dim, end_dim)
+ def _numel(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+ shape = self.shape_of(x)
+        return relax.const(reduce(lambda x, y: x * y, [s.value for s in shape]), "int32")
+
def _size(self, node: fx.Node) -> relax.Expr:
x = self.env[node.args[0]]
shape = self.shape_of(x)
@@ -511,6 +516,18 @@ class TorchFXImporter(BaseFXGraphImporter):
)
)
+ def _one_hot(self, node: fx.Node) -> relax.Var:
+ x = self.env[node.args[0]]
+        num_classes = node.args[1] if len(node.args) > 1 else node.kwargs.get("num_classes")
+        if num_classes is None:
+            raise ValueError("num_classes not found in node.args or node.kwargs")
+        on_value = node.args[2] if len(node.args) > 2 else node.kwargs.get("on_value", 1)
+        off_value = node.args[3] if len(node.args) > 3 else node.kwargs.get("off_value", 0)
+        axis = node.args[4] if len(node.args) > 4 else node.kwargs.get("axis", -1)
+        on_value = relax.PrimValue(on_value)
+        off_value = relax.PrimValue(off_value)
+        return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis))
+
def _tensor(self, node: fx.Node) -> relax.Var:
dtype = node.kwargs.get("dtype", None)
if isinstance(node.args[0], float):
@@ -735,6 +752,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"flatten": self._flatten,
"flip": self._flip,
"gather": self._gather,
+ "numel": self._numel,
"permute": self._permute,
"repeat": self._repeat,
"reshape": self._reshape,
@@ -753,6 +771,7 @@ class TorchFXImporter(BaseFXGraphImporter):
# tensor creation
"arange": self._arange,
"empty": self._empty,
+ "empty_like": self._empty_like,
"fill_": self._inplace_fill,
"full": self._full,
"index_select": self._index_select,
@@ -761,6 +780,7 @@ class TorchFXImporter(BaseFXGraphImporter):
"masked_scatter": self._masked_scatter,
"new_ones": self._new_ones,
"ones": self._ones,
+ "one_hot": self._one_hot,
"tensor": self._tensor,
# datatype
"astype": self._type,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 0b4b34e0c9..020fc8f5b3 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4037,5 +4037,67 @@ def test_take():
verify_model(Take(), [([5], "float32"), ([3], "int32")], {}, Expected)
+def test_one_hot():
+ class OneHot(Module):
+ def forward(self, indices):
+ return torch.nn.functional.one_hot(indices, num_classes=10)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5,), dtype="int32"),
+ ) -> R.Tensor((5, 10), dtype="int64"):
+ with R.dataflow():
+ lv: R.Tensor((5, 10), dtype="int64") = R.one_hot(
+ inp_0, R.prim_value(1), R.prim_value(0), depth=10, axis=-1
+ )
+ gv: R.Tensor((5, 10), dtype="int64") = lv
+ R.output(gv)
+
+ return gv
+
+ verify_model(OneHot(), [([5], "int32")], {}, Expected)
+
+
+def test_empty_like():
+ class EmptyLike(Module):
+ def forward(self, data):
+ return torch.empty_like(data)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5,), dtype="float32"),
+ ) -> R.Tensor((5,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((5,), dtype="float32") = R.zeros_like(inp_0)
+ gv: R.Tensor((5,), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(EmptyLike(), [([5], "float32")], {}, Expected)
+
+
+def test_numel():
+ class Numel(Module):
+ def forward(self, data):
+ return torch.numel(data)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5, 3), dtype="float32"),
+ ) -> R.Tensor((), dtype="int32"):
+ with R.dataflow():
+ gv: R.Tensor((), dtype="int32") = R.const(15, "int32")
+ R.output(gv)
+ return gv
+
+ verify_model(Numel(), [([5, 3], "float32")], {}, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
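For reference, the new handlers can also be exercised directly through the public `from_fx` entry point (a minimal sketch mirroring the `Numel` test above; assumes a TVM build with the Relax frontend):

    import torch
    from torch import fx
    from tvm.relax.frontend.torch import from_fx

    class Numel(torch.nn.Module):
        def forward(self, data):
            return torch.numel(data)

    graph_model = fx.symbolic_trace(Numel())
    mod = from_fx(graph_model, [((5, 3), "float32")])
    mod.show()  # prints the imported Relax module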