kushanam updated this revision to Diff 522290. kushanam added a comment. Addressing review changes and removing bf16 registers
Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D149976/new/ https://reviews.llvm.org/D149976 Files: clang/include/clang/Basic/BuiltinsNVPTX.def llvm/include/llvm/IR/IntrinsicsNVVM.td llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp llvm/lib/Target/NVPTX/NVPTXInstrInfo.td llvm/lib/Target/NVPTX/NVPTXIntrinsics.td llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp llvm/lib/Target/NVPTX/NVPTXMCExpr.h llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp llvm/lib/Target/NVPTX/NVPTXSubtarget.h llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp llvm/test/CodeGen/NVPTX/bf16-instructions.ll
Index: llvm/test/CodeGen/NVPTX/bf16-instructions.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/NVPTX/bf16-instructions.ll @@ -0,0 +1,88 @@ +; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s +; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %} + + +; CHECK-LABEL: test_fadd( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fadd_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fadd_param_1]; +; CHECK-NEXT: add.rn.bf16 [[R:%f[0-9]+]], [[A]], [[B]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; + +define bfloat @test_fadd(bfloat %0, bfloat %1) { + %3 = fadd bfloat %0, %1 + ret bfloat %3 +} + +; CHECK-LABEL: test_fsub( +; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fsub_param_0]; +; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fsub_param_1]; +; CHECK-NEXT: sub.rn.bf16 [[R:%f[0-9]+]], [[A]], [[B]]; +; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; + +define bfloat @test_fsub(bfloat %0, bfloat %1) { + %3 = fsub bfloat %0, %1 + ret bfloat %3 +} + +; CHECK-LABEL: test_faddx2( +; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_faddx2_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_faddx2_param_1]; +; CHECK-NEXT: add.rn.bf16x2 [[R:%f[0-9]+]], [[A]], [[B]]; + +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; + +define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fadd <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} + +; CHECK-LABEL: test_fsubx2( +; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_fsubx2_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_fsubx2_param_1]; +; CHECK-NEXT: sub.rn.bf16x2 [[R:%f[0-9]+]], [[A]], [[B]]; + +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; + +define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fsub <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} + +; 
CHECK-LABEL: test_fmulx2( +; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_fmulx2_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_fmulx2_param_1]; +; CHECK-NEXT: mul.rn.bf16x2 [[R:%f[0-9]+]], [[A]], [[B]]; + +; CHECK: st.param.b32 [func_retval0+0], [[R]]; +; CHECK: ret; + +define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fmul <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} + +; CHECK-LABEL: test_fdiv( +; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_fdiv_param_0]; +; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_fdiv_param_1]; +; CHECK-DAG: mov.b32 {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]] +; CHECK-DAG: mov.b32 {[[B0:%h[0-9]+]], [[B1:%h[0-9]+]]}, [[B]] +; CHECK-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]]; +; CHECK-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]; +; CHECK-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]]; +; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]]; +; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]]; +; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]]; +; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%h[0-9]+]], [[FR0]]; +; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%h[0-9]+]], [[FR1]]; +; CHECK-NEXT: mov.b32 [[R:%hh[0-9]+]], {[[R0]], [[R1]]} +; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]]; +; CHECK-NEXT: ret; + +define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 { + %r = fdiv <2 x bfloat> %a, %b + ret <2 x bfloat> %r +} Index: llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -204,6 +204,14 @@ return {Intrinsic::fma, FTZ_MustBeOff, true}; case Intrinsic::nvvm_fma_rn_ftz_f16x2: return {Intrinsic::fma, FTZ_MustBeOn, true}; + case Intrinsic::nvvm_fma_rn_bf16: + return {Intrinsic::fma, FTZ_MustBeOff, true}; + case Intrinsic::nvvm_fma_rn_ftz_bf16: + return {Intrinsic::fma, FTZ_MustBeOn, true}; + case
Intrinsic::nvvm_fma_rn_bf16x2: + return {Intrinsic::fma, FTZ_MustBeOff, true}; + case Intrinsic::nvvm_fma_rn_ftz_bf16x2: + return {Intrinsic::fma, FTZ_MustBeOn, true}; case Intrinsic::nvvm_fmax_d: return {Intrinsic::maxnum, FTZ_Any}; case Intrinsic::nvvm_fmax_f: Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.h =================================================================== --- llvm/lib/Target/NVPTX/NVPTXSubtarget.h +++ llvm/lib/Target/NVPTX/NVPTXSubtarget.h @@ -76,7 +76,9 @@ inline bool hasHWROT32() const { return SmVersion >= 32; } bool hasImageHandles() const; bool hasFP16Math() const { return SmVersion >= 53; } + bool hasBF16Math() const { return SmVersion >= 80; } bool allowFP16Math() const; + bool allowBF16Math() const; bool hasMaskOperator() const { return PTXVersion >= 71; } bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; } unsigned int getSmVersion() const { return SmVersion; } Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp +++ llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp @@ -26,7 +26,10 @@ NoF16Math("nvptx-no-f16-math", cl::Hidden, cl::desc("NVPTX Specific: Disable generation of f16 math ops."), cl::init(false)); - +static cl::opt<bool> + NoBF16Math("nvptx-no-bf16-math", cl::Hidden, + cl::desc("NVPTX Specific: Disable generation of bf16 math ops."), + cl::init(false)); // Pin the vtable to this file. 
void NVPTXSubtarget::anchor() {} @@ -65,3 +68,7 @@ bool NVPTXSubtarget::allowFP16Math() const { return hasFP16Math() && NoF16Math == false; } + +bool NVPTXSubtarget::allowBF16Math() const { + return hasBF16Math() && NoBF16Math == false; +} \ No newline at end of file Index: llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td +++ llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td @@ -60,8 +60,10 @@ def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>; def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4), VRFrame32, VRFrameLocal32)>; def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>; -def Float16Regs : NVPTXRegClass<[f16,bf16], 16, (add (sequence "H%u", 0, 4))>; -def Float16x2Regs : NVPTXRegClass<[v2f16,v2bf16], 32, (add (sequence "HH%u", 0, 4))>; +def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>; +def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>; +def BFloat16Regs : NVPTXRegClass<[bf16], 16, (add (sequence "H%u", 0, 4))>; +def BFloat16x2Regs : NVPTXRegClass<[v2bf16], 32, (add (sequence "HH%u", 0, 4))>; def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>; def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>; def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>; Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.h =================================================================== --- llvm/lib/Target/NVPTX/NVPTXMCExpr.h +++ llvm/lib/Target/NVPTX/NVPTXMCExpr.h @@ -21,6 +21,7 @@ public: enum VariantKind { VK_NVPTX_None, + VK_NVPTX_BFLOAT_PREC_FLOAT, // FP constant in bfloat-precision VK_NVPTX_HALF_PREC_FLOAT, // FP constant in half-precision VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision @@ -40,6 +41,11 @@ 
static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt, MCContext &Ctx); + static const NVPTXFloatMCExpr *createConstantBFPHalf(const APFloat &Flt, + MCContext &Ctx) { + return create(VK_NVPTX_BFLOAT_PREC_FLOAT, Flt, Ctx); + } + static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt, MCContext &Ctx) { return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx); Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp +++ llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp @@ -34,6 +34,11 @@ NumHex = 4; APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored); break; + case VK_NVPTX_BFLOAT_PREC_FLOAT: + OS << "0x"; + NumHex = 4; + APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &Ignored); + break; case VK_NVPTX_SINGLE_PREC_FLOAT: OS << "0f"; NumHex = 8; Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -973,6 +973,18 @@ FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, Float16Regs, [hasPTX70, hasSM80]>, + FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, BFloat16Regs, [hasPTX70, hasSM80]>, + FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, BFloat16Regs, + [hasPTX70, hasSM80]>, + FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, BFloat16Regs, + [hasPTX70, hasSM80]>, + FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, BFloat16Regs, + [hasPTX70, hasSM80]>, + FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, BFloat16Regs, + [hasPTX70, hasSM80]>, + FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, BFloat16Regs, + [hasPTX70, hasSM80]>, + FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, Float16x2Regs, [hasPTX42, hasSM53]>, FMA_TUPLE<"_rn_ftz_f16x2", int_nvvm_fma_rn_ftz_f16x2, Float16x2Regs, @@ -986,13 +998,9 @@ FMA_TUPLE<"_rn_ftz_relu_f16x2", 
int_nvvm_fma_rn_ftz_relu_f16x2, Float16x2Regs, [hasPTX70, hasSM80]>, - FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX70, hasSM80]>, - FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs, - [hasPTX70, hasSM80]>, - - FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, Int32Regs, + FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, BFloat16x2Regs, [hasPTX70, hasSM80]>, - FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, Int32Regs, + FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, BFloat16x2Regs, [hasPTX70, hasSM80]> ] in { def P.Variant : @@ -1243,24 +1251,6 @@ def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b), (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>; -def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>; -def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>; -def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>; -def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b), - (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>; - -def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRN)>; -def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>; -def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>; -def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a), - (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>; - def CVT_tf32_f32 : NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a), "cvt.rna.tf32.f32 \t$dest, $a;", @@ -2136,6 +2126,8 @@ defm INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64 \t$result, [$src];", Int64Regs>; defm INT_PTX_LDU_GLOBAL_f16 : LDU_G<"b16 \t$result, [$src];", Float16Regs>; defm INT_PTX_LDU_GLOBAL_f16x2 : LDU_G<"b32 \t$result, [$src];", Float16x2Regs>; +defm 
INT_PTX_LDU_GLOBAL_bf16 : LDU_G<"b16 \t$result, [$src];", BFloat16Regs>; +defm INT_PTX_LDU_GLOBAL_bf16x2 : LDU_G<"b32 \t$result, [$src];", BFloat16x2Regs>; defm INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32 \t$result, [$src];", Float32Regs>; defm INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64 \t$result, [$src];", Float64Regs>; defm INT_PTX_LDU_GLOBAL_p32 : LDU_G<"u32 \t$result, [$src];", Int32Regs>; @@ -2190,6 +2182,10 @@ : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Float16Regs>; defm INT_PTX_LDU_G_v2f16x2_ELE : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Float16x2Regs>; +defm INT_PTX_LDU_G_v2bf16_ELE + : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", BFloat16Regs>; +defm INT_PTX_LDU_G_v2bf16x2_ELE + : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", BFloat16x2Regs>; defm INT_PTX_LDU_G_v2f32_ELE : VLDU_G_ELE_V2<"v2.f32 \t{{$dst1, $dst2}}, [$src];", Float32Regs>; defm INT_PTX_LDU_G_v2i64_ELE @@ -2253,6 +2249,10 @@ : LDG_G<"b16 \t$result, [$src];", Float16Regs>; defm INT_PTX_LDG_GLOBAL_f16x2 : LDG_G<"b32 \t$result, [$src];", Float16x2Regs>; +defm INT_PTX_LDG_GLOBAL_bf16 + : LDG_G<"b16 \t$result, [$src];", BFloat16Regs>; +defm INT_PTX_LDG_GLOBAL_bf16x2 + : LDG_G<"b32 \t$result, [$src];", BFloat16x2Regs>; defm INT_PTX_LDG_GLOBAL_f32 : LDG_G<"f32 \t$result, [$src];", Float32Regs>; defm INT_PTX_LDG_GLOBAL_f64 Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td =================================================================== --- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -19,6 +19,8 @@ let OperandType = "OPERAND_IMMEDIATE" in { def f16imm : Operand<f16>; + def bf16imm : Operand<bf16>; + } // List of vector specific properties @@ -172,6 +174,7 @@ def useShortPtr : Predicate<"useShortPointers()">; def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; +def useBFP16Math: Predicate<"Subtarget->allowBF16Math()">; // Helper class to aid conversion between ValueType and a matching RegisterClass. 
@@ -184,8 +187,8 @@ !eq(name, "i64"): Int64Regs, !eq(name, "f16"): Float16Regs, !eq(name, "v2f16"): Float16x2Regs, - !eq(name, "bf16"): Float16Regs, - !eq(name, "v2bf16"): Float16x2Regs, + !eq(name, "bf16"): BFloat16Regs, + !eq(name, "v2bf16"): BFloat16x2Regs, !eq(name, "f32"): Float32Regs, !eq(name, "f64"): Float64Regs, !eq(name, "ai32"): Int32ArgRegs, @@ -322,6 +325,31 @@ !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"), [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math]>; + def bf16rr_ftz : + NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16Regs:$a, BFloat16Regs:$b), + !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"), + [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>, + Requires<[useBFP16Math, doF32FTZ]>; + def bf16rr : + NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16Regs:$a, BFloat16Regs:$b), + !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"), + [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>, + Requires<[useBFP16Math]>; + + def bf16x2rr_ftz : + NVPTXInst<(outs BFloat16x2Regs:$dst), + (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b), + !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"), + [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>, + Requires<[useBFP16Math, doF32FTZ]>; + def bf16x2rr : + NVPTXInst<(outs BFloat16x2Regs:$dst), + (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b), + !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"), + [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>, + Requires<[useBFP16Math]>; } // Template for instructions which take three FP args. 
The @@ -396,7 +424,31 @@ !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"), [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math, allowFMA]>; - + def bf16rr_ftz : + NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16Regs:$a, BFloat16Regs:$b), + !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"), + [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>, + Requires<[useBFP16Math, allowFMA, doF32FTZ]>; + def bf16rr : + NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16Regs:$a, BFloat16Regs:$b), + !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"), + [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>, + Requires<[useBFP16Math, allowFMA]>; + + def bf16x2rr_ftz : + NVPTXInst<(outs BFloat16x2Regs:$dst), + (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b), + !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"), + [(set (v2bf16 BFloat16x2Regs:$dst), (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>, + Requires<[useBFP16Math, allowFMA, doF32FTZ]>; + def bf16x2rr : + NVPTXInst<(outs BFloat16x2Regs:$dst), + (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b), + !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"), + [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>, + Requires<[useBFP16Math, allowFMA]>; // These have strange names so we don't perturb existing mir tests. 
def _rnf64rr : NVPTXInst<(outs Float64Regs:$dst), @@ -458,6 +510,30 @@ !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"), [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>, Requires<[useFP16Math, noFMA]>; + def _rnbf16rr_ftz : + NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16Regs:$a, BFloat16Regs:$b), + !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"), + [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>, + Requires<[useBFP16Math, noFMA, doF32FTZ]>; + def _rnbf16rr : + NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16Regs:$a, BFloat16Regs:$b), + !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"), + [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>, + Requires<[useBFP16Math, noFMA]>; + def _rnbf16x2rr_ftz : + NVPTXInst<(outs BFloat16x2Regs:$dst), + (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b), + !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"), + [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>, + Requires<[useBFP16Math, noFMA, doF32FTZ]>; + def _rnbf16x2rr : + NVPTXInst<(outs BFloat16x2Regs:$dst), + (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b), + !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"), + [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>, + Requires<[useBFP16Math, noFMA]>; } // Template for operations which take two f32 or f64 operands. 
Provides three @@ -534,6 +610,11 @@ (ins Float16Regs:$src, CvtMode:$mode), !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", FromName, ".f16 \t$dst, $src;"), []>; + def _bf16 : + NVPTXInst<(outs RC:$dst), + (ins BFloat16Regs:$src, CvtMode:$mode), + !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.", + FromName, ".bf16 \t$dst, $src;"), []>; def _f32 : NVPTXInst<(outs RC:$dst), (ins Float32Regs:$src, CvtMode:$mode), @@ -556,6 +637,7 @@ defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>; defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>; defm CVT_f16 : CVT_FROM_ALL<"f16", Float16Regs>; + defm CVT_bf16 : CVT_FROM_ALL<"bf16", BFloat16Regs>; defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>; defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>; @@ -574,18 +656,7 @@ def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src), "cvt.s64.s32 \t$dst, $src;", []>; -multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> { - def _f32 : - NVPTXInst<(outs RC:$dst), - (ins Float32Regs:$src, CvtMode:$mode), - !strconcat("cvt${mode:base}${mode:relu}.", - FromName, ".f32 \t$dst, $src;"), []>, - Requires<[hasPTX70, hasSM80]>; - } - - defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>; - - multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> { + multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> { def _f32 : NVPTXInst<(outs RC:$dst), (ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode), @@ -594,7 +665,7 @@ Requires<[hasPTX70, hasSM80]>; } - defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>; + defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", BFloat16x2Regs>; defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>; } @@ -659,7 +730,7 @@ defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>; defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>; defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>; - +defm SELP_bf16 : SELP_PATTERN<"b16", bf16, BFloat16Regs, bf16imm, fpimm>; defm SELP_f32 : SELP_PATTERN<"f32", 
f32, Float32Regs, f32imm, fpimm>; defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>; @@ -1023,7 +1094,9 @@ def LOAD_CONST_F16 : NVPTXInst<(outs Float16Regs:$dst), (ins f16imm:$a), "mov.b16 \t$dst, $a;", []>; - +def LOAD_CONST_BF16 : + NVPTXInst<(outs BFloat16Regs:$dst), (ins bf16imm:$a), + "mov.b16 \t$dst, $a;", []>; defm FADD : F3_fma_component<"add", fadd>; defm FSUB : F3_fma_component<"sub", fsub>; defm FMUL : F3_fma_component<"mul", fmul>; @@ -1051,6 +1124,20 @@ def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>; def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", v2f16, Float16x2Regs, True>; +// +// BF16 NEG +// + +class FNEG_BF16_F16X2<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> : + NVPTXInst<(outs RC:$dst), (ins RC:$src), + !strconcat(OpcStr, " \t$dst, $src;"), + [(set RC:$dst, (fneg (T RC:$src)))]>, + Requires<[useFP16Math, hasPTX70, hasSM80, Pred]>; +def BFNEG16_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16", bf16, BFloat16Regs, doF32FTZ>; +def BFNEG16 : FNEG_BF16_F16X2<"neg.bf16", bf16, BFloat16Regs, True>; +def BFNEG16x2_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16x2", v2bf16, BFloat16x2Regs, doF32FTZ>; +def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, BFloat16x2Regs, True>; + // // F64 division // @@ -1229,10 +1316,21 @@ Requires<[useFP16Math, Pred]>; } +multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> { + def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), + !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), + [(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>, + Requires<[useBFP16Math, Pred]>; +} + defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Float16Regs, doF32FTZ>; defm FMA16 : FMA_F16<"fma.rn.f16", f16, Float16Regs, True>; defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>; defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Float16x2Regs, True>; +defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, BFloat16Regs, doF32FTZ>; +defm BFMA16 : 
FMA_BF16<"fma.rn.bf16", bf16, BFloat16Regs, True>; +defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, BFloat16x2Regs, doF32FTZ>; +defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, BFloat16x2Regs, True>; defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>; defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>; defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>; @@ -1679,6 +1777,18 @@ "setp${cmp:base}${cmp:ftz}.f16x2 \t$p|$q, $a, $b;", []>, Requires<[useFP16Math]>; +def SETP_bf16rr : + NVPTXInst<(outs Int1Regs:$dst), + (ins BFloat16Regs:$a, BFloat16Regs:$b, CmpMode:$cmp), + "setp${cmp:base}${cmp:ftz}.bf16 \t$dst, $a, $b;", + []>, Requires<[useBFP16Math]>; + +def SETP_bf16x2rr : + NVPTXInst<(outs Int1Regs:$p, Int1Regs:$q), + (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b, CmpMode:$cmp), + "setp${cmp:base}${cmp:ftz}.bf16x2 \t$p|$q, $a, $b;", + []>, + Requires<[useBFP16Math]>; // FIXME: This doesn't appear to be correct. The "set" mnemonic has the form @@ -1709,6 +1819,7 @@ defm SET_s64 : SET<"s64", Int64Regs, i64imm>; defm SET_u64 : SET<"u64", Int64Regs, i64imm>; defm SET_f16 : SET<"f16", Float16Regs, f16imm>; +defm SET_bf16 : SET<"bf16", BFloat16Regs, bf16imm>; defm SET_f32 : SET<"f32", Float32Regs, f32imm>; defm SET_f64 : SET<"f64", Float64Regs, f64imm>; @@ -1781,6 +1892,8 @@ def FMOV16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$src), // We have to use .b16 here as there's no mov.f16. 
"mov.b16 \t$dst, $src;", []>; + def BFMOV16rr : NVPTXInst<(outs BFloat16Regs:$dst), (ins BFloat16Regs:$src), + "mov.b16 \t$dst, $src;", []>; def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), "mov.f32 \t$dst, $src;", []>; def FMOV64rr : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src), @@ -1963,7 +2076,27 @@ (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, Requires<[useFP16Math]>; - // f32 -> pred + // bf16 -> pred + def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))), + (SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>, + Requires<[useBFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))), + (SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, Mode)>, + Requires<[useBFP16Math]>; + def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)), + (SETP_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>, + Requires<[useBFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)), + (SETP_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>, + Requires<[useBFP16Math]>; + def : Pat<(i1 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))), + (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, ModeFTZ)>, + Requires<[useBFP16Math,doF32FTZ]>; + def : Pat<(i1 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))), + (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, Mode)>, + Requires<[useBFP16Math]>; + + //f32 -> pred def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)), (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>, Requires<[doF32FTZ]>; @@ -2007,6 +2140,26 @@ def : Pat<(i32 (OpNode fpimm:$a, (f16 Float16Regs:$b))), (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>, Requires<[useFP16Math]>; + + // bf16 -> i32 + def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))), + (SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>, + Requires<[useBFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode (bf16 
BFloat16Regs:$a), (bf16 BFloat16Regs:$b))), + (SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, Mode)>, + Requires<[useBFP16Math]>; + def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)), + (SET_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>, + Requires<[useBFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)), + (SET_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>, + Requires<[useBFP16Math]>; + def : Pat<(i32 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))), + (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, ModeFTZ)>, + Requires<[useBFP16Math, doF32FTZ]>; + def : Pat<(i32 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))), + (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, Mode)>, + Requires<[useBFP16Math]>; // f32 -> i32 def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)), @@ -2296,10 +2449,14 @@ def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">; def LoadParamMemF16 : LoadParamMemInst<Float16Regs, ".b16">; def LoadParamMemF16x2 : LoadParamMemInst<Float16x2Regs, ".b32">; +def LoadParamMemBF16 : LoadParamMemInst<BFloat16Regs, ".b16">; +def LoadParamMemBF16x2 : LoadParamMemInst<BFloat16x2Regs, ".b32">; def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">; def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">; def LoadParamMemV2F16 : LoadParamV2MemInst<Float16Regs, ".b16">; def LoadParamMemV2F16x2: LoadParamV2MemInst<Float16x2Regs, ".b32">; +def LoadParamMemV2BF16 : LoadParamV2MemInst<BFloat16Regs, ".b16">; +def LoadParamMemV2BF16x2: LoadParamV2MemInst<BFloat16x2Regs, ".b32">; def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">; def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">; def LoadParamMemV4F16 : LoadParamV4MemInst<Float16Regs, ".b16">; @@ -2322,6 +2479,10 @@ def StoreParamF16 : StoreParamInst<Float16Regs, ".b16">; def StoreParamF16x2 : StoreParamInst<Float16x2Regs, ".b32">; + +def StoreParamBF16 : StoreParamInst<BFloat16Regs, ".b16">; +def 
StoreParamBF16x2 : StoreParamInst<BFloat16x2Regs, ".b32">; + def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">; def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">; def StoreParamV2F16 : StoreParamV2Inst<Float16Regs, ".b16">; @@ -2348,6 +2509,8 @@ def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">; def StoreRetvalF16 : StoreRetvalInst<Float16Regs, ".b16">; def StoreRetvalF16x2 : StoreRetvalInst<Float16x2Regs, ".b32">; +def StoreRetvalBF16 : StoreRetvalInst<BFloat16Regs, ".b16">; +def StoreRetvalBF16x2 : StoreRetvalInst<BFloat16x2Regs, ".b32">; def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">; def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">; def StoreRetvalV2F16 : StoreRetvalV2Inst<Float16Regs, ".b16">; @@ -2450,6 +2613,7 @@ def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">; def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">; def MoveParamF16 : MoveParamInst<f16, Float16Regs, ".f16">; +def MoveParamBF16 : MoveParamInst<bf16, BFloat16Regs, ".bf16">; class PseudoUseParamInst<NVPTXRegClass regclass> : NVPTXInst<(outs), (ins regclass:$src), @@ -2473,11 +2637,11 @@ def ProxyRegI32 : ProxyRegInst<"b32", i32, Int32Regs>; def ProxyRegI64 : ProxyRegInst<"b64", i64, Int64Regs>; def ProxyRegF16 : ProxyRegInst<"b16", f16, Float16Regs>; - def ProxyRegBF16 : ProxyRegInst<"b16", bf16, Float16Regs>; + def ProxyRegBF16 : ProxyRegInst<"b16", bf16, BFloat16Regs>; def ProxyRegF32 : ProxyRegInst<"f32", f32, Float32Regs>; def ProxyRegF64 : ProxyRegInst<"f64", f64, Float64Regs>; def ProxyRegF16x2 : ProxyRegInst<"b32", v2f16, Float16x2Regs>; - def ProxyRegBF16x2 : ProxyRegInst<"b32", v2bf16, Float16x2Regs>; + def ProxyRegBF16x2 : ProxyRegInst<"b32", v2bf16, BFloat16x2Regs>; } // @@ -2578,7 +2742,9 @@ defm ST_i32 : ST<Int32Regs>; defm ST_i64 : ST<Int64Regs>; defm ST_f16 : ST<Float16Regs>; + defm ST_bf16 : ST<BFloat16Regs>; defm ST_f16x2 : ST<Float16x2Regs>; + defm ST_bf16x2 : ST<BFloat16x2Regs>; defm ST_f32 : ST<Float32Regs>; defm 
ST_f64 : ST<Float64Regs>; } @@ -2667,6 +2833,8 @@ defm LDV_i64 : LD_VEC<Int64Regs>; defm LDV_f16 : LD_VEC<Float16Regs>; defm LDV_f16x2 : LD_VEC<Float16x2Regs>; + defm LDV_bf16 : LD_VEC<BFloat16Regs>; + defm LDV_bf16x2 : LD_VEC<BFloat16x2Regs>; defm LDV_f32 : LD_VEC<Float32Regs>; defm LDV_f64 : LD_VEC<Float64Regs>; } @@ -2762,6 +2930,8 @@ defm STV_i64 : ST_VEC<Int64Regs>; defm STV_f16 : ST_VEC<Float16Regs>; defm STV_f16x2 : ST_VEC<Float16x2Regs>; + defm STV_bf16 : ST_VEC<BFloat16Regs>; + defm STV_bf16x2 : ST_VEC<BFloat16x2Regs>; defm STV_f32 : ST_VEC<Float32Regs>; defm STV_f64 : ST_VEC<Float64Regs>; } @@ -2816,6 +2986,26 @@ def : Pat<(f16 (uint_to_fp Int64Regs:$a)), (CVT_f16_u64 Int64Regs:$a, CvtRN)>; +// sint -> bf16 +def : Pat<(bf16 (sint_to_fp Int1Regs:$a)), + (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(bf16 (sint_to_fp Int16Regs:$a)), + (CVT_bf16_s16 Int16Regs:$a, CvtRN)>; +def : Pat<(bf16 (sint_to_fp Int32Regs:$a)), + (CVT_bf16_s32 Int32Regs:$a, CvtRN)>; +def : Pat<(bf16 (sint_to_fp Int64Regs:$a)), + (CVT_bf16_s64 Int64Regs:$a, CvtRN)>; + +// uint -> bf16 +def : Pat<(bf16 (uint_to_fp Int1Regs:$a)), + (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; +def : Pat<(bf16 (uint_to_fp Int16Regs:$a)), + (CVT_bf16_u16 Int16Regs:$a, CvtRN)>; +def : Pat<(bf16 (uint_to_fp Int32Regs:$a)), + (CVT_bf16_u32 Int32Regs:$a, CvtRN)>; +def : Pat<(bf16 (uint_to_fp Int64Regs:$a)), + (CVT_bf16_u64 Int64Regs:$a, CvtRN)>; + // sint -> f32 def : Pat<(f32 (sint_to_fp Int1Regs:$a)), (CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>; @@ -2877,6 +3067,25 @@ def : Pat<(i64 (fp_to_uint (f16 Float16Regs:$a))), (CVT_u64_f16 Float16Regs:$a, CvtRZI)>; +// bf16 -> sint +def : Pat<(i1 (fp_to_sint (bf16 BFloat16Regs:$a))), + (SETP_b16ri (BITCONVERT_16_BF2I BFloat16Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_sint (bf16 BFloat16Regs:$a))), + (CVT_s16_bf16 (bf16 BFloat16Regs:$a), CvtRZI)>; +def : Pat<(i32 (fp_to_sint (bf16 BFloat16Regs:$a))), + (CVT_s32_bf16 (bf16 
BFloat16Regs:$a), CvtRZI)>; +def : Pat<(i64 (fp_to_sint (bf16 BFloat16Regs:$a))), + (CVT_s64_bf16 BFloat16Regs:$a, CvtRZI)>; + +// bf16 -> uint +def : Pat<(i1 (fp_to_uint (bf16 BFloat16Regs:$a))), + (SETP_b16ri (BITCONVERT_16_BF2I BFloat16Regs:$a), 0, CmpEQ)>; +def : Pat<(i16 (fp_to_uint (bf16 BFloat16Regs:$a))), + (CVT_u16_bf16 BFloat16Regs:$a, CvtRZI)>; +def : Pat<(i32 (fp_to_uint (bf16 BFloat16Regs:$a))), + (CVT_u32_bf16 BFloat16Regs:$a, CvtRZI)>; +def : Pat<(i64 (fp_to_uint (bf16 BFloat16Regs:$a))), + (CVT_u64_bf16 BFloat16Regs:$a, CvtRZI)>; // f32 -> sint def : Pat<(i1 (fp_to_sint Float32Regs:$a)), (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>; @@ -3024,6 +3233,9 @@ def : Pat<(select Int32Regs:$pred, (f16 Float16Regs:$a), (f16 Float16Regs:$b)), (SELP_f16rr Float16Regs:$a, Float16Regs:$b, (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; +def : Pat<(select Int32Regs:$pred, (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)), + (SELP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, + (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b), (SELP_f32rr Float32Regs:$a, Float32Regs:$b, (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>; @@ -3124,6 +3336,42 @@ (ins Int32Regs:$src), "mov.b32 \t{{$lo, $hi}}, $src;", []>; + def BF16x2toBF16_0 : NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16x2Regs:$src), + "{{ .reg .b16 \t%tmp_hi;\n\t" + " mov.b32 \t{$dst, %tmp_hi}, $src; }}", + [(set BFloat16Regs:$dst, + (extractelt (v2bf16 BFloat16x2Regs:$src), 0))]>; + def BF16x2toBF16_1 : NVPTXInst<(outs BFloat16Regs:$dst), + (ins BFloat16x2Regs:$src), + "{{ .reg .b16 \t%tmp_lo;\n\t" + " mov.b32 \t{%tmp_lo, $dst}, $src; }}", + [(set BFloat16Regs:$dst, + (extractelt (v2bf16 BFloat16x2Regs:$src), 1))]>; + + // // Coalesce two bf16 registers into bf16x2 + // def BuildBF16x2 : NVPTXInst<(outs BFloat16x2Regs:$dst), + // (ins BFloat16Regs:$a, BFloat16Regs:$b), + // "mov.b32 \t$dst, {{$a, $b}};", + // [(set 
(v2bf16 BFloat16x2Regs:$dst), + // (build_vector (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>; + + // // Directly initializing underlying the b32 register is one less SASS + // // instruction than than vector-packing move. + // def BuildBF16x2i : NVPTXInst<(outs BFloat16x2Regs:$dst), (ins i32imm:$src), + // "mov.b32 \t$dst, $src;", + // []>; + + // // Split f16x2 into two f16 registers. + // def SplitBF16x2 : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi), + // (ins BFloat16x2Regs:$src), + // "mov.b32 \t{{$lo, $hi}}, $src;", + // []>; + // // Split an i32 into two f16 + // def SplitI32toBF16x2 : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi), + // (ins Int32Regs:$src), + // "mov.b32 \t{{$lo, $hi}}, $src;", + // []>; } // Count leading zeros @@ -3193,10 +3441,17 @@ def : Pat<(f16 (fpround Float32Regs:$a)), (CVT_f16_f32 Float32Regs:$a, CvtRN)>; +// fpround f32 -> bf16 +def : Pat<(bf16 (fpround Float32Regs:$a)), + (CVT_bf16_f32 Float32Regs:$a, CvtRN)>; + // fpround f64 -> f16 def : Pat<(f16 (fpround Float64Regs:$a)), (CVT_f16_f64 Float64Regs:$a, CvtRN)>; +// fpround f64 -> bf16 +def : Pat<(bf16 (fpround Float64Regs:$a)), + (CVT_bf16_f64 Float64Regs:$a, CvtRN)>; // fpround f64 -> f32 def : Pat<(f32 (fpround Float64Regs:$a)), (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>; @@ -3208,11 +3463,20 @@ (CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; def : Pat<(f32 (fpextend (f16 Float16Regs:$a))), (CVT_f32_f16 Float16Regs:$a, CvtNONE)>; +// fpextend bf16 -> f32 +def : Pat<(f32 (fpextend (bf16 BFloat16Regs:$a))), + (CVT_f32_bf16 BFloat16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; +def : Pat<(f32 (fpextend (bf16 BFloat16Regs:$a))), + (CVT_f32_bf16 BFloat16Regs:$a, CvtNONE)>; // fpextend f16 -> f64 def : Pat<(f64 (fpextend (f16 Float16Regs:$a))), (CVT_f64_f16 Float16Regs:$a, CvtNONE)>; +// fpextend bf16 -> f64 +def : Pat<(f64 (fpextend (bf16 BFloat16Regs:$a))), + (CVT_f64_bf16 BFloat16Regs:$a, CvtNONE)>; + // fpextend f32 -> 
f64 def : Pat<(f64 (fpextend Float32Regs:$a)), (CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>; @@ -3227,6 +3491,8 @@ multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> { def : Pat<(OpNode (f16 Float16Regs:$a)), (CVT_f16_f16 Float16Regs:$a, Mode)>; + def : Pat<(OpNode (bf16 BFloat16Regs:$a)), + (CVT_bf16_bf16 BFloat16Regs:$a, Mode)>; def : Pat<(OpNode Float32Regs:$a), (CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(OpNode Float32Regs:$a), Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -143,6 +143,26 @@ } } +static bool Isv2f16Orv2bf16Type(MVT VT) { + switch (VT.SimpleTy) { + default: + return false; + case MVT::v2f16: + case MVT::v2bf16: + return true; + } +} + +static bool Isf16Orbf16Type(MVT VT) { + switch (VT.SimpleTy) { + default: + return false; + case MVT::f16: + case MVT::bf16: + return true; + } +} + /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive /// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors /// into their primitive components. @@ -193,7 +213,7 @@ // Vectors with an even number of f16 elements will be passed to // us as an array of v2f16/v2bf16 elements. We must match this so we // stay in sync with Ins/Outs. - if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) { + if ((Isf16Orbf16Type(EltVT.getSimpleVT())) && NumElts % 2 == 0) { EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16; NumElts /= 2; } @@ -398,6 +418,11 @@ setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action); }; + auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action, + LegalizeAction NoBF16Action) { + setOperationAction(Op, VT, STI.allowBF16Math() ? 
Action : NoBF16Action); + }; + addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass); addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass); addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass); @@ -406,8 +431,6 @@ addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass); addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass); addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass); - addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass); - addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass); // Conversion to/from FP16/FP16x2 is always legal. setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal); @@ -420,6 +443,16 @@ setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote); setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand); + // Conversion to/from BFP16/BFP16x2 is always legal. + setOperationAction(ISD::SINT_TO_FP, MVT::bf16, Legal); + setOperationAction(ISD::FP_TO_SINT, MVT::bf16, Legal); + setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom); + setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom); + setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand); + setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand); + + setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote); + setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand); // Operations not directly supported by NVPTX. 
for (MVT VT : {MVT::f16, MVT::v2f16, MVT::f32, MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::i32, MVT::i64}) { @@ -476,17 +509,25 @@ // Turn FP extload into load/fpextend setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand); + setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand); setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand); // Turn FP truncstore into trunc + store. // FIXME: vector types should also be expanded setTruncStoreAction(MVT::f32, MVT::f16, Expand); setTruncStoreAction(MVT::f64, MVT::f16, Expand); + setTruncStoreAction(MVT::f32, MVT::bf16, Expand); + setTruncStoreAction(MVT::f64, MVT::bf16, Expand); setTruncStoreAction(MVT::f64, MVT::f32, Expand); // PTX does not support load / store predicate registers @@ -563,9 +604,9 @@ setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM}); - // setcc for f16x2 needs special handling to prevent legalizer's - // attempt to scalarize it due to v2i1 not being legal. - if (STI.allowFP16Math()) + // setcc for f16x2 and bf16x2 needs special handling to prevent + // legalizer's attempt to scalarize it due to v2i1 not being legal. 
+ if (STI.allowFP16Math() || STI.allowBF16Math()) setTargetDAGCombine(ISD::SETCC); // Promote fp16 arithmetic if fp16 hardware isn't available or the @@ -579,6 +620,11 @@ setFP16OperationAction(Op, MVT::v2f16, Legal, Expand); } + for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) { + setBF16OperationAction(Op, MVT::bf16, Legal, Promote); + setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand); + } + // f16/f16x2 neg was introduced in PTX 60, SM_53. const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 && STI.getPTXVersion() >= 60 && @@ -587,19 +633,29 @@ setOperationAction(ISD::FNEG, VT, IsFP16FP16x2NegAvailable ? Legal : Expand); + const bool IsBFP16FP16x2NegAvailable = STI.getSmVersion() >= 80 && + STI.getPTXVersion() >= 70 && + STI.allowBF16Math(); + for (const auto &VT : {MVT::bf16, MVT::v2bf16}) + setOperationAction(ISD::FNEG, VT, + IsBFP16FP16x2NegAvailable ? Legal : Expand); // (would be) Library functions. // These map to conversion instructions for scalar FP types. for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT, ISD::FROUNDEVEN, ISD::FTRUNC}) { + setOperationAction(Op, MVT::bf16, Legal); setOperationAction(Op, MVT::f16, Legal); setOperationAction(Op, MVT::f32, Legal); setOperationAction(Op, MVT::f64, Legal); setOperationAction(Op, MVT::v2f16, Expand); + setOperationAction(Op, MVT::v2bf16, Expand); } setOperationAction(ISD::FROUND, MVT::f16, Promote); setOperationAction(ISD::FROUND, MVT::v2f16, Expand); + setOperationAction(ISD::FROUND, MVT::bf16, Promote); + setOperationAction(ISD::FROUND, MVT::v2bf16, Expand); setOperationAction(ISD::FROUND, MVT::f32, Custom); setOperationAction(ISD::FROUND, MVT::f64, Custom); @@ -607,6 +663,8 @@ // 'Expand' implements FCOPYSIGN without calling an external library. 
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand); setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand); setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand); setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand); @@ -616,9 +674,11 @@ for (const auto &Op : {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) { setOperationAction(Op, MVT::f16, Promote); + setOperationAction(Op, MVT::bf16, Promote); setOperationAction(Op, MVT::f32, Legal); setOperationAction(Op, MVT::f64, Legal); setOperationAction(Op, MVT::v2f16, Expand); + setOperationAction(Op, MVT::v2bf16, Expand); } // max.f16, max.f16x2 and max.NaN are supported on sm_80+. auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) { @@ -636,6 +696,12 @@ setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand)); setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand); } + for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) { + setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote); + setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); + setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand); + setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand); + } // No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate. // No FPOW or FREM in PTX. @@ -1252,7 +1318,7 @@ if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 && VT.getScalarType() == MVT::i1) return TypeSplitVector; - if (VT == MVT::v2f16) + if (Isv2f16Orv2bf16Type(VT)) return TypeLegal; return TargetLoweringBase::getPreferredVectorAction(VT); } @@ -1402,7 +1468,7 @@ sz = promoteScalarArgumentSize(sz); } else if (isa<PointerType>(Ty)) { sz = PtrVT.getSizeInBits(); - } else if (Ty->isHalfTy()) + } else if (Ty->isHalfTy() || Ty->isBFloatTy()) // PTX ABI requires all scalar parameters to be at least 32 // bits in size. 
fp16 normally uses .b16 as its storage type // in PTX, so its size must be adjusted here, too. @@ -2037,7 +2103,7 @@ // generates good SASS in both cases. SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { - if (!(Op->getValueType(0) == MVT::v2f16 && + if (!(Isv2f16Orv2bf16Type(Op->getOperand(0).getValueType().getSimpleVT()) && isa<ConstantFPSDNode>(Op->getOperand(0)) && isa<ConstantFPSDNode>(Op->getOperand(1)))) return Op; @@ -2048,7 +2114,7 @@ cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt(); SDValue Const = DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32); - return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const); + return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const); } SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, @@ -2409,7 +2475,7 @@ // v2f16 is legal, so we can't rely on legalizer to handle unaligned // loads and have to handle it here. - if (Op.getValueType() == MVT::v2f16) { + if (Isv2f16Orv2bf16Type(Op.getValueType().getSimpleVT())) { LoadSDNode *Load = cast<LoadSDNode>(Op); EVT MemVT = Load->getMemoryVT(); if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(), @@ -2454,7 +2520,7 @@ // v2f16 is legal, so we can't rely on legalizer to handle unaligned // stores and have to handle it here. - if (VT == MVT::v2f16 && + if ((Isv2f16Orv2bf16Type(VT.getSimpleVT())) && !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(), VT, *Store->getMemOperand())) return expandUnalignedStore(Store, DAG); @@ -2541,7 +2607,7 @@ // v8f16 is a special case. PTX doesn't have st.v8.f16 // instruction. Instead, we split the vector into v2f16 chunks and // store them with st.v4.b32. 
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) && + assert((Isf16Orbf16Type(EltVT.getSimpleVT())) && "Wrong type for the vector."); Opcode = NVPTXISD::StoreV4; StoreF16x2 = true; @@ -2557,11 +2623,12 @@ // Combine f16,f16 -> v2f16 NumElts /= 2; for (unsigned i = 0; i < NumElts; ++i) { - SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val, + SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val, DAG.getIntPtrConstant(i * 2, DL)); - SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val, + SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val, DAG.getIntPtrConstant(i * 2 + 1, DL)); - SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1); + EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2); + SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1); Ops.push_back(V2); } } else { @@ -2733,9 +2800,9 @@ EVT LoadVT = EltVT; if (EltVT == MVT::i1) LoadVT = MVT::i8; - else if (EltVT == MVT::v2f16) + else if (Isv2f16Orv2bf16Type(EltVT.getSimpleVT())) // getLoad needs a vector type, but it can't handle - // vectors which contain v2f16 elements. So we must load + // vectors which contain v2f16 or v2bf16 elements. So we must load // using i32 here and then bitcast back. LoadVT = MVT::i32; @@ -5171,7 +5238,7 @@ // v8f16 is a special case. PTX doesn't have ld.v8.f16 // instruction. Instead, we split the vector into v2f16 chunks and // load them with ld.v4.b32. 
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) && + assert(Isf16Orbf16Type(EltVT.getSimpleVT()) && "Unsupported v8 vector type."); LoadF16x2 = true; Opcode = NVPTXISD::LoadV4; Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -72,6 +72,7 @@ bool trySurfaceIntrinsic(SDNode *N); bool tryBFE(SDNode *N); bool tryConstantFP16(SDNode *N); + bool tryConstantBF16(SDNode *N); bool SelectSETP_F16X2(SDNode *N); bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N); Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -537,6 +537,16 @@ return true; } +bool NVPTXDAGToDAGISel::tryConstantBF16(SDNode *N) { + if (N->getValueType(0) != MVT::bf16) + return false; + SDValue Val = CurDAG->getTargetConstantFP( + cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::bf16); + SDNode *LoadConstBF16 = + CurDAG->getMachineNode(NVPTX::LOAD_CONST_BF16, SDLoc(N), MVT::bf16, Val); + ReplaceNode(N, LoadConstBF16); + return true; +} // Map ISD:CONDCODE value to appropriate CmpMode expected by // NVPTXInstPrinter::printCmpMode() static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) { @@ -1288,6 +1298,10 @@ assert(NumElts % 2 == 0 && "Vector must have even number of elements"); EltVT = MVT::v2f16; NumElts /= 2; + } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) { + assert(NumElts % 2 == 0 && "Vector must have even number of elements"); + EltVT = MVT::v2bf16; + NumElts /= 2; } } Index: llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp =================================================================== --- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -267,6 +267,10 @@ MCOp = MCOperand::createExpr( 
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext)); break; + case Type::BFloatTyID: + MCOp = MCOperand::createExpr( + NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext)); + break; case Type::FloatTyID: MCOp = MCOperand::createExpr( NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext)); @@ -1353,8 +1357,10 @@ } break; } + case Type::BFloatTyID: case Type::HalfTyID: - // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly. + // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53 + // PTX assembly. return "b16"; case Type::FloatTyID: return "f32"; @@ -1588,7 +1594,7 @@ } else if (PTy) { assert(PTySizeInBits && "Invalid pointer size"); sz = PTySizeInBits; - } else if (Ty->isHalfTy()) + } else if (Ty->isHalfTy() || Ty->isBFloatTy()) // PTX ABI requires all scalar parameters to be at least 32 // bits in size. fp16 normally uses .b16 as its storage type // in PTX, so its size must be adjusted here, too. Index: llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp =================================================================== --- llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -61,9 +61,11 @@ OS << "%fd"; break; case 7: + case 9: OS << "%h"; break; case 8: + case 10: OS << "%hh"; break; } Index: llvm/include/llvm/IR/IntrinsicsNVVM.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsNVVM.td +++ llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -597,16 +597,18 @@ [IntrNoMem, IntrSpeculatable, Commutative]>; } - foreach variant = ["_bf16", "_nan_bf16", "_xorsign_abs_bf16", - "_nan_xorsign_abs_bf16"] in { + foreach variant = ["_bf16", "_ftz_bf16", "_nan_bf16", "_ftz_nan_bf16", + "_xorsign_abs_bf16", "_ftz_xorsign_abs_bf16", "_nan_xorsign_abs_bf16", + "_ftz_nan_xorsign_abs_bf16"] in { def int_nvvm_f # operation # variant : ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>, 
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty], [IntrNoMem, IntrSpeculatable, Commutative]>; } - foreach variant = ["_bf16x2", "_nan_bf16x2", "_xorsign_abs_bf16x2", - "_nan_xorsign_abs_bf16x2"] in { + foreach variant = ["_bf16x2", "_ftz_bf16x2", "_nan_bf16x2", + "_ftz_nan_bf16x2", "_xorsign_abs_bf16x2", "_ftz_xorsign_abs_bf16x2", + "_nan_xorsign_abs_bf16x2", "_ftz_nan_xorsign_abs_bf16x2"] in { def int_nvvm_f # operation # variant : ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>, DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], @@ -874,17 +876,19 @@ [IntrNoMem, IntrSpeculatable]>; } - foreach variant = ["_rn_bf16", "_rn_relu_bf16"] in { + foreach variant = ["_rn_bf16", "_rn_ftz_bf16", "_rn_sat_bf16", + "_rn_ftz_sat_bf16", "_rn_relu_bf16", "_rn_ftz_relu_bf16"] in { def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>, - DefaultAttrsIntrinsic<[llvm_i16_ty], - [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty], + DefaultAttrsIntrinsic<[llvm_bfloat_ty], + [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>; } - foreach variant = ["_rn_bf16x2", "_rn_relu_bf16x2"] in { + foreach variant = ["_rn_bf16x2", "_rn_ftz_bf16x2", "_rn_sat_bf16x2", + "_rn_ftz_sat_bf16x2", "_rn_relu_bf16x2", "_rn_ftz_relu_bf16x2"] in { def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>, - DefaultAttrsIntrinsic<[llvm_i32_ty], - [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], + DefaultAttrsIntrinsic<[llvm_v2bf16_ty], + [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty], [IntrNoMem, IntrSpeculatable]>; } @@ -1236,6 +1240,11 @@ def int_nvvm_f2h_rn : ClangBuiltin<"__nvvm_f2h_rn">, DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>; + def int_nvvm_bf2h_rn_ftz : ClangBuiltin<"__nvvm_bf2h_rn_ftz">, + DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>; + def int_nvvm_bf2h_rn : ClangBuiltin<"__nvvm_bf2h_rn">, + 
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>; + def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">, Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>; def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">, Index: clang/include/clang/Basic/BuiltinsNVPTX.def =================================================================== --- clang/include/clang/Basic/BuiltinsNVPTX.def +++ clang/include/clang/Basic/BuiltinsNVPTX.def @@ -145,12 +145,16 @@ TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "", AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "", AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "", @@ -187,12 +191,16 @@ TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "", AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", 
AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) +TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70)) TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "", AND(SM_86, PTX72)) TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits