This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 2a62c72154 [FP8][Codegen] Add make_fp8 vector constructors (#17065)
2a62c72154 is described below

commit 2a62c7215419a859321460c7fb9e2da272f4d003
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Jun 5 07:45:04 2024 -0700

    [FP8][Codegen] Add make_fp8 vector constructors (#17065)

    * [FP8][Codegen] Add make_fp8 vector constructors. Allows vectorized
      fp8 loading.

    ---------

    Co-authored-by: Chris Sullivan <csulli...@octoml.ai>
---
 src/target/source/codegen_cuda.cc                   | 25 ++++++++++++-------------
 src/target/source/literal/cuda_half_t.h             | 20 ++++++++++++++++++++
 .../python/codegen/test_target_codegen_cuda_fp8.py  |  2 +-
 3 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index ecb0957611..bd28048301 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -48,21 +48,22 @@ std::string GetFP8Type(DataType type) {
   if (type.is_scalar()) {
     vec = "";
   } else if (lanes == 2) {
-    vec = "_2";
+    vec = "x2";
   } else if (lanes == 4) {
-    vec = "_4";
-  } else if (lanes == 8) {
-    vec = "_8";
+    vec = "x4";
   } else {
     LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
   }
+  stream << "__nv_fp8";
+  std::string suffix;
   if (type.code() == DataType::kE4M3Float) {
-    stream << "fp8_e4" << vec << "_t";
+    suffix = "_e4m3";
   } else if (type.code() == DataType::kE5M2Float) {
-    stream << "fp8_e5" << vec << "_t";
+    suffix = "_e5m2";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
   }
+  stream << vec << suffix;
   return stream.str();
 }
@@ -146,12 +147,6 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
     decl_stream << "#include <cuda_fp8.h>\n";
-    decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
-    decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
-    decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
-    decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
-    decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
-    decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
     decl_stream << "#endif\n\n";
   }
   declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
@@ -299,7 +294,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     if (!fail) return;
   } else if (t.is_float8()) {
     enable_fp8_ = true;
-    os << GetFP8Type(t);
+    if (t.lanes() <= 4) {
+      os << GetFP8Type(t);
+    } else {
+      os << "uint" << t.lanes() / 4;
+    }
     return;
   } else if (t == DataType::Bool()) {
     os << "bool";
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 27d44d9f7f..c5ecda07a4 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -431,6 +431,26 @@ struct __align__(8) half4 {
         (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
     return result;
   }
+  __device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
+    __nv_fp8x2_e5m2 result;
+    result.__x = (x) | (y << 8);
+    return result;
+  }
+  __device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
+    __nv_fp8x4_e5m2 result;
+    result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
+    return result;
+  }
+  __device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
+    __nv_fp8x2_e4m3 result;
+    result.__x = (x) | (y << 8);
+    return result;
+  }
+  __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
+    __nv_fp8x4_e4m3 result;
+    result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
+    return result;
+  }
 )";
   }
   stream << R"(
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 5566ae2434..adcb05839b 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -64,7 +64,7 @@ def test_e4m3_conversions():
     fadd = tvm.build(sch.mod, target=target)

     cuda_src = fadd.imported_modules[0].get_source()
-    assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA"
+    assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in generated CUDA"

     dev = tvm.device(target, 0)
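
Editor's note: for context, below is a minimal standalone sketch (not part of the commit) of how the emitted make_fp8x4_e4m3 helper can be used to produce a vectorized fp8 store. The kernel name pack_fp8x4 and the use of __nv_cvt_float_to_fp8 from cuda_fp8.h are illustrative assumptions rather than TVM-generated code, and the helper is re-declared locally here because TVM emits it into the generated CUDA source rather than into a public header.

    // Sketch only: requires sm_89+ and CUDA 11.8+ (cuda_fp8.h).
    #include <cuda_fp8.h>

    // Mirrors the helper added to cuda_half_t.h: each __nv_fp8_storage_t is one
    // byte, packed little-endian into the 32-bit payload of __nv_fp8x4_e4m3.
    __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b,
                                                __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
      __nv_fp8x4_e4m3 result;
      result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
      return result;
    }

    // Hypothetical kernel: convert four floats to E4M3 and store them as one
    // 4-lane fp8 vector, i.e. the vectorized fp8 path this commit enables.
    __global__ void pack_fp8x4(const float4* in, __nv_fp8x4_e4m3* out) {
      float4 v = in[threadIdx.x];
      out[threadIdx.x] = make_fp8x4_e4m3(
          __nv_cvt_float_to_fp8(v.x, __NV_SATFINITE, __NV_E4M3),
          __nv_cvt_float_to_fp8(v.y, __NV_SATFINITE, __NV_E4M3),
          __nv_cvt_float_to_fp8(v.z, __NV_SATFINITE, __NV_E4M3),
          __nv_cvt_float_to_fp8(v.w, __NV_SATFINITE, __NV_E4M3));
    }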