This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new d1ac73ca2d [CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#16888) d1ac73ca2d is described below commit d1ac73ca2d3c14dc69e47818871478e8b0f295aa Author: Ivan Sidorenko <98739392+ibsidore...@users.noreply.github.com> AuthorDate: Tue Apr 16 21:55:11 2024 +0300 [CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#16888) [CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#63) Co-authored-by: Andrey Malyshev <elvin.n...@gmail.com> --- include/tvm/runtime/data_type.h | 3 ++ python/tvm/contrib/tvmjs.py | 19 +++++++++ python/tvm/relax/backend/contrib/cublas.py | 16 ++++++- python/tvm/relax/transform/legalize_ops/qdq.py | 27 +++++++----- src/relax/backend/contrib/utils.h | 4 ++ src/relax/op/tensor/qdq.cc | 18 +++++--- src/runtime/contrib/cublas/cublas.cc | 3 ++ src/tir/op/op.cc | 2 + tests/python/relax/test_codegen_cublas.py | 59 ++++++++++++++++++++++++++ tests/python/relax/test_op_qdq.py | 37 ++++++++++++++++ 10 files changed, 169 insertions(+), 19 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index f7284ec690..a330ccbbdf 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -126,6 +126,9 @@ class DataType { code() == DataType::kE5M2Float) && bits() == 8; } + bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); } + + bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); } /*! \return whether type is a float16 type. */ bool is_float16() const { return is_float() && bits() == 16; } /*! \return whether type is a bfloat16 type. */ diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index 8d8bd1b051..923301a1f5 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -28,6 +28,11 @@ from typing import Iterator, Mapping, Tuple, Union import numpy as np +try: + import ml_dtypes +except ImportError: + ml_dtypes = None + import tvm from tvm._ffi.libinfo import find_lib_path @@ -295,6 +300,20 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device): arr = tvm.nd.empty(shape, dtype, device=device) assert offset + nbytes <= len(raw_data) buffer_source = raw_data[offset : offset + nbytes] + if dtype == "e4m3_float8": + if ml_dtypes is not None: + dtype = ml_dtypes.float8_e4m3fn + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert e4m3_float8 array to numpy." + ) + if dtype == "e5m2_float8": + if ml_dtypes is not None: + dtype = ml_dtypes.float8_e5m2 + else: + raise RuntimeError( + "ml_dtypes is not installed, cannot convert e5m2_float8 array to numpy." + ) if encode_format == "f32-to-bf16" and dtype == "float32": data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape) arr.copyfrom(_convert_bf16_to_f32(data)) diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py index eecd531e74..f66001d0e8 100644 --- a/python/tvm/relax/backend/contrib/cublas.py +++ b/python/tvm/relax/backend/contrib/cublas.py @@ -28,8 +28,11 @@ from ..patterns import make_matmul_pattern from ..utils import has_leaking_intermediate_variables -def _is_supported_dtype(lhs_dtype, rhs_dtype): +def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): """Check if dtypes in the given workload are supported by cuBLAS BYOC.""" + if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": + # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8' + return out_dtype != "e5m2_float8" return ( (lhs_dtype == "float16" and rhs_dtype == "float16") or (lhs_dtype == "float32" and rhs_dtype == "float32") @@ -42,10 +45,12 @@ def _check_matmul(context: PatternCheckContext) -> bool: return False lhs = context.annotated_expr["lhs"] rhs = context.annotated_expr["rhs"] + matmul_call = context.annotated_expr["root"] lhs_dtype = lhs.struct_info.dtype rhs_dtype = rhs.struct_info.dtype - if not _is_supported_dtype(lhs_dtype, rhs_dtype): + out_dtype = matmul_call.struct_info.dtype + if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype): return False lhs_shape = lhs.struct_info.shape.values @@ -62,6 +67,13 @@ def _check_matmul(context: PatternCheckContext) -> bool: if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0: # Rows number must be multiples of 4 for IGEMM return False + elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8": + # Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS + # docs, but it was observed during testing. + if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0: + return False + if not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int)) or rhs_shape[-2] % 16 != 0: + return False lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1) rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1) diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py index 4f1e43d988..7484285c1e 100644 --- a/python/tvm/relax/transform/legalize_ops/qdq.py +++ b/python/tvm/relax/transform/legalize_ops/qdq.py @@ -52,7 +52,8 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr: def quantize_compute(*indices): scale_value = scale if is_const_scalar(scale) else scale[indices[axis]] zp_value = zp if is_const_scalar(zp) else zp[indices[axis]] - round_val = te.round(data[indices] / scale_value) + zp_value + scaled = data[indices] / scale_value + round_val = (te.round(scaled) if "int" in out_dtype else scaled) + zp_value return clip_cast(round_val, out_dtype) output_shape = data.shape @@ -75,15 +76,18 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr: Compute datatype: float32 Example of lowering: - qnn.dequantize(data, scale, zp, "float32") --> - sub = subtract(cast(data, "int32"), zp) - out = multiply(cast(sub, "float32"), scale) - - qnn.dequantize(data, scale, zp, "float16") --> - sub = subtract(cast(data, "int32"), zp) - mul = multiply(cast(sub, "float32"), cast(scale, "float32")) - clipped_out = clip(mul, float32(-65504.0), float32(65504.0)) - out = cast(clipped_out, "float16") + + dtype = ["int32"|"float32"] + + qnn.dequantize(data, scale, zp, "float32") --> + sub = subtract(cast(data, dtype), zp) + out = multiply(cast(sub, "float32"), scale) + + qnn.dequantize(data, scale, zp, "float16") --> + sub = subtract(cast(data, dtype), zp) + mul = multiply(cast(sub, "float32"), cast(scale, "float32")) + clipped_out = clip(mul, float32(-65504.0), float32(65504.0)) + out = cast(clipped_out, "float16") """ axis = call.attrs.axis out_dtype = call.attrs.out_dtype @@ -96,7 +100,8 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr: def dequantize_compute(*indices): scale_value = scale if is_const_scalar(scale) else scale[indices[axis]] zp_value = zp if is_const_scalar(zp) else zp[indices[axis]] - sub = te.subtract(data[indices].astype("int32"), zp_value) + dtype = "float32" if "float" in data.dtype else "int32" + sub = te.subtract(data[indices].astype(dtype), zp_value) out = te.multiply(sub, scale_value.astype("float32")) if out_dtype == "float32": return out diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index ee1240aaed..412651d3f9 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -72,6 +72,10 @@ inline std::string DType2String(const tvm::DataType dtype) { std::ostringstream os; if (dtype.is_float()) { os << "float"; + } else if (dtype.is_e4m3_float8()) { + os << "e4m3_float"; + } else if (dtype.is_e5m2_float8()) { + os << "e5m2_float"; } else if (dtype.is_int()) { os << "int"; } else if (dtype.is_uint()) { diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc index f8b0ed0ca2..0189ef9678 100644 --- a/src/relax/op/tensor/qdq.cc +++ b/src/relax/op/tensor/qdq.cc @@ -49,7 +49,9 @@ TVM_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize); StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { const auto* attrs = call->attrs.as<QuantizeAttrs>(); if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) && - attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16)) { + attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) && + attrs->out_dtype != DataType::NVFloat8E4M3() && + attrs->out_dtype != DataType::NVFloat8E5M2()) { ctx->ReportFatal(Diagnostic::Error(call) << "Unsupported output datatype attribute for operation: '" << attrs->out_dtype); @@ -73,9 +75,10 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) { } // Check datatype of zero_point param: - if (zp_sinfo->dtype != DataType::Int(8)) { + if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) { ctx->ReportFatal(Diagnostic::Error(call) - << "zero_point param datatype should be int8, but got " << zp_sinfo->dtype); + << "zero_point param datatype should be 'int8' or 'float16', but got " + << zp_sinfo->dtype); } // Check that "axis" attribute is not out of range: @@ -142,7 +145,9 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) // Check input datatype: if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype != DataType::UInt(8) && input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype != DataType::UInt(16) && - input_sinfo->dtype != DataType::Int(32)) { + input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::NVFloat8E4M3() && + input_sinfo->dtype != DataType::NVFloat8E5M2() && input_sinfo->dtype != DataType::Float(16) && + input_sinfo->dtype != DataType::Float(32)) { ctx->ReportFatal(Diagnostic::Error(call) << "Unsupported input datatype for operation: " << attrs->out_dtype); } @@ -155,9 +160,10 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx) } // Check datatype of zero_point param: - if (zp_sinfo->dtype != DataType::Int(8)) { + if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) { ctx->ReportFatal(Diagnostic::Error(call) - << "zero_point param datatype should be int8, but got " << zp_sinfo->dtype); + << "zero_point param datatype should be 'int8' or 'float16', but got " + << zp_sinfo->dtype); } // Check that "axis" attribute is not out of range: diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 7a867f4bae..49aa35a7e0 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -161,6 +161,9 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream, ab_type = CUDA_R_16F; } else if (TypeMatch(A->dtype, kDLInt, 8)) { ab_type = CUDA_R_8I; + } else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) { + ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8)); + ab_type = CUDA_R_8F_E4M3; } if (TypeMatch(C->dtype, kDLFloat, 16)) { diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index b613639786..c79a148e4b 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -263,6 +263,7 @@ PrimExpr max_value(const DataType& dtype, Span span) { } else if (dtype.is_bfloat16()) { return FloatImm(dtype, std::numeric_limits<float>::max(), span); } else if (dtype.is_float8()) { + // according to https://arxiv.org/pdf/2209.05433.pdf if (dtype.code() == DataType::TypeCode::kE5M2Float) { return FloatImm(dtype, 57344.0, span); } else if (dtype.code() == DataType::TypeCode::kE4M3Float) { @@ -303,6 +304,7 @@ PrimExpr min_value(const DataType& dtype, Span span) { } else if (dtype.is_bfloat16()) { return FloatImm(dtype, std::numeric_limits<float>::lowest(), span); } else if (dtype.is_float8()) { + // according to https://arxiv.org/pdf/2209.05433.pdf if (dtype.code() == DataType::TypeCode::kE5M2Float) { return FloatImm(dtype, -57344.0, span); } else if (dtype.code() == DataType::TypeCode::kE4M3Float) { diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 52ad8b94b9..11247b3801 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -25,6 +25,11 @@ from tvm.relax.backend.contrib.cublas import partition_for_cublas from tvm.relax.testing import get_relax_matmul_module from tvm.script import relax as R +try: + import ml_dtypes +except ImportError: + ml_dtypes = None + @pytest.fixture(autouse=True) def reset_seed(): @@ -226,6 +231,60 @@ def test_matmul_igemm_offload( tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) +@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed") +@pytest.mark.parametrize( + "x_shape, y_shape, transpose_y, out_dtype", + [ + ((10, 32), (64, 32), True, "float32"), + ((32, 16), (32, 16), True, "float16"), + ((2, 10, 32), (2, 64, 32), True, "float32"), + ], +) +def test_matmul_fp8_offload( + x_shape, + y_shape, + transpose_y, + out_dtype, +): + in_dtype = "e4m3_float8" + mod = get_relax_matmul_module( + x_shape, + y_shape, + in_dtype, + out_dtype, + bias_shape=None, + transposed_y=transpose_y, + activation=None, + ) + numpytype = "float8_e4m3fn" + x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype) + y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype) + z = np.swapaxes(y, -2, -1) if transpose_y else y + args = (x, y) + + out = get_result_with_relax_cublas_offload(mod, args) + ref_out = np.matmul(x, z).astype(out_dtype) + + tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize( + "M, N, K, out_dtype, partition_done", + [ + (15, 64, 32, "float32", True), + (15, 64, 32, "e4m3_float8", True), + (15, 64, 32, "e5m2_float8", False), + (16, 32, 60, "float32", False), + (16, 30, 64, "float32", False), + ], +) +def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done): + mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True) + mod = partition_for_cublas(mod) + func_name = "relax_matmul_cublas" if partition_done else "R.matmul" + assert func_name in mod["main"].script() + + def test_cublas_partition_matmul_without_bias(): # cuBLAS does not handle 2D bias (residual input) mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32)) diff --git a/tests/python/relax/test_op_qdq.py b/tests/python/relax/test_op_qdq.py index 42391120e9..8b2d499041 100644 --- a/tests/python/relax/test_op_qdq.py +++ b/tests/python/relax/test_op_qdq.py @@ -68,5 +68,42 @@ def test_qdq_op_infer_struct_info_symbolic(): ) +def test_qdq_e4m3_float8_op_infer_struct_info_symbolic(): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor((n, 3), "float32")) + dx = relax.Var("dx", R.Tensor((n, 3), "e4m3_float8")) + s = relax.Var("s", R.Tensor([3], "float32")) + zp = relax.Var("zp", R.Tensor([3], "float16")) + _check_inference( + bb, + relax.op.quantize(x, s, zp, 1, "e4m3_float8"), + relax.TensorStructInfo((n, 3), "e4m3_float8"), + ) + _check_inference( + bb, + relax.op.dequantize(dx, s, zp, 1, "float32"), + relax.TensorStructInfo((n, 3), "float32"), + ) + + +def test_qdq_e5m2_float8_op_infer_struct_info_symbolic(): + dtype = "e5m2_float8" + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor((n, 3), "float32")) + dx = relax.Var("dx", R.Tensor((n, 3), dtype)) + s = relax.Var("s", R.Tensor([3], "float32")) + zp = relax.Var("zp", R.Tensor([3], "float16")) + _check_inference( + bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorStructInfo((n, 3), dtype) + ) + _check_inference( + bb, + relax.op.dequantize(dx, s, zp, 1, "float32"), + relax.TensorStructInfo((n, 3), "float32"), + ) + + if __name__ == "__main__": tvm.testing.main()