This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new c9d87ef54f [Relax][Bugfix] Annotate ComputePrimValue output as host function (#17032)

c9d87ef54f is described below

commit c9d87ef54fbba29b16a0a8420fb61c669808a256
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Tue May 28 19:49:20 2024 -0500

    [Relax][Bugfix] Annotate ComputePrimValue output as host function (#17032)

    The `ComputePrimValue` transform is used to compute the value of symbolic expressions that may appear within a Relax function. For example, it computes the boolean condition used by a `relax::If` node. These generated functions perform small host-side computations before a device kernel is launched.

    This commit updates `ComputePrimValue` to annotate the generated `PrimFunc` with `tir::attr::kIsHostFunc`. This annotation is required for correct behavior in two places: in `tvm.dlight.ApplyDefaultSchedule`, to avoid erroneously scheduling this function for the GPU, and in `tir::transform::BindTarget`, to ensure that the function is compiled for execution on the host.

    Co-authored-by: Chris Sullivan <csulli...@octo.ai>
---
 src/relax/transform/compute_prim_value.cc               | 3 ++-
 tests/python/relax/test_transform_compute_prim_value.py | 3 +++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc
index 9fe2a3a06f..716550ba04 100644
--- a/src/relax/transform/compute_prim_value.cc
+++ b/src/relax/transform/compute_prim_value.cc
@@ -45,7 +45,8 @@ class PrimValueComputeInjector : public ExprMutator {
     auto param_vars = tir::UndefinedVars(node->value);
     tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), {node->value}));

-    tir::PrimFunc func(param_vars, body, PrimType(ret_dtype));
+    tir::PrimFunc func(param_vars, body, PrimType(ret_dtype), {},
+                       DictAttrs({{tir::attr::kIsHostFunc, Bool(true)}}));
     func = tir::RenewDefs(func);

     auto callee = builder_->AddFunction(func, "compute_symbolic_expr");
diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py
index 9fee35414d..5d9caf2d36 100644
--- a/tests/python/relax/test_transform_compute_prim_value.py
+++ b/tests/python/relax/test_transform_compute_prim_value.py
@@ -44,6 +44,7 @@ class TestPrimValueInAssertCondition(BaseCompare):

         @T.prim_func(private=True)
         def compute_symbolic_expr(N: T.int64) -> T.bool:
+            T.func_attr({"tir.is_host_func": True})
             T.ret(N % 16 == 0)


@@ -73,6 +74,7 @@ class TestPrimValueInBranchCondition(BaseCompare):

         @T.prim_func(private=True)
         def compute_symbolic_expr(N: T.int64) -> T.bool:
+            T.func_attr({"tir.is_host_func": True})
             T.ret(N % 16 == 0)


@@ -97,6 +99,7 @@ class TestPrimValueInPureFunction(BaseCompare):

         @T.prim_func(private=True)
         def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64:
+            T.func_attr({"tir.is_host_func": True})
             T.ret(N * M)

