This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b1918c74fd [Fix][Relax]: ONNX Clip NaN bounds and preserve input NaN
(ORT parity) (#19535)
b1918c74fd is described below
commit b1918c74fd2516ed6e9c1438b6e942e4af4e2452
Author: ConvolutedDog <[email protected]>
AuthorDate: Tue May 12 19:06:06 2026 +0800
[Fix][Relax]: ONNX Clip NaN bounds and preserve input NaN (ORT parity)
(#19535)
This PR fixes https://github.com/apache/tvm/issues/19533:
- Sanitize floating-point tensor min/max bounds: before topi.maximum /
topi.minimum, replace a NaN min with -inf and a NaN max with +inf, so a
NaN bound means "unbounded on that side", matching ONNX Runtime.
- After clamping, preserve NaNs from the input tensor for floating-point
dtypes.
- Extend check_correctness with equal_nan for float outputs containing
NaN.
- Add parametrized Clip opset-13 tests for NaN min/max tensor bounds.
- Fix a missing colon in the definition of
test_meta_schedule_evolutionary_search_skip_invalid_measured_trace.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 36 ++++++++++++++--
tests/python/relax/test_frontend_onnx.py | 49 ++++++++++++++++++++++
.../test_meta_schedule_search_strategy.py | 2 +-
3 files changed, 83 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 622e262cc4..878f976c95 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -52,12 +52,26 @@ import tvm
from tvm import TVMError, relax, tirx, topi
from tvm.ir import IRModule
from tvm.ir.supply import NameSupply
+from tvm.runtime import DataType, DataTypeCode
from tvm.tirx.generic import cast
from tvm.topi.utils import get_const_tuple
from ..common import autopad
+def _relax_dtype_is_floating_point(dtype: str) -> bool:
+    """True for float, bfloat and float8/float4 dtypes; False otherwise or if unparseable."""
+    try:
+        code = DataType(dtype).type_code
+    except (ValueError, TypeError, TVMError):
+        return False
+    return (
+        code == DataTypeCode.FLOAT
+        or code == DataTypeCode.BFLOAT
+        or DataTypeCode.Float8E3M4 <= code <= DataTypeCode.Float4E2M1FN  # assumes contiguous codes
+    )
+
+
def get_type(elem_type: str | int) -> str:
"""Converts onnx integer datatype to numpy datatype"""
# If a string was passed instead of a tensor type, it does not need
@@ -311,6 +325,7 @@ class OnnxOpConverter:
return getattr(cls, f"_impl_v{version}")
raise NotImplementedError(f"opset version {version} of {cls.__name__}
not implemented")
+
class QuantizeLinear(OnnxOpConverter):
@classmethod
def _impl_v10(cls, bb, inputs, attr, params):
@@ -379,6 +394,7 @@ class DynamicQuantizeLinear(OnnxOpConverter):
y = relax.op.quantize(x, y_scale, y_zero_point, axis=0,
out_dtype="uint8")
return relax.Tuple([y, y_scale, y_zero_point])
+
class MatMul(OnnxOpConverter):
"""Converts an onnx MatMul node into an equivalent Relax expression."""
@@ -1350,6 +1366,15 @@ class Where(OnnxOpConverter):
class Clip(OnnxOpConverter):
"""Converts an onnx Clip node into an equivalent Relax expression."""
+    @staticmethod
+    def _sanitize_nan_clip_bound(bb, bound: relax.Expr, *, for_min: bool) -> relax.Expr:
+        """Map a NaN float bound to -inf (min) / +inf (max); ORT treats NaN bounds as unbounded."""
+        dtype = bound.struct_info.dtype
+        if not _relax_dtype_is_floating_point(dtype):
+            return bound
+        repl = -_np.inf if for_min else _np.inf  # NaN would otherwise poison max/min
+        return bb.emit(relax.op.where(relax.op.isnan(bound), relax.const(repl, dtype), bound))
+
@classmethod
def _impl_v1(cls, bb, inputs, attr, params):
min = float(attr.get("min", -_np.inf))
@@ -1366,11 +1391,16 @@ class Clip(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
- results = inputs[0]
+ x: Any = inputs[0]  # kept to restore input NaNs after clamping
+ results = x
if inputs[1] is not None:
- results = bb.emit_te(topi.maximum, results, inputs[1])
+ lo = cls._sanitize_nan_clip_bound(bb, inputs[1], for_min=True)  # NaN min -> -inf
+ results = bb.emit_te(topi.maximum, results, lo)
if inputs[2] is not None:
- results = bb.emit_te(topi.minimum, results, inputs[2])
+ hi = cls._sanitize_nan_clip_bound(bb, inputs[2], for_min=False)  # NaN max -> +inf
+ results = bb.emit_te(topi.minimum, results, hi)
+ if _relax_dtype_is_floating_point(x.struct_info.dtype):
+ results = bb.emit(relax.op.where(relax.op.isnan(x), x, results))  # keep input NaNs
return results
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 94b85ab95a..c46709e33d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1597,6 +1597,55 @@ def test_clip_v6(max, min):
check_correctness(model, opset=10)
+@pytest.mark.parametrize(
+    "min,max",
+    [
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(6.0, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(np.nan, dtype=np.float32),
+            np.array(6.0, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(np.nan, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "input",
+    [
+        np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32),
+        np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32),
+    ],
+)
+def test_clip_v13(input, min, max):
+    # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT); input NaN preserved.
+    clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
+    graph = helper.make_graph(
+        [clip_node],
+        "clip_v13_nan_max",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [5]),
+            helper.make_tensor_value_info("min", TensorProto.FLOAT, []),
+            helper.make_tensor_value_info("max", TensorProto.FLOAT, []),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [5])],
+    )
+    model = helper.make_model(graph, producer_name="clip_v13_nan_max")
+    check_correctness(
+        model,
+        inputs={"input": input, "min": min, "max": max},
+        opset=13,
+    )
+
+
def test_equal():
equal_node = helper.make_node("Equal", ["a", "b"], ["output"])
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
index 370eff27c7..f9cec06aea 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
@@ -324,7 +324,7 @@ def
test_meta_schedule_evolutionary_search_fail_init_population(): # pylint: di
assert candidates is None
-def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace() #
pylint: disable = invalid-name
+def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace(): #
pylint: disable = invalid-name
# Construct an incompatible measured trace: it references block name
"other",
# which doesn't exist in Matmul. Replaying this trace should fail and be
skipped.
wrong_sch = Schedule(OtherBlock)