wrongtest commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1152928725

   Thanks for all the great discussions! It is exciting that we will have a more powerful ability to handle things like padding and imperfect tiles.
   
   Since our team relies on the S-TIR code path, we are extremely interested in the story for S-TIR. I would really appreciate some details on S-TIR padding. Let me use a [127, 127, 127] matmul to illustrate my questions :)
   
   ```python
   @T.prim_func
   def matmul(A: T.Buffer[(127, 127), "float32"], B: T.Buffer[(127, 127), "float32"], C: T.Buffer[(127, 127), "float32"]):
       for i, j, k in T.grid(127, 127, 127):
           with T.block("compute"):
               vi, vj, vk = T.axis.remap("SSR", [i, j, k])
               with T.init():
                   C[vi, vj] = 0.0
               C[vi, vj] += A[vi, vk] * B[vk, vj]
   ```
   
   In the current S-TIR state, we can construct the padded loops and buffers with existing primitives via a "split and then fuse" trick:
   ```python
   s = tvm.tir.Schedule(matmul)
   blk = s.get_block("compute")
   i, j, k = s.get_loops(blk)
   s.fuse(*s.split(i, factors=[4, 32]))
   s.fuse(*s.split(j, factors=[4, 32]))
   s.fuse(*s.split(k, factors=[4, 32]))
   s.transform_layout(blk, "A", lambda i, k: ((i // 32) * 32 + i % 32, (k // 32) * 32 + k % 32))
   s.transform_layout(blk, "B", lambda k, j: ((k // 32) * 32 + k % 32, (j // 32) * 32 + j % 32))
   s.transform_layout(blk, "C", lambda i, j: ((i // 32) * 32 + i % 32, (j // 32) * 32 + j % 32))
   ```
   We will get (after simplification):
   ```python
   @T.prim_func
   def func(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]):
       for i_0_i_1_fused, j_0_j_1_fused, k_0_k_1_fused in T.grid(128, 128, 128):
           with T.block("compute"):
               vi = T.axis.spatial(127, i_0_i_1_fused)
               vj = T.axis.spatial(127, j_0_j_1_fused)
               vk = T.axis.reduce(127, k_0_k_1_fused)
               T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)
               T.reads(A[vi, vk], B[vk, vj])
               T.writes(C[vi, vj])
               with T.init():
                   C[vi, vj] = T.float32(0)
               C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   ```
   Then the only thing left is the padding condition: `T.where(i_0_i_1_fused < 127 and j_0_j_1_fused < 127 and k_0_k_1_fused < 127)`. I believe this brings us to the point of the current RFC about the tradeoff between over-computation and branching. Below are some of my questions ~
   
   1. What happens when we change to `s.transform_layout(..., pad_value=0)` (if we want over-computation)?
      - (possible behavior 1) Insert padding-filling code as a producer block of `compute` (see the sketch after this list).
        - Since the effect is immediate, maybe we would not need `BufferConstraint` annotations afterwards?
      - (possible behavior 2) Annotate the buffers and let lowering passes handle the padding.
        - We may require `BufferConstraint` to direct the lowering passes.
      - (possible behavior 3) Pass `BufferConstraint` upwards into the graph level.
        - Thus we assume the param buffers already match the constraint and do not write the edge values.
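
   To make (possible behavior 1) a bit more concrete, here is a rough sketch of what I imagine the inserted producer block could look like for `A` (the block name `A_pad` and the exact predicate are only my guesses, not something the RFC specifies):
   ```python
   # sketch only: fill the out-of-bounds fringe of the padded 128x128 buffer
   # with the pad value before the "compute" block reads it
   for i, k in T.grid(128, 128):
       with T.block("A_pad"):
           vi, vk = T.axis.remap("SS", [i, k])
           T.where(i >= 127 or k >= 127)  # touch only the padded fringe
           A[vi, vk] = T.float32(0)
   ```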
      
   2. For (1.2) and (1.3), it seems that encoding the `BufferConstraint` into the buffer object is not the only choice.
       - For S-TIR (correct me if I am wrong), at least in common cases the constraint could be treated as local w.r.t. the transformed block. What if we encode the constraint just into the block, as part of its memory access properties (a hypothetical sketch follows this list)?
         We previously found that the block memory annotations `T.reads` and `T.writes` (`BufferRegion`) have the limitation that they lose conditional access information. Maybe we could also combine `BufferConstraint` with `BufferRegion`?
   
       - For graph-level annotations, IIUC, the graph level conceptually works with "Tensor"-typed values instead of "Buffer"s. Maybe we would still need another construct rather than a `Buffer` with a `BufferConstraint` field?
         We could also consider instantiating the graph-level transformation explicitly. This is our current solution: https://discuss.tvm.apache.org/t/introducing-ty-nnp-backend-with-end2end-tensorir-integration/11807/4.
 
   
       - Nevertheless, if we finally decide to extend the buffer node structure, I hope the `BufferConstraint` can have an explicit lifetime within the TIR lowering, so that later storage-related passes do not have to care about it, especially customized passes developed by vendors.
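
   To make the block-local idea in the first bullet more concrete, here is a purely hypothetical sketch (the `buffer_constraint` annotation below does not exist in TVMScript today; it only illustrates where the information could live, next to the block's access declarations):
   ```python
   with T.block("compute"):
       vi, vj, vk = T.axis.remap("SSR", [i, j, k])
       T.reads(A[vi, vk], B[vk, vj])
       T.writes(C[vi, vj])
       # hypothetical annotation: record "the padded fringe of A/B is 0" as a
       # property of this block's accesses rather than as a field on the Buffer
       T.block_attr({"buffer_constraint": {"A": "pad_value == 0", "B": "pad_value == 0"}})
       with T.init():
           C[vi, vj] = T.float32(0)
       C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   ```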
   
   3. For the reduction axis padding mentioned in https://github.com/apache/tvm-rfcs/pull/77#discussion_r894899301:
       - At the TIR level, since schedule primitives should preserve semantic correctness, how do we prove that the padding along the `k` dimension can only be zero, especially when we do not know that the op is a "matmul" in general? I think this is important if we want to use a padded `transform_layout` in auto-scheduling style applications (see the quick check below).
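
   As context for why this matters, here is a quick numpy check (my own illustration, not from the RFC): padding the shared `k` dimension with zeros keeps the result intact, while any nonzero pad value silently changes it, and proving the former requires knowing that the reduction is a sum whose identity element is 0.
   ```python
   import numpy as np

   a = np.random.rand(127, 127).astype("float32")
   b = np.random.rand(127, 127).astype("float32")
   ref = a @ b

   def padded_matmul(pad_value):
       # pad the shared k dimension from 127 up to 128 with the given value
       a_pad = np.pad(a, ((0, 0), (0, 1)), constant_values=pad_value)
       b_pad = np.pad(b, ((0, 1), (0, 0)), constant_values=pad_value)
       return a_pad @ b_pad

   assert np.allclose(padded_matmul(0.0), ref)      # zero padding is safe
   assert not np.allclose(padded_matmul(1.0), ref)  # nonzero padding corrupts C
   ```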
   
   cc @Lunderberg @tqchen @vinx13 @Hzfengsy 

