This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c9d87ef54f [Relax][Bugfix] Annotate ComputePrimValue output as host function (#17032)
c9d87ef54f is described below

commit c9d87ef54fbba29b16a0a8420fb61c669808a256
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Tue May 28 19:49:20 2024 -0500

    [Relax][Bugfix] Annotate ComputePrimValue output as host function (#17032)
    
    The `ComputePrimValue` transform is used to compute the value of
    symbolic expressions that may appear within a Relax function.  For
    example, it may compute a boolean condition used by a `relax::If` node.
    These functions are used for small host-side computations, prior to
    launching a device kernel.
    
    This commit updates `ComputePrimValue` to annotate the generated
    `PrimFunc` with `tir::attr::kIsHostFunc`.  This annotation is required
    for correct behavior in `tvm.dlight.ApplyDefaultSchedule`, to avoid
    erroneous scheduling of this function for the GPU, and for
    `tir::transform::BindTarget`, to ensure that the function is compiled
    for execution on the host.
    
    Co-authored-by: Chris Sullivan <csulli...@octo.ai>
---
 src/relax/transform/compute_prim_value.cc               | 3 ++-
 tests/python/relax/test_transform_compute_prim_value.py | 3 +++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc
index 9fe2a3a06f..716550ba04 100644
--- a/src/relax/transform/compute_prim_value.cc
+++ b/src/relax/transform/compute_prim_value.cc
@@ -45,7 +45,8 @@ class PrimValueComputeInjector : public ExprMutator {
     auto param_vars = tir::UndefinedVars(node->value);
     tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), 
{node->value}));
 
-    tir::PrimFunc func(param_vars, body, PrimType(ret_dtype));
+    tir::PrimFunc func(param_vars, body, PrimType(ret_dtype), {},
+                       DictAttrs({{tir::attr::kIsHostFunc, Bool(true)}}));
     func = tir::RenewDefs(func);
 
     auto callee = builder_->AddFunction(func, "compute_symbolic_expr");
diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py
index 9fee35414d..5d9caf2d36 100644
--- a/tests/python/relax/test_transform_compute_prim_value.py
+++ b/tests/python/relax/test_transform_compute_prim_value.py
@@ -44,6 +44,7 @@ class TestPrimValueInAssertCondition(BaseCompare):
 
         @T.prim_func(private=True)
         def compute_symbolic_expr(N: T.int64) -> T.bool:
+            T.func_attr({"tir.is_host_func": True})
             T.ret(N % 16 == 0)
 
 
@@ -73,6 +74,7 @@ class TestPrimValueInBranchCondition(BaseCompare):
 
         @T.prim_func(private=True)
         def compute_symbolic_expr(N: T.int64) -> T.bool:
+            T.func_attr({"tir.is_host_func": True})
             T.ret(N % 16 == 0)
 
 
@@ -97,6 +99,7 @@ class TestPrimValueInPureFunction(BaseCompare):
 
         @T.prim_func(private=True)
         def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64:
+            T.func_attr({"tir.is_host_func": True})
             T.ret(N * M)
 
 

Reply via email to