This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 532db3392a [TIR] Fix host/device function check for build (#18199)
532db3392a is described below

commit 532db3392ae3ab954257208dff29015e9ac1e7fd
Author: Ruihang Lai <[email protected]>
AuthorDate: Fri Aug 8 14:54:06 2025 -0400

    [TIR] Fix host/device function check for build (#18199)
    
    This PR fixes a bug in deciding whether a function is a host
    function or a device function in the TIR build.
    
    Previously the decision was made by checking whether `"cpu"` is
    a substring of the target string. This check breaks for the ROCm
    target, which usually carries an `"mcpu"` attribute that also
    contains `"cpu"` as a substring.
    
    This PR fixes the issue by checking the target kind instead:
    functions whose target kind is `"llvm"` or `"c"` are treated as
    host functions.
---
 python/tvm/tir/build.py                          | 16 +++++++++-------
 tests/python/codegen/test_target_codegen_cuda.py | 10 ++++++----
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
index 3eb3648533..98e549cc9c 100644
--- a/python/tvm/tir/build.py
+++ b/python/tvm/tir/build.py
@@ -17,14 +17,14 @@
 
 # pylint: disable=invalid-name
 """The build utils in python."""
-from typing import Union, Optional, Dict, Tuple
+from typing import Dict, Optional, Tuple, Union
 
 import tvm
 from tvm import ir
-from tvm.runtime import ndarray
-from tvm.tir import PrimFunc
 from tvm.ir.module import IRModule
+from tvm.runtime import ndarray
 from tvm.target import Target
+from tvm.tir import PrimFunc
 
 
 def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModule]]:
@@ -100,10 +100,12 @@ def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModu
         - Device kernel functions: use `calling_conv: 2` (kDeviceKernelLaunch)
     """
 
-    host_mod = tvm.tir.transform.Filter(lambda f: "cpu" in str(f.attrs.get("target", "cpu")))(mod)
-    device_mod = tvm.tir.transform.Filter(lambda f: "cpu" not in str(f.attrs.get("target", "cpu")))(
-        mod
-    )
+    def is_host_func(f):
+        target = f.attrs.get("target", tvm.target.Target("llvm"))
+        return str(target.kind) in ["llvm", "c"]
+
+    host_mod = tvm.tir.transform.Filter(is_host_func)(mod)
+    device_mod = tvm.tir.transform.Filter(lambda f: not is_host_func(f))(mod)
     # TODO(syfeng): Here we use str as key since target hash is not correct
     target_str2target = {}
     device_func_dict = {}
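
For reference, a minimal standalone sketch of the kind-based criterion used by
`is_host_func` above (the helper name and target constructions here are
illustrative; only the `"llvm"` and `"c"` kinds count as host):

    import tvm

    def is_host_target(target: tvm.target.Target) -> bool:
        # Same criterion as is_host_func in the patch: host kinds are "llvm" and "c".
        return str(target.kind) in ["llvm", "c"]

    print(is_host_target(tvm.target.Target("llvm")))  # True  (host)
    print(is_host_target(tvm.target.Target("c")))     # True  (host)
    print(is_host_target(tvm.target.Target("cuda")))  # False (device)
    print(is_host_target(tvm.target.Target("rocm")))  # False (device)
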
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 28dfb6b9d4..a304cb1e41 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -820,10 +820,12 @@ def test_device_host_call_same_func():
                 for tx in T.thread_binding(length, "threadIdx.x"):
                     C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Call from device
 
-    # If we set host to llvm, it will raise an error of
-    # "the tir.ret should be transformed to return zero before the llvm code 
generation."
-    # Need to revisit this.
-    target = tvm.target.Target("cuda", host="c")
+    # 1. If we set host to llvm, it will raise an error of
+    #    "the tir.ret should be transformed to return zero before the llvm 
code generation."
+    #    Need to revisit this.
+    # 2. We set a dummy mcpu value for testing purpose,
+    #    in order to avoid checking a function is host or device based on the "cpu" substring.
+    target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, 
host="c")
     lib = tvm.compile(Module, target=target)
     cuda_code = lib.mod.imported_modules[0].get_source()
     assert 'extern "C" __device__ int add(int a, int b) {\n  return (a + 
b);\n}' in cuda_code
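
For context, a short hedged sketch of why the updated test builds the CUDA
target from a dict config with a dummy `"mcpu"` attribute: the attribute drags
`"cpu"` into the target string, which would have fooled the old substring
check, while the kind-based check is unaffected (printed results are
illustrative):

    import tvm

    # Dict-form target from the test: "mcpu" makes "cpu" a substring of the target string.
    target = tvm.target.Target({"kind": "cuda", "mcpu": "dummy_mcpu"}, host="c")

    # Old criterion: substring match on the target string.
    print("cpu" in str(target))               # True -> would be misread as a host target

    # New criterion: look at the target kind only.
    print(str(target.kind) in ["llvm", "c"])  # False -> correctly a device target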
