This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b89f9537c [Codegen][CUDA] Fix codegen of cast among vector bfloat16, fp8 and fp4 (#17741)
6b89f9537c is described below

commit 6b89f9537c0270a626e67f431e19edf3c9d3b20e
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Mar 12 07:53:16 2025 -0400

    [Codegen][CUDA] Fix codegen of cast among vector bfloat16, fp8 and fp4 (#17741)
    
    This PR fixes the CUDA code generation for casts between vectorized fp8 (and fp4) and
    bfloat16 types, and adds a few utility functions for vector data-type conversion.
---
 src/target/source/codegen_cuda.cc                  |  58 +++++---
 src/target/source/literal/cuda_half_t.h            | 151 +++++++++++++++------
 .../python/codegen/test_target_codegen_cuda_fp8.py |  49 +++++--
 3 files changed, 186 insertions(+), 72 deletions(-)

diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 34023e0bb7..a97e66d346 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -194,7 +194,7 @@ std::string CodeGenCUDA::Finish() {
     decl_stream << "#include <cuda_fp4.h>\n";
     decl_stream << "#endif\n\n";
   }
-  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_, 
enable_fp4_);
+  declare_vector_type_extensions(decl_stream, enable_fp16_, enable_bf16_, 
enable_fp8_, enable_fp4_);
 
   if (enable_warp_shuffle_) {
     decl_stream << _cuda_warp_intrinsic_util;
@@ -331,8 +331,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     if (t.is_scalar()) {
       os << "nv_bfloat16";
     } else if (lanes <= 8) {
-      ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
-      os << "uint" << lanes / 2;
+      ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
+      if (lanes <= 4) {
+        os << "nv_bfloat16" << lanes;
+      } else {
+        os << "uint" << lanes / 2;
+      }
     } else {
       fail = true;
     }
@@ -575,7 +579,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
       os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << 
access[i % 2];
     }
   } else if (t.is_bfloat16()) {
-    os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << 
access[i % 2];
+    if (t.lanes() <= 4) {
+      os << vec << "." << access[i];
+    } else {
+      os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" << 
access[i % 2];
+    }
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -630,8 +638,12 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
     }
 
   } else if (t.is_bfloat16()) {
-    stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" 
<< access[i % 2]
-           << " = " << value << ";\n";
+    if (t.lanes() <= 4) {
+      stream << vec << "." << access[i] << " = " << value << ";\n";
+    } else {
+      stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << 
")))->" << access[i % 2]
+             << " = " << value << ";\n";
+    }
   } else if (t.lanes() > 4 && t.lanes() <= 8) {
     std::string type_name;
     if (t.bits() == 16) {
@@ -736,9 +748,13 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
       target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() == 
DataType::kFloat8_e4m3fn ||
       from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == 
DataType::kFloat4_e2m1fn) {
     std::ostringstream val;
-    val << "(";
-    PrintType(target_ty, val);
-    val << ")(" << PrintExpr(op->value) << ")";
+    if (target_ty.code() == DataType::kBFloat && target_ty.lanes() == 2) {
+      val << "cast_to_nv_bfloat162(" << PrintExpr(op->value) << ")";
+    } else {
+      val << "(";
+      PrintType(target_ty, val);
+      val << ")(" << PrintExpr(op->value) << ")";
+    }
     os << val.str();
     return;
   }
@@ -1384,9 +1400,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NO
     std::string v = PrintExpr(op->value);
     PrintVecConstructor(op->dtype, os);
     os << '(';
-    for (int i = 0; i < lanes / 2; ++i) {
-      if (i != 0) os << ", ";
-      os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+    if (lanes > 4) {
+      for (int i = 0; i < lanes / 2; ++i) {
+        if (i != 0) os << ", ";
+        os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+      }
+    } else {
+      for (int i = 0; i < lanes; ++i) {
+        if (i != 0) os << ", ";
+        os << v;
+      }
     }
     os << ')';
     return;
@@ -1660,15 +1683,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
       PrintVecConstructor(t, os);
       os << '(';
     }
-    if (i % 2 == 0) {
-      os << "__pack_nv_bfloat162(" << value;
+    if (i == t.lanes() - 1) {
+      os << value << ")";
     } else {
-      os << "," << value << ")";
-      if (i != t.lanes() - 1) {
-        os << ",";
-      } else {
-        os << ")";
-      }
+      os << value << ",";
     }
     return;
   }
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 86f2219fe8..b095f5b8cf 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -385,52 +385,70 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
 
 )";
 
-void declare_vector_type_extensions(std::ostringstream& stream, bool 
enable_fp16, bool enable_fp8,
-                                    bool enable_fp4) {
-  if (enable_fp16 || enable_fp8 || enable_fp4) {
+void declare_vector_type_extensions(std::ostringstream& stream, bool 
enable_fp16, bool enable_bf16,
+                                    bool enable_fp8, bool enable_fp4) {
+  if (enable_fp16 || enable_bf16) {
     stream << R"(
-struct __align__(8) half4 {
-  __half x, y, z, w;
-  __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), 
w(__half(0)) {}
-  __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), 
y(y), z(z), w(w) {}
+#include <type_traits>
+template <typename T, typename TVec2>
+struct __align__(8) half4_bfloat164 {
+  T x, y, z, w;
+  __host__ __device__ half4_bfloat164() : x(T(0)), y(T(0)), z(T(0)), w(T(0)) {}
+  __host__ __device__ half4_bfloat164(T x, T y, T z, T w) : x(x), y(y), z(z), 
w(w) {}
 )";
     if (enable_fp8) {
       stream << R"(
-  __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
-    __nv_fp8x2_e4m3 lo_part, hi_part;
-    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
-    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 
0xFFFF);
-    __half2 lo_half2 = static_cast<__half2>(lo_part);
-    __half2 hi_half2 = static_cast<__half2>(hi_part);
-    x = reinterpret_cast<__half*>(&lo_half2)[0];
-    y = reinterpret_cast<__half*>(&lo_half2)[1];
-    z = reinterpret_cast<__half*>(&hi_half2)[0];
-    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
+    if constexpr (std::is_same_v<T, __half>) {
+      __nv_fp8x2_e4m3 lo_part, hi_part;
+      lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+      hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 
0xFFFF);
+      TVec2 lo_half2 = static_cast<TVec2>(lo_part);
+      TVec2 hi_half2 = static_cast<TVec2>(hi_part);
+      x = reinterpret_cast<T*>(&lo_half2)[0];
+      y = reinterpret_cast<T*>(&lo_half2)[1];
+      z = reinterpret_cast<T*>(&hi_half2)[0];
+      w = reinterpret_cast<T*>(&hi_half2)[1];
+    } else {
+      __nv_fp8_storage_t elem0_raw = static_cast<__nv_fp8_storage_t>(fp8x4.__x 
& 0xFF);
+      __nv_fp8_storage_t elem1_raw = 
static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 8) & 0xFF);
+      __nv_fp8_storage_t elem2_raw = 
static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 16) & 0xFF);
+      __nv_fp8_storage_t elem3_raw = 
static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 24) & 0xFF);
+      __nv_fp8_e4m3 elem0, elem1, elem2, elem3;
+      elem0.__x = elem0_raw;
+      elem1.__x = elem1_raw;
+      elem2.__x = elem2_raw;
+      elem3.__x = elem3_raw;
+      x = T(elem0);
+      y = T(elem1);
+      z = T(elem2);
+      w = T(elem3);
+    }
   }
   __host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
     __nv_fp8x4_e4m3 result;
-    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
-    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+    TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
     __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
     result.__x =
         (static_cast<__uint32_t>(lo_part.__x) | 
(static_cast<__uint32_t>(hi_part.__x) << 16));
     return result;
   }
-  __host__ __device__ explicit half4(const __nv_fp8x4_e5m2& fp8x4) {
-    __nv_fp8x2_e5m2 lo_part, hi_part;
-    lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
-    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 
0xFFFF);
-    __half2 lo_half2 = static_cast<__half2>(lo_part);
-    __half2 hi_half2 = static_cast<__half2>(hi_part);
-    x = reinterpret_cast<__half*>(&lo_half2)[0];
-    y = reinterpret_cast<__half*>(&lo_half2)[1];
-    z = reinterpret_cast<__half*>(&hi_half2)[0];
-    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e5m2& fp8x4) {
+      __nv_fp8x2_e5m2 lo_part, hi_part;
+      lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+      hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 
0xFFFF);
+      TVec2 lo_half2 = static_cast<TVec2>(lo_part);
+      TVec2 hi_half2 = static_cast<TVec2>(hi_part);
+      x = reinterpret_cast<T*>(&lo_half2)[0];
+      y = reinterpret_cast<T*>(&lo_half2)[1];
+      z = reinterpret_cast<T*>(&hi_half2)[0];
+      w = reinterpret_cast<T*>(&hi_half2)[1];
   }
   __host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
     __nv_fp8x4_e5m2 result;
-    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
-    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+    TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
     __nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
     result.__x =
         (static_cast<__uint32_t>(lo_part.__x) | 
(static_cast<__uint32_t>(hi_part.__x) << 16));
@@ -460,31 +478,70 @@ struct __align__(8) half4 {
     }
     if (enable_fp4) {
       stream << R"(
-  __host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) {
-    __nv_fp4x2_storage_t lo_part, hi_part;
-    lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
-    hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
-    __half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
-    __half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
-    x = reinterpret_cast<__half*>(&lo_half2)[0];
-    y = reinterpret_cast<__half*>(&lo_half2)[1];
-    z = reinterpret_cast<__half*>(&hi_half2)[0];
-    w = reinterpret_cast<__half*>(&hi_half2)[1];
+  __host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
+    if constexpr (std::is_same_v<T, __half>) {
+      __nv_fp4x2_storage_t lo_part = 
static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
+      __nv_fp4x2_storage_t hi_part = 
static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
+      TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
+      TVec2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
+      x = reinterpret_cast<T*>(&lo_half2)[0];
+      y = reinterpret_cast<T*>(&lo_half2)[1];
+      z = reinterpret_cast<T*>(&hi_half2)[0];
+      w = reinterpret_cast<T*>(&hi_half2)[1];
+    } else {
+      __nv_fp4_e2m1 elem0, elem1, elem2, elem3;
+      elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x4.__x & 0xF);
+      elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 4) & 0xF);
+      elem2.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 8) & 0xF);
+      elem3.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 12) & 0xF);
+      x = T(elem0);
+      y = T(elem1);
+      z = T(elem2);
+      w = T(elem3);
+    }
   }
   __host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
-    __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
-    __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+    TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+    TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
     return __nv_fp4x4_e2m1(lo_half2, hi_half2);
   }
   )";
     }
     stream << R"(
 };
+)";
+  }
+  if (enable_fp16) {
+    stream << R"(
+using half4 = half4_bfloat164<__half, __half2>;
 __host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
     return half4(x, y, z, w);
 }
 )";
   }
+  if (enable_bf16) {
+    stream << R"(
+using nv_bfloat164 = half4_bfloat164<nv_bfloat16, nv_bfloat162>;
+__host__ __device__ nv_bfloat164 make_nv_bfloat164(nv_bfloat16 x, nv_bfloat16 
y, nv_bfloat16 z, nv_bfloat16 w) {
+    return nv_bfloat164(x, y, z, w);
+}
+__host__ __device__ nv_bfloat162 make_nv_bfloat162(nv_bfloat16 x, nv_bfloat16 
y) {
+    return nv_bfloat162(x, y);
+}
+)";
+    if (enable_fp8) {
+      stream << R"(
+__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& 
fp8x2) {
+    __nv_fp8_e4m3 elem0, elem1;
+    elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);
+    elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);
+    nv_bfloat16 x = nv_bfloat16(elem0);
+    nv_bfloat16 y = nv_bfloat16(elem1);
+    return nv_bfloat162(x, y);
+}
+)";
+    }
+  }
   if (enable_fp4) {
     stream << R"(
 __device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 
y) {
@@ -497,6 +554,14 @@ __device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 b
   result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) | 
(static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) | 
(static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) | 
(static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
   return result;
 }
+__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp4x2_e2m1& 
fp4x2) {
+    __nv_fp4_e2m1 elem0, elem1;
+    elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x2.__x & 0xF);
+    elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x2.__x >> 4) & 0xF);
+    nv_bfloat16 x = nv_bfloat16(elem0);
+    nv_bfloat16 y = nv_bfloat16(elem1);
+    return nv_bfloat162(x, y);
+}
 )";
   }
 }
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 5e0a4c3000..7b3a20463b 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import sys
+from itertools import product
 from typing import List, Tuple
 
 import numpy as np
@@ -26,13 +26,9 @@ import tvm.testing
 from tvm import DataType, DataTypeCode, IRModule
 from tvm import dlight as dl
 from tvm import relax, te, tir, topi
-from tvm.relax.frontend import nn
-from tvm.runtime import NDArray
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tir as T
-from tvm.target import Target
-from tvm.topi.utils import get_const_tuple
 
 try:
     import ml_dtypes
@@ -67,7 +63,7 @@ def test_e4m3_conversions():
     sch.bind(tx, "threadIdx.x")
 
     target = "cuda"
-    fadd = tvm.compile(sch.mod, target=target)
+    fadd = tvm.tir.build(sch.mod, target=target)
 
     cuda_src = fadd.imported_modules[0].get_source()
     assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found 
in generated CUDA"
@@ -179,7 +175,7 @@ def test_e4m3_vector_conversions(native_dtype, promoted_dtype):
     sch.bind(tx, "threadIdx.x")
 
     target = "cuda"
-    fadd = tvm.compile(sch.mod, target=target)
+    fadd = tvm.tir.build(sch.mod, target=target)
     cuda_src = fadd.imported_modules[0].get_source()
     dev = tvm.device(target, 0)
 
@@ -700,7 +696,7 @@ class BaseFP8E4M3QuantScaleOnly:
         def print_cuda(target, mod, name=None):
             if name:
                 mod = mod[name]
-            f = tvm.compile(mod, target=target)
+            f = tvm.tir.build(mod, target=target)
             cuda_src = f.imported_modules[0].get_source()
             print(cuda_src)
 
@@ -963,6 +959,41 @@ def test_moe_gemv_shfl_down_illegal_instr():
     vm["main"](x, indptr, weight, scale)
 
 
+@tvm.testing.requires_cuda_compute_version(8, 9)
+def test_fp8_fp16_bf16_vectorize_arith():
+    for vec_length, dtype in product([2, 4], ["float16", "bfloat16"]):
+
+        @T.prim_func
+        def func_vectorize(
+            A: T.Buffer((128,), "float8_e4m3fn"),
+            B: T.Buffer((128,), dtype),
+            C: T.Buffer((128,), dtype),
+        ) -> None:
+            for i in T.serial(128):
+                with T.block("compute"):
+                    vi = T.axis.remap("S", [i])
+                    C[vi] = (A[vi].astype(dtype) * B[vi]) + T.bfloat16(3.0)
+
+        sch = tir.Schedule(func_vectorize)
+        (l,) = sch.get_loops(sch.get_block("compute"))
+        lo, li = sch.split(l, [None, vec_length])
+        sch.bind(lo, "threadIdx.x")
+        sch.vectorize(li)
+
+        device = tvm.cuda()
+        target = tvm.target.Target.from_device(device)
+        f = tir.build(sch.mod, target=target)
+
+        a_np = np.random.rand(128).astype("float8_e4m3fn")
+        b_np = np.random.rand(128).astype(dtype)
+        c_np = (a_np.astype(dtype) * b_np) + 3
+        a_tvm = tvm.nd.array(a_np, device=device)
+        b_tvm = tvm.nd.array(b_np, device=device)
+        c_tvm = tvm.nd.empty((128,), dtype=dtype, device=device)
+        f(a_tvm, b_tvm, c_tvm)
+        c_tvm = c_tvm.numpy()
+        np.testing.assert_allclose(c_tvm, c_np, atol=1e-3, rtol=1e-3)
+
+
 if __name__ == "__main__":
-    # test_half_broadcast(6)
     tvm.testing.main()

Reply via email to