https://github.com/sergey-kozub updated https://github.com/llvm/llvm-project/pull/102969
>From 7db2478f59c5a4f46df040ed4799da815b28bc43 Mon Sep 17 00:00:00 2001 From: Sergey Kozub <sko...@nvidia.com> Date: Mon, 12 Aug 2024 12:52:01 -0700 Subject: [PATCH] [NVPTX] Add conversion intrinsics from/to fp8 types (e4m3, e5m2) --- clang/include/clang/Basic/BuiltinsNVPTX.def | 15 ++++ clang/test/CodeGen/builtins-nvptx.c | 36 +++++++++ llvm/include/llvm/IR/IntrinsicsNVVM.td | 27 +++++++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 31 ++++++++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 27 +++++++ llvm/test/CodeGen/NVPTX/convert-sm89.ll | 86 +++++++++++++++++++++ 6 files changed, 222 insertions(+) create mode 100644 llvm/test/CodeGen/NVPTX/convert-sm89.ll diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def index 504314d8d96e91..c11970c279c4bb 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.def +++ b/clang/include/clang/Basic/BuiltinsNVPTX.def @@ -584,6 +584,21 @@ TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70)) TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70)) +TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn, "sff", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_ff_to_e4m3x2_rn_relu, "sff", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn, "sff", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_ff_to_e5m2x2_rn_relu, "sff", "", AND(SM_89,PTX81)) + +TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn, "sV2h", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_f16x2_to_e4m3x2_rn_relu, "sV2h", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn, "sV2h", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_f16x2_to_e5m2x2_rn_relu, "sV2h", "", AND(SM_89,PTX81)) + +TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_e4m3x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn, "V2hs", "", AND(SM_89,PTX81)) +TARGET_BUILTIN(__nvvm_e5m2x2_to_f16x2_rn_relu, "V2hs", "", AND(SM_89,PTX81)) + // Bitcast BUILTIN(__nvvm_bitcast_f2i, "if", "") diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 75b9d6d1fe1902..20399b73e63757 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -22,6 +22,9 @@ // RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_86 -target-feature +ptx72 \ // RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \ // RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 -check-prefix=LP64 %s +// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown -target-cpu sm_89 -target-feature +ptx81 \ +// RUN: -fcuda-is-device -emit-llvm -o - -x cuda %s \ +// RUN: | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX81_SM89 %s #define __device__ __attribute__((device)) #define __global__ __attribute__((global)) @@ -968,6 +971,39 @@ __device__ void nvvm_cvt_sm80() { // CHECK: ret void } +// CHECK-LABEL: nvvm_cvt_sm89 +__device__ void nvvm_cvt_sm89() { +#if __CUDA_ARCH__ >= 890 + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e4m3x2_rn(1.0f, 1.0f); + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e4m3x2_rn_relu(1.0f, 1.0f); + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e5m2x2_rn(1.0f, 1.0f); + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float 1.000000e+00, float 1.000000e+00) + __nvvm_ff_to_e5m2x2_rn_relu(1.0f, 1.0f); + + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>) + __nvvm_f16x2_to_e4m3x2_rn({1.0f16, 1.0f16}); + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>) + __nvvm_f16x2_to_e4m3x2_rn_relu({1.0f16, 1.0f16}); + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> <half 0xH3C00, half 0xH3C00>) + __nvvm_f16x2_to_e5m2x2_rn({1.0f16, 1.0f16}); + // CHECK_PTX81_SM89: call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> <half 0xH3C00, half 0xH3C00>) + __nvvm_f16x2_to_e5m2x2_rn_relu({1.0f16, 1.0f16}); + + // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 18504) + __nvvm_e4m3x2_to_f16x2_rn(0x4848); + // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 18504) + __nvvm_e4m3x2_to_f16x2_rn_relu(0x4848); + // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 19532) + __nvvm_e5m2x2_to_f16x2_rn(0x4c4c); + // CHECK_PTX81_SM89: call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 19532) + __nvvm_e5m2x2_to_f16x2_rn_relu(0x4c4c); +#endif + // CHECK: ret void +} + #define NAN32 0x7FBFFFFF #define NAN16 (__bf16)0x7FBF #define BF16 (__bf16)0.1f diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 7caada24dad564..42dcf08cc65cac 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1296,6 +1296,33 @@ let TargetPrefix = "nvvm" in { def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">, Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_ff_to_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn">, + Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_ff_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn_relu">, + Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_ff_to_e5m2x2_rn : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn">, + Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_ff_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_ff_to_e5m2x2_rn_relu">, + Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; + + def int_nvvm_f16x2_to_e4m3x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn">, + Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_f16x2_to_e4m3x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e4m3x2_rn_relu">, + Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_f16x2_to_e5m2x2_rn : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn">, + Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_f16x2_to_e5m2x2_rn_relu : ClangBuiltin<"__nvvm_f16x2_to_e5m2x2_rn_relu">, + Intrinsic<[llvm_i16_ty], [llvm_v2f16_ty], [IntrNoMem, IntrNoCallback]>; + + def int_nvvm_e4m3x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn">, + Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_e4m3x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e4m3x2_to_f16x2_rn_relu">, + Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_e5m2x2_to_f16x2_rn : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn">, + Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>; + def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">, + Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>; + // // Bitcast // diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index d75dc8781f7802..48d6caeebb46f5 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -722,6 +722,37 @@ let hasSideEffects = false in { defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Int32Regs>; defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>; + + // FP8 conversions. + multiclass CVT_TO_F8X2<string F8Name> { + def _f32 : + NVPTXInst<(outs Int16Regs:$dst), + (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode), + !strconcat("cvt${mode:base}.satfinite${mode:relu}.", + F8Name, "x2.f32 \t$dst, $src1, $src2;"), []>, + Requires<[hasPTX<81>, hasSM<89>]>; + def _f16x2 : + NVPTXInst<(outs Int16Regs:$dst), + (ins Int32Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}.satfinite${mode:relu}.", + F8Name, "x2.f16x2 \t$dst, $src;"), []>, + Requires<[hasPTX<81>, hasSM<89>]>; + } + + defm CVT_e4m3x2 : CVT_TO_F8X2<"e4m3">; + defm CVT_e5m2x2 : CVT_TO_F8X2<"e5m2">; + + multiclass CVT_FROM_F8X2<string F8Name> { + def x2 : + NVPTXInst<(outs Int32Regs:$dst), + (ins Int16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:relu}.f16x2.", + F8Name, "x2 \t$dst, $src;"), []>, + Requires<[hasPTX<81>, hasSM<89>]>; + } + + defm CVT_f16x2_e4m3 : CVT_FROM_F8X2<"e4m3">; + defm CVT_f16x2_e5m2 : CVT_FROM_F8X2<"e5m2">; } //----------------------------------- diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 887951b55fb3b7..a9116e15c3671b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1524,6 +1524,33 @@ def : Pat<(int_nvvm_f2h_rn_ftz Float32Regs:$a), def : Pat<(int_nvvm_f2h_rn Float32Regs:$a), (CVT_f16_f32 Float32Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_ff_to_e4m3x2_rn Float32Regs:$a, Float32Regs:$b), + (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; +def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu Float32Regs:$a, Float32Regs:$b), + (CVT_e4m3x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; +def : Pat<(int_nvvm_ff_to_e5m2x2_rn Float32Regs:$a, Float32Regs:$b), + (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; +def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu Float32Regs:$a, Float32Regs:$b), + (CVT_e5m2x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; + +def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn Int32Regs:$a), + (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu Int32Regs:$a), + (CVT_e4m3x2_f16x2 Int32Regs:$a, CvtRN_RELU)>; +def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn Int32Regs:$a), + (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu Int32Regs:$a), + (CVT_e5m2x2_f16x2 Int32Regs:$a, CvtRN_RELU)>; + +def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn Int16Regs:$a), + (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu Int16Regs:$a), + (CVT_f16x2_e4m3x2 Int16Regs:$a, CvtRN_RELU)>; +def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a), + (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN)>; +def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a), + (CVT_f16x2_e5m2x2 Int16Regs:$a, CvtRN_RELU)>; + // // Bitcast // diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll new file mode 100644 index 00000000000000..5a1a640afc1cf9 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll @@ -0,0 +1,86 @@ +; RUN: llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | FileCheck %s +; RUN: %if ptxas-12.1 %{ llc < %s -march=nvptx64 -mcpu=sm_89 -mattr=+ptx81 | %ptxas-verify -arch=sm_89 %} + +; CHECK-LABEL: cvt_rn_e4m3x2_f32 +define i16 @cvt_rn_e4m3x2_f32(float %f1, float %f2) { +; CHECK: cvt.rn.satfinite.e4m3x2.f32 + %val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %f1, float %f2); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_relu_e4m3x2_f32 +define i16 @cvt_rn_relu_e4m3x2_f32(float %f1, float %f2) { +; CHECK: cvt.rn.satfinite.relu.e4m3x2.f32 + %val = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %f1, float %f2); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_e5m2x2_f32 +define i16 @cvt_rn_e5m2x2_f32(float %f1, float %f2) { +; CHECK: cvt.rn.satfinite.e5m2x2.f32 + %val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %f1, float %f2); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_relu_e5m2x2_f32 +define i16 @cvt_rn_relu_e5m2x2_f32(float %f1, float %f2) { +; CHECK: cvt.rn.satfinite.relu.e5m2x2.f32 + %val = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %f1, float %f2); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_e4m3x2_f16x2 +define i16 @cvt_rn_e4m3x2_f16x2(<2 x half> %in) { +; CHECK: cvt.rn.satfinite.e4m3x2.f16x2 + %val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %in); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_relu_e4m3x2_f16x2 +define i16 @cvt_rn_relu_e4m3x2_f16x2(<2 x half> %in) { +; CHECK: cvt.rn.satfinite.relu.e4m3x2.f16x2 + %val = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %in); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_e5m2x2_f16x2 +define i16 @cvt_rn_e5m2x2_f16x2(<2 x half> %in) { +; CHECK: cvt.rn.satfinite.e5m2x2.f16x2 + %val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %in); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_relu_e5m2x2_f16x2 +define i16 @cvt_rn_relu_e5m2x2_f16x2(<2 x half> %in) { +; CHECK: cvt.rn.satfinite.relu.e5m2x2.f16x2 + %val = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %in); + ret i16 %val +} + +; CHECK-LABEL: cvt_rn_f16x2_e4m3x2 +define <2 x half> @cvt_rn_f16x2_e4m3x2(i16 %in) { +; CHECK: cvt.rn.f16x2.e4m3x2 + %val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %in); + ret <2 x half> %val +} + +; CHECK-LABEL: cvt_rn_relu_f16x2_e4m3x2 +define <2 x half> @cvt_rn_relu_f16x2_e4m3x2(i16 %in) { +; CHECK: cvt.rn.relu.f16x2.e4m3x2 + %val = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %in); + ret <2 x half> %val +} + +; CHECK-LABEL: cvt_rn_f16x2_e5m2x2 +define <2 x half> @cvt_rn_f16x2_e5m2x2(i16 %in) { +; CHECK: cvt.rn.f16x2.e5m2x2 + %val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %in); + ret <2 x half> %val +} + +; CHECK-LABEL: cvt_rn_relu_f16x2_e5m2x2 +define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) { +; CHECK: cvt.rn.relu.f16x2.e5m2x2 + %val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in); + ret <2 x half> %val +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits