yzh119 opened a new pull request, #14448:
URL: https://github.com/apache/tvm/pull/14448

   # Motivation
   Currently, the `reorder` primitive only changes the order of loops; the order of block iterable variables stays the same.
   `transform_block_layout` can change the block iterable variables, but it requires the loops outside the given block to have no branches, which limits its usage.
   
   This schedule primitive changes the order of block iterable variables directly, with the following API:
   ```python
   def reorder_block_iter_var(self, block: BlockRV, new_order: List[int]) -> None:
       """Reorder the itervars inside a given block.

       Parameters
       ----------
       block : BlockRV
           The block to be transformed.
       new_order : List[int]
           The new block itervar order.
       """
   ```
   where `new_order` must be a permutation of `[0, 1, ..., n-1]`, with `n` being the number of itervars in the block.
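   
   As a minimal sketch of that constraint, a valid `new_order` has exactly one entry per itervar and contains each index once. The helper `is_valid_new_order` is hypothetical and not part of TVM; the primitive's own validation may report errors differently.
   
   ```python
   from typing import List


   def is_valid_new_order(new_order: List[int], num_iter_vars: int) -> bool:
       """Return True if new_order is a permutation of [0, 1, ..., num_iter_vars - 1]."""
       # The length must match the number of block itervars.
       if len(new_order) != num_iter_vars:
           return False
       # Every index in range must appear exactly once.
       return sorted(new_order) == list(range(num_iter_vars))


   print(is_valid_new_order([2, 1, 0], 3))  # True
   print(is_valid_new_order([0, 0, 1], 3))  # False: duplicate index
   print(is_valid_new_order([0, 1], 3))     # False: wrong length
   ```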
   
   # Example
   
   Suppose we need to change the block itervar order in block "C":
   ```python
   from tvm.script import tir as T

   @T.prim_func
   def matmul(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None:
       for i, j, k in T.grid(128, 128, 128):
           with T.block("C"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   C[vi, vj] = 0.0
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
   ```
   
   After applying:
   ```python
   from tvm import tir

   sch = tir.Schedule(matmul, debug_mask="all")
   C = sch.get_block("C")
   sch.reorder_block_iter_var(C, [2, 1, 0])
   ```
   
   the block itervar order is changed to `vk, vj, vi`, while the loops and the block body stay the same:
   ```python
   @T.prim_func
   def matmul_after_reorder_block_iter_var(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
       for i, j, k in T.grid(128, 128, 128):
           with T.block("C"):
               vk, vj, vi = T.axis.remap("RSS", [k, j, i])
               T.reads(A[vi, vk], B[vj, vk])
               T.writes(C[vi, vj])
               with T.init():
                   C[vi, vj] = T.float32(0)
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
   ```
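   
   To make the mapping concrete, the pure-Python sketch below models a block's itervars as parallel lists of names, iter kinds (the `"SSR"` string passed to `T.axis.remap`), and loop bindings, and permutes them. It has no TVM dependency, and `reorder_iter_vars` is a hypothetical helper. It assumes `new_order[p]` is the original index of the itervar placed at position `p`; because `[2, 1, 0]` is its own inverse, this example does not distinguish that convention from the inverse one.
   
   ```python
   from typing import List, Tuple


   def reorder_iter_vars(
       names: List[str], kinds: str, bindings: List[str], new_order: List[int]
   ) -> Tuple[List[str], str, List[str]]:
       """Permute parallel descriptions of block itervars.

       Assumption: position p of the result takes the itervar originally
       at index new_order[p].
       """
       assert sorted(new_order) == list(range(len(names))), "new_order must be a permutation"
       new_names = [names[i] for i in new_order]
       new_kinds = "".join(kinds[i] for i in new_order)
       new_bindings = [bindings[i] for i in new_order]
       return new_names, new_kinds, new_bindings


   names, kinds, bindings = reorder_iter_vars(["vi", "vj", "vk"], "SSR", ["i", "j", "k"], [2, 1, 0])
   print(names, kinds, bindings)  # ['vk', 'vj', 'vi'] RSS ['k', 'j', 'i']
   ```
   
   The result matches the `T.axis.remap("RSS", [k, j, i])` line above: each itervar keeps its kind and its binding, and only the declaration order changes.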
   
   cc @junrushao @vinx13 @Hzfengsy 

