This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2a62c72154 [FP8][Codegen] Add make_fp8 vector constructors (#17065)
2a62c72154 is described below

commit 2a62c7215419a859321460c7fb9e2da272f4d003
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Jun 5 07:45:04 2024 -0700

    [FP8][Codegen] Add make_fp8 vector constructors (#17065)
    
    * [FP8][Codegen] Add make_fp8 vector constructors.
    
    Allows vectorized fp8 loading.
    
    ---------
    
    Co-authored-by: Chris Sullivan <csulli...@octoml.ai>
---
 src/target/source/codegen_cuda.cc                  | 25 +++++++++++-----------
 src/target/source/literal/cuda_half_t.h            | 20 +++++++++++++++++
 .../python/codegen/test_target_codegen_cuda_fp8.py |  2 +-
 3 files changed, 33 insertions(+), 14 deletions(-)
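Before the patch body, a brief note on what changes for generated kernels: the CUDA prelude no longer declares the fp8_e4_t / fp8_e5_t alias family, and GetFP8Type now emits the <cuda_fp8.h> builtin names directly. A minimal sketch of the resulting type names, assuming CUDA 11.8+ and an sm_89+ device (the kernel and variable names below are illustrative only, not part of the commit):

    // Sketch only: the CUDA builtin fp8 types now emitted by the codegen.
    #include <cuda_fp8.h>

    __global__ void fp8_type_sketch() {
      // Previously the prelude emitted aliases such as
      //   using fp8_e4_t   = __nv_fp8_e4m3;
      //   using fp8_e4_4_t = __nv_fp8x4_e4m3;
      // Now the builtin names are used directly:
      __nv_fp8_e4m3   e4_scalar;  // DataType::kE4M3Float, 1 lane
      __nv_fp8x2_e4m3 e4_vec2;    // DataType::kE4M3Float, 2 lanes
      __nv_fp8x4_e5m2 e5_vec4;    // DataType::kE5M2Float, 4 lanes
      (void)e4_scalar; (void)e4_vec2; (void)e5_vec4;
    }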

diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index ecb0957611..bd28048301 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -48,21 +48,22 @@ std::string GetFP8Type(DataType type) {
   if (type.is_scalar()) {
     vec = "";
   } else if (lanes == 2) {
-    vec = "_2";
+    vec = "x2";
   } else if (lanes == 4) {
-    vec = "_4";
-  } else if (lanes == 8) {
-    vec = "_8";
+    vec = "x4";
   } else {
     LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for 
FP8";
   }
+  stream << "__nv_fp8";
+  std::string suffix;
   if (type.code() == DataType::kE4M3Float) {
-    stream << "fp8_e4" << vec << "_t";
+    suffix = "_e4m3";
   } else if (type.code() == DataType::kE5M2Float) {
-    stream << "fp8_e5" << vec << "_t";
+    suffix = "_e5m2";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
   }
+  stream << vec << suffix;
   return stream.str();
 }
 
@@ -146,12 +147,6 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
     decl_stream << "#include <cuda_fp8.h>\n";
-    decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
-    decl_stream << "using fp8_e4_2_t = __nv_fp8x2_e4m3;\n";
-    decl_stream << "using fp8_e4_4_t = __nv_fp8x4_e4m3;\n";
-    decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
-    decl_stream << "using fp8_e5_2_t = __nv_fp8x2_e5m2;\n";
-    decl_stream << "using fp8_e5_4_t = __nv_fp8x4_e5m2;\n";
     decl_stream << "#endif\n\n";
   }
   declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
@@ -299,7 +294,11 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
     if (!fail) return;
   } else if (t.is_float8()) {
     enable_fp8_ = true;
-    os << GetFP8Type(t);
+    if (t.lanes() <= 4) {
+      os << GetFP8Type(t);
+    } else {
+      os << "uint" << t.lanes() / 4;
+    }
     return;
   } else if (t == DataType::Bool()) {
     os << "bool";
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 27d44d9f7f..c5ecda07a4 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -431,6 +431,26 @@ struct __align__(8) half4 {
        (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
     return result;
   }
+  __device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
+      __nv_fp8x2_e5m2 result;
+      result.__x = (x) | (y << 8);
+      return result;
+  }
+  __device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
+      __nv_fp8x4_e5m2 result;
+      result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
+      return result;
+  }
+  __device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
+      __nv_fp8x2_e4m3 result;
+      result.__x = (x) | (y << 8);
+      return result;
+  }
+  __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
+      __nv_fp8x4_e4m3 result;
+      result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
+      return result;
+  }
   )";
     }
     stream << R"(
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index 5566ae2434..adcb05839b 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -64,7 +64,7 @@ def test_e4m3_conversions():
     fadd = tvm.build(sch.mod, target=target)
 
     cuda_src = fadd.imported_modules[0].get_source()
-    assert "fp8_e4_t" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found in 
generated CUDA"
+    assert "__nv_fp8_e4m3" in cuda_src, "FP8E4M3 (fp8_e4_t) datatype not found 
in generated CUDA"
 
     dev = tvm.device(target, 0)
 
