This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6b89f9537c [Codegen][CUDA] Fix codegen of cast among vector bfloat16, fp8 and fp4 (#17741)
6b89f9537c is described below
commit 6b89f9537c0270a626e67f431e19edf3c9d3b20e
Author: Ruihang Lai <[email protected]>
AuthorDate: Wed Mar 12 07:53:16 2025 -0400
[Codegen][CUDA] Fix codegen of cast among vector bfloat16, fp8 and fp4 (#17741)
This PR fixes the CUDA code generation for casts among vector bfloat16, fp8, and fp4 values.
It also adds a few vector data conversion utility functions.
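For reference, a minimal self-contained sketch of the packed fp8x2 (e4m3) -> nv_bfloat162 helper
that this patch adds to cuda_half_t.h (the includes are for illustration only and assume a CUDA
toolkit recent enough to ship cuda_fp8.h and cuda_bf16.h; in generated kernels the headers are
emitted behind the usual CUDA-version guards):

    #include <cuda_bf16.h>  // nv_bfloat16, nv_bfloat162
    #include <cuda_fp8.h>   // __nv_fp8x2_e4m3, __nv_fp8_e4m3, __nv_fp8_storage_t

    __host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8x2) {
      __nv_fp8_e4m3 elem0, elem1;
      elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);         // low byte -> lane x
      elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);  // high byte -> lane y
      return nv_bfloat162(nv_bfloat16(elem0), nv_bfloat16(elem1));           // per-lane fp8 -> bf16 cast
    }

The cast codegen for a 2-lane bfloat16 target then emits a call to this helper instead of a plain
C-style vector cast; the 4-lane path goes through the templated half4_bfloat164 struct introduced
in the diff below.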
---
src/target/source/codegen_cuda.cc | 58 +++++---
src/target/source/literal/cuda_half_t.h | 151 +++++++++++++++------
.../python/codegen/test_target_codegen_cuda_fp8.py | 49 +++++--
3 files changed, 186 insertions(+), 72 deletions(-)
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 34023e0bb7..a97e66d346 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -194,7 +194,7 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#include <cuda_fp4.h>\n";
decl_stream << "#endif\n\n";
}
- declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_, enable_fp4_);
+ declare_vector_type_extensions(decl_stream, enable_fp16_, enable_bf16_, enable_fp8_, enable_fp4_);
if (enable_warp_shuffle_) {
decl_stream << _cuda_warp_intrinsic_util;
@@ -331,8 +331,12 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
if (t.is_scalar()) {
os << "nv_bfloat16";
} else if (lanes <= 8) {
- ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
- os << "uint" << lanes / 2;
+ ICHECK_EQ(lanes % 2, 0) << "only support even lane for bfloat16 type";
+ if (lanes <= 4) {
+ os << "nv_bfloat16" << lanes;
+ } else {
+ os << "uint" << lanes / 2;
+ }
} else {
fail = true;
}
@@ -575,7 +579,11 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" <<
access[i % 2];
}
} else if (t.is_bfloat16()) {
- os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" <<
access[i % 2];
+ if (t.lanes() <= 4) {
+ os << vec << "." << access[i];
+ } else {
+ os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->" <<
access[i % 2];
+ }
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -630,8 +638,12 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
}
} else if (t.is_bfloat16()) {
- stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2]
- << " = " << value << ";\n";
+ if (t.lanes() <= 4) {
+ stream << vec << "." << access[i] << " = " << value << ";\n";
+ } else {
+ stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] <<
")))->" << access[i % 2]
+ << " = " << value << ";\n";
+ }
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -736,9 +748,13 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
target_ty.code() == DataType::kFloat4_e2m1fn || from_ty.code() == DataType::kFloat8_e4m3fn ||
from_ty.code() == DataType::kFloat8_e5m2 || from_ty.code() == DataType::kFloat4_e2m1fn) {
std::ostringstream val;
- val << "(";
- PrintType(target_ty, val);
- val << ")(" << PrintExpr(op->value) << ")";
+ if (target_ty.code() == DataType::kBFloat && target_ty.lanes() == 2) {
+ val << "cast_to_nv_bfloat162(" << PrintExpr(op->value) << ")";
+ } else {
+ val << "(";
+ PrintType(target_ty, val);
+ val << ")(" << PrintExpr(op->value) << ")";
+ }
os << val.str();
return;
}
@@ -1384,9 +1400,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
- for (int i = 0; i < lanes / 2; ++i) {
- if (i != 0) os << ", ";
- os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+ if (lanes > 4) {
+ for (int i = 0; i < lanes / 2; ++i) {
+ if (i != 0) os << ", ";
+ os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
+ }
+ } else {
+ for (int i = 0; i < lanes; ++i) {
+ if (i != 0) os << ", ";
+ os << v;
+ }
}
os << ')';
return;
@@ -1660,15 +1683,10 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
PrintVecConstructor(t, os);
os << '(';
}
- if (i % 2 == 0) {
- os << "__pack_nv_bfloat162(" << value;
+ if (i == t.lanes() - 1) {
+ os << value << ")";
} else {
- os << "," << value << ")";
- if (i != t.lanes() - 1) {
- os << ",";
- } else {
- os << ")";
- }
+ os << value << ",";
}
return;
}
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 86f2219fe8..b095f5b8cf 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -385,52 +385,70 @@ static constexpr const char* _cuda_warp_intrinsic_util = R"(
)";
-void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_fp8,
- bool enable_fp4) {
- if (enable_fp16 || enable_fp8 || enable_fp4) {
+void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16, bool enable_bf16,
+ bool enable_fp8, bool enable_fp4) {
+ if (enable_fp16 || enable_bf16) {
stream << R"(
-struct __align__(8) half4 {
- __half x, y, z, w;
- __host__ __device__ half4() : x(__half(0)), y(__half(0)), z(__half(0)), w(__half(0)) {}
- __host__ __device__ half4(__half x, __half y, __half z, __half w) : x(x), y(y), z(z), w(w) {}
+#include <type_traits>
+template <typename T, typename TVec2>
+struct __align__(8) half4_bfloat164 {
+ T x, y, z, w;
+ __host__ __device__ half4_bfloat164() : x(T(0)), y(T(0)), z(T(0)), w(T(0)) {}
+ __host__ __device__ half4_bfloat164(T x, T y, T z, T w) : x(x), y(y), z(z), w(w) {}
)";
if (enable_fp8) {
stream << R"(
- __host__ __device__ explicit half4(const __nv_fp8x4_e4m3& fp8x4) {
- __nv_fp8x2_e4m3 lo_part, hi_part;
- lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
- hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
- __half2 lo_half2 = static_cast<__half2>(lo_part);
- __half2 hi_half2 = static_cast<__half2>(hi_part);
- x = reinterpret_cast<__half*>(&lo_half2)[0];
- y = reinterpret_cast<__half*>(&lo_half2)[1];
- z = reinterpret_cast<__half*>(&hi_half2)[0];
- w = reinterpret_cast<__half*>(&hi_half2)[1];
+ __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
+ if constexpr (std::is_same_v<T, __half>) {
+ __nv_fp8x2_e4m3 lo_part, hi_part;
+ lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+ hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
+ TVec2 lo_half2 = static_cast<TVec2>(lo_part);
+ TVec2 hi_half2 = static_cast<TVec2>(hi_part);
+ x = reinterpret_cast<T*>(&lo_half2)[0];
+ y = reinterpret_cast<T*>(&lo_half2)[1];
+ z = reinterpret_cast<T*>(&hi_half2)[0];
+ w = reinterpret_cast<T*>(&hi_half2)[1];
+ } else {
+ __nv_fp8_storage_t elem0_raw = static_cast<__nv_fp8_storage_t>(fp8x4.__x & 0xFF);
+ __nv_fp8_storage_t elem1_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 8) & 0xFF);
+ __nv_fp8_storage_t elem2_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 16) & 0xFF);
+ __nv_fp8_storage_t elem3_raw = static_cast<__nv_fp8_storage_t>((fp8x4.__x >> 24) & 0xFF);
+ __nv_fp8_e4m3 elem0, elem1, elem2, elem3;
+ elem0.__x = elem0_raw;
+ elem1.__x = elem1_raw;
+ elem2.__x = elem2_raw;
+ elem3.__x = elem3_raw;
+ x = T(elem0);
+ y = T(elem1);
+ z = T(elem2);
+ w = T(elem3);
+ }
}
__host__ __device__ explicit operator __nv_fp8x4_e4m3() const {
__nv_fp8x4_e4m3 result;
- __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
- __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+ TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+ TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
__nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
- __host__ __device__ explicit half4(const __nv_fp8x4_e5m2& fp8x4) {
- __nv_fp8x2_e5m2 lo_part, hi_part;
- lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
- hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
- __half2 lo_half2 = static_cast<__half2>(lo_part);
- __half2 hi_half2 = static_cast<__half2>(hi_part);
- x = reinterpret_cast<__half*>(&lo_half2)[0];
- y = reinterpret_cast<__half*>(&lo_half2)[1];
- z = reinterpret_cast<__half*>(&hi_half2)[0];
- w = reinterpret_cast<__half*>(&hi_half2)[1];
+ __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e5m2& fp8x4) {
+ __nv_fp8x2_e5m2 lo_part, hi_part;
+ lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
+ hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
+ TVec2 lo_half2 = static_cast<TVec2>(lo_part);
+ TVec2 hi_half2 = static_cast<TVec2>(hi_part);
+ x = reinterpret_cast<T*>(&lo_half2)[0];
+ y = reinterpret_cast<T*>(&lo_half2)[1];
+ z = reinterpret_cast<T*>(&hi_half2)[0];
+ w = reinterpret_cast<T*>(&hi_half2)[1];
}
__host__ __device__ explicit operator __nv_fp8x4_e5m2() const {
__nv_fp8x4_e5m2 result;
- __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
- __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+ TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+ TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
__nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2);
result.__x =
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
@@ -460,31 +478,70 @@ struct __align__(8) half4 {
}
if (enable_fp4) {
stream << R"(
- __host__ __device__ explicit half4(const __nv_fp4x4_e2m1& fp4x4) {
- __nv_fp4x2_storage_t lo_part, hi_part;
- lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
- hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
- __half2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
- __half2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
- x = reinterpret_cast<__half*>(&lo_half2)[0];
- y = reinterpret_cast<__half*>(&lo_half2)[1];
- z = reinterpret_cast<__half*>(&hi_half2)[0];
- w = reinterpret_cast<__half*>(&hi_half2)[1];
+ __host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
+ if constexpr (std::is_same_v<T, __half>) {
+ __nv_fp4x2_storage_t lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
+ __nv_fp4x2_storage_t hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
+ TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
+ TVec2 hi_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(hi_part, __NV_E2M1));
+ x = reinterpret_cast<T*>(&lo_half2)[0];
+ y = reinterpret_cast<T*>(&lo_half2)[1];
+ z = reinterpret_cast<T*>(&hi_half2)[0];
+ w = reinterpret_cast<T*>(&hi_half2)[1];
+ } else {
+ __nv_fp4_e2m1 elem0, elem1, elem2, elem3;
+ elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x4.__x & 0xF);
+ elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 4) & 0xF);
+ elem2.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 8) & 0xF);
+ elem3.__x = static_cast<__nv_fp4_storage_t>((fp4x4.__x >> 12) & 0xF);
+ x = T(elem0);
+ y = T(elem1);
+ z = T(elem2);
+ w = T(elem3);
+ }
}
__host__ __device__ explicit operator __nv_fp4x4_e2m1() const {
- __half2 lo_half2 = *reinterpret_cast<const __half2*>(&x);
- __half2 hi_half2 = *reinterpret_cast<const __half2*>(&z);
+ TVec2 lo_half2 = *reinterpret_cast<const TVec2*>(&x);
+ TVec2 hi_half2 = *reinterpret_cast<const TVec2*>(&z);
return __nv_fp4x4_e2m1(lo_half2, hi_half2);
}
)";
}
stream << R"(
};
+)";
+ }
+ if (enable_fp16) {
+ stream << R"(
+using half4 = half4_bfloat164<__half, __half2>;
__host__ __device__ half4 make_half4(__half x, __half y, __half z, __half w) {
return half4(x, y, z, w);
}
)";
}
+ if (enable_bf16) {
+ stream << R"(
+using nv_bfloat164 = half4_bfloat164<nv_bfloat16, nv_bfloat162>;
+__host__ __device__ nv_bfloat164 make_nv_bfloat164(nv_bfloat16 x, nv_bfloat16 y, nv_bfloat16 z, nv_bfloat16 w) {
+ return nv_bfloat164(x, y, z, w);
+}
+__host__ __device__ nv_bfloat162 make_nv_bfloat162(nv_bfloat16 x, nv_bfloat16 y) {
+ return nv_bfloat162(x, y);
+}
+)";
+ if (enable_fp8) {
+ stream << R"(
+__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8x2) {
+ __nv_fp8_e4m3 elem0, elem1;
+ elem0.__x = static_cast<__nv_fp8_storage_t>(fp8x2.__x & 0xFF);
+ elem1.__x = static_cast<__nv_fp8_storage_t>((fp8x2.__x >> 8) & 0xFF);
+ nv_bfloat16 x = nv_bfloat16(elem0);
+ nv_bfloat16 y = nv_bfloat16(elem1);
+ return nv_bfloat162(x, y);
+}
+)";
+ }
+ }
if (enable_fp4) {
stream << R"(
__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) {
@@ -497,6 +554,14 @@ __device__ __nv_fp4x4_e2m1 make___nv_fp4x4_e2m1(__nv_fp4_e2m1 a, __nv_fp4_e2m1 b
result.__x = (static_cast<__nv_fp4x4_storage_t>(a.__x)) |
(static_cast<__nv_fp4x4_storage_t>(b.__x) << 4) |
(static_cast<__nv_fp4x4_storage_t>(c.__x) << 8) |
(static_cast<__nv_fp4x4_storage_t>(d.__x) << 12);
return result;
}
+__host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp4x2_e2m1& fp4x2) {
+ __nv_fp4_e2m1 elem0, elem1;
+ elem0.__x = static_cast<__nv_fp4_storage_t>(fp4x2.__x & 0xF);
+ elem1.__x = static_cast<__nv_fp4_storage_t>((fp4x2.__x >> 4) & 0xF);
+ nv_bfloat16 x = nv_bfloat16(elem0);
+ nv_bfloat16 y = nv_bfloat16(elem1);
+ return nv_bfloat162(x, y);
+}
)";
}
}
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 5e0a4c3000..7b3a20463b 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-import sys
+from itertools import product
from typing import List, Tuple
import numpy as np
@@ -26,13 +26,9 @@ import tvm.testing
from tvm import DataType, DataTypeCode, IRModule
from tvm import dlight as dl
from tvm import relax, te, tir, topi
-from tvm.relax.frontend import nn
-from tvm.runtime import NDArray
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
-from tvm.target import Target
-from tvm.topi.utils import get_const_tuple
try:
import ml_dtypes
@@ -67,7 +63,7 @@ def test_e4m3_conversions():
sch.bind(tx, "threadIdx.x")
target = "cuda"
- fadd = tvm.compile(sch.mod, target=target)
+ fadd = tvm.tir.build(sch.mod, target=target)
cuda_src = fadd.imported_modules[0].get_source()
assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found
in generated CUDA"
@@ -179,7 +175,7 @@ def test_e4m3_vector_conversions(native_dtype, promoted_dtype):
sch.bind(tx, "threadIdx.x")
target = "cuda"
- fadd = tvm.compile(sch.mod, target=target)
+ fadd = tvm.tir.build(sch.mod, target=target)
cuda_src = fadd.imported_modules[0].get_source()
dev = tvm.device(target, 0)
@@ -700,7 +696,7 @@ class BaseFP8E4M3QuantScaleOnly:
def print_cuda(target, mod, name=None):
if name:
mod = mod[name]
- f = tvm.compile(mod, target=target)
+ f = tvm.tir.build(mod, target=target)
cuda_src = f.imported_modules[0].get_source()
print(cuda_src)
@@ -963,6 +959,41 @@ def test_moe_gemv_shfl_down_illegal_instr():
vm["main"](x, indptr, weight, scale)
+@tvm.testing.requires_cuda_compute_version(8, 9)
+def test_fp8_fp16_bf16_vectorize_arith():
+ for vec_length, dtype in product([2, 4], ["float16", "bfloat16"]):
+
+ @T.prim_func
+ def func_vectorize(
+ A: T.Buffer((128,), "float8_e4m3fn"),
+ B: T.Buffer((128,), dtype),
+ C: T.Buffer((128,), dtype),
+ ) -> None:
+ for i in T.serial(128):
+ with T.block("compute"):
+ vi = T.axis.remap("S", [i])
+ C[vi] = (A[vi].astype(dtype) * B[vi]) + T.bfloat16(3.0)
+
+ sch = tir.Schedule(func_vectorize)
+ (l,) = sch.get_loops(sch.get_block("compute"))
+ lo, li = sch.split(l, [None, vec_length])
+ sch.bind(lo, "threadIdx.x")
+ sch.vectorize(li)
+
+ device = tvm.cuda()
+ target = tvm.target.Target.from_device(device)
+ f = tir.build(sch.mod, target=target)
+
+ a_np = np.random.rand(128).astype("float8_e4m3fn")
+ b_np = np.random.rand(128).astype(dtype)
+ c_np = (a_np.astype(dtype) * b_np) + 3
+ a_tvm = tvm.nd.array(a_np, device=device)
+ b_tvm = tvm.nd.array(b_np, device=device)
+ c_tvm = tvm.nd.empty((128,), dtype=dtype, device=device)
+ f(a_tvm, b_tvm, c_tvm)
+ c_tvm = c_tvm.numpy()
+ np.testing.assert_allclose(c_tvm, c_np, atol=1e-3, rtol=1e-3)
+
+
if __name__ == "__main__":
- # test_half_broadcast(6)
tvm.testing.main()