This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c00c66259a [Relax][ONNX] Support AllClassNMS Operator for ONNX Frontend (#18321)
c00c66259a is described below
commit c00c66259a8dd4cf197601c978c566ce2db9bc17
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Oct 1 16:09:59 2025 -0400
[Relax][ONNX] Support AllClassNMS Operator for ONNX Frontend (#18321)
Following #18175, this PR adds support for the AllClassNMS operator in the ONNX frontend.
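For context, here is a minimal, illustrative sketch (not part of this commit) of exercising the new converter through the Relax ONNX frontend; the tiny graph, shapes, and the plain `from_onnx(model)` call are assumptions based on the existing frontend API.

```python
# Illustrative only: a tiny ONNX graph with one NonMaxSuppression node,
# imported through the Relax ONNX frontend extended by this PR.
from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

nms_node = helper.make_node(
    "NonMaxSuppression",
    inputs=["boxes", "scores", "max_output_boxes_per_class", "iou_threshold", "score_threshold"],
    outputs=["selected_indices"],
)
graph = helper.make_graph(
    [nms_node],
    "nms_example",
    inputs=[
        helper.make_tensor_value_info("boxes", TensorProto.FLOAT, [1, 6, 4]),
        helper.make_tensor_value_info("scores", TensorProto.FLOAT, [1, 1, 6]),
    ],
    outputs=[helper.make_tensor_value_info("selected_indices", TensorProto.INT64, [None, 3])],
    initializer=[
        helper.make_tensor("max_output_boxes_per_class", TensorProto.INT64, [1], [3]),
        helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]),
        helper.make_tensor("score_threshold", TensorProto.FLOAT, [1], [0.0]),
    ],
)
model = helper.make_model(graph)

# The converter lowers the node to relax.op.vision.all_class_non_max_suppression.
mod = from_onnx(model)
mod.show()
```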
---
include/tvm/relax/attrs/vision.h | 54 +++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 179 +++++++-
python/tvm/relax/op/__init__.py | 1 +
python/tvm/relax/op/op_attrs.py | 5 +
.../tvm/{topi/cpp => relax/op}/vision/__init__.py | 9 +-
.../__init__.py => relax/op/vision/_ffi_api.py} | 7 +-
python/tvm/relax/op/vision/nms.py | 75 ++++
.../tvm/relax/transform/legalize_ops/__init__.py | 1 +
python/tvm/relax/transform/legalize_ops/vision.py | 120 +++++
python/tvm/script/ir_builder/relax/ir.py | 2 +
python/tvm/topi/__init__.py | 1 +
python/tvm/topi/cpp/vision/__init__.py | 1 +
python/tvm/topi/{cpp => }/vision/__init__.py | 9 +-
python/tvm/topi/vision/nms.py | 500 +++++++++++++++++++++
python/tvm/topi/vision/nms_util.py | 473 +++++++++++++++++++
src/relax/ir/emit_te.h | 4 +
src/relax/op/vision/nms.cc | 114 +++++
src/relax/op/vision/nms.h | 44 ++
src/te/operation/create_primfunc.cc | 5 +-
tests/python/relax/test_frontend_onnx.py | 426 ++++++++++++++++++
tests/python/relax/test_op_vision.py | 90 ++++
.../relax/test_tvmscript_parser_op_vision.py | 80 ++++
22 files changed, 2179 insertions(+), 21 deletions(-)
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
new file mode 100644
index 0000000000..2fd98533b5
--- /dev/null
+++ b/include/tvm/relax/attrs/vision.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/relax/attrs/vision.h
+ * \brief Auxiliary attributes for vision operators.
+ */
+#ifndef TVM_RELAX_ATTRS_VISION_H_
+#define TVM_RELAX_ATTRS_VISION_H_
+
+#include <tvm/ffi/string.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/type.h>
+#include <tvm/relax/expr.h>
+#include <tvm/runtime/object.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes used in AllClassNonMaximumSuppression operator */
+struct AllClassNonMaximumSuppressionAttrs
+ : public AttrsNodeReflAdapter<AllClassNonMaximumSuppressionAttrs> {
+ ffi::String output_format;
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<AllClassNonMaximumSuppressionAttrs>().def_ro(
+ "output_format", &AllClassNonMaximumSuppressionAttrs::output_format,
+        "Output format, onnx or tensorflow. Returns outputs in a way that can be easily "
+ "consumed by each frontend.");
+ }
+
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.AllClassNonMaximumSuppressionAttrs",
+                                    AllClassNonMaximumSuppressionAttrs, BaseAttrsNode);
+}; // struct AllClassNonMaximumSuppressionAttrs
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_ATTRS_VISION_H_
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7a4a65df6e..7432967c29 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3455,6 +3455,182 @@ class SequenceAt(OnnxOpConverter):
return input_sequence[position]
+class NonMaxSuppression(OnnxOpConverter):
+    """Converts an onnx NonMaxSuppression node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v10(cls, bb, inputs, attr, params):
+ """
+        NonMaxSuppression performs non-maximum suppression (NMS) on all classes.
+
+ Inputs:
+ - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2]
+ - scores: (N, C) tensor of scores for each box and class
+ - max_output_boxes_per_class: maximum number of boxes to keep per class
+ - iou_threshold: IoU threshold for NMS
+ - score_threshold: score threshold for filtering
+
+ Outputs:
+ - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx]
+ """
+ boxes = inputs[0]
+ scores = inputs[1]
+ max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None
+ iou_threshold = inputs[3] if len(inputs) > 3 else None
+ score_threshold = inputs[4] if len(inputs) > 4 else None
+
+ center_point_box = attr.get("center_point_box", 0)
+
+ if max_output_boxes_per_class is not None and isinstance(
+ max_output_boxes_per_class, relax.Constant
+ ):
+            max_output_boxes_per_class = int(max_output_boxes_per_class.data.numpy())
+ elif max_output_boxes_per_class is not None and isinstance(
+ max_output_boxes_per_class, relax.Var
+ ):
+ var_name = max_output_boxes_per_class.name_hint
+ if var_name in params[1]:
+ _, param_value = params[1][var_name]
+ max_output_boxes_per_class = int(param_value.numpy().item())
+ else:
+ max_output_boxes_per_class = 100 # Default value
+ else:
+ max_output_boxes_per_class = 100 # Default value
+
+        if iou_threshold is not None and isinstance(iou_threshold, relax.Constant):
+ iou_threshold = float(iou_threshold.data.numpy())
+ else:
+ iou_threshold = 0.5 # Default value
+
+        if score_threshold is not None and isinstance(score_threshold, relax.Constant):
+ score_threshold = float(score_threshold.data.numpy())
+        elif score_threshold is not None and isinstance(score_threshold, relax.Var):
+ var_name = score_threshold.name_hint
+ if var_name in params[1]:
+ _, param_value = params[1][var_name]
+ score_threshold = float(param_value.numpy().item())
+ else:
+ score_threshold = 0.0 # Default value
+ else:
+ score_threshold = 0.0 # Default value
+
+ if center_point_box != 0:
+ split_result = relax.op.split(boxes, 4, axis=2)
+ xc = split_result[0]
+ yc = split_result[1]
+ w = split_result[2]
+ h = split_result[3]
+ half_w = w / relax.const(2.0, boxes.struct_info.dtype)
+ half_h = h / relax.const(2.0, boxes.struct_info.dtype)
+ x1 = xc - half_w
+ x2 = xc + half_w
+ y1 = yc - half_h
+ y2 = yc + half_h
+ boxes = relax.op.concat([y1, x1, y2, x2], axis=2)
+
+ nms_out = bb.normalize(
+ relax.op.vision.all_class_non_max_suppression(
+ boxes,
+ scores,
+ relax.const(max_output_boxes_per_class, dtype="int64"),
+ relax.const(iou_threshold, dtype="float32"),
+ relax.const(score_threshold, dtype="float32"),
+ output_format="onnx",
+ )
+ )
+
+ selected_indices = bb.emit(relax.TupleGetItem(nms_out, 0))
+
+ return selected_indices
+
+
+class AllClassNMS(OnnxOpConverter):
+    """Converts an onnx AllClassNMS node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ """
+ AllClassNMS performs non-maximum suppression (NMS) on all classes.
+
+ Inputs:
+ - boxes: (N, 4) tensor of bounding boxes in format [x1, y1, x2, y2]
+ - scores: (N, C) tensor of scores for each box and class
+ - max_output_boxes_per_class: maximum number of boxes to keep per class
+ - iou_threshold: IoU threshold for NMS
+ - score_threshold: score threshold for filtering
+
+ Outputs:
+ - selected_indices: (M, 3) tensor with [batch_idx, class_idx, box_idx]
+ """
+ boxes = inputs[0]
+ scores = inputs[1]
+ max_output_boxes_per_class = inputs[2] if len(inputs) > 2 else None
+ iou_threshold = inputs[3] if len(inputs) > 3 else None
+ score_threshold = inputs[4] if len(inputs) > 4 else None
+
+ center_point_box = attr.get("center_point_box", 0)
+
+ if max_output_boxes_per_class is not None and isinstance(
+ max_output_boxes_per_class, relax.Constant
+ ):
+            max_output_boxes_per_class = int(max_output_boxes_per_class.data.numpy())
+ elif max_output_boxes_per_class is not None and isinstance(
+ max_output_boxes_per_class, relax.Var
+ ):
+ var_name = max_output_boxes_per_class.name_hint
+ if var_name in params[1]:
+ _, param_value = params[1][var_name]
+ max_output_boxes_per_class = int(param_value.numpy().item())
+ else:
+ max_output_boxes_per_class = 100 # Default value
+ else:
+ max_output_boxes_per_class = 100 # Default value
+
+        if iou_threshold is not None and isinstance(iou_threshold, relax.Constant):
+ iou_threshold = float(iou_threshold.data.numpy())
+ else:
+ iou_threshold = 0.5 # Default value
+
+        if score_threshold is not None and isinstance(score_threshold, relax.Constant):
+ score_threshold = float(score_threshold.data.numpy())
+        elif score_threshold is not None and isinstance(score_threshold, relax.Var):
+ var_name = score_threshold.name_hint
+ if var_name in params[1]:
+ _, param_value = params[1][var_name]
+ score_threshold = float(param_value.numpy().item())
+ else:
+ score_threshold = 0.0 # Default value
+ else:
+ score_threshold = 0.0 # Default value
+
+ if center_point_box != 0:
+ split_result = relax.op.split(boxes, 4, axis=2)
+ xc = split_result[0]
+ yc = split_result[1]
+ w = split_result[2]
+ h = split_result[3]
+ half_w = w / relax.const(2.0, boxes.struct_info.dtype)
+ half_h = h / relax.const(2.0, boxes.struct_info.dtype)
+ x1 = xc - half_w
+ x2 = xc + half_w
+ y1 = yc - half_h
+ y2 = yc + half_h
+ boxes = relax.op.concat([y1, x1, y2, x2], axis=2)
+
+ nms_out = bb.normalize(
+ relax.op.vision.all_class_non_max_suppression(
+ boxes,
+ scores,
+ relax.const(max_output_boxes_per_class, dtype="int64"),
+ relax.const(iou_threshold, dtype="float32"),
+ relax.const(score_threshold, dtype="float32"),
+ output_format="onnx",
+ )
+ )
+
+ return nms_out
+
+
def _get_convert_map():
return {
# defs/experimental
@@ -3605,7 +3781,8 @@ def _get_convert_map():
# "LRN": LRN,
# "MaxRoiPool": MaxRoiPool,
# "RoiAlign": RoiAlign,
- # "NonMaxSuppression": NonMaxSuppression,
+ "NonMaxSuppression": NonMaxSuppression,
+ "AllClassNMS": AllClassNMS,
# "GridSample": GridSample,
"Upsample": Upsample,
# others
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 6ea8305eca..19096decd9 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -155,6 +155,7 @@ from .unary import (
tanh,
trunc,
)
+from .vision import all_class_non_max_suppression
def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 4062aae0c7..229a789a45 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -239,6 +239,11 @@ class AttentionAttrs(Attrs):
"""Attributes used in attention operator"""
+@tvm_ffi.register_object("relax.attrs.AllClassNonMaximumSuppressionAttrs")
+class AllClassNonMaximumSuppressionAttrs(Attrs):
+ """Attributes for vision.all_class_non_max_suppression"""
+
+
@tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
class Conv1DAttrs(Attrs):
"""Attributes for nn.conv1d"""
diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py
similarity index 84%
copy from python/tvm/topi/cpp/vision/__init__.py
copy to python/tvm/relax/op/vision/__init__.py
index 8acbb38610..be45458d36 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -14,10 +14,5 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""FFI for vision TOPI ops and schedules"""
-import tvm_ffi
-
-from . import yolo
-
-tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
+"""Vision operators."""
+from .nms import *
diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/relax/op/vision/_ffi_api.py
similarity index 86%
copy from python/tvm/topi/cpp/vision/__init__.py
copy to python/tvm/relax/op/vision/_ffi_api.py
index 8acbb38610..8af761dc5a 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/relax/op/vision/_ffi_api.py
@@ -14,10 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""FFI for vision TOPI ops and schedules"""
+"""Constructor APIs"""
import tvm_ffi
-from . import yolo
-
-tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
+tvm_ffi.init_ffi_api("relax.op.vision", __name__)
diff --git a/python/tvm/relax/op/vision/nms.py b/python/tvm/relax/op/vision/nms.py
new file mode 100644
index 0000000000..3714b00b01
--- /dev/null
+++ b/python/tvm/relax/op/vision/nms.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Non-maximum suppression operator"""
+# from tvm import relax # Unused import
+from . import _ffi_api
+
+
+def all_class_non_max_suppression(
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ output_format="onnx",
+):
+    """Non-maximum suppression operator for object detection, corresponding to ONNX
+ NonMaxSuppression and TensorFlow combined_non_max_suppression.
+ NMS is performed for each class separately.
+
+ Parameters
+ ----------
+ boxes : relax.Expr
+ 3-D tensor with shape (batch_size, num_boxes, 4)
+ scores: relax.Expr
+ 3-D tensor with shape (batch_size, num_classes, num_boxes)
+ max_output_boxes_per_class : relax.Expr
+        The maximum number of output selected boxes per class
+ iou_threshold : relax.Expr
+ IoU test threshold
+ score_threshold : relax.Expr
+ Score threshold to filter out low score boxes early
+ output_format : str, optional
+ "onnx" or "tensorflow", see below.
+
+ Returns
+ -------
+ out : relax.Expr
+        If `output_format` is "onnx", the output is two tensors. The first is `indices` of size
+        `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor
+        `num_total_detection` of shape `(1,)` representing the total number of selected
+        boxes. The three values in `indices` encode batch, class, and box indices.
+        Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
+        first, in descending order of scores, followed by boxes from batch 0, class 1, etc.
+        Out of `batch_size * num_class * num_boxes` rows of indices, only the first
+        `num_total_detection` rows are valid.
+
+        TODO: Implement true dynamic output shapes to match ONNX Runtime behavior exactly.
+        This would eliminate the need for manual trimming and improve memory efficiency.
+        If `output_format` is "tensorflow", the output is three tensors: the first
+        is `indices` of size `(batch_size, num_class * num_boxes, 2)`, the second is `scores` of
+        size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size
+        `(batch_size,)` representing the total number of selected boxes per batch. The two values
+        in `indices` encode class and box indices. Of the num_class * num_boxes boxes in
+        `indices` at batch b, only the first `num_total_detection[b]` entries are valid.
+        The second axis of `indices` and `scores` is sorted within each class by box scores,
+        but not across classes. So the box indices and scores for class 0 come first in
+        sorted order, followed by class 1, etc.
+ """
+ return _ffi_api.all_class_non_max_suppression(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, output_format
+ )
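To make the new operator's interface concrete, here is a rough BlockBuilder sketch (not from this commit; shapes and constants are illustrative assumptions):

```python
# Illustrative sketch: emitting relax.vision.all_class_non_max_suppression
# with the Relax BlockBuilder API.
from tvm import relax

bb = relax.BlockBuilder()
boxes = relax.Var("boxes", relax.TensorStructInfo((1, 6, 4), "float32"))
scores = relax.Var("scores", relax.TensorStructInfo((1, 1, 6), "float32"))

with bb.function("main", [boxes, scores]):
    nms = bb.emit(
        relax.op.vision.all_class_non_max_suppression(
            boxes,
            scores,
            relax.const(3, "int64"),      # max_output_boxes_per_class
            relax.const(0.5, "float32"),  # iou_threshold
            relax.const(0.0, "float32"),  # score_threshold
            output_format="onnx",
        )
    )
    bb.emit_func_output(nms)

mod = bb.get()  # the "onnx" format yields a (selected_indices, num_total_detection) tuple
```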
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py
index b4aba0291f..5614d02296 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -31,3 +31,4 @@ from . import qdq
from . import search
from . import statistical
from . import unary
+from . import vision
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py
new file mode 100644
index 0000000000..f910f62cec
--- /dev/null
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Default legalization function for vision network related operators."""
+import tvm
+from tvm import topi, te
+from tvm import relax
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
+from .common import register_legalize
+
+
+def _create_onnx_nms_te(boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
+ """Create a proper NMS implementation that follows the correct algorithm"""
+ scores_shape = list(scores.shape)
+ if len(scores_shape) == 3:
+ batch, num_classes, _ = scores_shape
+ elif len(scores_shape) == 2:
+ num_classes, _ = scores_shape
+ batch = 1
+ else:
+ raise ValueError(f"Unexpected scores shape: {scores_shape}")
+
+ if hasattr(max_output_boxes_per_class, "data"):
+ max_boxes = int(max_output_boxes_per_class.data.numpy())
+ else:
+ max_boxes = 3 # Default value
+
+ expected_detections = batch * num_classes * max_boxes
+
+ selected_indices_full, _ = topi.vision.all_class_non_max_suppression(
+        boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, "onnx"
+ )
+
+ def slice_to_onnx_shape(data, expected_size):
+ def compute_element(i, j):
+            return tvm.tir.if_then_else(i < expected_size, data[i, j], tvm.tir.Cast("int64", 0))
+
+        return te.compute((expected_size, 3), compute_element, name="sliced_indices")
+
+    sliced_indices = slice_to_onnx_shape(selected_indices_full, expected_detections)
+
+ actual_detections = te.compute(
+        (1,), lambda i: tvm.tir.Cast("int64", expected_detections), name="actual_detections"
+ )
+
+ return [sliced_indices, actual_detections]
+
+
+@register_legalize("relax.vision.all_class_non_max_suppression")
+def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
+ """Legalize all_class_non_max_suppression with fixed shape output.
+
+    Note: This implementation outputs fixed-size tensors with trailing garbage data.
+    Only the first `num_total_detection` rows contain valid data. Users should use
+ the `valid_count` tensor to determine how many rows are actually valid.
+
+ For complete ONNX compatibility, users can post-process the output:
+ ```python
+ selected_indices, valid_count = nms_output
+ actual_count = int(valid_count.numpy()[0])
+ valid_indices = selected_indices.numpy()[:actual_count, :]
+ ```
+ """
+ boxes = call.args[0]
+ scores = call.args[1]
+ max_output_boxes_per_class = call.args[2]
+ iou_threshold = call.args[3]
+ score_threshold = call.args[4]
+ output_format = call.attrs.output_format
+
+ scores_shape = scores.struct_info.shape
+ if len(scores_shape) == 3:
+ _, _, num_boxes = scores_shape
+ elif len(scores_shape) == 2:
+ _, num_boxes = scores_shape
+ else:
+ raise ValueError(f"Unexpected scores shape: {scores_shape}")
+
+ if isinstance(max_output_boxes_per_class, relax.Constant):
+ max_boxes_val = int(max_output_boxes_per_class.data.numpy())
+ else:
+ max_boxes_val = int(num_boxes)
+
+ # Get NMS result with fixed shape from TOPI
+ nms_result = block_builder.call_te(
+ topi.vision.all_class_non_max_suppression,
+ boxes,
+ scores,
+ max_boxes_val,
+ iou_threshold,
+ score_threshold,
+ output_format,
+ )
+
+ # TODO: Implement dynamic output trimming for better memory efficiency
+ # Current approach returns fixed-size output with trailing garbage data
+ # Future improvements could include:
+ # 1. Dynamic strided_slice based on num_total_detections
+ # 2. Custom Relax operator with true dynamic shapes
+ # 3. VM builtin functions for runtime shape adjustment
+ # 4. Symbolic shape inference in Relax IR
+ #
+ # For now, users should trim manually:
+ # actual_count = int(num_total_detections.numpy()[0])
+ # valid_indices = selected_indices.numpy()[:actual_count, :]
+
+ return nms_result
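A hedged end-to-end sketch of how the legalized operator might be compiled and its fixed-shape output trimmed, given an IRModule `mod` whose "main" function emits the op (the `relax.build` entry point, target, and device choices are assumptions about the surrounding API):

```python
# Illustrative sketch: legalize, build, run on the Relax VM, then trim the
# fixed-size output using the returned valid count.
import numpy as np
import tvm
from tvm import relax

mod = relax.transform.LegalizeOps()(mod)   # lowers relax.vision.* via the TOPI compute above
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

boxes_np = np.random.rand(1, 6, 4).astype("float32")
scores_np = np.random.rand(1, 1, 6).astype("float32")
out = vm["main"](tvm.nd.array(boxes_np), tvm.nd.array(scores_np))
selected_indices, valid_count = out[0], out[1]

actual = int(valid_count.numpy()[0])
valid_rows = selected_indices.numpy()[:actual, :]  # drop trailing padding rows
```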
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 3fa735197a..f221a13089 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -188,6 +188,7 @@ from tvm.relax.op import (
wrap_param,
zeros,
zeros_like,
+ vision,
)
from tvm.relax.op.builtin import stop_lift_params
from tvm.relax.struct_info import StructInfo
@@ -950,4 +951,5 @@ __all__ = [
"nn",
"ccl",
"erf",
+ "vision",
]
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index 9503aea0cd..c73e8bf54c 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -50,6 +50,7 @@ from .signal import *
from . import nn
from . import utils
from . import image
+from . import vision
from . import gpu
# error reporting
diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/topi/cpp/vision/__init__.py
index 8acbb38610..467ce70fbd 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/topi/cpp/vision/__init__.py
@@ -19,5 +19,6 @@
import tvm_ffi
from . import yolo
+from ...vision import nms
tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
diff --git a/python/tvm/topi/cpp/vision/__init__.py b/python/tvm/topi/vision/__init__.py
similarity index 84%
copy from python/tvm/topi/cpp/vision/__init__.py
copy to python/tvm/topi/vision/__init__.py
index 8acbb38610..f12758bb9c 100644
--- a/python/tvm/topi/cpp/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -14,10 +14,5 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""FFI for vision TOPI ops and schedules"""
-import tvm_ffi
-
-from . import yolo
-
-tvm_ffi.init_ffi_api("topi.vision", "tvm.topi.cpp.vision")
+"""Vision operators."""
+from .nms import *
diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py
new file mode 100644
index 0000000000..f4aae45ef9
--- /dev/null
+++ b/python/tvm/topi/vision/nms.py
@@ -0,0 +1,500 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-error, invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements, too-many-function-args
+"""Non-maximum suppression operator"""
+import tvm
+from tvm import te
+
+from tvm.tir import if_then_else
+
+from ..sort import argsort
+from ..math import cast
+from ..transform import reshape, gather
+from .. import reduction
+from ..scan import cumsum
+from .nms_util import (
+ binary_search,
+ collect_selected_indices,
+ collect_selected_indices_and_scores,
+ run_all_class_nms,
+)
+
+
+def get_valid_counts(
+ data, score_threshold=0, id_index=0, score_index=1
+): # pylint: disable=unused-argument
+ """Get valid count of bounding boxes given a score threshold.
+ Also moves valid boxes to the top of input data.
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ Input data. 3-D tensor with shape [batch_size, num_anchors, 6]
+ or [batch_size, num_anchors, 5].
+ score_threshold : optional, float
+ Lower limit of score for valid bounding boxes.
+ id_index : optional, int
+ index of the class categories, -1 to disable.
+ score_index: optional, int
+ Index of the scores/confidence of boxes.
+ Returns
+ -------
+ valid_count : tvm.te.Tensor
+ 1-D tensor for valid number of boxes.
+ out_tensor : tvm.te.Tensor
+ Rearranged data tensor.
+ out_indices: tvm.te.Tensor or numpy NDArray
+ Related index in input data.
+ """
+ if isinstance(score_threshold, (float, int)):
+ score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype)
+ # id_index_const = tvm.tir.const(id_index, "int32") # Unused
+ # score_index_const = tvm.tir.const(score_index, "int32") # Unused
+ return (
+        te.compute((data.shape[0],), lambda i: data.shape[1], name="valid_count"),
+ data,
+        te.compute((data.shape[0], data.shape[1]), lambda i, j: j, name="out_indices"),
+ )
+
+
+def _nms_loop(
+ ib,
+ batch_size,
+ top_k,
+ iou_threshold,
+ max_output_size,
+ valid_count,
+ on_new_valid_box_func,
+ on_new_invalidated_box_func,
+ needs_bbox_check_func,
+ calc_overlap_func,
+ out_scores,
+ num_valid_boxes,
+ score_threshold=None,
+):
+ def nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local):
+ on_new_valid_box_func(ib, 0, num_valid_boxes_local[0], i, j)
+ num_valid_boxes_local[0] += 1
+
+ num_boxes_to_check = nkeep - (j + 1)
+
+        with ib.for_range(0, num_boxes_to_check, name="_k", kind="parallel") as _k:
+ k = j + 1 + _k
+
+ with ib.if_scope(
+ tvm.tir.all(
+ k < nkeep,
+ out_scores[i, k] > 0, # is the box k still valid?
+ needs_bbox_check_func(i, j, k),
+ )
+ ):
+ iou = calc_overlap_func(i, j, k)
+
+ with ib.if_scope(iou >= iou_threshold):
+ out_scores[i, k] = -1.0
+ on_new_invalidated_box_func(i, k)
+
+ with ib.for_range(0, batch_size, name="i") as i:
+        nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i])
+ # Use max_output_size directly without if_then_else
+        # max_output_size = if_then_else(max_output_size > te.const(0), max_output_size, nkeep)
+
+        with ib.if_scope(tvm.tir.all(iou_threshold > te.const(0), valid_count[i] > te.const(0))):
+ num_valid_boxes_local = ib.allocate(
+ "int32", (1,), name="num_valid_boxes_local", scope="local"
+ )
+ num_valid_boxes_local[0] = 0
+
+            # Use for_range to iterate through all boxes, but limit selection count
+ with ib.for_range(0, nkeep, name="j") as j:
+ with ib.if_scope(
+ tvm.tir.all(
+ out_scores[i, j] > -1.0, # box is still valid
+                        num_valid_boxes_local[0] < max_output_size,  # haven't reached max limit
+ )
+ ):
+ if score_threshold is not None:
+                        with ib.if_scope(out_scores[i, j] > score_threshold[()]):
+                            nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local)
+ else:
+ nms_inner_loop(ib, i, j, nkeep, num_valid_boxes_local)
+
+ num_valid_boxes[i] = num_valid_boxes_local[0]
+
+ with ib.else_scope():
+ num_valid_boxes[i] = 0
+
+ return ib.get()
+
+
+def _get_valid_box_count(scores, score_threshold):
+ batch_classes, num_boxes = scores.shape
+
+ def searchsorted_ir(scores, score_thresh, valid_count):
+ ib = tvm.tir.ir_builder.create()
+ scores = ib.buffer_ptr(scores)
+ valid_count = ib.buffer_ptr(valid_count)
+
+ with ib.for_range(0, batch_classes, name="i", kind="parallel") as i:
+ if hasattr(score_threshold, "shape"):
+ if len(score_threshold.shape) == 0:
+ score_thresh_scalar = score_thresh[()]
+                elif len(score_threshold.shape) == 1 and score_threshold.shape[0] > 0:
+ score_thresh_scalar = score_thresh[0]
+ else:
+ score_thresh_scalar = tvm.tir.FloatImm("float32", 0.0)
+ else:
+ score_thresh_scalar = score_threshold
+            binary_search(ib, i, num_boxes, scores, score_thresh_scalar, valid_count)
+
+ return ib.get()
+
+    scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8)
+ searchsorted_buf = tvm.tir.decl_buffer(
+ (batch_classes,), "int32", "searchsorted", data_alignment=8
+ )
+
+ if hasattr(score_threshold, "shape"):
+ score_thresh_buf = tvm.tir.decl_buffer(
+            score_threshold.shape, score_threshold.dtype, "score_thresh_buf", data_alignment=8
+ )
+ return te.extern(
+ [(batch_classes,)],
+ [scores, score_threshold],
+ lambda ins, outs: searchsorted_ir(ins[0], ins[1], outs[0]),
+ dtype=["int32"],
+ in_buffers=[scores_buf, score_thresh_buf],
+ out_buffers=[searchsorted_buf],
+ name="searchsorted",
+ tag="searchsorted",
+ )
+ else:
+
+ def searchsorted_ir_scalar(scores, valid_count):
+ ib = tvm.tir.ir_builder.create()
+ scores = ib.buffer_ptr(scores)
+ valid_count = ib.buffer_ptr(valid_count)
+
+            with ib.for_range(0, batch_classes, name="i", kind="parallel") as i:
+ if isinstance(score_threshold, te.Tensor):
+ if len(score_threshold.shape) == 0:
+ score_thresh_tir = score_threshold()
+                    elif len(score_threshold.shape) == 1 and score_threshold.shape[0] == 1:
+ score_thresh_tir = score_threshold[0]
+ else:
+ score_thresh_tir = tvm.tir.FloatImm("float32", 0.0)
+ else:
+                    score_thresh_tir = tvm.tir.FloatImm("float32", float(score_threshold))
+                binary_search(ib, i, num_boxes, scores, score_thresh_tir, valid_count)
+
+ return ib.get()
+
+ return te.extern(
+ [(batch_classes,)],
+ [scores],
+ lambda ins, outs: searchsorted_ir_scalar(ins[0], outs[0]),
+ dtype=["int32"],
+ in_buffers=[scores_buf],
+ out_buffers=[searchsorted_buf],
+ name="searchsorted",
+ tag="searchsorted",
+ )
+
+
+def _collect_selected_indices_ir(
+    num_class, selected_indices, num_detections, row_offsets, out, max_output_boxes_per_class=None
+):
+ batch_classes, _ = selected_indices.shape
+
+ ib = tvm.tir.ir_builder.create()
+
+ selected_indices = ib.buffer_ptr(selected_indices)
+ num_detections = ib.buffer_ptr(num_detections)
+ row_offsets = ib.buffer_ptr(row_offsets)
+ out = ib.buffer_ptr(out)
+
+ # Initialize output buffer to zero
+ # Calculate the actual output shape based on max_output_boxes_per_class
+ if isinstance(max_output_boxes_per_class, int):
+ max_output_rows = batch_classes * max_output_boxes_per_class
+ else:
+        # Fallback to a reasonable default if max_output_boxes_per_class is not an integer
+ max_output_rows = batch_classes * 10
+ with ib.for_range(0, max_output_rows, name="init_i") as init_i:
+ with ib.for_range(0, 3, name="init_j") as init_j: # 3 columns
+ out[init_i, init_j] = cast(0, "int64")
+
+ with ib.for_range(0, batch_classes, name="i", kind="parallel") as i:
+ i = cast(i, "int64")
+ batch_id = i // num_class
+ class_id = i % num_class
+
+ if isinstance(max_output_boxes_per_class, int):
+ limit = tvm.tir.min(
+                num_detections[i], tvm.tir.IntImm("int32", max_output_boxes_per_class)
+ )
+ elif isinstance(max_output_boxes_per_class, te.Tensor):
+ if len(max_output_boxes_per_class.shape) == 0:
+ max_boxes_val = max_output_boxes_per_class[()]
+ else:
+ max_boxes_val = max_output_boxes_per_class[0]
+ limit = tvm.tir.min(num_detections[i], max_boxes_val)
+ else:
+ limit = num_detections[i]
+
+ with ib.for_range(0, limit, name="j") as j:
+ out[row_offsets[i] + j, 0] = batch_id
+ out[row_offsets[i] + j, 1] = class_id
+ out[row_offsets[i] + j, 2] = cast(selected_indices[i, j], "int64")
+
+ return ib.get()
+
+
+def _collect_selected_indices_and_scores_ir(
+ selected_indices,
+ selected_scores,
+ num_detections,
+ row_offsets,
+ num_total_detections,
+ collected_indices,
+ collected_scores,
+):
+ batch_size, num_class = row_offsets.shape
+ num_boxes = selected_indices.shape[1]
+
+ ib = tvm.tir.ir_builder.create()
+
+ selected_indices = ib.buffer_ptr(selected_indices)
+ selected_scores = ib.buffer_ptr(selected_scores)
+ num_detections = ib.buffer_ptr(num_detections)
+ row_offsets = ib.buffer_ptr(row_offsets)
+ num_total_detections = ib.buffer_ptr(num_total_detections)
+ collected_indices = ib.buffer_ptr(collected_indices)
+ collected_scores = ib.buffer_ptr(collected_scores)
+ zero = cast(0, "int64")
+
+    with ib.for_range(0, batch_size * num_class, name="i", kind="parallel") as i:
+ i = cast(i, "int64")
+ batch_id = i // num_class
+ class_id = i % num_class
+
+ with ib.for_range(0, num_boxes, name="j") as j:
+ with ib.if_scope(j < num_detections[batch_id, class_id]):
+ offset = row_offsets[batch_id, class_id] + j
+ collected_indices[batch_id, offset, 0] = class_id
+                collected_indices[batch_id, offset, 1] = cast(selected_indices[i, j], "int64")
+ collected_scores[batch_id, offset] = selected_scores[i, j]
+ with ib.else_scope():
+ offset = (
+ num_total_detections[batch_id]
+ + class_id * num_boxes
+ - row_offsets[batch_id, class_id]
+ + j
+ - num_detections[batch_id, class_id]
+ )
+ collected_indices[batch_id, offset, 0] = zero
+ collected_indices[batch_id, offset, 1] = zero
+ collected_scores[batch_id, offset] = 0.0
+
+ return ib.get()
+
+
+def all_class_non_max_suppression(
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ output_format="onnx",
+ output_shape=None,
+):
+    """Non-maximum suppression operator for object detection, corresponding to ONNX
+ NonMaxSuppression and TensorFlow combined_non_max_suppression.
+ NMS is performed for each class separately.
+ Parameters
+ ----------
+ boxes : tvm.te.Tensor
+ 3-D tensor with shape (batch_size, num_boxes, 4)
+ scores: tvm.te.Tensor
+ 3-D tensor with shape (batch_size, num_classes, num_boxes)
+ max_output_boxes_per_class : int or tvm.te.Tensor, optional
+        The maximum number of output selected boxes per class
+    iou_threshold : float or tvm.te.Tensor, optional
+ IoU test threshold
+ score_threshold : float or tvm.te.Tensor, optional
+ Score threshold to filter out low score boxes early
+ output_format : str, optional
+ "onnx" or "tensorflow", see below.
+ Returns
+ -------
+ out : list of tvm.te.Tensor
+        If `output_format` is "onnx", the output is two tensors. The first is `indices` of size
+        `(batch_size * num_class * num_boxes, 3)` and the second is a scalar tensor
+        `num_total_detection` of shape `(1,)` representing the total number of selected
+        boxes. The three values in `indices` encode batch, class, and box indices.
+        Rows of `indices` are ordered such that selected boxes from batch 0, class 0 come
+        first, in descending order of scores, followed by boxes from batch 0, class 1, etc.
+        Out of `batch_size * num_class * num_boxes` rows of indices, only the first
+        `num_total_detection` rows are valid.
+
+        .. note::
+            **Important**: The output tensor has a fixed size based on `max_output_boxes_per_class`,
+            but only the first `num_total_detection` rows contain valid data. The remaining rows
+            may contain garbage values. When comparing with ONNX Runtime or other implementations
+            that output dynamic shapes, you should only compare the first
+            `num_total_detection` rows.
+            Example:
+            ```python
+            selected_indices, valid_count = nms_output
+            actual_count = int(valid_count.numpy()[0])
+            valid_indices = selected_indices.numpy()[:actual_count, :]
+            ```
+        If `output_format` is "tensorflow", the output is three tensors: the first
+        is `indices` of size `(batch_size, num_class * num_boxes, 2)`, the second is `scores` of
+        size `(batch_size, num_class * num_boxes)`, and the third is `num_total_detection` of size
+        `(batch_size,)` representing the total number of selected boxes per batch. The two values
+        in `indices` encode class and box indices. Of the num_class * num_boxes boxes in
+        `indices` at batch b, only the first `num_total_detection[b]` entries are valid.
+        The second axis of `indices` and `scores` is sorted within each class by box scores,
+        but not across classes. So the box indices and scores for class 0 come first in
+        sorted order, followed by class 1, etc.
+ """
+ batch, num_class, num_boxes = scores.shape
+ scores = reshape(scores, (batch * num_class, num_boxes))
+
+ sorted_indices = argsort(scores, axis=1, is_ascend=False, dtype="int32")
+ sorted_scores = gather(scores, 1, sorted_indices)
+
+ if not isinstance(score_threshold, te.Tensor):
+        score_threshold_tensor = te.compute((), lambda: score_threshold, name="score_threshold")
+ else:
+ score_threshold_tensor = score_threshold
+
+ valid_count = _get_valid_box_count(sorted_scores, score_threshold_tensor)
+
+ selected_indices, selected_scores, num_detections = run_all_class_nms(
+ boxes,
+ sorted_scores,
+ sorted_indices,
+ valid_count,
+ max_output_boxes_per_class,
+ iou_threshold,
+ _nms_loop,
+ return_scores=(output_format == "tensorflow"),
+        score_threshold=score_threshold_tensor,  # Passed score_threshold as tensor
+ )
+
+ if output_format == "onnx":
+ row_offsets = cumsum(num_detections, exclusive=True, dtype="int64")
+
+ def _sum_clamped_total():
+ if isinstance(max_output_boxes_per_class, int):
+                k_expr = tvm.tir.IntImm("int32", int(max_output_boxes_per_class))
+ clamped = te.compute(
+ num_detections.shape,
+ lambda i: tvm.tir.min(num_detections[i], k_expr),
+ name="clamped_num",
+ )
+ return reduction.sum(cast(clamped, "int64"), axis=0)
+ if isinstance(max_output_boxes_per_class, tvm.tir.IntImm):
+ k_expr = tvm.tir.Cast("int32", max_output_boxes_per_class)
+ clamped = te.compute(
+ num_detections.shape,
+ lambda i: tvm.tir.min(num_detections[i], k_expr),
+ name="clamped_num",
+ )
+ return reduction.sum(cast(clamped, "int64"), axis=0)
+ if isinstance(max_output_boxes_per_class, te.Tensor):
+ if len(max_output_boxes_per_class.shape) == 0:
+ kb = te.compute(
+ num_detections.shape,
+ lambda i: cast(max_output_boxes_per_class, "int32"),
+ name="k_broadcast",
+ )
+ elif (
+ len(max_output_boxes_per_class.shape) == 1
+ and max_output_boxes_per_class.shape[0] == 1
+ ):
+ kb = te.compute(
+ num_detections.shape,
+ lambda i: cast(max_output_boxes_per_class[0], "int32"),
+ name="k_broadcast",
+ )
+ else:
+ return reduction.sum(cast(num_detections, "int64"), axis=0)
+
+ clamped = te.compute(
+ num_detections.shape,
+ lambda i: tvm.tir.min(num_detections[i], kb[i]),
+ name="clamped_num",
+ )
+ return reduction.sum(cast(clamped, "int64"), axis=0)
+ return reduction.sum(cast(num_detections, "int64"), axis=0)
+
+ num_total_scalar = _sum_clamped_total()
+ num_total_detections = reshape(num_total_scalar, (1,))
+
+ if output_shape is not None:
+ selected_indices = collect_selected_indices(
+ num_class,
+ selected_indices,
+ num_detections,
+ row_offsets,
+ _collect_selected_indices_ir,
+ max_output_boxes_per_class=max_output_boxes_per_class,
+ output_shape=output_shape,
+ )
+ else:
+ # Use num_total_detections to enable dynamic trimming
+ # Pass image size for intelligent default estimation
+ input_image_size = None
+ if hasattr(scores, "shape") and len(scores.shape) >= 3:
+                # Extract image size from scores shape: (batch, num_classes, num_boxes)
+                # We can estimate image size from num_boxes (more boxes = larger image)
+                input_image_size = (scores.shape[2],)  # Use num_boxes as proxy for image size
+
+ # TODO: Improve image size estimation by:
+ # 1. Accepting actual image dimensions as parameters
+ # 2. Using model metadata to infer typical image sizes
+ # 3. Learning from historical detection patterns
+ # 4. Providing user-configurable estimation strategies
+
+ selected_indices = collect_selected_indices(
+ num_class,
+ selected_indices,
+ num_detections,
+ row_offsets,
+ _collect_selected_indices_ir,
+ max_output_boxes_per_class=max_output_boxes_per_class,
+ num_total_detections=num_total_detections,
+ input_image_size=input_image_size,
+ )
+ return [selected_indices, num_total_detections]
+
+ num_detections_per_batch = reshape(num_detections, (batch, num_class))
+    row_offsets = cumsum(num_detections_per_batch, exclusive=True, dtype="int64", axis=1)
+    num_total_detections = reduction.sum(cast(num_detections_per_batch, "int64"), axis=1)
+
+ selected_indices, selected_scores = collect_selected_indices_and_scores(
+ selected_indices,
+ selected_scores,
+ num_detections_per_batch,
+ row_offsets,
+ num_total_detections,
+ _collect_selected_indices_and_scores_ir,
+ )
+
+ return [selected_indices, selected_scores, num_total_detections]
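For reference, a rough TE-level sketch of invoking the TOPI compute above directly (placeholder shapes and scalar arguments are illustrative; scheduling and building are omitted):

```python
# Illustrative sketch: building the TOPI compute graph for all-class NMS.
from tvm import te, topi

boxes = te.placeholder((1, 6, 4), name="boxes", dtype="float32")
scores = te.placeholder((1, 1, 6), name="scores", dtype="float32")

selected_indices, num_total_detections = topi.vision.all_class_non_max_suppression(
    boxes,
    scores,
    max_output_boxes_per_class=3,
    iou_threshold=0.5,
    score_threshold=0.0,
    output_format="onnx",
)
# selected_indices: (batch * num_classes * max_boxes, 3) rows of
# [batch_idx, class_idx, box_idx]; only the first num_total_detections rows are valid.
```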
diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py
new file mode 100644
index 0000000000..1633c923e1
--- /dev/null
+++ b/python/tvm/topi/vision/nms_util.py
@@ -0,0 +1,473 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Common utilities used in Non-maximum suppression operators"""
+import tvm
+from tvm import te
+
+
+def _get_boundaries(output, box_idx):
+ l = tvm.te.min(
+ output[box_idx],
+ output[box_idx + 2],
+ )
+ t = tvm.te.min(
+ output[box_idx + 1],
+ output[box_idx + 3],
+ )
+ r = tvm.te.max(
+ output[box_idx],
+ output[box_idx + 2],
+ )
+ b = tvm.te.max(
+ output[box_idx + 1],
+ output[box_idx + 3],
+ )
+ return l, t, r, b
+
+
+def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
+ """Calculate overlap of two boxes."""
+ a_l, a_t, a_r, a_b = _get_boundaries(out_tensor, box_a_idx)
+ b_l, b_t, b_r, b_b = _get_boundaries(out_tensor, box_b_idx)
+
+ # Overlapping width and height
+ w = tvm.te.max(0.0, tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l))
+ h = tvm.te.max(0.0, tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t))
+
+ # Overlapping area
+ area = h * w
+
+ # total area of the figure formed by box a and box b
+ # except for overlapping area
+ u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
+ return tvm.tir.Select(u <= 0.0, 0.0, area / u)
+
+
+def binary_search(ib, y, num_boxes, scores, score_threshold, out):
+    """Binary search for score_threshold on scores sorted in descending order"""
+ lo = ib.allocate("int32", (1,), name="lo", scope="local")
+ hi = ib.allocate("int32", (1,), name="hi", scope="local")
+
+ lo[0] = 0
+ hi[0] = num_boxes.astype("int32")
+
+ with ib.while_loop(lo[0] < hi[0]):
+ mid = (hi[0] + lo[0]) >> 1
+ with ib.if_scope(scores[y, mid] > score_threshold):
+ lo[0] = mid + 1
+ with ib.else_scope():
+ hi[0] = mid
+
+ out[y] = lo[0]
+
+
+def _estimate_max_detections(batch_class, input_image_size=None):
+    """Estimate maximum detections based on input image size and number of classes.
+
+ This provides a more intelligent default for production environments.
+ """
+ if input_image_size is not None:
+        # Estimate based on image size: larger images typically have more objects
+ if len(input_image_size) >= 2:
+ height, width = input_image_size[-2], input_image_size[-1]
+ total_pixels = height * width
+
+ # Base estimation per class based on image size
+ if total_pixels < 300000: # Small images (< 300k pixels)
+                base_detections_per_class = min(50, max(10, total_pixels // 2000))
+            elif total_pixels < 1000000:  # Medium images (< 1M pixels)
+                base_detections_per_class = min(100, max(25, total_pixels // 3000))
+            else:  # Large images (>= 1M pixels)
+                base_detections_per_class = min(200, max(50, total_pixels // 4000))
+
+            # Scale down for many classes (more realistic for multi-class scenarios)
+ if batch_class > 20:
+                # For many classes, reduce per-class detections to avoid explosion
+ detections_per_class = min(base_detections_per_class, 50)
+ else:
+ detections_per_class = base_detections_per_class
+ else:
+ detections_per_class = 50 # fallback
+ else:
+ # Fallback to class-based estimation
+ if batch_class == 1:
+ detections_per_class = 100 # Single class detection
+ elif batch_class <= 10:
+ detections_per_class = 50 # Small multi-class
+ else:
+ detections_per_class = 25 # Large multi-class (COCO-like)
+
+ return batch_class * detections_per_class
+
+
+def collect_selected_indices(
+ num_class,
+ selected_indices,
+ num_detections,
+ row_offsets,
+ ir,
+ max_output_boxes_per_class=None,
+ output_shape=None,
+ num_total_detections=None,
+ input_image_size=None,
+):
+ """Collect selected indices from the core NMS loop into one linear output
+ Parameters
+ ----------
+ num_class : int
+ selected_indices: tvm.te.Tensor
+        2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices
+        of selected boxes by the core NMS loop.
+    num_detections : tvm.te.Tensor
+        1-D tensor with shape (batch_size * num_classes,), representing
+        the number of boxes selected by the core NMS loop, per batch and class
+    row_offsets : tvm.te.Tensor
+        1-D tensor with shape (batch_size * num_classes,), this should be the exclusive scan
+        of num_detections
+    ir : function
+        A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py
+    Returns
+    -------
+    out : tvm.te.Tensor
+        The output is indices of size (batch_size * num_class * num_boxes, 3).
+        Rows of indices are ordered such that selected boxes from batch 0, class 0 come
+        first, in descending order of scores, followed by boxes from batch 0, class 1, etc.
+ """
+ batch_class, num_boxes = selected_indices.shape
+
+ if output_shape is not None:
+ return te.extern(
+ [output_shape],
+ [selected_indices, num_detections, row_offsets],
+ lambda ins, outs: ir(
+                num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class
+ ),
+ dtype=["int64"],
+ name="collect_indices",
+ tag="collect_indices",
+ )
+
+ # TODO: Implement dynamic trimming based on num_total_detections
+ if num_total_detections is not None:
+ if isinstance(max_output_boxes_per_class, int):
+ out_rows = batch_class * max_output_boxes_per_class
+ else:
+            # Smart fallback based on input image size and typical production scenarios
+ out_rows = _estimate_max_detections(batch_class, input_image_size)
+
+ return te.extern(
+ [(out_rows, 3)],
+ [selected_indices, num_detections, row_offsets],
+ lambda ins, outs: ir(
+                num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class
+ ),
+ dtype=["int64"],
+ name="collect_indices",
+ tag="collect_indices",
+ )
+
+ if isinstance(max_output_boxes_per_class, int):
+ out_rows = batch_class * max_output_boxes_per_class
+ return te.extern(
+ [(out_rows, 3)],
+ [selected_indices, num_detections, row_offsets],
+ lambda ins, outs: ir(
+                num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class
+ ),
+ dtype=["int64"],
+ name="collect_indices",
+ tag="collect_indices",
+ )
+
+ if isinstance(max_output_boxes_per_class, te.Tensor):
+ try:
+ if len(max_output_boxes_per_class.shape) == 0:
+ max_boxes_val = int(max_output_boxes_per_class.data.numpy())
+ elif (
+ len(max_output_boxes_per_class.shape) == 1
+ and max_output_boxes_per_class.shape[0] == 1
+ ):
+ max_boxes_val = int(max_output_boxes_per_class.data.numpy()[0])
+ else:
+ max_boxes_val = num_boxes
+ except (ValueError, IndexError, AttributeError):
+ max_boxes_val = num_boxes
+
+ out_rows = batch_class * max_boxes_val
+ return te.extern(
+ [(out_rows, 3)],
+ [selected_indices, num_detections, row_offsets],
+ lambda ins, outs: ir(
+                num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class
+ ),
+ dtype=["int64"],
+ name="collect_indices",
+ tag="collect_indices",
+ )
+
+ return te.extern(
+ [(batch_class * num_boxes, 3)],
+ [selected_indices, num_detections, row_offsets],
+ lambda ins, outs: ir(
+            num_class, ins[0], ins[1], ins[2], outs[0], max_output_boxes_per_class
+ ),
+ dtype=["int64"],
+ name="collect_indices",
+ tag="collect_indices",
+ )
+
+
+def collect_selected_indices_and_scores(
+    selected_indices, selected_scores, num_detections, row_offsets, num_total_detections, ir
+):
+    """Collect selected indices and scores from the core NMS loop into one linear output
+    Parameters
+    ----------
+    num_class : int
+    selected_indices: tvm.te.Tensor
+        2-D tensor with shape (batch_size * num_classes, num_boxes), representing the indices
+        of selected boxes by the core NMS loop.
+    selected_scores: tvm.te.Tensor
+        2-D tensor with shape (batch_size * num_classes, num_boxes), representing the scores
+        of selected boxes by the core NMS loop.
+    num_detections : tvm.te.Tensor
+        2-D tensor with shape (batch_size, num_classes), representing
+        the number of boxes selected by the core NMS loop, per batch and class
+    row_offsets : tvm.te.Tensor
+        2-D tensor with shape (batch_size, num_classes), this should be the exclusive scan
+        of num_detections along axis 1
+    ir : function
+        A function to generate IR for CPU or GPU, see its usage in vision/nms.py and cuda/nms.py
+    Returns
+    -------
+    out : [tvm.te.Tensor, tvm.te.Tensor]
+        The output is two tensors. The first is indices of size
+        (batch_size, num_class * num_boxes, 2), and the second is scores of size
+        (batch_size, num_class * num_boxes).
+ """
+ batch_size, num_class = row_offsets.shape
+ num_boxes = selected_indices.shape[1]
+ return te.extern(
+        [(batch_size, num_class * num_boxes, 2), (batch_size, num_class * num_boxes)],
+        [selected_indices, selected_scores, num_detections, row_offsets, num_total_detections],
+        lambda ins, outs: ir(ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], outs[1]),
+ dtype=["int64", "float32"],
+ name="collect_indices_and_scores",
+ tag="collect_indices_and_scores",
+ )
+
+
+def _all_class_nms_ir(
+ boxes,
+ sorted_scores,
+ sorted_indices,
+ valid_count,
+ batch_class,
+ num_class,
+ num_anchors,
+ iou_threshold,
+ max_output_size_per_class,
+ box_indices,
+ selected_scores,
+ num_valid_boxes,
+ nms_loop,
+ score_threshold=None,
+):
+ ib = tvm.tir.ir_builder.create()
+ boxes = ib.buffer_ptr(boxes)
+ sorted_scores = ib.buffer_ptr(sorted_scores)
+ sorted_indices = ib.buffer_ptr(sorted_indices)
+ valid_count = ib.buffer_ptr(valid_count)
+ box_indices = ib.buffer_ptr(box_indices)
+ num_valid_boxes = ib.buffer_ptr(num_valid_boxes)
+
+ if selected_scores is not None:
+ selected_scores = ib.buffer_ptr(selected_scores)
+
+ if isinstance(iou_threshold, float):
+ iou_threshold = tvm.tir.FloatImm("float32", iou_threshold)
+ elif isinstance(iou_threshold, te.Tensor):
+ if len(iou_threshold.shape) == 0:
+ iou_threshold = iou_threshold()
+ elif len(iou_threshold.shape) == 1 and iou_threshold.shape[0] == 1:
+ iou_threshold = iou_threshold[0]
+ else:
+ iou_threshold = tvm.tir.FloatImm("float32", 0.5)
+
+ if isinstance(max_output_size_per_class, int):
+ max_output_size_per_class = tvm.tir.const(max_output_size_per_class)
+ elif isinstance(max_output_size_per_class, te.Tensor):
+ if len(max_output_size_per_class.shape) == 0:
+ max_output_size_per_class = max_output_size_per_class()
+    elif len(max_output_size_per_class.shape) == 1 and max_output_size_per_class.shape[0] == 1:
+ # Use tensor indexing to get the first element
+ max_output_size_per_class = max_output_size_per_class[0]
+ else:
+ max_output_size_per_class = tvm.tir.const(1000)
+
+ def calc_overlap(i, j, k):
+ offset_j = sorted_indices[i, j] * 4
+ offset_k = sorted_indices[i, k] * 4
+ batch_id = i // num_class
+ base_bbox_idx = batch_id * num_anchors * 4
+ return calculate_overlap(
+ boxes,
+ base_bbox_idx + offset_j,
+ base_bbox_idx + offset_k,
+ )
+
+ def on_new_valid_box(ib, tid, num_current_valid_box, i, j):
+ with ib.if_scope(tid + 0 == 0):
+ box_indices[i, num_current_valid_box] = sorted_indices[i, j]
+
+ if selected_scores is not None:
+ selected_scores[i, num_current_valid_box] = sorted_scores[i, j]
+
+ def on_new_invalidated_box(*_):
+ pass
+
+ def needs_bbox_check(*_):
+ return tvm.tir.const(True)
+
+ return nms_loop(
+ ib,
+ batch_class,
+ tvm.tir.IntImm("int32", -1), # top_k
+ iou_threshold,
+ max_output_size_per_class,
+ valid_count,
+ on_new_valid_box,
+ on_new_invalidated_box,
+ needs_bbox_check,
+ calc_overlap,
+ sorted_scores,
+ num_valid_boxes,
+ score_threshold,
+ )
+
+
+def run_all_class_nms(
+ boxes,
+ sorted_scores,
+ sorted_indices,
+ valid_count,
+ max_output_size_per_class,
+ iou_threshold,
+ nms_loop,
+ return_scores=False,
+ score_threshold=None,
+):
+ """The core all class NMS routine
+ Parameters
+ ----------
+ boxes : tvm.te.Tensor
+ 3-D tensor with shape (batch_size, num_boxes, 4)
+ sorted_scores: tvm.te.Tensor
+ 2-D tensor with shape (batch_size * num_classes, num_boxes)
+ One of the outputs from argsort
+ sorted_indices: tvm.te.Tensor
+ 2-D tensor with shape (batch_size * num_classes, num_boxes)
+ The other output from argsort
+ valid_count: tvm.te.Tensor
+ 1-D tensor with shape (batch_size * num_classes,), representing
+        the number of boxes whose score is above score_threshold, per batch and class
+    max_output_size_per_class : int or tvm.te.Tensor, optional
+        The maximum number of output selected boxes per class
+    iou_threshold : float or tvm.te.Tensor, optional
+        IoU test threshold
+    nms_loop : function
+        A core NMS loop, see its usage in vision/nms.py and cuda/nms.py
+    return_scores : bool, optional
+        Whether or not to return selected scores, needed by the tensorflow output format.
+    Returns
+    -------
+    out : a list of tvm.te.Tensor
+        The output is three tensors, the first and second are indices and scores of size
+        (batch_size * num_class, num_boxes), and the third is a tensor
+        num_selected_boxes of shape (batch_size * num_class,) representing the total number of
+        selected boxes per batch and class. If return_scores is False, the second output is
+        None.
+ """
+ batch, num_boxes, _ = boxes.shape
+ batch_class = sorted_scores.shape[0]
+ num_class = batch_class // batch
+
+ if return_scores is False:
+ all_class_num0_buf = tvm.tir.decl_buffer(
+            (batch_class, num_boxes), "int32", "all_class_nms0", data_alignment=8
+ )
+ all_class_num1_buf = tvm.tir.decl_buffer(
+ (batch_class,), "int32", "all_class_nms1", data_alignment=8
+ )
+ extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count]
+ if score_threshold is not None:
+ extern_inputs.append(score_threshold)
+
+ selected_indices, num_detections = te.extern(
+ [(batch_class, num_boxes), (batch_class,)],
+ extern_inputs,
+ lambda ins, outs: _all_class_nms_ir(
+ ins[0], # boxes
+ ins[1], # sorted_scores
+ ins[2], # sorted_indices
+ ins[3], # valid_count
+ batch_class,
+ num_class,
+ num_boxes,
+ iou_threshold,
+ max_output_size_per_class,
+ outs[0], # box_indices
+ None, # scores
+ outs[1], # num_selected_boxes
+ nms_loop,
+                ins[4] if score_threshold is not None else None,  # score_threshold
+ ),
+ out_buffers=[all_class_num0_buf, all_class_num1_buf],
+ dtype=["int32", "int32"],
+ name="all_class_nms",
+ tag="all_class_nms",
+ )
+ return selected_indices, None, num_detections
+
+ extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count]
+ if score_threshold is not None:
+ extern_inputs.append(score_threshold)
+
+ return te.extern(
+ [(batch_class, num_boxes), (batch_class, num_boxes), (batch_class,)],
+ extern_inputs,
+ lambda ins, outs: _all_class_nms_ir(
+ ins[0], # boxes
+ ins[1], # sorted_scores
+ ins[2], # sorted_indices
+ ins[3], # valid_count
+ batch_class,
+ num_class,
+ num_boxes,
+ iou_threshold,
+ max_output_size_per_class,
+ outs[0], # box_indices
+ outs[1], # selected scores
+ outs[2], # num_selected_boxes
+ nms_loop,
+ ins[4] if score_threshold is not None else None, # score_threshold
+ ),
+ dtype=["int32", "float32", "int32"],
+ name="all_class_nms",
+ tag="all_class_nms",
+ )
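As a sanity check on the overlap formula used by calculate_overlap above, here is a plain-Python rendering of the same IoU computation (illustrative only, not part of the commit):

```python
# Plain-Python mirror of calculate_overlap/_get_boundaries for two boxes
# given as [x1, y1, x2, y2] (corner order may be arbitrary).
def iou(a, b):
    al, at, ar, ab = min(a[0], a[2]), min(a[1], a[3]), max(a[0], a[2]), max(a[1], a[3])
    bl, bt, br, bb = min(b[0], b[2]), min(b[1], b[3]), max(b[0], b[2]), max(b[1], b[3])
    w = max(0.0, min(ar, br) - max(al, bl))   # overlapping width
    h = max(0.0, min(ab, bb) - max(at, bt))   # overlapping height
    inter = w * h
    union = (ar - al) * (ab - at) + (br - bl) * (bb - bt) - inter
    return 0.0 if union <= 0.0 else inter / union

print(iou([0, 0, 2, 2], [1, 1, 3, 3]))  # 1 / 7 ≈ 0.1429
```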
diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h
index bb4098ae82..f09dcb7f82 100644
--- a/src/relax/ir/emit_te.h
+++ b/src/relax/ir/emit_te.h
@@ -51,6 +51,10 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode {
.def_ro("shape", &RXPlaceholderOpNode::shape)
.def_ro("dtype", &RXPlaceholderOpNode::dtype);
}
+
+ // FFI system configuration for structural equality and hashing
+ static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode;
+
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.TEPlaceholderOp", RXPlaceholderOpNode, te::PlaceholderOpNode);
};
diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc
new file mode 100644
index 0000000000..2a1ad8f40a
--- /dev/null
+++ b/src/relax/op/vision/nms.cc
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include "nms.h"
+
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/ir/expr.h>
+#include <tvm/ir/op.h>
+#include <tvm/relax/attrs/vision.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/object.h>
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { AllClassNonMaximumSuppressionAttrs::RegisterReflection(); }
+
+/* relax.vision.all_class_non_max_suppression */
+
+Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxes_per_class,
+ Expr iou_threshold, Expr score_threshold,
+ ffi::String output_format) {
+ auto attrs = tvm::ffi::make_object<AllClassNonMaximumSuppressionAttrs>();
+ attrs->output_format = output_format;
+
+ static const Op& op = Op::Get("relax.vision.all_class_non_max_suppression");
+ return Call(op,
+ {std::move(boxes), std::move(scores), std::move(max_output_boxes_per_class),
+ std::move(iou_threshold), std::move(score_threshold)},
+ Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.vision.all_class_non_max_suppression",
+ all_class_non_max_suppression);
+}
+
+StructInfo InferStructInfoAllClassNMS(const Call& call, const BlockBuilder& ctx) {
+ tvm::ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+ const auto boxes_sinfo = input_sinfo[0];
+ const auto scores_sinfo = input_sinfo[1];
+ ICHECK(!boxes_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ ICHECK(!scores_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ ICHECK_EQ(boxes_sinfo->ndim, 3) << "AllClassNMS input boxes should be 3-D.";
+ ICHECK_EQ(scores_sinfo->ndim, 3) << "AllClassNMS input scores should be 3-D.";
+
+ const auto batch = boxes_sinfo->shape.as<ShapeExprNode>()->values[0];
+ const auto num_classes = scores_sinfo->shape.as<ShapeExprNode>()->values[1];
+ const auto num_boxes = boxes_sinfo->shape.as<ShapeExprNode>()->values[1];
+
+ auto vdev = input_sinfo[0]->vdevice;
+ const auto* attrs = call->attrs.as<AllClassNonMaximumSuppressionAttrs>();
+ if (attrs->output_format == "onnx") {
+ auto vdev = input_sinfo[0]->vdevice;
+ auto num_total_boxes = batch * num_classes * num_boxes;
+ tvm::ffi::Array<PrimExpr> oshape_values = {num_total_boxes, 3};
+ ShapeExpr oshape(oshape_values);
+ tvm::ffi::Array<PrimExpr> counts_values = {1};
+ ShapeExpr counts_shape(counts_values);
+ tvm::ffi::Array<StructInfo> fields = {TensorStructInfo(oshape, DataType::Int(64), vdev),
+ TensorStructInfo(counts_shape, DataType::Int(64), vdev)};
+ return TupleStructInfo(fields);
+ }
+
+ auto num_total_boxes_per_batch = num_classes * num_boxes;
+ tvm::ffi::Array<PrimExpr> indices_values = {batch, num_total_boxes_per_batch, 2};
+ ShapeExpr indices_shape(indices_values);
+ tvm::ffi::Array<PrimExpr> scores_values = {batch, num_total_boxes_per_batch};
+ ShapeExpr scores_shape(scores_values);
+ tvm::ffi::Array<PrimExpr> counts_values = {batch};
+ ShapeExpr counts_shape(counts_values);
+ tvm::ffi::Array<StructInfo> fields = {TensorStructInfo(indices_shape, DataType::Int(64), vdev),
+ TensorStructInfo(scores_shape, DataType::Float(32), vdev),
+ TensorStructInfo(counts_shape, DataType::Int(64), vdev)};
+ return TupleStructInfo(fields);
+}
+
+TVM_REGISTER_OP("relax.vision.all_class_non_max_suppression")
+ .set_attrs_type<AllClassNonMaximumSuppressionAttrs>()
+ .set_num_inputs(5)
+ .add_argument("boxes", "Tensor", "The input boxes in the format [batch,
num_boxes, 4].")
+ .add_argument("scores", "Tensor",
+ "Scores for each box and class in the format [batch,
num_classes, num_boxes].")
+ .add_argument("max_output_boxes_per_class", "Tensor",
+ "The maximum number of output boxes per class.")
+ .add_argument("iou_threshold", "Tensor", "The IoU threshold for box the
overlap test.")
+ .add_argument("score_threshold", "Tensor",
+ "The score threshold to filter out low score boxes early.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllClassNMS)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+} // namespace relax
+} // namespace tvm
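For reference, the struct info inferred above depends only on the static input dimensions and the output_format attribute. A small shape-only sketch (plain Python, illustrative rather than part of the patch) of what callers of the op can expect:

def all_class_nms_output_shapes(batch, num_classes, num_boxes, output_format="onnx"):
    """Shapes produced by InferStructInfoAllClassNMS (illustrative only)."""
    if output_format == "onnx":
        # (selected_indices int64, num_total_detections int64)
        return [(batch * num_classes * num_boxes, 3), (1,)]
    # "tensorflow": (selected_indices int64, selected_scores float32, num_detections int64)
    per_batch = num_classes * num_boxes
    return [(batch, per_batch, 2), (batch, per_batch), (batch,)]

# Matches the unit tests below: batch=10, num_classes=8, num_boxes=5.
assert all_class_nms_output_shapes(10, 8, 5) == [(400, 3), (1,)]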
diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h
new file mode 100644
index 0000000000..c86bf98c94
--- /dev/null
+++ b/src/relax/op/vision/nms.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file nms.h
+ * \brief The functions to make Relax Non-maximum suppression operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_NMS_H_
+#define TVM_RELAX_OP_VISION_NMS_H_
+
+#include <tvm/ffi/string.h>
+#include <tvm/relax/attrs/vision.h>
+#include <tvm/runtime/object.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Compute All Class NonMaximumSuppression. */
+Expr all_class_non_max_suppression(Expr boxes, Expr scores, Expr max_output_boxes_per_class,
+ Expr iou_threshold, Expr score_threshold,
+ ffi::String output_format);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_VISION_NMS_H_
diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc
index 24c16ab268..fa84ab3863 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -650,7 +650,10 @@ Stmt GenerateStmtFromExternOp(const te::ExternOp& extern_op, CreateFuncInfo* inf
// reads/writes filled in.
BufferSubstituter substituter(var_map, input_buffer_map);
- Stmt body = substituter(extern_op->body);
+ Stmt substituted_body = substituter(extern_op->body);
+
+ ProducerToBufferTransformer transformer(info->tensor2buffers);
+ Stmt body = transformer(substituted_body);
// Step 4. Generate opaque block as body.
return BlockRealize(/*iter_values=*/{},
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index d2f5a65593..e4960e5b1a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -3230,6 +3230,7 @@ def test_shape_dim_string_expression_graph_div_1():
gv: R.Tensor((A, B, A // B), dtype="float32") = x
R.output(gv)
return gv
+
# fmt: on
tvm.ir.assert_structural_equal(tvm_model, Expected)
@@ -3269,5 +3270,430 @@ def test_shape_dim_string_expression_graph_div_2():
tvm.ir.assert_structural_equal(tvm_model, Expected)
+def test_nms():
+ """Test NonMaxSuppression operator conversion using our AllClassNMS
implementation."""
+ nms_node = helper.make_node(
+ "NonMaxSuppression",
+ ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold",
"score_threshold"],
+ ["selected_indices"],
+ center_point_box=0,
+ )
+
+ boxes_shape = [1, 5, 4] # batch_size, num_boxes, 4
+ scores_shape = [1, 2, 5] # batch_size, num_classes, num_boxes
+
+ graph = helper.make_graph(
+ [nms_node],
+ "nms_test",
+ inputs=[
+ helper.make_tensor_value_info("boxes", TensorProto.FLOAT,
boxes_shape),
+ helper.make_tensor_value_info("scores", TensorProto.FLOAT,
scores_shape),
+ ],
+ initializer=[
+ helper.make_tensor("max_output_boxes_per_class",
TensorProto.INT64, [1], [3]),
+ helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]),
+ helper.make_tensor("score_threshold", TensorProto.FLOAT, [1],
[0.1]),
+ ],
+ outputs=[helper.make_tensor_value_info("selected_indices",
TensorProto.INT64, [0, 3])],
+ )
+
+ model = helper.make_model(graph, producer_name="nms_test")
+ model.opset_import[0].version = 11
+
+ # Use deterministic random inputs for consistent testing
+ bg = np.random.MT19937(0)
+ rg = np.random.Generator(bg)
+ boxes = rg.standard_normal(size=boxes_shape).astype(np.float32)
+ scores = rg.standard_normal(size=scores_shape).astype(np.float32)
+ inputs = {"boxes": boxes, "scores": scores}
+
+ # Run ONNX Runtime
+ ort_session = onnxruntime.InferenceSession(
+ model.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ ort_output = ort_session.run([], inputs)
+
+ # Run TVM
+ tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
+ tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
+ tvm_model = relax.transform.LegalizeOps()(tvm_model)
+ tvm_model, params = relax.frontend.detach_params(tvm_model)
+
+ with tvm.transform.PassContext(opt_level=3):
+ ex = tvm.compile(tvm_model, target="llvm")
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ input_list = [
+ inputs[key.name_hint] for key in tvm_model["main"].params if
key.name_hint in inputs
+ ]
+ if params:
+ input_list += params["main"]
+
+ vm.set_input("main", *input_list)
+ vm.invoke_stateful("main")
+ tvm_output = vm.get_outputs("main")
+
+ if isinstance(tvm_output, (list, tuple)):
+ tvm_selected = tvm_output[0].numpy()
+ else:
+ tvm_selected = tvm_output.numpy()
+ ort_selected = ort_output[0]
+
+ min_rows = min(tvm_selected.shape[0], ort_selected.shape[0])
+ if min_rows > 0:
+ tvm.testing.assert_allclose(
+ tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5,
atol=1e-5
+ )
+
+
+def test_nms_algorithm_correctness():
+ """Test NMS algorithm correctness with fixed data to verify suppression
logic."""
+ nms_node = helper.make_node(
+ "NonMaxSuppression",
+ ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold",
"score_threshold"],
+ ["selected_indices"],
+ center_point_box=0,
+ )
+
+ # Create fixed test data with known expected results
+ # Boxes: [x1, y1, x2, y2] format
+ boxes_data = np.array(
+ [
+ [
+ [0.0, 0.0, 1.0, 1.0], # Box 0: [0,0,1,1] - highest score, always selected
+ [
+ 0.5,
+ 0.5,
+ 1.5,
+ 1.5,
+ ], # Box 1: [0.5,0.5,1.5,1.5] - IoU with box 0 is only ~0.14 (< 0.5), so it is also kept
+ [2.0, 2.0, 3.0, 3.0],
+ ]
+ ], # Box 2: [2,2,3,3] - disjoint, but dropped once the 2-box-per-class cap is reached
+ dtype=np.float32,
+ )
+
+ # Scores: higher score = better
+ scores_data = np.array(
+ [
+ [[0.9, 0.8, 0.7], [0.6, 0.5, 0.4]] # Class 0: [0.9, 0.8, 0.7] -
box 0 has highest score
+ ], # Class 1: [0.6, 0.5, 0.4] - box 0 has highest score
+ dtype=np.float32,
+ )
+
+ boxes_shape = [1, 3, 4] # batch_size, num_boxes, 4
+ scores_shape = [1, 2, 3] # batch_size, num_classes, num_boxes
+
+ graph = helper.make_graph(
+ [nms_node],
+ "nms_test_correctness",
+ inputs=[
+ helper.make_tensor_value_info("boxes", TensorProto.FLOAT,
boxes_shape),
+ helper.make_tensor_value_info("scores", TensorProto.FLOAT,
scores_shape),
+ ],
+ initializer=[
+ helper.make_tensor(
+ "max_output_boxes_per_class", TensorProto.INT64, [1], [2]
+ ), # Only 2 boxes per class
+ helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1],
[0.5]), # IoU threshold 0.5
+ helper.make_tensor(
+ "score_threshold", TensorProto.FLOAT, [1], [0.1]
+ ), # Score threshold 0.1
+ ],
+ outputs=[helper.make_tensor_value_info("selected_indices",
TensorProto.INT64, [4, 3])],
+ )
+
+ model = helper.make_model(graph, producer_name="nms_test_correctness")
+
+ # Use fixed inputs instead of random
+ inputs = {
+ "boxes": boxes_data,
+ "scores": scores_data,
+ }
+
+ check_correctness(model, inputs=inputs, opset=11)
+
+
+def test_nms_iou_suppression():
+ """Test that NMS correctly suppresses overlapping boxes based on IoU
threshold."""
+ nms_node = helper.make_node(
+ "NonMaxSuppression",
+ ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold",
"score_threshold"],
+ ["selected_indices"],
+ center_point_box=0,
+ )
+
+ # Create overlapping boxes where box 0 has higher score and should be kept
+ boxes_data = np.array(
+ [
+ [
+ [0.0, 0.0, 1.0, 1.0], # Box 0: [0,0,1,1] - highest score
+ [
+ 0.1,
+ 0.1,
+ 1.1,
+ 1.1,
+ ], # Box 1: [0.1,0.1,1.1,1.1] - high IoU with box 0, should
be suppressed
+ [2.0, 2.0, 3.0, 3.0],
+ ]
+ ], # Box 2: [2,2,3,3] - no overlap, should be kept
+ dtype=np.float32,
+ )
+
+ # Box 0 has highest score, Box 1 should be suppressed due to IoU with box 0
+ scores_data = np.array([[[0.9, 0.8, 0.7]]], dtype=np.float32)
+
+ boxes_shape = [1, 3, 4]
+ scores_shape = [1, 1, 3]
+
+ graph = helper.make_graph(
+ [nms_node],
+ "nms_test_iou_suppression",
+ inputs=[
+ helper.make_tensor_value_info("boxes", TensorProto.FLOAT,
boxes_shape),
+ helper.make_tensor_value_info("scores", TensorProto.FLOAT,
scores_shape),
+ ],
+ initializer=[
+ helper.make_tensor("max_output_boxes_per_class",
TensorProto.INT64, [1], [2]),
+ helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1],
[0.5]), # IoU threshold 0.5
+ helper.make_tensor("score_threshold", TensorProto.FLOAT, [1],
[0.1]),
+ ],
+ outputs=[helper.make_tensor_value_info("selected_indices",
TensorProto.INT64, [2, 3])],
+ )
+
+ model = helper.make_model(graph, producer_name="nms_test_iou_suppression")
+ model.opset_import[0].version = 11
+
+ inputs = {
+ "boxes": boxes_data,
+ "scores": scores_data,
+ }
+
+ # Run ONNX Runtime
+ ort_session = onnxruntime.InferenceSession(
+ model.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ ort_output = ort_session.run([], inputs)
+
+ # Run TVM
+ tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
+ tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
+ tvm_model = relax.transform.LegalizeOps()(tvm_model)
+ tvm_model, params = relax.frontend.detach_params(tvm_model)
+
+ with tvm.transform.PassContext(opt_level=3):
+ ex = tvm.compile(tvm_model, target="llvm")
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ input_list = [
+ inputs[key.name_hint] for key in tvm_model["main"].params if
key.name_hint in inputs
+ ]
+ if params:
+ input_list += params["main"]
+
+ vm.set_input("main", *input_list)
+ vm.invoke_stateful("main")
+ tvm_output = vm.get_outputs("main")
+
+ # Custom NMS output comparison
+ if isinstance(tvm_output, (list, tuple)):
+ tvm_selected = tvm_output[0].numpy()
+ else:
+ tvm_selected = tvm_output.numpy()
+ ort_selected = ort_output[0]
+
+ # For NMS, compare only the valid rows
+ min_rows = min(tvm_selected.shape[0], ort_selected.shape[0])
+ if min_rows > 0:
+ tvm.testing.assert_allclose(
+ tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5,
atol=1e-5
+ )
+
+
+def test_nms_max_boxes_limit():
+ """Test that NMS correctly limits the number of boxes per class."""
+ nms_node = helper.make_node(
+ "NonMaxSuppression",
+ ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold",
"score_threshold"],
+ ["selected_indices"],
+ center_point_box=0,
+ )
+
+ # Create data with 4 boxes, but limit to 2 per class
+ boxes_data = np.array(
+ [
+ [
+ [0.0, 0.0, 1.0, 1.0], # Box 0
+ [2.0, 0.0, 3.0, 1.0], # Box 1
+ [0.0, 2.0, 1.0, 3.0], # Box 2
+ [2.0, 2.0, 3.0, 3.0],
+ ]
+ ], # Box 3
+ dtype=np.float32,
+ )
+
+ # All boxes have different scores
+ scores_data = np.array([[[0.9, 0.8, 0.7, 0.6]]], dtype=np.float32)
+
+ boxes_shape = [1, 4, 4]
+ scores_shape = [1, 1, 4]
+
+ graph = helper.make_graph(
+ [nms_node],
+ "nms_test_max_boxes_limit",
+ inputs=[
+ helper.make_tensor_value_info("boxes", TensorProto.FLOAT,
boxes_shape),
+ helper.make_tensor_value_info("scores", TensorProto.FLOAT,
scores_shape),
+ ],
+ initializer=[
+ helper.make_tensor(
+ "max_output_boxes_per_class", TensorProto.INT64, [1], [2]
+ ), # Limit to 2 boxes
+ helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1],
[0.1]), # Low IoU threshold
+ helper.make_tensor("score_threshold", TensorProto.FLOAT, [1],
[0.1]),
+ ],
+ outputs=[helper.make_tensor_value_info("selected_indices",
TensorProto.INT64, [2, 3])],
+ )
+
+ model = helper.make_model(graph, producer_name="nms_test_max_boxes_limit")
+ model.opset_import[0].version = 11
+
+ inputs = {
+ "boxes": boxes_data,
+ "scores": scores_data,
+ }
+
+ # Run ONNX Runtime
+ ort_session = onnxruntime.InferenceSession(
+ model.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ ort_output = ort_session.run([], inputs)
+
+ # Run TVM
+ tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
+ tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
+ tvm_model = relax.transform.LegalizeOps()(tvm_model)
+ tvm_model, params = relax.frontend.detach_params(tvm_model)
+
+ with tvm.transform.PassContext(opt_level=3):
+ ex = tvm.compile(tvm_model, target="llvm")
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ input_list = [
+ inputs[key.name_hint] for key in tvm_model["main"].params if
key.name_hint in inputs
+ ]
+ if params:
+ input_list += params["main"]
+
+ vm.set_input("main", *input_list)
+ vm.invoke_stateful("main")
+ tvm_output = vm.get_outputs("main")
+
+ # Custom NMS output comparison
+ if isinstance(tvm_output, (list, tuple)):
+ tvm_selected = tvm_output[0].numpy()
+ else:
+ tvm_selected = tvm_output.numpy()
+ ort_selected = ort_output[0]
+
+ # For NMS, compare only the valid rows
+ min_rows = min(tvm_selected.shape[0], ort_selected.shape[0])
+ if min_rows > 0:
+ tvm.testing.assert_allclose(
+ tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5,
atol=1e-5
+ )
+
+
+def test_nms_score_threshold():
+ """Test that NMS correctly filters boxes based on score threshold.
+
+ Note: This test uses a low score threshold (0.05) so that both TVM and ONNX Runtime
+ produce the same fixed output shape [3, 3]; the two results are then compared row by
+ row below.
+ """
+ nms_node = helper.make_node(
+ "NonMaxSuppression",
+ ["boxes", "scores", "max_output_boxes_per_class", "iou_threshold",
"score_threshold"],
+ ["selected_indices"],
+ center_point_box=0,
+ )
+
+ # Create data with varying scores - ensure we get exactly 3 boxes after NMS
+ boxes_data = np.array(
+ [
+ [[0.0, 0.0, 1.0, 1.0], [2.0, 0.0, 3.0, 1.0], [0.0, 2.0, 1.0, 3.0]]
# Box 0 # Box 1
+ ], # Box 2
+ dtype=np.float32,
+ )
+
+ # Scores: 0.9, 0.3, 0.1 - adjust score threshold to get exactly 3 boxes
+ scores_data = np.array([[[0.9, 0.3, 0.1]]], dtype=np.float32)
+
+ boxes_shape = [1, 3, 4]
+ scores_shape = [1, 1, 3]
+
+ graph = helper.make_graph(
+ [nms_node],
+ "nms_test_score_threshold",
+ inputs=[
+ helper.make_tensor_value_info("boxes", TensorProto.FLOAT,
boxes_shape),
+ helper.make_tensor_value_info("scores", TensorProto.FLOAT,
scores_shape),
+ ],
+ initializer=[
+ helper.make_tensor("max_output_boxes_per_class",
TensorProto.INT64, [1], [3]),
+ helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.1]),
+ helper.make_tensor("score_threshold", TensorProto.FLOAT, [1],
[0.05]),
+ ],
+ outputs=[helper.make_tensor_value_info("selected_indices",
TensorProto.INT64, [3, 3])],
+ )
+
+ model = helper.make_model(graph, producer_name="nms_test_score_threshold")
+ model.opset_import[0].version = 11
+
+ inputs = {
+ "boxes": boxes_data,
+ "scores": scores_data,
+ }
+
+ # Run ONNX Runtime
+ ort_session = onnxruntime.InferenceSession(
+ model.SerializeToString(), providers=["CPUExecutionProvider"]
+ )
+ ort_output = ort_session.run([], inputs)
+
+ # Run TVM
+ tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
+ tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
+ tvm_model = relax.transform.LegalizeOps()(tvm_model)
+ tvm_model, params = relax.frontend.detach_params(tvm_model)
+
+ with tvm.transform.PassContext(opt_level=3):
+ ex = tvm.compile(tvm_model, target="llvm")
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ input_list = [
+ inputs[key.name_hint] for key in tvm_model["main"].params if
key.name_hint in inputs
+ ]
+ if params:
+ input_list += params["main"]
+
+ vm.set_input("main", *input_list)
+ vm.invoke_stateful("main")
+ tvm_output = vm.get_outputs("main")
+
+ # Custom NMS output comparison
+ if isinstance(tvm_output, (list, tuple)):
+ tvm_selected = tvm_output[0].numpy()
+ else:
+ tvm_selected = tvm_output.numpy()
+ ort_selected = ort_output[0]
+
+ # For NMS, compare only the valid rows
+ min_rows = min(tvm_selected.shape[0], ort_selected.shape[0])
+ if min_rows > 0:
+ tvm.testing.assert_allclose(
+ tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5,
atol=1e-5
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
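The NMS tests above repeat the same compile-and-compare boilerplate against ONNX Runtime. A hypothetical helper along these lines (the name _run_tvm_nms and its signature are mine, not part of the patch; it reuses the from_onnx, relax, and tvm imports already at the top of this test module) captures the shared TVM side:

def _run_tvm_nms(model, inputs, opset=11):
    """Compile an ONNX model through the Relax frontend and return selected_indices as NumPy."""
    tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)
    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    input_list = [inputs[p.name_hint] for p in tvm_model["main"].params if p.name_hint in inputs]
    if params:
        input_list += params["main"]
    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    out = vm.get_outputs("main")
    return (out[0] if isinstance(out, (list, tuple)) else out).numpy()

Each test would then only build its ONNX graph, call the helper, and compare the first min(n_tvm, n_ort) rows of selected_indices against onnxruntime, exactly as the inline versions above do.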
diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py
new file mode 100644
index 0000000000..97145a53ff
--- /dev/null
+++ b/tests/python/relax/test_op_vision.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op, VDevice
+from tvm.script import relax as R
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
+ ret = bb.normalize(call)
+ tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_all_class_non_max_suppression_infer_struct_info():
+ bb = relax.BlockBuilder()
+ batch_size, num_classes, num_boxes = 10, 8, 5
+ boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32"))
+ scores = relax.Var("scores", R.Tensor((batch_size, num_classes,
num_boxes), "float32"))
+ max_output_boxes_per_class = relax.const(10, "int64")
+ iou_threshold = relax.const(0.5, "float32")
+ score_threshold = relax.const(0.1, "float32")
+
+ _check_inference(
+ bb,
+ relax.op.vision.all_class_non_max_suppression(
+ boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, "onnx"
+ ),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((batch_size * num_classes * num_boxes,
3), "int64"),
+ relax.TensorStructInfo((1,), "int64"),
+ ]
+ ),
+ )
+
+
+def test_all_class_non_max_suppression_wrong_input_number():
+ bb = relax.BlockBuilder()
+ boxes = relax.Var("boxes", R.Tensor((1, 5, 4), "float32"))
+ scores = relax.Var("scores", R.Tensor((1, 3, 5), "float32"))
+
+ with pytest.raises(TVMError):
+ relax.op.vision.all_class_non_max_suppression(boxes, scores)
+
+
+def test_all_class_non_max_suppression_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+ batch_size = tir.Var("batch_size", "int64")
+ num_classes = tir.Var("num_classes", "int64")
+ num_boxes = tir.Var("num_boxes", "int64")
+ boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32"))
+ scores = relax.Var("scores", R.Tensor((batch_size, num_classes,
num_boxes), "float32"))
+ max_output_boxes_per_class = relax.const(10, "int64")
+ iou_threshold = relax.const(0.5, "float32")
+ score_threshold = relax.const(0.1, "float32")
+
+ _check_inference(
+ bb,
+ relax.op.vision.all_class_non_max_suppression(
+ boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, "onnx"
+ ),
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((batch_size * num_classes * num_boxes,
3), "int64"),
+ relax.TensorStructInfo((1,), "int64"),
+ ]
+ ),
+ )
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py
new file mode 100644
index 0000000000..66e0adac3d
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+ parsed: Union[relax.Function, IRModule],
+ expect: Optional[Union[relax.Function, IRModule]],
+):
+ test = parsed.script(show_meta=True)
+ roundtrip_mod = tvm.script.from_source(test)
+ tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+ if expect:
+ tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_all_class_non_max_suppression():
+ @R.function
+ def foo(
+ boxes: R.Tensor((10, 5, 4), "float32"),
+ scores: R.Tensor((10, 8, 5), "float32"),
+ max_output_boxes_per_class: R.Tensor((), "int64"),
+ iou_threshold: R.Tensor((), "float32"),
+ score_threshold: R.Tensor((), "float32"),
+ ) -> R.Tuple(R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64")):
+ gv: R.Tuple(
+ R.Tensor((400, 3), "int64"), R.Tensor((1,), "int64")
+ ) = R.vision.all_class_non_max_suppression(
+ boxes,
+ scores,
+ max_output_boxes_per_class,
+ iou_threshold,
+ score_threshold,
+ "onnx",
+ )
+ return gv
+
+ boxes = relax.Var("boxes", R.Tensor((10, 5, 4), "float32"))
+ scores = relax.Var("scores", R.Tensor((10, 8, 5), "float32"))
+ max_output_boxes_per_class = relax.Var("max_output_boxes_per_class",
R.Tensor((), "int64"))
+ iou_threshold = relax.Var("iou_threshold", R.Tensor((), "float32"))
+ score_threshold = relax.Var("score_threshold", R.Tensor((), "float32"))
+
+ bb = relax.BlockBuilder()
+ with bb.function(
+ "foo", [boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold]
+ ):
+ gv = bb.emit(
+ relax.op.vision.all_class_non_max_suppression(
+ boxes, scores, max_output_boxes_per_class, iou_threshold,
score_threshold, "onnx"
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+ tvm.testing.main()