This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7db7a91402 [Relax][PyTorch] Fix PyTorch Dynamo frontend for Darwin
compatibility (#18619)
7db7a91402 is described below
commit 7db7a914021fb210cdb3045cc94ba70a484ec669
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Dec 28 22:13:16 2025 +0800
[Relax][PyTorch] Fix PyTorch Dynamo frontend for Darwin compatibility
(#18619)
## Why
The `llvm_target()` function unconditionally reads `/proc/cpuinfo`, which exists only on
Linux, so the tests fail on macOS with `FileNotFoundError`.
## How
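- Add cross-platform CPU feature detection to `llvm_target()`:
  - dispatch on `platform.system()`;
  - read `/proc/cpuinfo` on Linux;
  - query `sysctl` on macOS;
  - fall back to the generic `llvm` target otherwise.
- Update the expected IR in `test_frontend_dynamo.py`:
  - `R.const(1.0, ...)` float constants;
  - the un-suffixed `inp_0`/`inp_1`/`lv`/`gv` names in `subgraph_1`.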
---
python/tvm/relax/frontend/torch/dynamo.py | 44 ++++++++++++++++++++++++++++--
tests/python/relax/test_frontend_dynamo.py | 14 +++++-----
2 files changed, 48 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/dynamo.py
b/python/tvm/relax/frontend/torch/dynamo.py
index 8837d96835..8dc9e2a55a 100644
--- a/python/tvm/relax/frontend/torch/dynamo.py
+++ b/python/tvm/relax/frontend/torch/dynamo.py
@@ -19,6 +19,7 @@
# pylint: disable=import-outside-toplevel, unused-argument, use-list-literal
# mypy: ignore-errors
"""PyTorch Dynamo backend of Relax."""
+
import functools
from typing import Optional
@@ -202,6 +203,43 @@ def dynamo_capture_subgraphs(model, *params, **kwargs) ->
tvm.IRModule:
@functools.lru_cache(None)
def llvm_target():
-    if "avx512" in open("/proc/cpuinfo").read():
-        return "llvm -mcpu=skylake-avx512"
-    return "llvm -mcpu=core-avx2"
+    """Return an LLVM target for the host, using AVX2/AVX-512 only when detected."""
+    import platform
+    import subprocess
+
+    AVX512_TARGET = "llvm -mcpu=skylake-avx512"
+    AVX2_TARGET = "llvm -mcpu=core-avx2"
+    DEFAULT_TARGET = "llvm"
+
+    # The AVX targets are x86-only; other hosts (aarch64 Linux, Apple
+    # Silicon) must not receive an -mcpu for a foreign architecture.
+    if platform.machine().lower() not in ("x86_64", "amd64", "i386", "i686"):
+        return DEFAULT_TARGET
+
+    system = platform.system()
+    cpu_features = ""
+    if system == "Linux":
+        try:
+            with open("/proc/cpuinfo", encoding="utf-8") as f:
+                cpu_features = f.read().lower()
+        except OSError:
+            pass
+    elif system == "Darwin":
+        # AVX2 and AVX-512 flags live in machdep.cpu.leaf7_features; the
+        # machdep.cpu.features key only lists leaf-1 features (up to AVX1.0).
+        for key in ("machdep.cpu.features", "machdep.cpu.leaf7_features"):
+            try:
+                result = subprocess.run(
+                    ["sysctl", "-n", key], capture_output=True, text=True, check=False
+                )
+            except (OSError, subprocess.SubprocessError):
+                break
+            if result.returncode == 0:
+                cpu_features += " " + result.stdout.lower()
+
+    if "avx512" in cpu_features:
+        return AVX512_TARGET
+    if "avx2" in cpu_features:
+        return AVX2_TARGET
+    # Unknown OS, unreadable feature list, or pre-AVX2 CPU: never guess AVX2.
+    return DEFAULT_TARGET
diff --git a/tests/python/relax/test_frontend_dynamo.py
b/tests/python/relax/test_frontend_dynamo.py
index 70619714dd..b3eac1d427 100644
--- a/tests/python/relax/test_frontend_dynamo.py
+++ b/tests/python/relax/test_frontend_dynamo.py
@@ -206,10 +206,10 @@ def test_subgraph_capture():
# block 0
with R.dataflow():
lv: R.Tensor((10,), dtype="float32") = R.sin(inp_0)
- lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1,
"float32"))
+ lv1: R.Tensor((10,), dtype="float32") = R.add(lv, R.const(1.0,
"float32"))
lv2: R.Tensor((10,), dtype="float32") = R.divide(inp_0, lv1)
lv3: R.Tensor((), dtype="float32") = R.sum(inp_1, axis=None,
keepdims=False)
- lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1,
"float32"))
+ lv4: R.Tensor((), dtype="bool") = R.less(lv3, R.const(1.0,
"float32"))
gv: R.Tuple(R.Tensor((10,), dtype="float32"), R.Tensor((),
dtype="bool")) = (
lv2,
lv4,
@@ -219,14 +219,14 @@ def test_subgraph_capture():
@R.function
def subgraph_1(
- inp_01: R.Tensor((10,), dtype="float32"), inp_11: R.Tensor((10,),
dtype="float32")
+ inp_0: R.Tensor((10,), dtype="float32"), inp_1: R.Tensor((10,),
dtype="float32")
) -> R.Tensor((10,), dtype="float32"):
# block 0
with R.dataflow():
- lv5: R.Tensor((10,), dtype="float32") = R.multiply(inp_01,
inp_11)
- gv1: R.Tensor((10,), dtype="float32") = lv5
- R.output(gv1)
- return gv1
+ lv: R.Tensor((10,), dtype="float32") = R.multiply(inp_0, inp_1)
+ gv: R.Tensor((10,), dtype="float32") = lv
+ R.output(gv)
+ return gv
mod = dynamo_capture_subgraphs(Input2, torch.randn(10), torch.ones(10))
tvm.ir.assert_structural_equal(mod, Expected2)