cblmemo opened a new pull request, #14413:
URL: https://github.com/apache/tvm/pull/14413

   This PR provides a workaround for a bug that occurs when the schedule 
primitive `reverse_compute_inline` creates a predicate. Previously, the block 
structure looked like this:
   
   ```
   For v0 in ... {
     For v1 in ... {
       BlockRealize {
         Predicate; // <-- `reverse_compute_inline` will generate a predicate here
         LoopBody;
       }
     }
   }
   ```
   
   Here the predicate refers to the loop variables, such as `v0` and `v1`, 
and sits inside the loops that define them.
   
   However, the block generated by memhammer is like this:
   
   ```
   BlockRealize {
     Predicate; // <-- `reverse_compute_inline` will generate a predicate here
     For v0 in ... {
       For v1 in ... {
         LoopBody;
       }
     }
   }
   ```
   
   This causes a use-before-def for every loop variable, because the predicate 
is placed outside the loops that define them, and it produces invalid TIR like 
this:
   
   ```python
   from tvm.script import ir as I
   from tvm.script import tir as T
   
   @I.ir_module
   class Module:
       @T.prim_func
       def main(a: T.Buffer((128,), "float32"), b: T.Buffer((128,), "float32")):
           T.func_attr({"global_symbol": "main"})
           # with T.block("root"):
           b_reindex_shared_dyn = T.alloc_buffer((32, 4), scope="shared.dyn")
           for i in T.thread_binding(32, thread="threadIdx.x"):
               for j in range(4):
                   with T.block("B"):
                       v0, v1 = T.axis.remap("SS", [i, j])
                       T.reads(a[v0 * 4 + v1])
                       T.writes(b_reindex_shared_dyn[v0, v1])
                       b_reindex_shared_dyn[v0, v1] = a[v0 * 4 + v1] * T.float32(2)
                   ax0 = T.int32()
                   ax1 = T.int32()
                   with T.block("b_reindex_shared.dyn"):
                       v0 = T.axis.spatial(4, j)
                       T.where(0 <= ax0 and ax0 < 32 and 0 <= j + ax1 and j + ax1 < 4) # ax0, ax1 used here
                       T.reads(b_reindex_shared_dyn[0:32, v0])
                       T.writes(b[v0:v0 + 125])
                       T.block_attr({"auto_copy": 1})
                       for ax0, ax1 in T.grid(32, 1): # ax0, ax1 defined here
                           b[ax0 * 4 + (v0 + ax1)] = b_reindex_shared_dyn[ax0, v0 + ax1]
   ```
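
The difference between the two layouts can be modelled without TVM. The following is a toy sketch in plain Python, not TVM code: the tuple encoding and the `undefined_vars` helper are hypothetical names introduced only for illustration. It walks a nested structure and reports which variables a block predicate uses without an enclosing loop defining them.

```python
def undefined_vars(node, defined=frozenset()):
    """Return the variables used by a predicate before any enclosing loop defines them.

    Toy encoding (an assumption of this sketch, not a TVM data structure):
      ("for", var, body)               a loop that defines `var` for its body
      ("block", predicate_vars, body)  a block whose predicate uses `predicate_vars`
      ("body",)                        a leaf standing in for LoopBody
    """
    kind = node[0]
    if kind == "for":
        _, var, body = node
        # The loop variable is visible only inside the loop body.
        return undefined_vars(body, defined | {var})
    if kind == "block":
        _, predicate_vars, body = node
        # The predicate is checked in the scope where the block itself sits.
        missing = set(predicate_vars) - defined
        return missing | undefined_vars(body, defined)
    if kind == "body":
        return set()
    raise ValueError(f"unknown node kind: {kind}")


# Layout 1: the loops enclose the block, so the predicate sees v0 and v1.
loops_outside = ("for", "v0", ("for", "v1", ("block", {"v0", "v1"}, ("body",))))

# Layout 2 (memhammer-style): the block encloses the loops, so it does not.
block_outside = ("block", {"v0", "v1"}, ("for", "v0", ("for", "v1", ("body",))))

print(sorted(undefined_vars(loops_outside)))  # []
print(sorted(undefined_vars(block_outside)))  # ['v0', 'v1']
```

In the second layout both loop variables are reported, which mirrors the `ax0` and `ax1` use-before-def in the TIR above.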
   
   For now, this PR provides a workaround for one special case: predicates 
that are constant true. By checking whether the predicate can be proved true, 
we can eliminate this error whenever it is, which is the common case. A general 
fix for this bug still needs discussion.
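
As a rough illustration of the idea behind the workaround, the sketch below decides whether a predicate holds over the whole loop domain and drops it if so. It is plain Python with a brute-force check standing in for a real proof; `provably_true` and `drop_if_trivial` are hypothetical helpers, and nothing here is the code in this PR. The example predicate and extents are taken from the TIR above (`j` in `[0, 4)`, `ax0` in `[0, 32)`, `ax1` in `[0, 1)`).

```python
from itertools import product


def provably_true(predicate, extents):
    """Return True if `predicate` holds at every point of the loop domain.

    Brute-force enumeration is an assumption of this sketch; it only
    works for small, constant extents.
    """
    domain = product(*(range(extent) for extent in extents))
    return all(predicate(*point) for point in domain)


def drop_if_trivial(predicate, extents):
    """Return None (no predicate needed) when the predicate is always true,
    otherwise return the predicate unchanged."""
    return None if provably_true(predicate, extents) else predicate


def generated_predicate(j, ax0, ax1):
    # Same condition as the T.where(...) in the example above.
    return 0 <= ax0 and ax0 < 32 and 0 <= j + ax1 and j + ax1 < 4


# With the extents from the example, the predicate is always true and can be dropped.
print(drop_if_trivial(generated_predicate, (4, 32, 1)) is None)  # True

# If ax1 ranged over [0, 2), j + ax1 could reach 4, so the predicate must stay.
print(drop_if_trivial(generated_predicate, (4, 32, 2)) is None)  # False
```

When the predicate cannot be dropped, the use-before-def problem remains, which is why this is only a workaround.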

