This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 78b5ed068c [Relax] Implement dynamic output trimming for NMS (#18676)
78b5ed068c is described below
commit 78b5ed068cda945913e9ded5788f4818fdf67c15
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Jan 22 19:43:14 2026 +0800
[Relax] Implement dynamic output trimming for NMS (#18676)
## Why
NMS operator returns fixed-size output with trailing garbage data,
wasting memory and requiring manual trimming for ONNX
compatibility.
## How
- Add dynamic_strided_slice to trim NMS output to valid detections only
- Build slice parameters using TE compute to avoid legalization issues
---
python/tvm/relax/op/vision/nms.py | 8 +-
python/tvm/relax/transform/legalize_ops/vision.py | 114 ++++++++++------------
tests/python/relax/test_op_vision.py | 98 ++++++++++++++++++-
3 files changed, 146 insertions(+), 74 deletions(-)
diff --git a/python/tvm/relax/op/vision/nms.py
b/python/tvm/relax/op/vision/nms.py
index 3714b00b01..4c50748bdb 100644
--- a/python/tvm/relax/op/vision/nms.py
+++ b/python/tvm/relax/op/vision/nms.py
@@ -54,12 +54,10 @@ def all_class_non_max_suppression(
`num_total_detection` of shape `(1,)` representing the total number of
selected
boxes. The three values in `indices` encode batch, class, and box
indices.
Rows of `indices` are ordered such that selected boxes from batch 0,
class 0 come
- first, in descending of scores, followed by boxes from batch 0, class
1 etc. Out of
- `batch_size * num_class* num_boxes` rows of indices, only the first
`num_total_detection`
- rows are valid.
+ first, in descending order of scores, followed by boxes from batch 0, class
1 etc.
+ The output uses dynamic_strided_slice to trim to only valid detections,
+ so the first tensor has shape (num_total_detection, 3) containing only
valid rows.
- TODO: Implement true dynamic output shapes to match ONNX Runtime
behavior exactly.
- This would eliminate the need for manual trimming and improve memory
efficiency.
If `output_format` is "tensorflow", the output is three tensors, the
first
is `indices` of size `(batch_size, num_class * num_boxes , 2)`, the
second is `scores` of
size `(batch_size, num_class * num_boxes)`, and the third is
`num_total_detection` of size
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py
b/python/tvm/relax/transform/legalize_ops/vision.py
index f910f62cec..9511c13018 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -15,64 +15,27 @@
# specific language governing permissions and limitations
# under the License.
"""Default legalization function for vision network related operators."""
-from tvm import topi, te
-from tvm import relax
+from tvm import relax, te, tir, topi
+
from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, TupleGetItem
from .common import register_legalize
-def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold):
- """Create a proper NMS implementation that follows the correct algorithm"""
- scores_shape = list(scores.shape)
- if len(scores_shape) == 3:
- batch, num_classes, _ = scores_shape
- elif len(scores_shape) == 2:
- num_classes, _ = scores_shape
- batch = 1
- else:
- raise ValueError(f"Unexpected scores shape: {scores_shape}")
-
- if hasattr(max_output_boxes_per_class, "data"):
- max_boxes = int(max_output_boxes_per_class.data.numpy())
- else:
- max_boxes = 3 # Default value
-
- expected_detections = batch * num_classes * max_boxes
-
- selected_indices_full, _ = topi.vision.all_class_non_max_suppression(
- boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, "onnx"
- )
-
- def slice_to_onnx_shape(data, expected_size):
- def compute_element(i, j):
- return tvm.tir.if_then_else(i < expected_size, data[i, j],
tvm.tir.Cast("int64", 0))
-
- return te.compute((expected_size, 3), compute_element,
name="sliced_indices")
-
- sliced_indices = slice_to_onnx_shape(selected_indices_full,
expected_detections)
-
- actual_detections = te.compute(
- (1,), lambda i: tvm.tir.Cast("int64", expected_detections),
name="actual_detections"
- )
-
- return [sliced_indices, actual_detections]
-
-
@register_legalize("relax.vision.all_class_non_max_suppression")
def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) ->
Expr:
- """Legalize all_class_non_max_suppression with fixed shape output.
-
- Note: This implementation outputs fixed-size tensors with trailing garbage
data.
- Only the first `num_total_detection` rows contain valid data. Users should
use
- the `valid_count` tensor to determine how many rows are actually valid.
-
- For complete ONNX compatibility, users can post-process the output:
- ```python
- selected_indices, valid_count = nms_output
- actual_count = int(valid_count.numpy()[0])
- valid_indices = selected_indices.numpy()[:actual_count, :]
- ```
+ """Legalize all_class_non_max_suppression with dynamic output trimming.
+
+ This implementation uses dynamic_strided_slice to trim the NMS output to
only
+ contain valid detections, improving memory efficiency and ONNX
compatibility.
+
+ Returns
+ -------
+ result : Tuple[Tensor, Tensor]
+ A tuple of (trimmed_indices, num_total_detections) where:
+ - trimmed_indices: Tensor of shape (num_total_detections, 3)
containing only
+ valid detection indices (batch_id, class_id, box_id)
+ - num_total_detections: Tensor of shape (1,) with the count of valid
detections
"""
boxes = call.args[0]
scores = call.args[1]
@@ -105,16 +68,37 @@ def _all_class_non_max_suppression(block_builder:
BlockBuilder, call: Call) -> E
output_format,
)
- # TODO: Implement dynamic output trimming for better memory efficiency
- # Current approach returns fixed-size output with trailing garbage data
- # Future improvements could include:
- # 1. Dynamic strided_slice based on num_total_detections
- # 2. Custom Relax operator with true dynamic shapes
- # 3. VM builtin functions for runtime shape adjustment
- # 4. Symbolic shape inference in Relax IR
- #
- # For now, users should trim manually:
- # actual_count = int(num_total_detections.numpy()[0])
- # valid_indices = selected_indices.numpy()[:actual_count, :]
-
- return nms_result
+ # Dynamic output trimming using dynamic_strided_slice
+ # Extract selected_indices and num_total_detections from the NMS result
+ selected_indices = block_builder.emit(TupleGetItem(nms_result, 0))
+ num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1))
+
+ # Build slicing parameters using TE to avoid high-level Relax ops during
legalization
+ def build_begin():
+ return te.compute((2,), lambda i: tir.const(0, "int64"), name="begin")
+
+ def build_strides():
+ return te.compute((2,), lambda i: tir.const(1, "int64"),
name="strides")
+
+ def build_end(count_tensor):
+ # end = [count_tensor[0], 3]
+ def compute_end(i):
+ return tir.if_then_else(
+ i == 0,
+ tir.Cast("int64", count_tensor[0]),
+ tir.const(3, "int64"),
+ )
+
+ return te.compute((2,), compute_end, name="end")
+
+ begin = block_builder.call_te(build_begin)
+ strides = block_builder.call_te(build_strides)
+ end = block_builder.call_te(build_end, num_total_detections)
+
+ # Apply dynamic strided slice to trim to valid detections only
+ trimmed_indices = block_builder.emit(
+ relax.op.dynamic_strided_slice(selected_indices, begin, end, strides)
+ )
+
+ # Return trimmed indices along with num_total_detections for compatibility
+ return relax.Tuple([trimmed_indices, num_total_detections])
diff --git a/tests/python/relax/test_op_vision.py
b/tests/python/relax/test_op_vision.py
index 97145a53ff..660b5d2772 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
+import numpy as np
import pytest
+
import tvm
import tvm.testing
-from tvm import relax, tir
-from tvm import TVMError
-from tvm.ir import Op, VDevice
+from tvm import TVMError, relax, tir
+from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R
@@ -53,7 +54,6 @@ def test_all_class_non_max_suppression_infer_struct_info():
def test_all_class_non_max_suppression_wrong_input_number():
- bb = relax.BlockBuilder()
boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32"))
scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32"))
@@ -86,5 +86,95 @@ def
test_all_class_non_max_suppression_infer_struct_info_shape_var():
)
+def test_all_class_non_max_suppression_legalize_dynamic_trim():
+ @tvm.script.ir_module
+ class NMSModule:
+ @R.function
+ def main(
+ boxes: R.Tensor((1, 5, 4), "float32"),
+ scores: R.Tensor((1, 2, 5), "float32"),
+ ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
+ max_output_boxes_per_class = R.const(3, "int64")
+ iou_threshold = R.const(0.5, "float32")
+ score_threshold = R.const(0.1, "float32")
+ return R.vision.all_class_non_max_suppression(
+ boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, "onnx"
+ )
+
+ mod = LegalizeOps()(NMSModule)
+
+ # Check legalized function has dynamic output (uses dynamic_strided_slice)
+ assert "dynamic_strided_slice" in str(mod)
+
+ ret_sinfo = mod["main"].ret_struct_info
+ tvm.ir.assert_structural_equal(
+ ret_sinfo,
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(ndim=2, dtype="int64"),
+ relax.TensorStructInfo((1,), "int64"),
+ ]
+ ),
+ )
+
+
+def test_all_class_non_max_suppression_legalize_e2e():
+ @tvm.script.ir_module
+ class NMSModule:
+ @R.function
+ def main(
+ boxes: R.Tensor((1, 5, 4), "float32"),
+ scores: R.Tensor((1, 2, 5), "float32"),
+ ) -> R.Tuple(R.Tensor(dtype="int64", ndim=2), R.Tensor((1,), "int64")):
+ max_output_boxes_per_class = R.const(3, "int64")
+ iou_threshold = R.const(0.5, "float32")
+ score_threshold = R.const(0.1, "float32")
+ return R.vision.all_class_non_max_suppression(
+ boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, "onnx"
+ )
+
+ boxes_data = np.array(
+ [
+ [
+ [0.0, 0.0, 1.0, 1.0],
+ [0.1, 0.1, 1.1, 1.1],
+ [2.0, 2.0, 3.0, 3.0],
+ [4.0, 4.0, 5.0, 5.0],
+ [6.0, 6.0, 7.0, 7.0],
+ ]
+ ],
+ dtype=np.float32,
+ )
+ scores_data = np.array(
+ [[[0.9, 0.8, 0.7, 0.6, 0.5], [0.85, 0.75, 0.65, 0.55, 0.45]]],
+ dtype=np.float32,
+ )
+
+ mod = LegalizeOps()(NMSModule)
+
+ # Check struct info
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo(ndim=2, dtype="int64"),
+ relax.TensorStructInfo((1,), "int64"),
+ ]
+ ),
+ )
+
+ # Check runtime execution
+ exe = tvm.compile(mod, target="llvm")
+ vm = relax.VirtualMachine(exe, tvm.cpu())
+ result = vm["main"](
+ tvm.runtime.tensor(boxes_data, tvm.cpu()),
+ tvm.runtime.tensor(scores_data, tvm.cpu()),
+ )
+
+ selected_indices = result[0].numpy()
+ num_total_detections = int(result[1].numpy()[0])
+ tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections,
3))
+
+
if __name__ == "__main__":
tvm.testing.main()