This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2ca6ec8a5d [Relax][PyTorch] Sort.default (#17852)
2ca6ec8a5d is described below
commit 2ca6ec8a5d1cd22dbf428b3b5cd9f899d058ea3c
Author: Hugo Latendresse <[email protected]>
AuthorDate: Mon Apr 21 15:17:16 2025 -0400
[Relax][PyTorch] Sort.default (#17852)
Add support for sort.default in the exported program translator.
There was an existing _sort() function in base_fx_graph_translator.py,
but it returned only the values. PyTorch returns a tuple of values and
indices, so that was corrected: the importer now uses argsort to obtain
the indices and gather_elements to obtain the values.
---
.../frontend/torch/base_fx_graph_translator.py | 7 +++++-
.../frontend/torch/exported_program_translator.py | 1 +
tests/python/relax/test_from_exported_to_cuda.py | 25 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 17 +++++++++++----
4 files changed, 45 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 13d13ff24c..20556167c1 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1278,10 +1278,15 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.block_builder.emit(relax.op.scatter_elements(x, index,
src, axis=dim))
def _sort(self, node: fx.Node) -> relax.Var:
+ # torch.sort() returns a tuple of values and indices
+ # we use argsort to get indices and gather_elements to get values
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim",
-1)
descending = node.args[2] if len(node.args) > 2 else
node.kwargs.get("descending", False)
- return self.block_builder.emit(relax.op.sort(x, dim, descending))
+
+ indices = self.block_builder.emit(relax.op.argsort(x, dim, descending))
+ values = self.block_builder.emit(relax.op.gather_elements(x, indices,
axis=dim))
+ return self.block_builder.emit(relax.Tuple([values, indices]))
def _split(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ed6740a25e..f38f353a9e 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -431,6 +431,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"roll.default": self._roll,
"select.int": self._select,
"slice.Tensor": self._slice,
+ "sort.default": self._sort,
"split.Tensor": self._split,
"split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index 76a4bb2039..6bb35b50b1 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -208,6 +208,31 @@ def test_ones(target, dev):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+@tvm.testing.parametrize_targets("cuda")
+def test_sort(target, dev):
+ raw_data = np.array([[4, 1, 13], [-30, 1, 3], [4, 0,
10]]).astype("float32")
+
+ # Test values
+ class SortModelValues(nn.Module):
+ def forward(self, x):
+ A, _ = torch.sort(x, dim=0, descending=True)
+ B, _ = torch.sort(x, dim=1, descending=False)
+ return A + B
+
+ torch_module = SortModelValues().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+ # Test indices
+ class SortModelIndices(nn.Module):
+ def forward(self, x):
+ _, A = torch.sort(x, dim=0, descending=True)
+ _, B = torch.sort(x, dim=1, descending=False)
+ return A + B
+
+ torch_module = SortModelIndices().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
@tvm.testing.parametrize_targets("cuda")
def test_tensor_clamp(target, dev):
class ClampBothTensor(torch.nn.Module):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index e8db6af347..2d27fa1f59 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4749,11 +4749,20 @@ def test_sort():
class Expected:
@R.function
def main(
- inp_0: R.Tensor((5, 3), dtype="float32"),
- ) -> R.Tensor((5, 3), dtype="float32"):
+ inp_0: R.Tensor((5, 3), dtype="float32")
+ ) -> R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3),
dtype="int32")):
with R.dataflow():
- lv: R.Tensor((5, 3), dtype="float32") = R.sort(inp_0, axis=1,
descending=True)
- gv: R.Tensor((5, 3), dtype="float32") = lv
+ lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
+ inp_0, axis=1, descending=True, dtype="int32"
+ )
+ lv1: R.Tensor((5, 3), dtype="float32") =
R.gather_elements(inp_0, lv, axis=1)
+ lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5,
3), dtype="int32")) = (
+ lv1,
+ lv,
+ )
+ gv: R.Tuple(
+ R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3),
dtype="int32")
+ ) = lv2
R.output(gv)
return gv