This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0d017c13a4 [Relax] Replaced call_pure_packed with tensor_to_shape operator (#18616)
0d017c13a4 is described below
commit 0d017c13a48d3d2ed8ae63b0329f5064204316b3
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Dec 30 16:02:29 2025 +0800
[Relax] Replaced call_pure_packed with tensor_to_shape operator (#18616)
## Why
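Simplify the code and address the purity issue raised in the TODO comment in
`_dynamic_strided_slice`, which recommended using the `tensor_to_shape` op
instead of calling `vm.builtin.tensor_to_shape` through `call_pure_packed`.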
## How
**Before**
```
output_shape = bb.emit(
call_pure_packed(
"vm.builtin.tensor_to_shape", output_shape,
sinfo_args=ShapeStructInfo(ndim=ndim)
)
)
```
**After**
```
output_shape = bb.emit(tensor_to_shape(output_shape))
```
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
python/tvm/relax/transform/legalize_ops/index.py | 13 ++-----------
.../test_transform_legalize_ops_index_linear_algebra.py | 8 ++------
2 files changed, 4 insertions(+), 17 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py
index d99c1f4db6..75c17f7fa9 100644
--- a/python/tvm/relax/transform/legalize_ops/index.py
+++ b/python/tvm/relax/transform/legalize_ops/index.py
@@ -17,7 +17,7 @@
# pylint: disable=invalid-name
"""Default legalization function for index operators."""
from tvm import topi, tir, te
-from ...op import call_pure_packed
+from ...op import tensor_to_shape
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from ...struct_info import ShapeStructInfo, PrimStructInfo
@@ -109,17 +109,8 @@ def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr:
)
# 2. Convert tensor to shape and match cast with new symbolic vars
- # Get shape length
ndim = int(output_shape.struct_info.shape[0])
- output_shape = bb.emit(
- # TODO(@relax-team): Ideally, we should use the tensor_to_shape op here to
- # address the issue with purity, but that introduces a staging issue:
- # we need to apply DecomposeOpsForInference in that case
- # and it's unclear when in the build it should happen
- call_pure_packed(
- "vm.builtin.tensor_to_shape", output_shape, sinfo_args=ShapeStructInfo(ndim=ndim)
- )
- )
+ output_shape = bb.emit(tensor_to_shape(output_shape))
output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)]
bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars))
diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index efa7f4dfff..a6e53dab4d 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -669,9 +669,7 @@ def test_dynamic_strided_slice():
(x, begin, end, strides),
out_sinfo=R.Tensor((4,), dtype="int64"),
)
- gv1: R.Shape(ndim=4) = R.call_pure_packed(
- "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=4),)
- )
+ gv1: R.Shape(ndim=4) = R.tensor_to_shape(gv)
gv2: R.Shape([s, s_1, s_2, s_3]) = R.match_cast(
gv1, R.Shape([s, s_1, s_2, s_3])
)
@@ -868,9 +866,7 @@ def test_dynamic_strided_slice_symbolic():
(x, begin, end, strides),
out_sinfo=R.Tensor((2,), dtype="int64"),
)
- gv1: R.Shape(ndim=2) = R.call_pure_packed(
- "vm.builtin.tensor_to_shape", gv, sinfo_args=(R.Shape(ndim=2),)
- )
+ gv1: R.Shape(ndim=2) = R.tensor_to_shape(gv)
gv2: R.Shape([s, s_1]) = R.match_cast(gv1, R.Shape([s, s_1]))
gv_1 = R.call_tir(
Expected.dynamic_strided_slice,