jcf94 commented on a change in pull request #8402:
URL: https://github.com/apache/tvm/pull/8402#discussion_r665284491



##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -94,32 +92,35 @@ def _schedule(cfg, s, C):
         cfg.define_knob("vec", [1, 2, 4, 8])
 
         # Ensure that the default parameters are applicable when autotvm is not in use
-        if m_dim % 32 == 0 and n_dim % 8 == 0:
-            cfg.define_knob("wmma_m", [32, 16, 8])
-        elif m_dim % 16 == 0 and n_dim % 16 == 0:
-            cfg.define_knob("wmma_m", [16, 8, 32])
-        elif m_dim % 8 == 0 and n_dim % 32 == 0:
-            cfg.define_knob("wmma_m", [8, 16, 32])
+        if data_dtype in ["float16", "uint8", "int8"]:
+            if m_dim % 32 == 0 and n_dim % 8 == 0:
+                cfg.define_knob("wmma_m", [32, 16, 8])
+            elif m_dim % 16 == 0 and n_dim % 16 == 0:
+                cfg.define_knob("wmma_m", [16, 8, 32])
+            elif m_dim % 8 == 0 and n_dim % 32 == 0:
+                cfg.define_knob("wmma_m", [8, 16, 32])
+            wmma_k = 16
+            wmma_m = cfg["wmma_m"].val
+            if wmma_m == 16:
+                wmma_n = 16
+            elif wmma_m == 8:
+                wmma_n = 32
+            elif wmma_m == 32:
+                wmma_n = 8
+        else:

Review comment:
       Is the else branch meant for int4?
   Even though we can assume the op strategy has already done the type check, I'd still suggest making it explicit in the code and adding an extra else branch that raises a warning for unsupported dtypes.
   The type asserts you added in the other files are really good.
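   For illustration only, a minimal sketch of the kind of explicit handling this suggests; the helper name `_wmma_shape` and the int4 branch are assumptions for this sketch, not code from the PR (CUDA's sub-byte WMMA only provides the 8x8x32 fragment shape, and fp16/int8 fragments use k=16 with m/n paired as 16x16, 8x32 or 32x8):

   ```python
   # Hypothetical helper, not code from this PR: make the dtype handling
   # explicit and fail loudly on anything the schedule does not support.
   def _wmma_shape(data_dtype, wmma_m=16):
       """Return the (wmma_m, wmma_n, wmma_k) fragment shape for a given dtype."""
       if data_dtype in ["float16", "uint8", "int8"]:
           # k is fixed at 16; m/n pair up as 16x16, 8x32 or 32x8,
           # mirroring the wmma_m knob choices in the schedule above.
           wmma_k = 16
           wmma_n = {16: 16, 8: 32, 32: 8}[wmma_m]
       elif data_dtype in ["int4", "uint4"]:
           # sub-byte TensorCore ops only come in the 8x8x32 shape
           wmma_m, wmma_n, wmma_k = 8, 8, 32
       else:
           raise ValueError(
               "data dtype %s is not supported by the TensorCore schedule" % data_dtype
           )
       return wmma_m, wmma_n, wmma_k

   # Example:
   #   _wmma_shape("float16", wmma_m=32)  -> (32, 8, 16)
   #   _wmma_shape("int4")                -> (8, 8, 32)
   ```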




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

