This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 532db3392a [TIR] Fix host/device function check for build (#18199)
532db3392a is described below
commit 532db3392ae3ab954257208dff29015e9ac1e7fd
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Aug 8 14:54:06 2025 -0400
[TIR] Fix host/device function check for build (#18199)
This PR fixes a bug in deciding whether a function is a host
or a device function in the TIR build.
Previously, the decision was made by checking whether `"cpu"`
is a substring of the target string. This check fails for
the ROCm target, which usually comes with an `"mcpu"` attribute
that also contains `"cpu"` as a substring.
This PR fixes the issue by checking the target kind. Targets with kind `"llvm"`
or `"c"` will be treated as host functions.
---
python/tvm/tir/build.py | 16 +++++++++-------
tests/python/codegen/test_target_codegen_cuda.py | 10 ++++++----
2 files changed, 15 insertions(+), 11 deletions(-)
diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
index 3eb3648533..98e549cc9c 100644
--- a/python/tvm/tir/build.py
+++ b/python/tvm/tir/build.py
@@ -17,14 +17,14 @@
# pylint: disable=invalid-name
"""The build utils in python."""
-from typing import Union, Optional, Dict, Tuple
+from typing import Dict, Optional, Tuple, Union
import tvm
from tvm import ir
-from tvm.runtime import ndarray
-from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
+from tvm.runtime import ndarray
from tvm.target import Target
+from tvm.tir import PrimFunc
def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]:
@@ -100,10 +100,12 @@ def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModu
- Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch)
"""
- host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod)
- device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))(
- mod
- )
+ def is_host_func(f):
+ target = f.attrs.get("target", tvm.target.Target("llvm"))
+ return str(target.kind) in ["llvm", "c"]
+
+ host_mod = tvm.tir.transform.Filter(is_host_func)(mod)
+ device_mod = tvm.tir.transform.Filter(lambda f: not is_host_func(f))(mod)
# TODO(syfeng): Here we use str as key since target hash is not correct
target_str2target = {}
device_func_dict = {}
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 28dfb6b9d4..a304cb1e41 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -820,10 +820,12 @@ def test_device_host_call_same_func():
for tx in T.thread_binding(length, "threadIdx.x"):
C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) # Call from device
- # If we set host to llvm, it will raise an error of
- # "the tir.ret should be transformed to return zero before the llvm code generation."
- # Need to revisit this.
- target = tvm.target.Target("cuda", host="c")
+ # 1. If we set host to llvm, it will raise an error of
+ # "the tir.ret should be transformed to return zero before the llvm code generation."
+ # Need to revisit this.
+ # 2. We set a dummy mcpu value for testing purpose,
+ # in order to avoid checking a function is host or device based on the "cpu" substring.
+ target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, host="c")
lib = tvm.compile(Module, target=target)
cuda_code = lib.mod.imported_modules[0].get_source()
assert 'extern "C" __device__ int add(int a, int b) {\n return (a + b);\n}' in cuda_code