dutZ1855 opened a new issue, #18590:
URL: https://github.com/apache/tvm/issues/18590
### Expected behavior
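Per the ONNX `Round` operator spec (the script below prints it; its output follows), halves must be rounded to the nearest even integer: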
```
import onnx
from onnx import defs

print('onnx.__version__ =', onnx.__version__)

# Collect every opset version at which the Round schema was (re)defined.
schemas = defs.get_all_schemas_with_history()
round_versions = sorted({s.since_version for s in schemas if s.name == 'Round'})
print('Round schema versions:', round_versions)

# Print the domain and documentation attached to each version of the schema.
for v in round_versions:
    s = defs.get_schema('Round', v)
    print('\n=== Round since_version', v, '===')
    print('domain:', s.domain)
    print('doc:')
    print(s.doc)
```
> onnx.__version__ = 1.17.0
> Round schema versions: [11, 22]
>
> === Round since_version 11 ===
> domain:
> doc:
>
> Round takes one input Tensor and rounds the values, element-wise, meaning
> it finds the nearest integer for each value.
> In case of halves, the rule is to round them to the nearest even integer.
> If input x is integral, +0, -0, NaN, or infinite, x itself is returned.
> The output tensor has the same shape and type as the input.
>
> Examples:
> ```
> round([0.9]) = [1.0]
> round([2.5]) = [2.0]
> round([2.3]) = [2.0]
> round([1.5]) = [2.0]
> round([-4.5]) = [-4.0]
> ```
>
>
> === Round since_version 22 ===
> domain:
> doc:
>
> Round takes one input Tensor and rounds the values, element-wise, meaning
> it finds the nearest integer for each value.
> In case of halves, the rule is to round them to the nearest even integer.
> If input x is integral, +0, -0, NaN, or infinite, x itself is returned.
> The output tensor has the same shape and type as the input.
>
> Examples:
> ```
> round([0.9]) = [1.0]
> round([2.5]) = [2.0]
> round([2.3]) = [2.0]
> round([1.5]) = [2.0]
> round([-4.5]) = [-4.0]
> ```
Therefore, for this repro (where `sigmoid(0)=0.5`):
- `Round(0.5) == 0` (nearest-even / ties-to-even)
### Actual behavior
For the following model,
<img width="218" height="281" alt="Image"
src="https://github.com/user-attachments/assets/5c2b12b7-706d-4bdb-ad52-b7daef6a3b7d"
/>
With TVM (Relax, LLVM target) for this repro:
- `Round(0.5) == 1` (the tie is not rounded to the nearest even integer)
### Environment
Operating System: Ubuntu 22.04.4 LTS
TVM version: 0.23.0dev
pytorch version: 2.9.1
ort version: 1.23.2
onnx version: 1.20.0
openvino: 2025.4.0
python: 3.11.14
### Steps to reproduce
**build a model**
```
from __future__ import annotations
import argparse
from pathlib import Path
import onnx
from onnx import TensorProto, helper
def make_model(path="round_sigmoid.onnx"):
    """Build and save an input-less ONNX model computing ``Round(Sigmoid(0.0))``.

    The graph is ``Constant(0.0) -> Sigmoid -> Round`` on a double scalar and
    is declared against opset 18 of the default domain.

    Args:
        path: Destination handed to ``onnx.save``.
    """
    # x = 0.0 (double scalar)
    const0 = helper.make_node(
        "Constant",
        inputs=[],
        outputs=["x"],
        value=helper.make_tensor("c0", TensorProto.DOUBLE, dims=[], vals=[0.0]),
    )
    sig = helper.make_node("Sigmoid", inputs=["x"], outputs=["s"])
    rnd = helper.make_node("Round", inputs=["s"], outputs=["y"])
    # The single graph output is the rounded value, also a double scalar.
    y_info = helper.make_tensor_value_info("y", TensorProto.DOUBLE, [])
    graph = helper.make_graph([const0, sig, rnd], "g", inputs=[], outputs=[y_info])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
    onnx.save(model, path)
    print("saved:", path)
def main() -> int:
    """Parse ``--out`` and write the repro model there.

    The parent directory of the output path is created if it is missing.

    Returns:
        0, used as the process exit status.
    """
    ap = argparse.ArgumentParser(
        description="Minimal repro for Round(sigmoid(0)) tie-breaking."
    )
    ap.add_argument(
        "--out",
        type=Path,
        default=Path("round_sigmoid_half.onnx"),
        help="Where to save the ONNX model.",
    )
    args = ap.parse_args()
    out_path = args.out.resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    make_model(out_path.as_posix())
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
**Comparison results**
```
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Any, Optional
import numpy as np
import onnx
def _fmt(x: Any) -> str:
    """Return a printable description of *x*.

    NumPy arrays are rendered with their shape, dtype and value; any other
    object is rendered with ``repr``.
    """
    if isinstance(x, np.ndarray):
        return f"ndarray(shape={x.shape}, dtype={x.dtype}, value={x})"
    return repr(x)
def _run_torch_reference() -> Optional[np.ndarray]:
    """Reference for this specific model: Round(sigmoid(0.0)).

    The ONNX file is not loaded; the same computation is evaluated directly
    with ``torch`` in float64.

    Returns:
        The result as a 0-d float64 array, or ``None`` when ``torch`` cannot
        be imported.
    """
    try:
        import torch  # type: ignore
    except Exception as e:
        # Best effort: a missing or broken torch install only skips this reference.
        print("[torch] not available:", e)
        return None
    x = torch.tensor(0.0, dtype=torch.float64)
    y = torch.round(torch.sigmoid(x))
    out = np.array(y.item(), dtype=np.float64)
    print("[torch] torch.__version__ =", getattr(torch, "__version__", None))
    print("[torch] y =", _fmt(out))
    return out
def _run_ort(model_bytes: bytes) -> Optional[np.ndarray]:
    """Run the serialized model with ONNX Runtime on the CPU provider.

    Args:
        model_bytes: Serialized ONNX model; it is run with no inputs.

    Returns:
        The first model output as an array, or ``None`` when ``onnxruntime``
        cannot be imported.
    """
    try:
        import onnxruntime as ort  # type: ignore
    except Exception as e:
        # Best effort: a missing or broken install only skips this runtime.
        print("[ort] not available:", e)
        return None
    sess_opts = ort.SessionOptions()
    sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess = ort.InferenceSession(
        model_bytes, sess_options=sess_opts, providers=["CPUExecutionProvider"]
    )
    outs = sess.run(None, {})  # no inputs
    out = np.array(outs[0])
    print("[ort] onnxruntime.__version__ =", getattr(ort, "__version__", None))
    print("[ort] y =", _fmt(out))
    return out
def _run_openvino(model_path: Path) -> Optional[np.ndarray]:
    """Compile the model with OpenVINO for the "CPU" device and run it once.

    Args:
        model_path: Path to the ONNX file; it is run with no inputs.

    Returns:
        The value of the first compiled output as an array, or ``None`` when
        ``openvino`` cannot be imported.
    """
    try:
        import openvino as ov  # type: ignore
    except Exception as e:
        # Best effort: a missing or broken install only skips this runtime.
        print("[ov] not available:", e)
        return None
    core = ov.Core()
    model = core.read_model(model_path.as_posix())
    compiled = core.compile_model(model, "CPU")
    req = compiled.create_infer_request()
    raw = req.infer({})  # no inputs
    # Results are looked up by output port; only the first output is read.
    out_port = compiled.outputs[0]
    out = np.array(raw[out_port])
    print("[ov] openvino.__version__ =", getattr(ov, "__version__", None))
    print("[ov] y =", _fmt(out))
    return out
def _ensure_repo_tvm_python_on_syspath() -> None:
    """Prepend ``<ancestor>/tvm/python`` to ``sys.path`` when it exists.

    ``<ancestor>`` is the directory three levels above the one containing this
    script. Nothing happens when the script sits too close to the filesystem
    root to have such an ancestor, when the directory does not exist, or when
    it is already on ``sys.path``.
    """
    ancestors = Path(__file__).resolve().parents
    # Indexing parents[3] on a shallow path would raise IndexError and abort
    # the script before TVM is even tried.
    if len(ancestors) <= 3:
        return
    tvm_python = ancestors[3] / "tvm" / "python"
    if not tvm_python.exists():
        return
    entry = tvm_python.as_posix()
    # Avoid stacking duplicate entries when called more than once.
    if entry not in sys.path:
        sys.path.insert(0, entry)
def _tvm_to_numpy(x: Any) -> np.ndarray:
    """Convert a runtime result to a NumPy array.

    Objects exposing a ``numpy()`` method are converted through it; anything
    else (Python scalars, NumPy scalars, sequences) goes through ``np.array``.
    """
    if hasattr(x, "numpy"):
        return x.numpy()
    # Scalars need no special case: np.array handles them like everything else.
    return np.array(x)
def _run_tvm(model_path: Path) -> Optional[np.ndarray]:
    """Import the model into TVM Relax, build it for LLVM and run it on the CPU.

    Args:
        model_path: Path to the ONNX file; it is run with no inputs.

    Returns:
        The output of ``main`` converted to an array, or ``None`` when TVM
        cannot be imported.
    """
    _ensure_repo_tvm_python_on_syspath()
    try:
        import tvm  # type: ignore
        from tvm import relax  # type: ignore
        from tvm.relax.frontend.onnx import from_onnx  # type: ignore
    except Exception as e:
        # Best effort: a missing or broken install only skips this runtime.
        print("[tvm] not available:", e)
        return None
    onnx_model = onnx.load(model_path.as_posix())
    mod = from_onnx(onnx_model, shape_dict={})
    # If the importer hands back a list/tuple, keep only its first element.
    if isinstance(mod, (list, tuple)):
        mod = mod[0]
    tgt = tvm.target.Target("llvm")
    pipeline = relax.pipeline.get_default_pipeline(tgt)
    with tvm.transform.PassContext(opt_level=3):
        ex = relax.build(mod, target=tgt, relax_pipeline=pipeline)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    # no inputs
    try:
        out = vm["main"]()
    except Exception:
        # Fall back to the stateful set_input/invoke/get_outputs API when the
        # direct call fails.
        vm.set_input("main")
        vm.invoke_stateful("main")
        out = vm.get_outputs("main")
    out_np = _tvm_to_numpy(out)
    print("[tvm] tvm.__file__ =", getattr(tvm, "__file__", None))
    print("[tvm] tvm.__version__ =", getattr(tvm, "__version__", None))
    print("[tvm] y =", _fmt(out_np))
    return out_np
def _eq(a: Optional[np.ndarray], b: Optional[np.ndarray]) -> Optional[bool]:
    """Compare two optional results for exact equality.

    Returns:
        ``None`` when either side is ``None`` or the comparison itself raises;
        otherwise the ``np.array_equal`` result as a plain ``bool``.
    """
    if a is None or b is None:
        return None
    try:
        return bool(np.array_equal(a, b))
    except Exception:
        # A failed comparison is reported as "unknown", not as a mismatch.
        return None
def main() -> int:
    """Run one input-less ONNX model on every available runtime and print the results.

    The torch reference (skipped with ``--no-torch``), ONNX Runtime, TVM and
    OpenVINO are run in turn; each runner returns ``None`` when its package is
    unavailable, and the four results are printed together at the end.

    Returns:
        0, used as the process exit status.

    Raises:
        FileNotFoundError: If the path given by ``--model`` does not exist.
    """
    ap = argparse.ArgumentParser(
        description="Run ONNX model across runtimes and print outputs."
    )
    ap.add_argument(
        "--model",
        type=Path,
        required=True,
        help="Path to ONNX model (must have no inputs).",
    )
    ap.add_argument(
        "--no-torch",
        action="store_true",
        help="Skip torch reference printing (useful if torch not installed).",
    )
    args = ap.parse_args()
    model_path = args.model.resolve()
    if not model_path.exists():
        raise FileNotFoundError(model_path)
    model_bytes = model_path.read_bytes()
    y_torch = None if args.no_torch else _run_torch_reference()
    y_ort = _run_ort(model_bytes)
    y_tvm = _run_tvm(model_path)
    y_ov = _run_openvino(model_path)
    print("torch :", y_torch)
    print("tvm:", y_tvm)
    print("ov :", y_ov)
    print("ort :", y_ort)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
### Triage
Please refer to the list of label tags
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the
relevant tags and add them below in a bullet format (example below).
* needs-triage
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]