This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 40af75b61f [Fix][TIR] UnifyThreadBinding creating unit loop with annotation (#14588)
40af75b61f is described below

commit 40af75b61ff7111b479e447714db63225609cbb5
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Tue Apr 11 20:18:35 2023 -0400

    [Fix][TIR] UnifyThreadBinding creating unit loop with annotation (#14588)
    
    This PR fixes a bug in the UnifyThreadBinding pass, which in one place
    assumes that a mutated statement is always a ForNode. That is not always true.
    
    More specifically, when a thread-binding loop has an annotation, the
    current implementation assumes that the value returned by the recursive
    mutation is also a ForNode, and applies the original annotation directly
    to that loop. However, the value returned by the recursive mutation is not
    necessarily a ForNode, and in that case the current behavior is incorrect.
    
    In that case, this PR wraps the mutated statement in a new unit-length
    serial loop that carries the annotation, so the annotation is preserved.
    
    Thanks to Bohan for catching this issue.
    
    Co-authored-by: Bohan Hou <spectromet...@gmail.com>
---
 src/tir/transforms/unify_thread_binding.cc         | 17 ++++++++++++---
 .../test_tir_transform_unify_thread_binding.py     | 25 ++++++++++++++++++++++
 2 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/src/tir/transforms/unify_thread_binding.cc 
b/src/tir/transforms/unify_thread_binding.cc
index da725f7f8e..09b0970dd3 100644
--- a/src/tir/transforms/unify_thread_binding.cc
+++ b/src/tir/transforms/unify_thread_binding.cc
@@ -64,9 +64,20 @@ class ThreadBindingUnifier : public StmtExprMutator {
     if (annotations.empty()) {
       return stmt;
     }
-    For new_loop = Downcast<For>(stmt);
-    new_loop.CopyOnWrite()->annotations = std::move(annotations);
-    return std::move(new_loop);
+    if (const auto* loop = stmt.as<ForNode>()) {
+      For new_loop = GetRef<For>(loop);
+      new_loop.CopyOnWrite()->annotations = std::move(annotations);
+      return std::move(new_loop);
+    } else {
+      // Create a new unit loop with the annotation.
+      DataType dtype = op->loop_var->dtype;
+      return For(/*loop_var=*/Var("var", dtype),   //
+                 /*min=*/IntImm(dtype, 0),         //
+                 /*extent=*/IntImm(dtype, 1),      //
+                 /*kind=*/ForKind::kSerial, stmt,  //
+                 /*thread_binding=*/NullOpt,       //
+                 /*annotation=*/std::move(annotations));
+    }
   }
 
   template <typename Node>
diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py 
b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
index e489298741..c49ea5e60f 100644
--- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py
+++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py
@@ -286,6 +286,31 @@ def test_implicit_block():
     _check(element_wise_implicit_block, unified_element_wise_implicit_block)
 
 
+def test_inner_binding_with_annotation():
+    @T.prim_func
+    def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: 
T.Buffer((64,), "float32")):
+        for bx in T.thread_binding(32, "blockIdx.x"):
+            for tx in T.thread_binding(2, "threadIdx.x", 
annotations={"my_annotation": 1}):
+                with T.block("block"):
+                    v = T.axis.spatial(64, bx * 2 + tx)
+                    B[v] = A[v]
+
+    @T.prim_func
+    def unified_inner_binding_with_annotation(
+        A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")
+    ):
+        for blockIdx_x in T.thread_binding(32, thread="blockIdx.x"):
+            for threadIdx_x in T.thread_binding(2, thread="threadIdx.x"):
+                for var in T.serial(1, annotations={"my_annotation": 1}):
+                    with T.block("block"):
+                        v = T.axis.spatial(64, blockIdx_x * 2 + threadIdx_x)
+                        T.reads(A[v])
+                        T.writes(B[v])
+                        B[v] = A[v]
+
+    _check(inner_binding_with_annotation, 
unified_inner_binding_with_annotation)
+
+
 def test_lower_te():
     a = te.placeholder((32, 2, 2))
     b = te.compute((32, 2, 2), lambda i, j, k: a[i, j, k] * 2.0)

Reply via email to