This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0d2eab2a99 [Cublas] Added support for bfloat16 while dispatching to
cublas kernels (#17796)
0d2eab2a99 is described below
commit 0d2eab2a9973d09bea6d00b97a0b06a63ad0f070
Author: Annanya <[email protected]>
AuthorDate: Mon Apr 7 14:53:13 2025 -0400
[Cublas] Added support for bfloat16 while dispatching to cublas kernels
(#17796)
In this PR I have made changes so that we can support CUBLAS
dispatch operations for bfloat16 data type.
---
python/tvm/relax/backend/cuda/cublas.py | 1 +
src/runtime/contrib/cublas/cublas.cc | 7 +++++-
src/runtime/contrib/cublas/cublas_utils.h | 5 ++++
tests/python/relax/test_codegen_cublas.py | 41 +++++++++++++++++++++++++++++++
4 files changed, 53 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py
index 6828381e68..f8621d9b56 100644
--- a/python/tvm/relax/backend/cuda/cublas.py
+++ b/python/tvm/relax/backend/cuda/cublas.py
@@ -43,6 +43,7 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
(lhs_dtype == "float16" and rhs_dtype == "float16")
or (lhs_dtype == "float32" and rhs_dtype == "float32")
or (lhs_dtype == "int8" and rhs_dtype == "int8")
+ or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
)
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index ba01f791d9..3fbda3ac94 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -125,7 +125,8 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
return TypeMatch(in_dtype, kDLInt, 8);
} else if (TypeMatch(out_dtype, kDLFloat, 32)) {
-    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16);
+    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) ||
+           TypeMatch(in_dtype, kDLBfloat, 16);
} else {
return false;
}
@@ -162,6 +163,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
if (TypeMatch(A->dtype, kDLFloat, 16)) {
ab_type = CUDA_R_16F;
+ } else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
+ ab_type = CUDA_R_16BF;
} else if (TypeMatch(A->dtype, kDLInt, 8)) {
ab_type = CUDA_R_8I;
} else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
@@ -171,6 +174,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
if (TypeMatch(C->dtype, kDLFloat, 16)) {
c_type = CUDA_R_16F;
+ } else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
+ c_type = CUDA_R_16BF;
} else if (TypeMatch(C->dtype, kDLInt, 32)) {
c_type = CUDA_R_32I;
compute_type = CUBLAS_COMPUTE_32I;
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 387065093e..3e9ded08de 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -116,6 +116,11 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
case 64:
return CUDA_R_64F;
}
+ } else if (type.code == kDLBfloat) {
+ switch (type.bits) {
+ case 16:
+ return CUDA_R_16BF;
+ }
}
LOG(FATAL) << "Unsupported cuda type";
}
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index dbcb25b69d..152f04fc3c 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -393,6 +393,47 @@ def test_matmul_fp8_multiply_offload():
tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+ "x_shape, y_shape, transpose_y, out_dtype",
+ [
+ ((10, 32), (64, 32), True, "float32"),
+ ((32, 16), (32, 16), True, "float32"),
+ ((2, 10, 32), (2, 64, 32), True, "float32"),
+ ],
+)
+def test_matmul_bfloat16_offload(
+ x_shape,
+ y_shape,
+ transpose_y,
+ out_dtype,
+):
+ in_dtype = "bfloat16"
+ mod = get_relax_matmul_module(
+ x_shape,
+ y_shape,
+ in_dtype,
+ out_dtype,
+ bias_shape=None,
+ transposed_y=transpose_y,
+ activation=None,
+ )
+    # Generate input data in float32 and then convert to bfloat16 using ml_dtypes.
+    x_float32 = np.random.uniform(low=0, high=5, size=x_shape).astype("float32")
+    y_float32 = np.random.uniform(low=0, high=5, size=y_shape).astype("float32")
+ x_bf16 = ml_dtypes.bfloat16(x_float32)
+ y_bf16 = ml_dtypes.bfloat16(y_float32)
+
+ # For the reference result, adjust y (if needed) in float32.
+ z = np.swapaxes(y_float32, -2, -1) if transpose_y else y_float32
+ args = (x_bf16, y_bf16)
+
+ out = get_result_with_relax_cublas_offload(mod, args)
+ ref_out = np.matmul(x_float32, z).astype(out_dtype)
+
+ tvm.testing.assert_allclose(out, ref_out, rtol=1e-2, atol=1e-2)
+
+
@pytest.mark.parametrize(
"M, N, K, out_dtype, transposed_y, partition_done",
[