dutZ1855 opened a new issue, #18590:
URL: https://github.com/apache/tvm/issues/18590

   ### Expected behavior
   
   Per ONNX `Round` operator spec:
   
   ```
   import onnx
   from onnx import defs

   print('onnx.__version__ =', onnx.__version__)
   # Collect every registered revision of the Round operator schema.
   round_schemas = [
       schema
       for schema in defs.get_all_schemas_with_history()
       if schema.name == 'Round'
   ]
   since_versions = sorted({schema.since_version for schema in round_schemas})
   print('Round schema versions:', since_versions)
   # Dump the domain and documentation text of each revision.
   for version in since_versions:
       schema = defs.get_schema('Round', version)
       print('\n=== Round since_version', version, '===')
       print('domain:', schema.domain)
       print('doc:')
       print(schema.doc)
   ```
   
   > onnx.__version__ = 1.17.0
   > Round schema versions: [11, 22]
   > 
   > === Round since_version 11 ===
   > domain: 
   > doc:
   > 
   > Round takes one input Tensor and rounds the values, element-wise, meaning
   > it finds the nearest integer for each value.
   > In case of halves, the rule is to round them to the nearest even integer.
   > If input x is integral, +0, -0, NaN,  or infinite, x itself is returned.
   > The output tensor has the same shape and type as the input.
   > 
   > Examples:
   > ```
   > round([0.9]) = [1.0]
   > round([2.5]) = [2.0]
   > round([2.3]) = [2.0]
   > round([1.5]) = [2.0]
   > round([-4.5]) = [-4.0]
   > ```
   > 
   > 
   > === Round since_version 22 ===
   > domain: 
   > doc:
   > 
   > Round takes one input Tensor and rounds the values, element-wise, meaning
   > it finds the nearest integer for each value.
   > In case of halves, the rule is to round them to the nearest even integer.
   > If input x is integral, +0, -0, NaN,  or infinite, x itself is returned.
   > The output tensor has the same shape and type as the input.
   > 
   > Examples:
   > ```
   > round([0.9]) = [1.0]
   > round([2.5]) = [2.0]
   > round([2.3]) = [2.0]
   > round([1.5]) = [2.0]
   > round([-4.5]) = [-4.0]
   > ```
   
   Therefore, for this repro (where `sigmoid(0)=0.5`):
   - `Round(0.5) == 0` (nearest-even / ties-to-even)
   
   ### Actual behavior
   
   For the following model,
   
   <img width="218" height="281" alt="Image" src="https://github.com/user-attachments/assets/5c2b12b7-706d-4bdb-ad52-b7daef6a3b7d" />
   
   With TVM (Relax, LLVM target) for this repro:
   - `Round(0.5) == 1`
   
   ### Environment
   
   - Operating System: Ubuntu 22.04.4 LTS
   - TVM version: 0.23.0dev
   - PyTorch version: 2.9.1
   - ONNX Runtime (ort) version: 1.23.2
   - ONNX version: 1.20.0
   - OpenVINO version: 2025.4.0
   - Python version: 3.11.14
   
   ### Steps to reproduce
   
   **build a model**
   
   ```
   from __future__ import annotations
   
   import argparse
   from pathlib import Path
   
   import onnx
   from onnx import TensorProto, helper
   
   
   def make_model(path="round_sigmoid.onnx"):
       """Build and save an ONNX graph that computes ``Round(Sigmoid(0.0))``.

       The graph takes no inputs: a double-precision scalar constant ``0.0``
       feeds ``Sigmoid`` and then ``Round``, and the result is exposed as the
       scalar double output ``y``. The model declares default-domain opset 18,
       is saved to ``path``, and the path is printed.
       """
       zero = helper.make_tensor("c0", TensorProto.DOUBLE, dims=[], vals=[0.0])
       nodes = [
           # x = 0.0 (double scalar)
           helper.make_node("Constant", inputs=[], outputs=["x"], value=zero),
           helper.make_node("Sigmoid", inputs=["x"], outputs=["s"]),
           helper.make_node("Round", inputs=["s"], outputs=["y"]),
       ]

       output_info = helper.make_tensor_value_info("y", TensorProto.DOUBLE, [])
       graph = helper.make_graph(nodes, "g", inputs=[], outputs=[output_info])
       opset = helper.make_opsetid("", 18)
       model = helper.make_model(graph, opset_imports=[opset])
       onnx.save(model, path)
       print("saved:", path)
   
   def main() -> int:
       """Parse the command line and write the repro model to disk.

       Accepts ``--out`` (default ``round_sigmoid_half.onnx``), creates the
       parent directory of the resolved path if needed, and saves the model
       there via ``make_model``.

       Returns:
           ``0``, intended to be used as the process exit status.
       """
       # Each string literal stays on one logical line; a raw newline inside
       # a quoted string is a syntax error.
       ap = argparse.ArgumentParser(
           description="Minimal repro for Round(sigmoid(0)) tie-breaking."
       )
       ap.add_argument(
           "--out",
           type=Path,
           default=Path("round_sigmoid_half.onnx"),
           help="Where to save the ONNX model.",
       )
       args = ap.parse_args()

       out_path = args.out.resolve()
       out_path.parent.mkdir(parents=True, exist_ok=True)
       make_model(out_path.as_posix())
       return 0
   
   # Run main() only when executed as a script; its return value becomes the
   # process exit status.
   if __name__ == "__main__":
       raise SystemExit(main())
   ```
   **Comparison results**
   ```
   from __future__ import annotations
   
   import argparse
   import sys
   from pathlib import Path
   from typing import Any, Optional
   
   import numpy as np
   import onnx
   
   
   def _fmt(x: Any) -> str:
       """Render ``x`` for printing.

       NumPy arrays are shown with their shape, dtype and value; any other
       object is shown via ``repr``.
       """
       if not isinstance(x, np.ndarray):
           return repr(x)
       return f"ndarray(shape={x.shape}, dtype={x.dtype}, value={x})"
   
   
   def _run_torch_reference() -> Optional[np.ndarray]:
       """Compute ``round(sigmoid(0.0))`` in float64 with PyTorch as a reference.

       Returns:
           The result as a 0-d float64 NumPy array, or ``None`` (after printing
           the reason) when ``torch`` cannot be imported.
       """
       try:
           import torch  # type: ignore
       except Exception as e:
           print("[torch] not available:", e)
           return None
       zero = torch.tensor(0.0, dtype=torch.float64)
       half = torch.sigmoid(zero)
       rounded = torch.round(half)
       result = np.array(rounded.item(), dtype=np.float64)
       print("[torch] torch.__version__ =", getattr(torch, "__version__", None))
       print("[torch] y =", _fmt(result))
       return result
   
   
   def _run_ort(model_bytes: bytes) -> Optional[np.ndarray]:
       """Run the serialized ONNX model with ONNX Runtime on the CPU provider.

       The session is created with graph optimization level
       ``ORT_ENABLE_BASIC`` and executed with an empty feed, since the model
       has no inputs.

       Args:
           model_bytes: The serialized ONNX model.

       Returns:
           The first model output as a NumPy array, or ``None`` (after printing
           the reason) when ``onnxruntime`` cannot be imported.
       """
       try:
           import onnxruntime as ort  # type: ignore
       except Exception as e:
           print("[ort] not available:", e)
           return None
       sess_opts = ort.SessionOptions()
       # Parenthesized so the long right-hand side may sit on its own line; a
       # bare line break after ``=`` would end the statement.
       sess_opts.graph_optimization_level = (
           ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
       )
       sess = ort.InferenceSession(
           model_bytes,
           sess_options=sess_opts,
           providers=["CPUExecutionProvider"],
       )
       outs = sess.run(None, {})  # no inputs
       out = np.array(outs[0])
       print("[ort] onnxruntime.__version__ =", getattr(ort, "__version__", None))
       print("[ort] y =", _fmt(out))
       return out
   
   
   def _run_openvino(model_path: Path) -> Optional[np.ndarray]:
       """Run the ONNX model at ``model_path`` with OpenVINO on the ``CPU`` device.

       Returns:
           The value of the compiled model's first output as a NumPy array, or
           ``None`` (after printing the reason) when ``openvino`` cannot be
           imported.
       """
       try:
           import openvino as ov  # type: ignore
       except Exception as e:
           print("[ov] not available:", e)
           return None
       core = ov.Core()
       compiled_model = core.compile_model(core.read_model(model_path.as_posix()), "CPU")
       # The model has no inputs, so inference is requested with an empty feed.
       results = compiled_model.create_infer_request().infer({})
       first_output = compiled_model.outputs[0]
       result = np.array(results[first_output])
       print("[ov] openvino.__version__ =", getattr(ov, "__version__", None))
       print("[ov] y =", _fmt(result))
       return result
   
   
   def _ensure_repo_tvm_python_on_syspath() -> None:
       """Put ``<repo_root>/tvm/python`` first on ``sys.path`` when it exists.

       ``repo_root`` is taken to be three levels above the directory that
       contains this script; nothing happens if the candidate is missing.
       """
       candidate = Path(__file__).resolve().parents[3].joinpath("tvm", "python")
       if not candidate.exists():
           return
       sys.path.insert(0, candidate.as_posix())
   
   
   def _tvm_to_numpy(x: Any) -> np.ndarray:
       """Convert a runtime result to a NumPy array.

       Args:
           x: Any value. Objects exposing a ``numpy`` attribute are converted
               by calling it; everything else is passed to ``np.array``.

       Returns:
           Whatever ``x.numpy()`` returns, or ``np.array(x)``.
       """
       if hasattr(x, "numpy"):
           return x.numpy()
       # Scalars need no special case: np.array handles them exactly like any
       # other value, so a separate isinstance branch would be redundant.
       return np.array(x)
   
   
   def _run_tvm(model_path: Path) -> Optional[np.ndarray]:
       """Import the ONNX model into Relax, build it for ``llvm`` and run it on CPU.

       Returns:
           The output of ``main`` converted to a NumPy array, or ``None``
           (after printing the reason) when TVM cannot be imported.
       """
       _ensure_repo_tvm_python_on_syspath()
       try:
           import tvm  # type: ignore
           from tvm import relax  # type: ignore
           from tvm.relax.frontend.onnx import from_onnx  # type: ignore
       except Exception as e:
           print("[tvm] not available:", e)
           return None

       ir_mod = from_onnx(onnx.load(model_path.as_posix()), shape_dict={})
       # If the importer returned a list/tuple, keep only its first element.
       if isinstance(ir_mod, (list, tuple)):
           ir_mod = ir_mod[0]

       target = tvm.target.Target("llvm")
       default_pipeline = relax.pipeline.get_default_pipeline(target)
       with tvm.transform.PassContext(opt_level=3):
           executable = relax.build(ir_mod, target=target, relax_pipeline=default_pipeline)
       vm = relax.VirtualMachine(executable, tvm.cpu())

       # The model has no inputs: call "main" directly and fall back to the
       # stateful set_input/invoke/get_outputs sequence if that raises.
       try:
           result = vm["main"]()
       except Exception:
           vm.set_input("main")
           vm.invoke_stateful("main")
           result = vm.get_outputs("main")

       result_np = _tvm_to_numpy(result)
       print("[tvm] tvm.__file__ =", getattr(tvm, "__file__", None))
       print("[tvm] tvm.__version__ =", getattr(tvm, "__version__", None))
       print("[tvm] y =", _fmt(result_np))
       return result_np
   
   
   def _eq(a: Optional[np.ndarray], b: Optional[np.ndarray]) -> Optional[bool]:
       """Tell whether ``a`` and ``b`` are equal arrays.

       Returns:
           ``True``/``False`` from ``np.array_equal``, or ``None`` when either
           operand is ``None`` or the comparison raises.
       """
       if any(operand is None for operand in (a, b)):
           return None
       try:
           verdict = bool(np.array_equal(a, b))
       except Exception:
           return None
       return verdict
   
   
   def main() -> int:
       """Run the ONNX model on each available runtime and print every output.

       Command-line options:
           ``--model``: required path to the ONNX model (must have no inputs).
           ``--no-torch``: skip the PyTorch reference computation.

       Returns:
           ``0``, intended to be used as the process exit status.

       Raises:
           FileNotFoundError: if the resolved ``--model`` path does not exist.
       """
       # Each string literal stays on one logical line; a raw newline inside
       # a quoted string is a syntax error.
       ap = argparse.ArgumentParser(
           description="Run ONNX model across runtimes and print outputs."
       )
       ap.add_argument(
           "--model",
           type=Path,
           required=True,
           help="Path to ONNX model (must have no inputs).",
       )
       ap.add_argument(
           "--no-torch",
           action="store_true",
           help="Skip torch reference printing (useful if torch not installed).",
       )
       args = ap.parse_args()

       model_path = args.model.resolve()
       if not model_path.exists():
           raise FileNotFoundError(model_path)
       model_bytes = model_path.read_bytes()

       y_torch = None if args.no_torch else _run_torch_reference()
       y_ort = _run_ort(model_bytes)
       y_tvm = _run_tvm(model_path)
       y_ov = _run_openvino(model_path)

       print("torch :", y_torch)
       print("tvm:",  y_tvm)
       print("ov :",  y_ov)
       print("ort :", y_ort)
       return 0
   
   
   # Run main() only when executed as a script; its return value becomes the
   # process exit status.
   if __name__ == "__main__":
       raise SystemExit(main())
   ```
   
   
   
   ### Triage
   
   Please refer to the list of label tags 
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the 
relevant tags and add them below in a bullet format (example below).
   
   * needs-triage
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to