jcf94 commented on a change in pull request #8402:
URL: https://github.com/apache/tvm/pull/8402#discussion_r665284491
##########
File path: python/tvm/topi/cuda/batch_matmul_tensorcore.py
##########
@@ -94,32 +92,35 @@ def _schedule(cfg, s, C):
         cfg.define_knob("vec", [1, 2, 4, 8])
         # Ensure that the default parameters are applicable when autotvm is not in use
-        if m_dim % 32 == 0 and n_dim % 8 == 0:
-            cfg.define_knob("wmma_m", [32, 16, 8])
-        elif m_dim % 16 == 0 and n_dim % 16 == 0:
-            cfg.define_knob("wmma_m", [16, 8, 32])
-        elif m_dim % 8 == 0 and n_dim % 32 == 0:
-            cfg.define_knob("wmma_m", [8, 16, 32])
+        if data_dtype in ["float16", "uint8", "int8"]:
+            if m_dim % 32 == 0 and n_dim % 8 == 0:
+                cfg.define_knob("wmma_m", [32, 16, 8])
+            elif m_dim % 16 == 0 and n_dim % 16 == 0:
+                cfg.define_knob("wmma_m", [16, 8, 32])
+            elif m_dim % 8 == 0 and n_dim % 32 == 0:
+                cfg.define_knob("wmma_m", [8, 16, 32])
+            wmma_k = 16
+            wmma_m = cfg["wmma_m"].val
+            if wmma_m == 16:
+                wmma_n = 16
+            elif wmma_m == 8:
+                wmma_n = 32
+            elif wmma_m == 32:
+                wmma_n = 8
+        else:

Review comment:
       Is this else branch meant for int4? Even though we can assume the op strategy has already done the type check, I still suggest spelling it out clearly in the code (an explicit branch for int4) and then adding an extra else branch that raises a warning. The type asserts you added in other files are really good.

-- 
This is an automated message from the Apache Git Service.
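A minimal sketch of the dispatch the review comment asks for, written as standalone Python so it runs without TVM. The helper names `select_wmma_shape` and `wmma_m_candidates` are hypothetical, and a plain `wmma_m` argument stands in for `cfg["wmma_m"].val`. Two details are assumptions not shown in the diff: the int4/uint4 tile (m = n = 8, k = 32, the CUDA WMMA sub-byte fragment shape) and raising when no candidate tile divides the problem shape.

```python
# Hypothetical helpers sketching the reviewer's suggestion; not TVM's API.

# wmma_n that pairs with each wmma_m for 16-bit and 8-bit inputs (k = 16),
# mirroring the if/elif chain in the diff.
_WMMA_N_FOR_M = {16: 16, 8: 32, 32: 8}


def wmma_m_candidates(m_dim, n_dim):
    """Order wmma_m choices so the first (default) one fits the shape."""
    if m_dim % 32 == 0 and n_dim % 8 == 0:
        return [32, 16, 8]
    if m_dim % 16 == 0 and n_dim % 16 == 0:
        return [16, 8, 32]
    if m_dim % 8 == 0 and n_dim % 32 == 0:
        return [8, 16, 32]
    # Assumption: the diff does not show this case, so fail loudly here.
    raise ValueError("shape (%d, %d) fits no WMMA tile" % (m_dim, n_dim))


def select_wmma_shape(data_dtype, m_dim, n_dim, wmma_m=None):
    """Return (wmma_m, wmma_n, wmma_k) with every dtype handled explicitly."""
    if data_dtype in ("float16", "uint8", "int8"):
        if wmma_m is None:
            # Stands in for cfg["wmma_m"].val when autotvm is not tuning.
            wmma_m = wmma_m_candidates(m_dim, n_dim)[0]
        if wmma_m not in _WMMA_N_FOR_M:
            raise ValueError("unsupported wmma_m %d" % wmma_m)
        return wmma_m, _WMMA_N_FOR_M[wmma_m], 16
    elif data_dtype in ("int4", "uint4"):
        # Assumption: sub-byte WMMA uses an m8n8k32 fragment.
        return 8, 8, 32
    else:
        # The extra branch the reviewer asks for: reject anything unexpected.
        raise ValueError("data dtype %s is not supported" % data_dtype)


print(select_wmma_shape("float16", 64, 64))  # (32, 8, 16)
print(select_wmma_shape("int4", 64, 64))     # (8, 8, 32)
```

The final `else` raises a `ValueError` instead of only emitting a warning, since there is no tile shape to fall back to; calling `warnings.warn` before raising would also satisfy the comment.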