This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0d2eab2a99 [Cublas] Added support for bfloat16 while dispatching to cublas kernels (#17796)
0d2eab2a99 is described below

commit 0d2eab2a9973d09bea6d00b97a0b06a63ad0f070
Author: Annanya <[email protected]>
AuthorDate: Mon Apr 7 14:53:13 2025 -0400

    [Cublas] Added support for bfloat16 while dispatching to cublas kernels (#17796)
    
    In this PR I have made changes so that we can support CUBLAS
    dispatch operations for bfloat16 data type.
---
 python/tvm/relax/backend/cuda/cublas.py   |  1 +
 src/runtime/contrib/cublas/cublas.cc      |  7 +++++-
 src/runtime/contrib/cublas/cublas_utils.h |  5 ++++
 tests/python/relax/test_codegen_cublas.py | 41 +++++++++++++++++++++++++++++++
 4 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py
index 6828381e68..f8621d9b56 100644
--- a/python/tvm/relax/backend/cuda/cublas.py
+++ b/python/tvm/relax/backend/cuda/cublas.py
@@ -43,6 +43,7 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
         (lhs_dtype == "float16" and rhs_dtype == "float16")
         or (lhs_dtype == "float32" and rhs_dtype == "float32")
         or (lhs_dtype == "int8" and rhs_dtype == "int8")
+        or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
     )
 
 
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index ba01f791d9..3fbda3ac94 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -125,7 +125,8 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
   if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
     return TypeMatch(in_dtype, kDLInt, 8);
   } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
-    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16);
+    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) ||
+           TypeMatch(in_dtype, kDLBfloat, 16);
   } else {
     return false;
   }
@@ -162,6 +163,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
 
   if (TypeMatch(A->dtype, kDLFloat, 16)) {
     ab_type = CUDA_R_16F;
+  } else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
+    ab_type = CUDA_R_16BF;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
   } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
@@ -171,6 +174,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
 
   if (TypeMatch(C->dtype, kDLFloat, 16)) {
     c_type = CUDA_R_16F;
+  } else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
+    c_type = CUDA_R_16BF;
   } else if (TypeMatch(C->dtype, kDLInt, 32)) {
     c_type = CUDA_R_32I;
     compute_type = CUBLAS_COMPUTE_32I;
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 387065093e..3e9ded08de 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -116,6 +116,11 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
       case 64:
         return CUDA_R_64F;
     }
+  } else if (type.code == kDLBfloat) {
+    switch (type.bits) {
+      case 16:
+        return CUDA_R_16BF;
+    }
   }
   LOG(FATAL) << "Unsupported cuda type";
 }
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index dbcb25b69d..152f04fc3c 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -393,6 +393,47 @@ def test_matmul_fp8_multiply_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float32"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_bfloat16_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "bfloat16"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    # Generate input data in float32 and then convert to bfloat16 using ml_dtypes.
+    x_float32 = np.random.uniform(low=0, high=5, size=x_shape).astype("float32")
+    y_float32 = np.random.uniform(low=0, high=5, size=y_shape).astype("float32")
+    x_bf16 = ml_dtypes.bfloat16(x_float32)
+    y_bf16 = ml_dtypes.bfloat16(y_float32)
+
+    # For the reference result, adjust y (if needed) in float32.
+    z = np.swapaxes(y_float32, -2, -1) if transpose_y else y_float32
+    args = (x_bf16, y_bf16)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x_float32, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-2, atol=1e-2)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [

Reply via email to