This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 361d21bbe9 [Relax][ONNX] add support for unique optional outputs (#18652)
361d21bbe9 is described below
commit 361d21bbe9e66bf7fbd8cd630ae49ff3278e176a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Jan 20 10:38:18 2026 +0800
[Relax][ONNX] add support for unique optional outputs (#18652)
## Why
The ONNX Unique operator defines one required output (the unique values)
and three optional outputs (indices, inverse_indices, and counts), but the
TVM ONNX frontend only returned the unique values.
## How
- Updated `Unique._impl_v11` to check the number of expected outputs via
  `attr["tvm_custom"]["num_outputs"]`
- Passed `return_index`, `return_inverse`, and `return_counts` to
  `relax.op.unique`
- Returned a `relax.Tuple` containing all requested outputs (see the sketch
  below)
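
For illustration, a minimal sketch of a model that exercises the new path
(not part of this change; it mirrors the updated `test_unique` below and
uses only the public `onnx.helper` and `from_onnx` APIs):

```python
from onnx import helper, TensorProto
from tvm.relax.frontend.onnx import from_onnx

# A Unique node that requests all four outputs.
node = helper.make_node(
    "Unique", ["x"], ["y", "indices", "inverse_indices", "counts"], sorted=1
)
graph = helper.make_graph(
    [node],
    "unique_all_outputs",
    inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [8])],
    outputs=[
        helper.make_tensor_value_info("y", TensorProto.FLOAT, [-1]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [-1]),
        helper.make_tensor_value_info("inverse_indices", TensorProto.INT64, [-1]),
        helper.make_tensor_value_info("counts", TensorProto.INT64, [-1]),
    ],
)
model = helper.make_model(graph, producer_name="unique_all_outputs")
mod = from_onnx(model)  # the converter now returns a 4-element relax.Tuple
```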
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 65 +++++++++++++++----
python/tvm/relax/op/set.py | 85 ++++++++++++++++++++++---
src/relax/op/tensor/set.cc | 43 ++++++++++---
tests/python/relax/test_frontend_onnx.py | 23 +++++--
4 files changed, 180 insertions(+), 36 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 4dbb0ca36f..e14e2ed956 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3239,24 +3239,63 @@ class Unique(OnnxOpConverter):
def _impl_v11(cls, bb, inputs, attr, params):
data = inputs[0]
axis = attr.get("axis", None)
- sorted = bool(attr.get("sorted", 1))
- # TODO(tvm-team): Add support for return_index, return_inverse, return_counts
- unique = relax.op.unique(data, sorted=sorted, axis=axis)
+ sorted_flag = bool(attr.get("sorted", 1))
+ num_outputs = attr["tvm_custom"]["num_outputs"]
+
+ return_index = num_outputs > 1
+ return_inverse = num_outputs > 2
+ return_counts = num_outputs > 3
+
+ unique = relax.op.unique(
+ data,
+ sorted=sorted_flag,
+ return_index=return_index,
+ return_inverse=return_inverse,
+ return_counts=return_counts,
+ axis=axis,
+ )
+
unique_numbers = tir.Var("unique_numbers", "int64")
input_shape = data.struct_info.shape
dtype = data.struct_info.dtype
if axis is None:
- # flatten the input tensor
- return bb.match_cast(unique, relax.TensorStructInfo((unique_numbers,), dtype))
-
- axis = axis if axis >= 0 else len(input_shape) + axis
- if axis < 0 or axis >= len(input_shape):
- raise ValueError(f"Axis {axis} is out of bounds")
- output_shape = [
- input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
- ]
- return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))
+ output_shape = (unique_numbers,)
+ else:
+ axis = axis if axis >= 0 else len(input_shape) + axis
+ if axis < 0 or axis >= len(input_shape):
+ raise ValueError(f"Axis {axis} is out of bounds")
+ output_shape = [
+ input_shape[i] if i != axis else unique_numbers for i in range(len(input_shape))
+ ]
+
+ if num_outputs == 1:
+ return bb.match_cast(unique, relax.TensorStructInfo(output_shape, dtype))
+
+ outputs = [bb.match_cast(unique[0], relax.TensorStructInfo(output_shape, dtype))]
+ tuple_idx = 1 # Track which index in the tuple we're at
+
+ if return_index:
+ index_shape = (unique_numbers,)
+ index_sinfo = relax.TensorStructInfo(index_shape, "int64")
+ outputs.append(bb.match_cast(unique[tuple_idx], index_sinfo))
+ tuple_idx += 1
+
+ if return_inverse:
+ # ONNX spec: inverse_indices is always 1D
+ # When axis is None: shape is [X.size]
+ # When axis is specified: shape is [X.shape[axis]]
+ inverse_shape = (tir.Var("inverse_numbers", "int64"),)
+ inverse_sinfo = relax.TensorStructInfo(inverse_shape, "int64")
+ outputs.append(bb.match_cast(unique[tuple_idx], inverse_sinfo))
+ tuple_idx += 1
+
+ if return_counts:
+ count_shape = (unique_numbers,)
+ count_sinfo = relax.TensorStructInfo(count_shape, "int64")
+ outputs.append(bb.match_cast(unique[tuple_idx], count_sinfo))
+
+ return relax.Tuple(outputs)
class NonZero(OnnxOpConverter):
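
For reference, the heart of the converter above is a single `relax.op.unique`
call followed by `match_cast`s; a hedged sketch of the equivalent direct
usage, with illustrative shapes and names:

```python
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((8,), "float32"))
with bb.function("main", [x]):
    # Requesting any optional output makes the op return a tuple.
    lv = bb.emit(
        relax.op.unique(
            x, sorted=True, return_index=True, return_inverse=True, return_counts=True
        )
    )
    bb.emit_func_output(lv)
mod = bb.get()  # lv is a 4-tuple: values, indices, inverse_indices, counts
```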
diff --git a/python/tvm/relax/op/set.py b/python/tvm/relax/op/set.py
index 87fd067e5d..a7d837d673 100644
--- a/python/tvm/relax/op/set.py
+++ b/python/tvm/relax/op/set.py
@@ -99,17 +99,84 @@ def numpy_unique(
"""
import builtins
- # TODO(prakalp): add support for returning a tuple when return_inverse or return_counts is True
- if bool(return_index) or bool(return_inverse) or bool(return_counts):
- raise NotImplementedError("missing support return_inverse or return_counts set to true")
x_numpy = x.numpy()
- # TODO(prakalp): use torch.unique instead of numpy when torch is installed in ci.
- output_sorted_numpy, indices = np.unique(x_numpy, return_index=True, axis=axis)
- if sorted:
- return tvm.runtime.tensor(output_sorted_numpy)
- output_numpy = np.take(x_numpy, builtins.sorted(indices), axis=axis)
- return tvm.runtime.tensor(output_numpy)
+ # Call numpy.unique with all the requested return flags
+ result = np.unique(
+ x_numpy,
+ return_index=bool(return_index),
+ return_inverse=bool(return_inverse),
+ return_counts=bool(return_counts),
+ axis=axis,
+ )
+
+ # If no optional outputs requested, result is just the unique values
+ if not bool(return_index) and not bool(return_inverse) and not bool(return_counts):
+ unique_values = result
+ if not sorted:
+ indices = np.unique(x_numpy, return_index=True, axis=axis)[1]
+ unique_values = np.take(x_numpy, builtins.sorted(indices), axis=axis)
+ return tvm.runtime.tensor(unique_values)
+
+ # Otherwise, numpy returns a tuple
+ unique_values = result[0]
+ output_list = []
+ result_idx = 1
+
+ # Handle sorting for unique values
+ if not sorted and bool(return_index):
+ # Get the indices from numpy result
+ indices = result[result_idx]
+ result_idx += 1
+ # Sort indices to get original order
+ sort_order = np.argsort(indices)
+ unique_values = np.take(unique_values, sort_order, axis=axis)
+ indices = np.sort(indices)
+ output_list.append(tvm.runtime.tensor(unique_values))
+ output_list.append(tvm.runtime.tensor(indices))
+ elif not sorted:
+ # Need to get indices to reorder
+ _, indices = np.unique(x_numpy, return_index=True, axis=axis)
+ sort_order = np.argsort(indices)
+ unique_values = np.take(unique_values, sort_order, axis=axis)
+ output_list.append(tvm.runtime.tensor(unique_values))
+ if bool(return_index):
+ indices_from_result = result[result_idx]
+ result_idx += 1
+
+ output_list.append(tvm.runtime.tensor(np.sort(indices_from_result)))
+ else:
+ # Sorted case
+ output_list.append(tvm.runtime.tensor(unique_values))
+ if bool(return_index):
+ output_list.append(tvm.runtime.tensor(result[result_idx]))
+ result_idx += 1
+
+ if bool(return_inverse):
+ inverse_indices = result[result_idx]
+ if not sorted:
+ # Need to remap inverse indices to match reordered unique values
+ _, orig_indices = np.unique(x_numpy, return_index=True, axis=axis)
+ sort_order = np.argsort(orig_indices)
+ inverse_mapping = np.empty_like(sort_order)
+ inverse_mapping[sort_order] = np.arange(len(sort_order))
+ inverse_indices = inverse_mapping[inverse_indices]
+ # ONNX spec: inverse_indices is always 1D
+ # When axis is None, it has length X.size (flattened)
+ # When axis is specified, it has length X.shape[axis]
+ # numpy.unique already returns 1D inverse_indices, so no reshaping needed
+ output_list.append(tvm.runtime.tensor(inverse_indices))
+ result_idx += 1
+
+ if bool(return_counts):
+ counts = result[result_idx]
+ if not sorted:
+ # Reorder counts to match reordered unique values
+ _, orig_indices = np.unique(x_numpy, return_index=True, axis=axis)
+ sort_order = np.argsort(orig_indices)
+ counts = counts[sort_order]
+ output_list.append(tvm.runtime.tensor(counts))
+
+ return tuple(output_list)
def nonzero(x: Expr) -> Expr:
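
The `sorted=False` branches above rely on a standard numpy idiom: recover
first-occurrence order from the sorted result via `argsort` of the
first-occurrence indices, then remap the inverse indices. A small
self-contained check of that idiom (values are illustrative):

```python
import numpy as np

x = np.array([3, 1, 3, 2, 1])
uniq, first_idx, inv, counts = np.unique(
    x, return_index=True, return_inverse=True, return_counts=True
)
order = np.argsort(first_idx)         # restores first-occurrence order
uniq_unsorted = uniq[order]           # [3, 1, 2]
remap = np.empty_like(order)
remap[order] = np.arange(len(order))  # sorted position -> unsorted position
inv_unsorted = remap[inv]             # inverse indices into uniq_unsorted
counts_unsorted = counts[order]
assert (uniq_unsorted[inv_unsorted] == x).all()
```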
diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc
index d80c73b131..c3ee496794 100644
--- a/src/relax/op/tensor/set.cc
+++ b/src/relax/op/tensor/set.cc
@@ -101,16 +101,41 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) {
output_sinfo.push_back(TensorStructInfo(data_sinfo->dtype, /*ndim=*/1, data_sinfo->vdevice));
}
- // index, reverse and counts
- TensorStructInfo int_return{nullptr};
- if (data_sinfo->ndim == 0) {
- int_return = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
- DataType::Int(64), data_sinfo->vdevice);
- } else {
- int_return = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+ // index, inverse_indices, and counts
+ // index: always 1D
+ if (f_convert_to_int64(return_index->value)) {
+ TensorStructInfo index_sinfo{nullptr};
+ if (data_sinfo->ndim == 0) {
+ index_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+ DataType::Int(64), data_sinfo->vdevice);
+ } else {
+ index_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+ }
+ output_sinfo.push_back(index_sinfo);
+ }
+
+ // inverse_indices: always 1D per ONNX spec
+ if (f_convert_to_int64(return_inverse->value)) {
+ TensorStructInfo inverse_sinfo{nullptr};
+ if (data_sinfo->ndim == 0) {
+ inverse_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+ DataType::Int(64), data_sinfo->vdevice);
+ } else {
+ inverse_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+ }
+ output_sinfo.push_back(inverse_sinfo);
}
- for (int i = 0; i < n_int_return; ++i) {
- output_sinfo.push_back(int_return);
+
+ // counts: always 1D
+ if (f_convert_to_int64(return_counts->value)) {
+ TensorStructInfo counts_sinfo{nullptr};
+ if (data_sinfo->ndim == 0) {
+ counts_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}),
+ DataType::Int(64), data_sinfo->vdevice);
+ } else {
+ counts_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice);
+ }
+ output_sinfo.push_back(counts_sinfo);
}
if (output_sinfo.size() == 1) {
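
The trailing `output_sinfo.size() == 1` check preserves the existing
single-output behavior: only callers that request optional outputs see a
tuple. A sketch of the visible difference (names are illustrative; struct
info is attached once the block builder normalizes the calls):

```python
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((8,), "float32"))
with bb.function("f", [x]):
    y = bb.emit(relax.op.unique(x))                       # bare tensor, as before
    yt = bb.emit(relax.op.unique(x, return_counts=True))  # 2-field tuple
    bb.emit_func_output(relax.Tuple([y, yt]))
```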
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6f5c7da5ef..df94c13478 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2922,19 +2922,32 @@ def test_onehot():
@pytest.mark.parametrize("axis", [None, 0, 1, -1])
@pytest.mark.parametrize("sorted", [0, 1])
-def test_unique(axis: Optional[int], sorted: int):
- input_shape = [32, 32]
[email protected]("num_outputs", [1, 2, 3, 4])
+def test_unique(axis: Optional[int], sorted: int, num_outputs: int):
+ input_shape = [8, 8]
if axis is None:
output_shape = [-1]
else:
- output_shape = [32, 32]
+ output_shape = [8, 8]
output_shape[axis] = -1
- unique_node = helper.make_node("Unique", ["x"], ["y"], axis=axis, sorted=sorted)
+
+ output_names = ["y", "indices", "inverse_indices", "counts"][:num_outputs]
+ unique_node = helper.make_node("Unique", ["x"], output_names, axis=axis, sorted=sorted)
+
+ outputs = [helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)]
+ if num_outputs > 1:
+ outputs.append(helper.make_tensor_value_info("indices", TensorProto.INT64, [-1]))
+ if num_outputs > 2:
+ # ONNX spec: inverse_indices is always 1D
+ outputs.append(helper.make_tensor_value_info("inverse_indices", TensorProto.INT64, [-1]))
+ if num_outputs > 3:
+ outputs.append(helper.make_tensor_value_info("counts", TensorProto.INT64, [-1]))
+
graph = helper.make_graph(
[unique_node],
"unique_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, input_shape)],
- outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, output_shape)],
+ outputs=outputs,
)
model = helper.make_model(graph, producer_name="unique_test")
check_correctness(model)
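
The new parametrization can be exercised locally with
`pytest tests/python/relax/test_frontend_onnx.py -k test_unique`.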