LeiWang1999 opened a new pull request, #14449:
URL: https://github.com/apache/tvm/pull/14449

   In the current CUDA codegen, `PrintMMAAssembly` instantiates the final asm 
code by replacing the placeholders A, B, ... in the template. However, matching 
on a single letter is incorrect in some cases: the letter can also match 
unintended text, such as part of an operand expression, which corrupts the 
generated code.
   
   
https://github.com/apache/tvm/blob/49e6695586d07c33c84097d2b0f58c79c2abd51e/src/target/source/ptx.cc#L564-L571
   
   For example, this is a case I ran into:
   
   ```
     {
       __asm__ __volatile__(
         "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"
         "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
         :  "=r"(((unsigned *)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[0]), 
"=r"(((unsigned *)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[1])
         : "r"(((unsigned *)(AC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + 
(ii_2 * 8)))[0]), "r"(((unsigned *)(AC_warp + ((ii_2 * 16) + (jj_2 * 
8))_shared_warp + (ii_2 * 8)))[1]), "r"(((unsigned *)(AC_warp + ((ii_2 * 16) + 
(jj_2 * 8))_shared_warp + (ii_2 * 8)))[2]), "r"(((unsigned *)(AC_warp + ((ii_2 
* 16) + (jj_2 * 8))_shared_warp + (ii_2 * 8)))[3]), "r"(((unsigned *)(BC_warp + 
((ii_2 * 16) + (jj_2 * 8))_shared_warp + (jj_2 * 8)))[0]), "r"(((unsigned 
*)(BC_warp + ((ii_2 * 16) + (jj_2 * 8))_shared_warp + (jj_2 * 8)))[1]), 
"r"(((unsigned *)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[0]), "r"(((unsigned 
*)(C_warp + ((ii_2 * 16) + (jj_2 * 8))))[1]));
     }
   ```
   
   The generated PTX operands and offsets should actually be:
   
   ```
             {
               __asm__ __volatile__(
                   "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"
                   "{%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n"
                   : "=r"(((unsigned *)(C_warp + ((ii_2 * 64) + (jj_2 * 
8))))[0]), "=r"(((unsigned *)(C_warp + ((ii_2 * 64) + (jj_2 * 8))))[1])
                   : "r"(((unsigned *)(A_shared_warp + (ii_2 * 8)))[0]), 
"r"(((unsigned *)(A_shared_warp + (ii_2 * 8)))[1]), "r"(((unsigned 
*)(A_shared_warp + (ii_2 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (ii_2 * 
8)))[3]), "r"(((unsigned *)(B_shared_warp + (jj_2 * 8)))[0]), "r"(((unsigned 
*)(B_shared_warp + (jj_2 * 8)))[1]), "r"(((unsigned *)(C_warp + ((ii_2 * 64) + 
(jj_2 * 8))))[0]), "r"(((unsigned *)(C_warp + ((ii_2 * 64) + (jj_2 * 8))))[1]));
             }
   ``` 
   
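   A simple fix is to use the braced patterns "{A}", "{B}", ... in the template 
instead of the bare letters "A", "B", ..., so that a pattern can no longer match 
inside an operand expression.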
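
To illustrate why bare single-letter patterns are fragile, here is a minimal standalone sketch. It is not TVM's actual replacer: `ApplyRules` and the operand names `CacheA`, `CacheB` and `acc` are hypothetical, and the sketch assumes the rules are applied one after another over the whole string.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper (not TVM's Replacer): applies each (pattern, value)
// rule in order, replacing every occurrence of the pattern in the string.
std::string ApplyRules(
    std::string str,
    const std::vector<std::pair<std::string, std::string>>& rules) {
  for (const auto& rule : rules) {
    const std::string& pattern = rule.first;
    const std::string& value = rule.second;
    size_t pos = 0;
    while ((pos = str.find(pattern, pos)) != std::string::npos) {
      str.replace(pos, pattern.size(), value);
      pos += value.size();  // continue after the text just inserted
    }
  }
  return str;
}

// Bare letters: the "C" rule also matches the "C" inside the already
// substituted names "CacheA" and "CacheB".
std::string BareLetterResult() {
  return ApplyRules("mma(A, B, C)",
                    {{"A", "CacheA"}, {"B", "CacheB"}, {"C", "acc"}});
}

// Braced placeholders: "{C}" cannot occur inside a substituted name,
// as long as the operand expressions themselves contain no braces.
std::string BracedResult() {
  return ApplyRules("mma({A}, {B}, {C})",
                    {{"{A}", "CacheA"}, {"{B}", "CacheB"}, {"{C}", "acc"}});
}
```

In this sketch `BareLetterResult()` yields `mma(accacheA, accacheB, acc)`, while `BracedResult()` yields the intended `mma(CacheA, CacheB, acc)`.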

