This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1720d305ae [Relax][ONNX] Fix TopK scalar K extraction in from_onnx (#19573)
1720d305ae is described below
commit 1720d305ae785cc33fa453a3d0f7eb91a0073534
Author: Javier De Jesus <[email protected]>
AuthorDate: Tue May 19 18:02:59 2026 +0200
[Relax][ONNX] Fix TopK scalar K extraction in from_onnx (#19573)
### Root Cause
`TopK._impl_v11` extracted `k` with `int(k.data.numpy())`. ONNX emits
`K` as a single-element 1-D tensor constant, so `numpy()` returns a 1-D
array and `int()` raises `TypeError: only 0-dimensional arrays can be
converted to Python scalars`, failing conversion of any model with a
`TopK` node.
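The following is a minimal NumPy sketch of the shape mismatch. It assumes only that NumPy is installed and does not need TVM. K arrives as a one-element 1-D array, and whether `int()` accepts such an array varies by NumPy release, so the snippet reports the outcome instead of assuming one. The variable name `k_np` is illustrative.

```python
import warnings

import numpy as np

# ONNX TopK takes K as a 1-D int64 tensor holding one element, so the array
# recovered from the constant has shape (1,), not the 0-d shape ().
k_np = np.array([3], dtype=np.int64)
print(k_np.shape, k_np.ndim)  # (1,) 1

# int() on a size-1, 1-D array is version dependent: older NumPy returns the
# value, NumPy 1.25+ emits a DeprecationWarning (escalated to an error here),
# and releases that expired that deprecation raise TypeError.
with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    try:
        print("int() accepted:", int(k_np))
    except (TypeError, DeprecationWarning) as exc:
        print("int() rejected:", type(exc).__name__, exc)
```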
### Solution
Resolve `k` with `get_constant(inputs[1], params)` and extract the
scalar with `.item()`, matching the `Trilu` and `Reshape` converters in
the same file. `get_constant` also handles `k` arriving as a parameter
when `keep_params_in_input=True`.
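The resolution step can be modelled roughly as follows. `Var`, `Constant`, `resolve_constant` and `extract_k` are hypothetical stand-ins written for illustration. They are not TVM's `relax.Var`, `relax.Constant` or `get_constant`, whose real signatures may differ, for example in how `params` is structured. Here `params` is assumed to be a plain name-to-array dict.

```python
from dataclasses import dataclass
from typing import Dict, Union

import numpy as np


@dataclass
class Var:
    """Hypothetical stand-in for a graph input variable."""
    name_hint: str


@dataclass
class Constant:
    """Hypothetical stand-in for a constant embedded in the graph."""
    data: np.ndarray


def resolve_constant(
    value: Union[Var, Constant], params: Dict[str, np.ndarray]
) -> Union[Var, Constant]:
    """Hypothetical model of the get_constant pattern: a variable whose name is
    a known parameter becomes a constant; anything else is returned unchanged."""
    if isinstance(value, Var) and value.name_hint in params:
        return Constant(np.asarray(params[value.name_hint]))
    return value


def extract_k(value: Union[Var, Constant], params: Dict[str, np.ndarray]) -> int:
    """Mirror the patched converter: resolve first, then require a constant."""
    k = resolve_constant(value, params)
    if not isinstance(k, Constant):
        raise ValueError("TopK k must be a constant")
    # .item() works for 0-d and single-element 1-D arrays alike.
    return int(k.data.item())


params = {"K": np.array([3], dtype=np.int64)}
print(extract_k(Var("K"), params))  # 3
```

Without the resolution step, a `Var` carrying K as a parameter would fail the `isinstance` check even though its value is known at conversion time.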
### Test Plan
`test_topk` in `tests/python/relax/test_frontend_onnx.py` already builds
`K` as a single-element 1-D INT64 constant, so it exercises this path.
`.item()` returns the scalar for both single-element 1-D and 0-d
constants.
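The Test Plan's claim can be cross-checked with a hypothetical NumPy reference for TopK. `topk_reference` is illustrative and is not part of TVM or its test suite. It extracts `k` with `.item()` and accepts K as either a 0-d or a single-element 1-D array. It follows the ONNX convention that the lower index comes first among equal values. Because it negates the input, it assumes a signed or floating-point `data` dtype.

```python
import numpy as np


def topk_reference(data: np.ndarray, k_const: np.ndarray, axis: int = -1, largest: bool = True):
    """Return (values, indices) of the top-k entries along `axis`.

    `k_const` may be 0-d or 1-D with a single element; `data` is assumed to be
    signed or floating point because the largest=True path negates it.
    """
    # .item() accepts any size-1 array regardless of ndim and raises
    # ValueError if the array holds more than one element.
    k = int(k_const.item())
    keys = -data if largest else data
    # A stable sort keeps equal elements in index order, so ties favor the lower index.
    order = np.argsort(keys, axis=axis, kind="stable")
    indices = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(data, indices, axis=axis)
    return values, indices


data = np.array([[1, 5, 3, 4]], dtype=np.float32)
for k_const in (np.array([2], dtype=np.int64), np.array(2, dtype=np.int64)):
    values, indices = topk_reference(data, k_const)
    print(values.tolist(), indices.tolist())  # [[5.0, 4.0]] [[1, 3]]
```

Both K layouts produce the same result. A K holding more than one element is rejected by `.item()` instead of being silently truncated to its first entry.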
### Issue
Fixes #19571
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5f41644149..6624110241 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4242,11 +4242,10 @@ class TopK(OnnxOpConverter):
     @classmethod
     def _impl_v11(cls, bb, inputs, attr, params):
         data = inputs[0]
-        k = inputs[1]
+        k = get_constant(inputs[1], params)
         if not isinstance(k, relax.Constant):
             raise ValueError("TopK k must be a constant")
-        # ONNX represents k as a tensor of shape [1]; flatten before scalar cast.
-        k = int(k.data.numpy().reshape(-1)[0])
+        k = int(k.data.numpy().item())
         axis = attr.get("axis", -1)
         largest = attr.get("largest", 1)
         sorted = attr.get("sorted", 1)