This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 40af75b61f [Fix][TIR] UnifyThreadBinding creating unit loop with annotation (#14588) 40af75b61f is described below commit 40af75b61ff7111b479e447714db63225609cbb5 Author: Ruihang Lai <ruiha...@cs.cmu.edu> AuthorDate: Tue Apr 11 20:18:35 2023 -0400 [Fix][TIR] UnifyThreadBinding creating unit loop with annotation (#14588) This PR fixes an incorrect assumption in the UnifyThreadBinding pass, which (at one place) assumes that a return value is always a ForNode; that assumption does not always hold. More specifically, when a thread-binding loop has an annotation, the current implementation assumes that the post-recursive-mutation value is also a ForNode and applies the previous annotation directly to the new loop. However, the post-recursive-mutation value may not be a ForNode, in which case the current behavior is incorrect. This PR creates a new unit-length loop in this case to preserve the annotation. Thanks to Bohan for catching this issue. Co-authored-by: Bohan Hou <spectromet...@gmail.com> --- src/tir/transforms/unify_thread_binding.cc | 17 ++++++++++++--- .../test_tir_transform_unify_thread_binding.py | 25 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/unify_thread_binding.cc b/src/tir/transforms/unify_thread_binding.cc index da725f7f8e..09b0970dd3 100644 --- a/src/tir/transforms/unify_thread_binding.cc +++ b/src/tir/transforms/unify_thread_binding.cc @@ -64,9 +64,20 @@ class ThreadBindingUnifier : public StmtExprMutator { if (annotations.empty()) { return stmt; } - For new_loop = Downcast<For>(stmt); - new_loop.CopyOnWrite()->annotations = std::move(annotations); - return std::move(new_loop); + if (const auto* loop = stmt.as<ForNode>()) { + For new_loop = GetRef<For>(loop); + new_loop.CopyOnWrite()->annotations = std::move(annotations); + return std::move(new_loop); + } else { + // Create a new unit loop with the annotation. 
+ DataType dtype = op->loop_var->dtype; + return For(/*loop_var=*/Var("var", dtype), // + /*min=*/IntImm(dtype, 0), // + /*extent=*/IntImm(dtype, 1), // + /*kind=*/ForKind::kSerial, stmt, // + /*thread_binding=*/NullOpt, // + /*annotation=*/std::move(annotations)); + } } template <typename Node> diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py index e489298741..c49ea5e60f 100644 --- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -286,6 +286,31 @@ def test_implicit_block(): _check(element_wise_implicit_block, unified_element_wise_implicit_block) +def test_inner_binding_with_annotation(): + @T.prim_func + def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")): + for bx in T.thread_binding(32, "blockIdx.x"): + for tx in T.thread_binding(2, "threadIdx.x", annotations={"my_annotation": 1}): + with T.block("block"): + v = T.axis.spatial(64, bx * 2 + tx) + B[v] = A[v] + + @T.prim_func + def unified_inner_binding_with_annotation( + A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32") + ): + for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"): + for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"): + for var in T.serial(1, annotations={"my_annotation": 1}): + with T.block("block"): + v = T.axis.spatial(64, blockIdx_x * 2 + threadIdx_x) + T.reads(A[v]) + T.writes(B[v]) + B[v] = A[v] + + _check(inner_binding_with_annotation, unified_inner_binding_with_annotation) + + def test_lower_te(): a = te.placeholder((32, 2, 2)) b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0)