wyc-ruiker commented on a change in pull request #8402:
URL: https://github.com/apache/tvm/pull/8402#discussion_r665319855



##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -94,32 +92,35 @@ def _schedule(cfg, s, C):
         cfg.define_knob("vec", [1, 2, 4, 8])
 
         # Ensure that the default parameters are applicable when autotvm is not in use
-        if m_dim % 32 == 0 and n_dim % 8 == 0:
-            cfg.define_knob("wmma_m", [32, 16, 8])
-        elif m_dim % 16 == 0 and n_dim % 16 == 0:
-            cfg.define_knob("wmma_m", [16, 8, 32])
-        elif m_dim % 8 == 0 and n_dim % 32 == 0:
-            cfg.define_knob("wmma_m", [8, 16, 32])
+        if data_dtype in ["float16", "uint8", "int8"]:
+            if m_dim % 32 == 0 and n_dim % 8 == 0:
+                cfg.define_knob("wmma_m", [32, 16, 8])
+            elif m_dim % 16 == 0 and n_dim % 16 == 0:
+                cfg.define_knob("wmma_m", [16, 8, 32])
+            elif m_dim % 8 == 0 and n_dim % 32 == 0:
+                cfg.define_knob("wmma_m", [8, 16, 32])
+            wmma_k = 16
+            wmma_m = cfg["wmma_m"].val
+            if wmma_m == 16:
+                wmma_n = 16
+            elif wmma_m == 8:
+                wmma_n = 32
+            elif wmma_m == 32:
+                wmma_n = 8
+        else:

Review comment:
       done.

##########
File path: python/tvm/topi/cuda/dense_tensorcore.py
##########
@@ -127,33 +131,36 @@ def _schedule_dense_tensorcore(cfg, s, C):
     cfg.define_knob("offsetCS", [0, 8])
     cfg.define_knob("vec", [1, 2, 4, 8])
 
-    # Ensure that the default parameters are applicable when autotvm is not in use
-    if batch % 32 == 0 and out_dim % 8 == 0:
-        cfg.define_knob("wmma_m", [32, 16, 8])
-    elif batch % 16 == 0 and out_dim % 16 == 0:
-        cfg.define_knob("wmma_m", [16, 8, 32])
-    elif batch % 8 == 0 and out_dim % 32 == 0:
-        cfg.define_knob("wmma_m", [8, 16, 32])
+    if data_dtype in ["float16", "int8", "uint8"]:
+        # Ensure that the default parameters are applicable when autotvm is not in use
+        if batch % 32 == 0 and out_dim % 8 == 0:
+            cfg.define_knob("wmma_m", [32, 16, 8])
+        elif batch % 16 == 0 and out_dim % 16 == 0:
+            cfg.define_knob("wmma_m", [16, 8, 32])
+        elif batch % 8 == 0 and out_dim % 32 == 0:
+            cfg.define_knob("wmma_m", [8, 16, 32])
+        wmma_k = 16
+        wmma_m = cfg["wmma_m"].val
+        if wmma_m == 16:
+            wmma_n = 16
+        elif wmma_m == 8:
+            wmma_n = 32
+        elif wmma_m == 32:
+            wmma_n = 8
+    else:

Review comment:
       done.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to