This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b1918c74fd [Fix][Relax]: ONNX Clip NaN bounds and preserve input NaN (ORT parity) (#19535)
b1918c74fd is described below

commit b1918c74fd2516ed6e9c1438b6e942e4af4e2452
Author: ConvolutedDog <[email protected]>
AuthorDate: Tue May 12 19:06:06 2026 +0800

    [Fix][Relax]: ONNX Clip NaN bounds and preserve input NaN (ORT parity) (#19535)
    
    This PR fixes https://github.com/apache/tvm/issues/19533:
    - Sanitize floating tensor min/max: replace NaN with +inf/-inf before
    topi max/min so bounds match ONNX "unbounded" semantics where NaN bounds
    default to no constraint.
    - After clamping, preserve NaNs from the input tensor on floating
    dtypes.
    - Extend check_correctness with equal_nan for float outputs containing
    NaN.
    - Add parametrized Clip opset-13 tests for NaN min/max tensor bounds.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py    | 36 ++++++++++++++--
 tests/python/relax/test_frontend_onnx.py           | 49 ++++++++++++++++++++++
 .../test_meta_schedule_search_strategy.py          |  2 +-
 3 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 622e262cc4..878f976c95 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -52,12 +52,26 @@ import tvm
 from tvm import TVMError, relax, tirx, topi
 from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
+from tvm.runtime import DataType, DataTypeCode
 from tvm.tirx.generic import cast
 from tvm.topi.utils import get_const_tuple

 from ..common import autopad


+def _relax_dtype_is_floating_point(dtype: str) -> bool:
+    """Whether a Relax dtype string is a floating point type."""
+    try:
+        code = DataType(dtype).type_code
+    except (ValueError, TypeError, TVMError):
+        return False
+    return (
+        code == DataTypeCode.FLOAT
+        or code == DataTypeCode.BFLOAT
+        or (code >= DataTypeCode.Float8E3M4 and code <= DataTypeCode.Float4E2M1FN)
+    )
+
+
 def get_type(elem_type: str | int) -> str:
     """Converts onnx integer datatype to numpy datatype"""
     # If a string was passed instead of a tensor type, it does not need
@@ -311,6 +325,7 @@ class OnnxOpConverter:
             return getattr(cls, f"_impl_v{version}")
         raise NotImplementedError(f"opset version {version} of {cls.__name__} not implemented")

+
 class QuantizeLinear(OnnxOpConverter):
     @classmethod
     def _impl_v10(cls, bb, inputs, attr, params):
@@ -379,6 +394,7 @@ class DynamicQuantizeLinear(OnnxOpConverter):
         y = relax.op.quantize(x, y_scale, y_zero_point, axis=0, out_dtype="uint8")
         return relax.Tuple([y, y_scale, y_zero_point])

+
 class MatMul(OnnxOpConverter):
     """Converts an onnx MatMul node into an equivalent Relax expression."""
 
@@ -1350,6 +1366,15 @@ class Where(OnnxOpConverter):
 class Clip(OnnxOpConverter):
     """Converts an onnx Clip node into an equivalent Relax expression."""

+    @staticmethod
+    def _sanitize_nan_clip_bound(bb, bound: relax.Expr, *, for_min: bool) -> relax.Expr:
+        """ONNX/ORT treat NaN clip bounds as unbounded; plain max/min with NaN poisons output."""
+        dtype = bound.struct_info.dtype
+        if not _relax_dtype_is_floating_point(dtype):
+            return bound
+        repl = -_np.inf if for_min else _np.inf
+        return bb.emit(relax.op.where(relax.op.isnan(bound), relax.const(repl, dtype), bound))
+
     @classmethod
     def _impl_v1(cls, bb, inputs, attr, params):
         min = float(attr.get("min", -_np.inf))
@@ -1366,11 +1391,16 @@ class Clip(OnnxOpConverter):

     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
-        results = inputs[0]
+        x: Any = inputs[0]
+        results = x
         if inputs[1] is not None:
-            results = bb.emit_te(topi.maximum, results, inputs[1])
+            lo = cls._sanitize_nan_clip_bound(bb, inputs[1], for_min=True)
+            results = bb.emit_te(topi.maximum, results, lo)
         if inputs[2] is not None:
-            results = bb.emit_te(topi.minimum, results, inputs[2])
+            hi = cls._sanitize_nan_clip_bound(bb, inputs[2], for_min=False)
+            results = bb.emit_te(topi.minimum, results, hi)
+        if _relax_dtype_is_floating_point(x.struct_info.dtype):
+            results = bb.emit(relax.op.where(relax.op.isnan(x), x, results))
         return results
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 94b85ab95a..c46709e33d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1597,6 +1597,55 @@ def test_clip_v6(max, min):
     check_correctness(model, opset=10)


+@pytest.mark.parametrize(
+    "min,max",
+    [
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(6.0, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(np.nan, dtype=np.float32),
+            np.array(6.0, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(np.nan, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+    ],
+)
+@pytest.mark.parametrize(
+    "input",
+    [
+        np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32),
+        np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32),
+    ],
+)
+def test_clip_v13(input, min, max):
+    # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT); input NaN preserved.
+    clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
+    graph = helper.make_graph(
+        [clip_node],
+        "clip_v13_nan_max",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [5]),
+            helper.make_tensor_value_info("min", TensorProto.FLOAT, []),
+            helper.make_tensor_value_info("max", TensorProto.FLOAT, []),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [5])],
+    )
+    model = helper.make_model(graph, producer_name="clip_v13_nan_max")
+    check_correctness(
+        model,
+        inputs={"input": input, "min": min, "max": max},
+        opset=13,
+    )
+
+
 def test_equal():
     equal_node = helper.make_node("Equal", ["a", "b"], ["output"])
 
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
index 370eff27c7..f9cec06aea 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
@@ -324,7 +324,7 @@ def test_meta_schedule_evolutionary_search_fail_init_population():  # pylint: di
     assert candidates is None


-def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace()  # pylint: disable = invalid-name
+def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace():  # pylint: disable = invalid-name
     # Construct an incompatible measured trace: it references block name "other",
     # which doesn't exist in Matmul. Replaying this trace should fail and be skipped.
     wrong_sch = Schedule(OtherBlock)

Reply via email to