This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d1ac73ca2d [CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#16888)
d1ac73ca2d is described below

commit d1ac73ca2d3c14dc69e47818871478e8b0f295aa
Author: Ivan Sidorenko <98739392+ibsidore...@users.noreply.github.com>
AuthorDate: Tue Apr 16 21:55:11 2024 +0300

    [CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#16888)
    
    [CUBLAS][FP8] Support e4m3 gemm in cuBLAS BYOC (#63)
    
    Co-authored-by: Andrey Malyshev <elvin.n...@gmail.com>
---
 include/tvm/runtime/data_type.h                |  3 ++
 python/tvm/contrib/tvmjs.py                    | 19 +++++++++
 python/tvm/relax/backend/contrib/cublas.py     | 16 ++++++-
 python/tvm/relax/transform/legalize_ops/qdq.py | 27 +++++++-----
 src/relax/backend/contrib/utils.h              |  4 ++
 src/relax/op/tensor/qdq.cc                     | 18 +++++---
 src/runtime/contrib/cublas/cublas.cc           |  3 ++
 src/tir/op/op.cc                               |  2 +
 tests/python/relax/test_codegen_cublas.py      | 59 ++++++++++++++++++++++++++
 tests/python/relax/test_op_qdq.py              | 37 ++++++++++++++++
 10 files changed, 169 insertions(+), 19 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index f7284ec690..a330ccbbdf 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -126,6 +126,9 @@ class DataType {
             code() == DataType::kE5M2Float) &&
            bits() == 8;
   }
+  bool is_e4m3_float8() const { return (code() == DataType::kE4M3Float && bits() == 8); }
+
+  bool is_e5m2_float8() const { return (code() == DataType::kE5M2Float && bits() == 8); }
   /*! \return whether type is a float16 type. */
   bool is_float16() const { return is_float() && bits() == 16; }
   /*! \return whether type is a bfloat16 type. */
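Side note (not part of the patch): the two new predicates correspond to the
string dtypes "e4m3_float8" and "e5m2_float8" used on the Python side. A
minimal sketch, assuming a TVM build with the FP8 type codes compiled in:

    import tvm

    # DataType parses the dtype string into the kE4M3Float type code that
    # is_e4m3_float8() checks for on the C++ side.
    dt = tvm.DataType("e4m3_float8")
    assert dt.bits == 8
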
diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 8d8bd1b051..923301a1f5 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -28,6 +28,11 @@ from typing import Iterator, Mapping, Tuple, Union
 
 import numpy as np
 
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+
 import tvm
 from tvm._ffi.libinfo import find_lib_path
 
@@ -295,6 +300,20 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
             arr = tvm.nd.empty(shape, dtype, device=device)
             assert offset + nbytes <= len(raw_data)
             buffer_source = raw_data[offset : offset + nbytes]
+            if dtype == "e4m3_float8":
+                if ml_dtypes is not None:
+                    dtype = ml_dtypes.float8_e4m3fn
+                else:
+                    raise RuntimeError(
+                        "ml_dtypes is not installed, cannot convert 
e4m3_float8 array to numpy."
+                    )
+            if dtype == "e5m2_float8":
+                if ml_dtypes is not None:
+                    dtype = ml_dtypes.float8_e5m2
+                else:
+                    raise RuntimeError(
+                        "ml_dtypes is not installed, cannot convert 
e5m2_float8 array to numpy."
+                    )
             if encode_format == "f32-to-bf16" and dtype == "float32":
                 data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
                 arr.copyfrom(_convert_bf16_to_f32(data))
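For illustration (not part of the patch), the ml_dtypes mapping is what lets
raw FP8 cache bytes be viewed as numpy values; a small sketch, assuming
ml_dtypes is installed:

    import numpy as np
    import ml_dtypes

    raw = bytes([0x38, 0x40, 0x48])  # three hypothetical e4m3 payload bytes
    arr = np.frombuffer(raw, dtype=ml_dtypes.float8_e4m3fn)
    print(arr.astype(np.float32))    # upcast for inspection -> [1. 2. 4.]
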
diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index eecd531e74..f66001d0e8 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -28,8 +28,11 @@ from ..patterns import make_matmul_pattern
 from ..utils import has_leaking_intermediate_variables
 
 
-def _is_supported_dtype(lhs_dtype, rhs_dtype):
+def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
     """Check if dtypes in the given workload are supported by cuBLAS BYOC."""
+    if lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+        # The output cannot be 'e5m2_float8' if inputs are 'e4m3_float8'
+        return out_dtype != "e5m2_float8"
     return (
         (lhs_dtype == "float16" and rhs_dtype == "float16")
         or (lhs_dtype == "float32" and rhs_dtype == "float32")
@@ -42,10 +45,12 @@ def _check_matmul(context: PatternCheckContext) -> bool:
         return False
     lhs = context.annotated_expr["lhs"]
     rhs = context.annotated_expr["rhs"]
+    matmul_call = context.annotated_expr["root"]
 
     lhs_dtype = lhs.struct_info.dtype
     rhs_dtype = rhs.struct_info.dtype
-    if not _is_supported_dtype(lhs_dtype, rhs_dtype):
+    out_dtype = matmul_call.struct_info.dtype
+    if not _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
         return False
 
     lhs_shape = lhs.struct_info.shape.values
@@ -62,6 +67,13 @@ def _check_matmul(context: PatternCheckContext) -> bool:
         if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 4 != 0:
             # Rows number must be multiples of 4 for IGEMM
             return False
+    elif lhs_dtype == "e4m3_float8" and rhs_dtype == "e4m3_float8":
+        # Matrix dimensions must be multiples of 16. This requirement is missing from the cuBLAS
+        # docs, but it was observed during testing.
+        if not isinstance(rhs_shape[-1], (tvm.tir.expr.IntImm, int)) or rhs_shape[-1] % 16 != 0:
+            return False
+        if not isinstance(rhs_shape[-2], (tvm.tir.expr.IntImm, int)) or rhs_shape[-2] % 16 != 0:
+            return False
 
     lhs_batches = reduce(operator.mul, lhs_shape[:-2], 1)
     rhs_batches = reduce(operator.mul, rhs_shape[:-2], 1)
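The new FP8 branch is exercised by the tests added further down; a usage
sketch of the partitioner with shapes that satisfy the multiple-of-16 rule
(a sketch, not a definitive recipe):

    from tvm.relax.backend.contrib.cublas import partition_for_cublas
    from tvm.relax.testing import get_relax_matmul_module

    # K=32 and N=64 are multiples of 16, so the e4m3 matmul is offloaded;
    # with K=60 the shape check above would reject the pattern.
    mod = get_relax_matmul_module(
        (15, 32), (64, 32), "e4m3_float8", "float32", transposed_y=True
    )
    mod = partition_for_cublas(mod)
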
diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py
index 4f1e43d988..7484285c1e 100644
--- a/python/tvm/relax/transform/legalize_ops/qdq.py
+++ b/python/tvm/relax/transform/legalize_ops/qdq.py
@@ -52,7 +52,8 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr:
         def quantize_compute(*indices):
             scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
             zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
-            round_val = te.round(data[indices] / scale_value) + zp_value
+            scaled = data[indices] / scale_value
+            round_val = (te.round(scaled) if "int" in out_dtype else scaled) + zp_value
             return clip_cast(round_val, out_dtype)
 
         output_shape = data.shape
@@ -75,15 +76,18 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
     Compute datatype: float32
 
     Example of lowering:
-    qnn.dequantize(data, scale, zp, "float32") -->
-        sub = subtract(cast(data, "int32"), zp)
-        out = multiply(cast(sub, "float32"), scale)
-
-    qnn.dequantize(data, scale, zp, "float16") -->
-        sub = subtract(cast(data, "int32"), zp)
-        mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
-        clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
-        out = cast(clipped_out, "float16")
+
+        dtype = ["int32"|"float32"]
+
+        qnn.dequantize(data, scale, zp, "float32") -->
+            sub = subtract(cast(data, dtype), zp)
+            out = multiply(cast(sub, "float32"), scale)
+
+        qnn.dequantize(data, scale, zp, "float16") -->
+            sub = subtract(cast(data, dtype), zp)
+            mul = multiply(cast(sub, "float32"), cast(scale, "float32"))
+            clipped_out = clip(mul, float32(-65504.0), float32(65504.0))
+            out = cast(clipped_out, "float16")
     """
     axis = call.attrs.axis
     out_dtype = call.attrs.out_dtype
@@ -96,7 +100,8 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
         def dequantize_compute(*indices):
             scale_value = scale if is_const_scalar(scale) else scale[indices[axis]]
             zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
-            sub = te.subtract(data[indices].astype("int32"), zp_value)
+            dtype = "float32" if "float" in data.dtype else "int32"
+            sub = te.subtract(data[indices].astype(dtype), zp_value)
             out = te.multiply(sub, scale_value.astype("float32"))
             if out_dtype == "float32":
                 return out
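A worked numpy sketch of the quantize path for an FP8 out_dtype (not part of
the patch; the ±448 bound matches the e4m3 range added to src/tir/op/op.cc
below, and ml_dtypes is assumed for the final cast):

    import numpy as np
    import ml_dtypes

    data, scale, zp = np.float32(500.0), np.float32(1.0), np.float16(0.0)
    scaled = data / scale + zp                   # no te.round() for float dtypes
    clipped = np.clip(scaled, -448.0, 448.0)     # e4m3 finite range
    q = clipped.astype(ml_dtypes.float8_e4m3fn)  # -> 448, the e4m3 max
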
diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h
index ee1240aaed..412651d3f9 100644
--- a/src/relax/backend/contrib/utils.h
+++ b/src/relax/backend/contrib/utils.h
@@ -72,6 +72,10 @@ inline std::string DType2String(const tvm::DataType dtype) {
   std::ostringstream os;
   if (dtype.is_float()) {
     os << "float";
+  } else if (dtype.is_e4m3_float8()) {
+    os << "e4m3_float";
+  } else if (dtype.is_e5m2_float8()) {
+    os << "e5m2_float";
   } else if (dtype.is_int()) {
     os << "int";
   } else if (dtype.is_uint()) {
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc
index f8b0ed0ca2..0189ef9678 100644
--- a/src/relax/op/tensor/qdq.cc
+++ b/src/relax/op/tensor/qdq.cc
@@ -49,7 +49,9 @@ TVM_REGISTER_GLOBAL("relax.op.quantize").set_body_typed(quantize);
 StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
   const auto* attrs = call->attrs.as<QuantizeAttrs>();
   if (attrs->out_dtype != DataType::Int(8) && attrs->out_dtype != DataType::UInt(8) &&
-      attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16)) {
+      attrs->out_dtype != DataType::Int(16) && attrs->out_dtype != DataType::UInt(16) &&
+      attrs->out_dtype != DataType::NVFloat8E4M3() &&
+      attrs->out_dtype != DataType::NVFloat8E5M2()) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Unsupported output datatype attribute for operation: 
'"
                      << attrs->out_dtype);
@@ -73,9 +75,10 @@ StructInfo InferStructInfoQuantize(const Call& call, const BlockBuilder& ctx) {
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be int8, but got " 
<< zp_sinfo->dtype);
+                     << "zero_point param datatype should be 'int8' or 
'float16', but got "
+                     << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -142,7 +145,9 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
   // Check input datatype:
   if (input_sinfo->dtype != DataType::Int(8) && input_sinfo->dtype != DataType::UInt(8) &&
       input_sinfo->dtype != DataType::Int(16) && input_sinfo->dtype != DataType::UInt(16) &&
-      input_sinfo->dtype != DataType::Int(32)) {
+      input_sinfo->dtype != DataType::Int(32) && input_sinfo->dtype != DataType::NVFloat8E4M3() &&
+      input_sinfo->dtype != DataType::NVFloat8E5M2() && input_sinfo->dtype != DataType::Float(16) &&
+      input_sinfo->dtype != DataType::Float(32)) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "Unsupported input datatype for operation: " << 
attrs->out_dtype);
   }
@@ -155,9 +160,10 @@ StructInfo InferStructInfoDequantize(const Call& call, const BlockBuilder& ctx)
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be int8, but got " 
<< zp_sinfo->dtype);
+                     << "zero_point param datatype should be 'int8' or 
'float16', but got "
+                     << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index 7a867f4bae..49aa35a7e0 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -161,6 +161,9 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
     ab_type = CUDA_R_16F;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
+  } else if (TypeMatch(A->dtype, DataType::TypeCode::kE4M3Float, 8)) {
+    ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kE4M3Float, 8));
+    ab_type = CUDA_R_8F_E4M3;
   }
 
   if (TypeMatch(C->dtype, kDLFloat, 16)) {
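Note that the CUDA_R_8F_E4M3 path only runs on GPUs with FP8 support; a
hedged pre-flight check one might add on the Python side (uses TVM's public
device API; the 8.9 threshold is the documented cublasLt FP8 requirement):

    import tvm

    dev = tvm.cuda(0)
    if dev.exist:
        # cublasLt FP8 GEMM requires compute capability >= 8.9 (Ada/Hopper)
        major, minor = map(int, dev.compute_version.split("."))
        assert (major, minor) >= (8, 9), "FP8 GEMM unsupported on this GPU"
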
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index b613639786..c79a148e4b 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -263,6 +263,7 @@ PrimExpr max_value(const DataType& dtype, Span span) {
   } else if (dtype.is_bfloat16()) {
     return FloatImm(dtype, std::numeric_limits<float>::max(), span);
   } else if (dtype.is_float8()) {
+    // according to https://arxiv.org/pdf/2209.05433.pdf
     if (dtype.code() == DataType::TypeCode::kE5M2Float) {
       return FloatImm(dtype, 57344.0, span);
     } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
@@ -303,6 +304,7 @@ PrimExpr min_value(const DataType& dtype, Span span) {
   } else if (dtype.is_bfloat16()) {
     return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
   } else if (dtype.is_float8()) {
+    // according to https://arxiv.org/pdf/2209.05433.pdf
     if (dtype.code() == DataType::TypeCode::kE5M2Float) {
       return FloatImm(dtype, -57344.0, span);
     } else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
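The 448 / 57344 limits can be cross-checked against ml_dtypes, which
implements the same formats described in that paper (assuming ml_dtypes is
installed):

    import ml_dtypes

    assert float(ml_dtypes.finfo(ml_dtypes.float8_e4m3fn).max) == 448.0
    assert float(ml_dtypes.finfo(ml_dtypes.float8_e5m2).max) == 57344.0
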
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 52ad8b94b9..11247b3801 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -25,6 +25,11 @@ from tvm.relax.backend.contrib.cublas import partition_for_cublas
 from tvm.relax.testing import get_relax_matmul_module
 from tvm.script import relax as R
 
+try:
+    import ml_dtypes
+except ImportError:
+    ml_dtypes = None
+
 
 @pytest.fixture(autouse=True)
 def reset_seed():
@@ -226,6 +231,60 @@ def test_matmul_igemm_offload(
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float16"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_fp8_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "e4m3_float8"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    z = np.swapaxes(y, -2, -1) if transpose_y else y
+    args = (x, y)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize(
+    "M, N, K, out_dtype, partition_done",
+    [
+        (15, 64, 32, "float32", True),
+        (15, 64, 32, "e4m3_float8", True),
+        (15, 64, 32, "e5m2_float8", False),
+        (16, 32, 60, "float32", False),
+        (16, 30, 64, "float32", False),
+    ],
+)
+def test_cublas_partition_fp8_matmul(M, N, K, out_dtype, partition_done):
+    mod = get_relax_matmul_module((M, K), (N, K), "e4m3_float8", out_dtype, transposed_y=True)
+    mod = partition_for_cublas(mod)
+    func_name = "relax_matmul_cublas" if partition_done else "R.matmul"
+    assert func_name in mod["main"].script()
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
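
To run just the new FP8 cases locally (standard pytest selection; assumes a
CUDA-enabled TVM build with cuBLAS):

    pytest tests/python/relax/test_codegen_cublas.py -k fp8 -v
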
diff --git a/tests/python/relax/test_op_qdq.py b/tests/python/relax/test_op_qdq.py
index 42391120e9..8b2d499041 100644
--- a/tests/python/relax/test_op_qdq.py
+++ b/tests/python/relax/test_op_qdq.py
@@ -68,5 +68,42 @@ def test_qdq_op_infer_struct_info_symbolic():
     )
 
 
+def test_qdq_e4m3_float8_op_infer_struct_info_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((n, 3), "float32"))
+    dx = relax.Var("dx", R.Tensor((n, 3), "e4m3_float8"))
+    s = relax.Var("s", R.Tensor([3], "float32"))
+    zp = relax.Var("zp", R.Tensor([3], "float16"))
+    _check_inference(
+        bb,
+        relax.op.quantize(x, s, zp, 1, "e4m3_float8"),
+        relax.TensorStructInfo((n, 3), "e4m3_float8"),
+    )
+    _check_inference(
+        bb,
+        relax.op.dequantize(dx, s, zp, 1, "float32"),
+        relax.TensorStructInfo((n, 3), "float32"),
+    )
+
+
+def test_qdq_e5m2_float8_op_infer_struct_info_symbolic():
+    dtype = "e5m2_float8"
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((n, 3), "float32"))
+    dx = relax.Var("dx", R.Tensor((n, 3), dtype))
+    s = relax.Var("s", R.Tensor([3], "float32"))
+    zp = relax.Var("zp", R.Tensor([3], "float16"))
+    _check_inference(
+        bb, relax.op.quantize(x, s, zp, 1, dtype), relax.TensorStructInfo((n, 3), dtype)
+    )
+    _check_inference(
+        bb,
+        relax.op.dequantize(dx, s, zp, 1, "float32"),
+        relax.TensorStructInfo((n, 3), "float32"),
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
