This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d9d129c93 [Relax][ONNX] Fix Cast operator float->int NaN/Inf handling
(#19626)
4d9d129c93 is described below
commit 4d9d129c93a0ac93e1c2643b3f35a67b05c0b451
Author: Neo Chien <[email protected]>
AuthorDate: Fri Jun 5 07:51:55 2026 +0800
[Relax][ONNX] Fix Cast operator float->int NaN/Inf handling (#19626)
Hi Committers,
This PR fixes issue #19542. Any suggestions or review comments would be
appreciated.
### Root cause:
Float-to-integer lowering is implementation-defined or undefined behavior (UB)
for NaN/Inf and out-of-range floats, so results are backend-dependent and can
differ from ONNX Runtime.
### Solution:
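Apply a minimal, deterministic frontend sanitization for float-to-integer
Casts: map NaN and ±Inf to 0.0 before astype. This prevents NaN/Inf from
reaching the backend fptosi/fptoui lowerings and yields stable behavior
across targets. For integer targets narrower than 64 bits, the sanitized
value is first cast to a wider integer and then wrapped modulo 2^bits
(e.g. 300.0 -> 44 for int8).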
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 57 +++++++++++++++++++++++++
tests/python/relax/test_frontend_onnx.py | 31 ++++++++++++++
2 files changed, 88 insertions(+)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index b82fceff1d..3a2a0fdaf2 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1105,6 +1105,63 @@ class Cast(OnnxOpConverter):
return relax.const(output, to_type)
if isinstance(inputs[0], relax.PrimValue):
return relax.PrimValue(inputs[0].value.astype(to_type))
+
+ try:
+ np_dst = _np.dtype(str(to_type))
+ except Exception:
+ return relax.op.astype(inputs[0], to_type)
+
+ if np_dst.kind in ("i", "u"):
+ src = inputs[0]
+ src_dtype = getattr(getattr(src, "struct_info", None), "dtype",
None) or getattr(
+ src, "dtype", None
+ )
+ if src_dtype is not None and
_relax_dtype_is_floating_point(src_dtype):
+ x_sanitized = bb.emit(
+ relax.op.where(
+ relax.op.logical_not(relax.op.isfinite(src)),
+ relax.const(0.0, src_dtype),
+ src,
+ )
+ )
+ dst_str = str(to_type)
+ if dst_str.startswith("uint"):
+ signed = False
+ bits = int(dst_str[4:])
+ elif dst_str.startswith("int"):
+ signed = True
+ bits = int(dst_str[3:])
+ else:
+ return relax.op.astype(x_sanitized, to_type)
+
+ if bits == 64:
+ return relax.op.astype(x_sanitized, to_type)
+
+ temp_dtype = "int64" if bits >= 32 else "int32"
+ t = relax.op.astype(x_sanitized, temp_dtype)
+ if bits == 32:
+ two_pow = relax.const(1 << bits, temp_dtype)
+ uw = relax.op.floor_mod(t, two_pow)
+ else:
+ mask_val = (1 << bits) - 1
+ mask = relax.const(mask_val, temp_dtype)
+ uw = relax.op.bitwise_and(t, mask)
+ if signed:
+ half = 1 << (bits - 1)
+ half_c = relax.const(half, temp_dtype)
+ if bits == 32:
+ two_pow = relax.const(1 << bits, temp_dtype)
+ else:
+ two_pow = relax.op.add(mask, relax.const(1,
temp_dtype))
+ wrapped = relax.op.where(
+ relax.op.greater_equal(uw, half_c),
+ relax.op.subtract(uw, two_pow),
+ uw,
+ )
+ else:
+ wrapped = uw
+ return relax.op.astype(wrapped, to_type)
+
return relax.op.astype(inputs[0], to_type)
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 7ee10993a4..9a644c4a3a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -863,6 +863,37 @@ def test_cast(from_type, to_type):
check_correctness(model, opset=13)
[email protected]("to_type", [TensorProto.INT64, TensorProto.UINT64])
+def test_cast_float_to_64bit_int_dynamic(to_type):
+ cast_node = helper.make_node("Cast", ["a"], ["b"], to=to_type)
+ graph = helper.make_graph(
+ [cast_node],
+ "cast_float_to_64bit_int_dynamic_test",
+ inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 8])],
+ outputs=[helper.make_tensor_value_info("b", to_type, [1, 8])],
+ )
+ model = helper.make_model(graph,
producer_name="cast_float_to_64bit_int_dynamic_test")
+ inputs = {"a": np.array([[0.0, 1.2, 2.8, 7.9, 15.1, 31.7, 63.4, 127.9]],
dtype=np.float32)}
+ check_correctness(model, inputs=inputs, opset=13, check_dtypes=True)
+
+
+def test_cast_nan_inf_to_int8():
+ vals = np.array([300.0, np.nan, np.inf, -np.inf, 50.0, -50.0],
dtype=np.float32)
+ node = helper.make_node("Cast", inputs=["a"], outputs=["b"],
to=TensorProto.INT8)
+ graph = helper.make_graph(
+ [node],
+ "cast_nan_inf_test",
+ inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT,
list(vals.shape))],
+ outputs=[helper.make_tensor_value_info("b", TensorProto.INT8,
list(vals.shape))],
+ )
+ model = helper.make_model(graph, producer_name="cast_nan_inf_test")
+ tvm_output = run_in_tvm(model, inputs={"a": vals}, opset=13)
+ out_np = tvm_output.numpy()
+ expected = np.array([44, 0, 0, 0, 50, -50], dtype=np.int8)
+ assert out_np.dtype == np.int8
+ np.testing.assert_array_equal(out_np, expected)
+
+
def test_gather():
def _verify_gather(data_shape, indices, out_shape, axis=0):
gather_node = helper.make_node("Gather", ["data", "indices"], ["y"],
axis=axis)