ibsidorenko commented on code in PR #16895:
URL: https://github.com/apache/tvm/pull/16895#discussion_r1568429222


##########
python/tvm/relax/backend/contrib/cublas.py:
##########
@@ -68,11 +69,30 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             # Rows number must be multiples of 4 for IGEMM
             return False
     elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
-        # Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS
-        # docs, but it was observed during testing.
-        if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0:
+        matmul_rhs_var = matmul_call.args[1]
+        rhs_transposed = False
+        if matmul_rhs_var in context.matched_bindings:
+            matmul_rhs_call = context.matched_bindings[matmul_rhs_var]
+            assert (
+                isinstance(matmul_rhs_call, tvm.relax.Call)
+                and matmul_rhs_call.op.name == "relax.permute_dims"
+            )

Review Comment:
   I am OK with this, thank you! Just a nit question: do we need the assert for the case when the rhs call is something other than `permute_dims`? We could instead leave `rhs_transposed == False` and return `False` in the next `if` (without crashing):
   ```
           if not rhs_transposed:
               # cuBLAS FP8 operations require rhs being transposed
               return False
   ```
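   To illustrate the suggestion, here is a minimal, self-contained sketch of the non-asserting variant. The `Call` dataclass and `check_fp8_rhs_transposed` helper are hypothetical stand-ins for the TVM Relax objects, not the actual API; the point is only the control flow: a non-`permute_dims` rhs makes the check return `False` instead of crashing on an assert.

   ```python
   from dataclasses import dataclass
   from typing import Optional


   @dataclass
   class Call:
       """Illustrative stand-in for a Relax call node."""
       op_name: str  # e.g. "relax.permute_dims"


   def check_fp8_rhs_transposed(rhs_binding: Optional[Call]) -> bool:
       """Return True only when rhs is produced by a permute_dims call."""
       rhs_transposed = False
       # Instead of asserting, only set the flag when the op matches.
       if rhs_binding is not None and rhs_binding.op_name == "relax.permute_dims":
           rhs_transposed = True
       if not rhs_transposed:
           # cuBLAS FP8 operations require rhs being transposed
           return False
       return True
   ```

   With this shape, an unexpected rhs op simply fails the pattern check rather than raising an `AssertionError` during pattern matching.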


