https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/80908
>From 1488b4e54982be4d3f5bc7f35617effcab52be48 Mon Sep 17 00:00:00 2001 From: Shilei Tian <i...@tianshilei.me> Date: Wed, 14 Feb 2024 09:41:00 -0500 Subject: [PATCH] [RFC][WIP][AMDGPU] Use `bf16` instead of `i16` for bfloat Currently it looks like we generally use `i16` to represent `bf16` in those tablegen files; I'm not sure of the reason behind that. My wild guess is that the `bf16` type was not yet available when the support was first enabled. This patch uses `bf16` directly in those tablegen files, aiming to fix #79369. Of course, a workaround for #79369 would be to treat all `INT16` variants as `BFloat` in `getOpFltSemantics`, but that doesn't look good IMHO. Since I'm fairly new to the AMDGPU backend, I'd appreciate it if you could point out anything I've misunderstood. --- clang/lib/CodeGen/CGBuiltin.cpp | 4 - .../builtins-amdgcn-dl-insts-gfx11.cl | 5 +- llvm/include/llvm/IR/IntrinsicsAMDGPU.td | 8 +- .../AMDGPU/AsmParser/AMDGPUAsmParser.cpp | 92 +++++++++++++++++++ .../AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp | 57 ++++++++++++ .../AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h | 2 + .../MCTargetDesc/AMDGPUMCCodeEmitter.cpp | 39 ++++++++ llvm/lib/Target/AMDGPU/SIDefines.h | 7 ++ llvm/lib/Target/AMDGPU/SIInstrInfo.cpp | 15 +++ llvm/lib/Target/AMDGPU/SIInstrInfo.td | 58 ++++++------ llvm/lib/Target/AMDGPU/SIRegisterInfo.td | 21 ++++- .../Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp | 54 +++++++++++ llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h | 16 ++++ llvm/lib/Target/AMDGPU/VOP3Instructions.td | 2 +- .../AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll | 39 ++++---- llvm/test/MC/AMDGPU/bf16_imm.s | 8 ++ llvm/test/MC/Disassembler/AMDGPU/bf16_imm.txt | 8 ++ 17 files changed, 379 insertions(+), 56 deletions(-) create mode 100644 llvm/test/MC/AMDGPU/bf16_imm.s create mode 100644 llvm/test/MC/Disassembler/AMDGPU/bf16_imm.txt diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index ee0b7504769622..9bc60466d09be6 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -5912,8 +5912,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, } } - assert(ArgValue->getType()->canLosslesslyBitCastTo(PTy) && - "Must be able to losslessly bit cast to param"); // Cast vector type (e.g., v256i32) to x86_amx, this only happen // in amx intrinsics. if (PTy->isX86_AMXTy()) @@ -5943,8 +5941,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, } } - assert(V->getType()->canLosslesslyBitCastTo(RetTy) && - "Must be able to losslessly bit cast result type"); // Cast x86_amx to vector type (e.g., v256i32), this only happen // in amx intrinsics.
if (V->getType()->isX86_AMXTy()) diff --git a/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl b/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl index dc7069decaaa61..7688dfa55a78e3 100644 --- a/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl +++ b/clang/test/CodeGenOpenCL/builtins-amdgcn-dl-insts-gfx11.cl @@ -11,7 +11,10 @@ typedef unsigned short __attribute__((ext_vector_type(2))) ushort2; // CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 false) // CHECK: call float @llvm.amdgcn.fdot2(<2 x half> %v2hA, <2 x half> %v2hB, float %fC, i1 true) // CHECK: call half @llvm.amdgcn.fdot2.f16.f16(<2 x half> %v2hA, <2 x half> %v2hB, half %hC) -// CHECK: call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, i16 %sC) +// CHECK: [[s1:%[0-9]+]] = bitcast <2 x i16> %v2ssA to <2 x bfloat> +// CHECK-NEXT: [[s2:%[0-9]+]] = bitcast <2 x i16> %v2ssB to <2 x bfloat> +// CHECK-NEXT: [[s3:%[0-9]+]] = bitcast i16 %sC to bfloat +// CHECK-NEXT: [[d:%[0-9]+]] = tail call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> [[s1]], <2 x bfloat> [[s2]], bfloat [[s3]]) // CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 false) // CHECK: call float @llvm.amdgcn.fdot2.f32.bf16(<2 x i16> %v2ssA, <2 x i16> %v2ssB, float %fC, i1 true) // CHECK: call i32 @llvm.amdgcn.udot4(i32 %uiA, i32 %uiB, i32 %uiC, i1 false) diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td index 202fa4e8f4ea81..6795fb7aa0edb8 100644 --- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td +++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td @@ -2819,11 +2819,11 @@ def int_amdgcn_fdot2_f16_f16 : def int_amdgcn_fdot2_bf16_bf16 : ClangBuiltin<"__builtin_amdgcn_fdot2_bf16_bf16">, DefaultAttrsIntrinsic< - [llvm_i16_ty], // %r + [llvm_bfloat_ty], // %r [ - llvm_v2i16_ty, // %a - llvm_v2i16_ty, // %b - llvm_i16_ty // %c + llvm_v2bf16_ty, // %a + llvm_v2bf16_ty, // %b + llvm_bfloat_ty // %c ], [IntrNoMem, IntrSpeculatable] >; diff --git a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp index 79ad6ddf7861fc..883b30562e911b 100644 --- a/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp +++ b/llvm/lib/Target/AMDGPU/AsmParser/AMDGPUAsmParser.cpp @@ -475,6 +475,8 @@ class AMDGPUOperand : public MCParsedAsmOperand { bool isSSrcF64() const { return isSCSrc_b64() || isLiteralImm(MVT::f64); } + bool isSSrc_bf16() const { return isSCSrcB16() || isLiteralImm(MVT::bf16); } + bool isSSrc_f16() const { return isSCSrcB16() || isLiteralImm(MVT::f16); } bool isSSrcV2F16() const { @@ -541,22 +543,40 @@ class AMDGPUOperand : public MCParsedAsmOperand { return isRegOrInlineNoMods(AMDGPU::VS_64RegClassID, MVT::f64); } + bool isVCSrcTBF16() const { + return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::bf16); + } + bool isVCSrcTF16() const { return isRegOrInlineNoMods(AMDGPU::VS_16RegClassID, MVT::f16); } + bool isVCSrcTBF16_Lo128() const { + return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::bf16); + } + bool isVCSrcTF16_Lo128() const { return isRegOrInlineNoMods(AMDGPU::VS_16_Lo128RegClassID, MVT::f16); } + bool isVCSrcFake16BF16_Lo128() const { + return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::bf16); + } + bool isVCSrcFake16F16_Lo128() const { return isRegOrInlineNoMods(AMDGPU::VS_32_Lo128RegClassID, MVT::f16); } + bool isVCSrc_bf16() const { + return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::bf16); + } + bool 
isVCSrc_f16() const { return isRegOrInlineNoMods(AMDGPU::VS_32RegClassID, MVT::f16); } + bool isVCSrc_v2bf16() const { return isVCSrc_bf16(); } + bool isVCSrc_v2f16() const { return isVCSrc_f16(); } bool isVSrc_b32() const { @@ -597,18 +617,34 @@ class AMDGPUOperand : public MCParsedAsmOperand { bool isVSrc_f64() const { return isVCSrcF64() || isLiteralImm(MVT::f64); } + bool isVSrcT_bf16() const { return isVCSrcTBF16() || isLiteralImm(MVT::bf16); } + bool isVSrcT_f16() const { return isVCSrcTF16() || isLiteralImm(MVT::f16); } + bool isVSrcT_bf16_Lo128() const { + return isVCSrcTBF16_Lo128() || isLiteralImm(MVT::bf16); + } + bool isVSrcT_f16_Lo128() const { return isVCSrcTF16_Lo128() || isLiteralImm(MVT::f16); } + bool isVSrcFake16_bf16_Lo128() const { + return isVCSrcFake16BF16_Lo128() || isLiteralImm(MVT::bf16); + } + bool isVSrcFake16_f16_Lo128() const { return isVCSrcFake16F16_Lo128() || isLiteralImm(MVT::f16); } + bool isVSrc_bf16() const { return isVCSrc_bf16() || isLiteralImm(MVT::bf16); } + bool isVSrc_f16() const { return isVCSrc_f16() || isLiteralImm(MVT::f16); } + bool isVSrc_v2bf16() const { + return isVSrc_bf16() || isLiteralImm(MVT::v2bf16); + } + bool isVSrc_v2f16() const { return isVSrc_f16() || isLiteralImm(MVT::v2f16); } bool isVISrcB32() const { @@ -635,6 +671,10 @@ class AMDGPUOperand : public MCParsedAsmOperand { return isVISrcF16() || isVISrcB32(); } + bool isVISrc_64_bf16() const { + return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::bf16); + } + bool isVISrc_64_f16() const { return isRegOrInlineNoMods(AMDGPU::VReg_64RegClassID, MVT::f16); } @@ -803,6 +843,10 @@ class AMDGPUOperand : public MCParsedAsmOperand { return isAISrc_128F16() || isAISrc_128_b32(); } + bool isVISrc_128_bf16() const { + return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::bf16); + } + bool isVISrc_128_f16() const { return isRegOrInlineNoMods(AMDGPU::VReg_128RegClassID, MVT::f16); } @@ -1890,6 +1934,14 @@ static const fltSemantics *getOpFltSemantics(uint8_t OperandType) { case AMDGPU::OPERAND_REG_IMM_V2FP16: case AMDGPU::OPERAND_KIMM16: return &APFloat::IEEEhalf(); + case AMDGPU::OPERAND_REG_IMM_BF16: + case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED: + case AMDGPU::OPERAND_REG_INLINE_C_BF16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: + case AMDGPU::OPERAND_REG_IMM_V2BF16: + return &APFloat::BFloat(); default: llvm_unreachable("unsupported fp type"); } @@ -2186,17 +2238,24 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo case AMDGPU::OPERAND_REG_INLINE_AC_INT32: case AMDGPU::OPERAND_REG_INLINE_AC_FP32: case AMDGPU::OPERAND_REG_IMM_INT16: + case AMDGPU::OPERAND_REG_IMM_BF16: case AMDGPU::OPERAND_REG_IMM_FP16: + case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED: case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED: case AMDGPU::OPERAND_REG_INLINE_C_INT16: + case AMDGPU::OPERAND_REG_INLINE_C_BF16: case AMDGPU::OPERAND_REG_INLINE_C_FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2INT16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_INLINE_AC_INT16: + case AMDGPU::OPERAND_REG_INLINE_AC_BF16: case AMDGPU::OPERAND_REG_INLINE_AC_FP16: case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: case AMDGPU::OPERAND_REG_IMM_V2INT16: + case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_IMM_V2FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP32: case 
AMDGPU::OPERAND_REG_IMM_V2FP32: @@ -2240,6 +2299,7 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo case AMDGPU::OPERAND_REG_INLINE_AC_INT32: case AMDGPU::OPERAND_REG_INLINE_AC_FP32: case AMDGPU::OPERAND_REG_IMM_V2INT16: + case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_IMM_V2FP16: case AMDGPU::OPERAND_REG_IMM_V2FP32: case AMDGPU::OPERAND_REG_INLINE_C_V2FP32: @@ -2295,6 +2355,22 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo setImmKindLiteral(); return; + case AMDGPU::OPERAND_REG_IMM_BF16: + case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED: + case AMDGPU::OPERAND_REG_INLINE_C_BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_BF16: + if (isSafeTruncation(Val, 16) && + AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val), + AsmParser->hasInv2PiInlineImm())) { + Inst.addOperand(MCOperand::createImm(Val)); + setImmKindConst(); + return; + } + + Inst.addOperand(MCOperand::createImm(Val & 0xffff)); + setImmKindLiteral(); + return; + case AMDGPU::OPERAND_REG_INLINE_C_V2INT16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: @@ -2306,6 +2382,17 @@ void AMDGPUOperand::addLiteralImmOperand(MCInst &Inst, int64_t Val, bool ApplyMo Inst.addOperand(MCOperand::createImm(Val)); return; } + + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: { + assert(isSafeTruncation(Val, 16)); + assert(AMDGPU::isInlinableLiteralBF16(static_cast<int16_t>(Val), + AsmParser->hasInv2PiInlineImm())); + + Inst.addOperand(MCOperand::createImm(Val)); + return; + } + case AMDGPU::OPERAND_KIMM32: Inst.addOperand(MCOperand::createImm(Literal.getLoBits(32).getZExtValue())); setImmKindMandatoryLiteral(); @@ -3429,6 +3516,11 @@ bool AMDGPUAsmParser::isInlineConstant(const MCInst &Inst, OperandType == AMDGPU::OPERAND_REG_IMM_V2FP16) return AMDGPU::isInlinableLiteralV2F16(Val); + if (OperandType == AMDGPU::OPERAND_REG_INLINE_C_V2BF16 || + OperandType == AMDGPU::OPERAND_REG_INLINE_AC_V2BF16 || + OperandType == AMDGPU::OPERAND_REG_IMM_V2BF16) + return AMDGPU::isInlinableLiteralV2BF16(Val); + return AMDGPU::isInlinableLiteral16(Val, hasInv2PiInlineImm()); } default: diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp index 4ab3aa5a0240ad..a45fea6701f35a 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.cpp @@ -488,6 +488,47 @@ static bool printImmediateFloat16(uint32_t Imm, const MCSubtargetInfo &STI, return true; } +static bool printImmediateBFloat16(uint32_t Imm, const MCSubtargetInfo &STI, + raw_ostream &O) { + if (Imm == 0x3F80) + O << "1.0"; + else if (Imm == 0xBF80) + O << "-1.0"; + else if (Imm == 0x3F00) + O << "0.5"; + else if (Imm == 0xBF00) + O << "-0.5"; + else if (Imm == 0x4000) + O << "2.0"; + else if (Imm == 0xC000) + O << "-2.0"; + else if (Imm == 0x4080) + O << "4.0"; + else if (Imm == 0xC080) + O << "-4.0"; + else if (Imm == 0x3E22 && STI.hasFeature(AMDGPU::FeatureInv2PiInlineImm)) + O << "0.15915494"; + else + return false; + + return true; +} + +void AMDGPUInstPrinter::printImmediateBF16(uint32_t Imm, + const MCSubtargetInfo &STI, + raw_ostream &O) { + int16_t SImm = static_cast<int16_t>(Imm); + if (isInlinableIntLiteral(SImm)) { + O << SImm; + return; + } + + if (printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O)) + return; + + O << formatHex(static_cast<uint64_t>(Imm)); +} + void 
AMDGPUInstPrinter::printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI, raw_ostream &O) { @@ -528,6 +569,13 @@ void AMDGPUInstPrinter::printImmediateV216(uint32_t Imm, uint8_t OpType, printImmediateFloat16(static_cast<uint16_t>(Imm), STI, O)) return; break; + case AMDGPU::OPERAND_REG_IMM_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: + if (isUInt<16>(Imm) && + printImmediateBFloat16(static_cast<uint16_t>(Imm), STI, O)) + return; + break; default: llvm_unreachable("bad operand type"); } @@ -799,11 +847,20 @@ void AMDGPUInstPrinter::printRegularOperand(const MCInst *MI, unsigned OpNo, case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED: printImmediate16(Op.getImm(), STI, O); break; + case AMDGPU::OPERAND_REG_INLINE_C_BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_BF16: + case AMDGPU::OPERAND_REG_IMM_BF16: + case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED: + printImmediateBF16(Op.getImm(), STI, O); + break; case AMDGPU::OPERAND_REG_IMM_V2INT16: + case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_IMM_V2FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2INT16: case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: printImmediateV216(Op.getImm(), OpTy, STI, O); break; diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h index e91ff86b219a0c..15ecbf2e5e5918 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUInstPrinter.h @@ -88,6 +88,8 @@ class AMDGPUInstPrinter : public MCInstPrinter { raw_ostream &O); void printImmediate16(uint32_t Imm, const MCSubtargetInfo &STI, raw_ostream &O); + void printImmediateBF16(uint32_t Imm, const MCSubtargetInfo &STI, + raw_ostream &O); void printImmediateV216(uint32_t Imm, uint8_t OpType, const MCSubtargetInfo &STI, raw_ostream &O); bool printImmediateFloat32(uint32_t Imm, const MCSubtargetInfo &STI, diff --git a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp index 11f5e456e8d348..e51bb40132f96e 100644 --- a/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp +++ b/llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCCodeEmitter.cpp @@ -157,6 +157,27 @@ static uint32_t getLit16Encoding(uint16_t Val, const MCSubtargetInfo &STI) { return 255; } +static uint32_t getLitBF16Encoding(uint16_t Val) { + uint16_t IntImm = getIntInlineImmEncoding(static_cast<int16_t>(Val)); + if (IntImm != 0) + return IntImm; + + // clang-format off + switch (Val) { + case 0x3F00: return 240; // 0.5 + case 0xBF00: return 241; // -0.5 + case 0x3F80: return 242; // 1.0 + case 0xBF80: return 243; // -1.0 + case 0x4000: return 244; // 2.0 + case 0xC000: return 245; // -2.0 + case 0x4080: return 246; // 4.0 + case 0xC080: return 247; // -4.0 + case 0x3E22: return 248; // 1.0 / (2.0 * pi) + default: return 255; + } + // clang-format on +} + static uint32_t getLit32Encoding(uint32_t Val, const MCSubtargetInfo &STI) { uint32_t IntImm = getIntInlineImmEncoding(static_cast<int32_t>(Val)); if (IntImm != 0) @@ -276,6 +297,7 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO, case AMDGPU::OPERAND_REG_INLINE_C_INT16: case AMDGPU::OPERAND_REG_INLINE_AC_INT16: return getLit16IntEncoding(static_cast<uint16_t>(Imm), STI); + case AMDGPU::OPERAND_REG_IMM_FP16: case 
AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED: case AMDGPU::OPERAND_REG_INLINE_C_FP16: @@ -283,16 +305,33 @@ AMDGPUMCCodeEmitter::getLitEncoding(const MCOperand &MO, // FIXME Is this correct? What do inline immediates do on SI for f16 src // which does not have f16 support? return getLit16Encoding(static_cast<uint16_t>(Imm), STI); + + case AMDGPU::OPERAND_REG_IMM_BF16: + case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED: + case AMDGPU::OPERAND_REG_INLINE_C_BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_BF16: + // We don't actually need to check Inv2Pi here because BF16 instructions can + // only be emitted for targets that already support the feature. + return getLitBF16Encoding(static_cast<uint16_t>(Imm)); + case AMDGPU::OPERAND_REG_IMM_V2INT16: case AMDGPU::OPERAND_REG_INLINE_C_V2INT16: case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: return AMDGPU::getInlineEncodingV2I16(static_cast<uint32_t>(Imm)) .value_or(255); + case AMDGPU::OPERAND_REG_IMM_V2FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: return AMDGPU::getInlineEncodingV2F16(static_cast<uint32_t>(Imm)) .value_or(255); + + case AMDGPU::OPERAND_REG_IMM_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: + return AMDGPU::getInlineEncodingV2BF16(static_cast<uint32_t>(Imm)) + .value_or(255); + case AMDGPU::OPERAND_KIMM32: case AMDGPU::OPERAND_KIMM16: return MO.getImm(); diff --git a/llvm/lib/Target/AMDGPU/SIDefines.h b/llvm/lib/Target/AMDGPU/SIDefines.h index ca6728cf3ddc67..bfbd5b13ccc404 100644 --- a/llvm/lib/Target/AMDGPU/SIDefines.h +++ b/llvm/lib/Target/AMDGPU/SIDefines.h @@ -202,9 +202,12 @@ enum OperandType : unsigned { OPERAND_REG_IMM_INT16, OPERAND_REG_IMM_FP32, OPERAND_REG_IMM_FP64, + OPERAND_REG_IMM_BF16, OPERAND_REG_IMM_FP16, + OPERAND_REG_IMM_BF16_DEFERRED, OPERAND_REG_IMM_FP16_DEFERRED, OPERAND_REG_IMM_FP32_DEFERRED, + OPERAND_REG_IMM_V2BF16, OPERAND_REG_IMM_V2FP16, OPERAND_REG_IMM_V2INT16, OPERAND_REG_IMM_V2INT32, @@ -214,10 +217,12 @@ enum OperandType : unsigned { OPERAND_REG_INLINE_C_INT16, OPERAND_REG_INLINE_C_INT32, OPERAND_REG_INLINE_C_INT64, + OPERAND_REG_INLINE_C_BF16, OPERAND_REG_INLINE_C_FP16, OPERAND_REG_INLINE_C_FP32, OPERAND_REG_INLINE_C_FP64, OPERAND_REG_INLINE_C_V2INT16, + OPERAND_REG_INLINE_C_V2BF16, OPERAND_REG_INLINE_C_V2FP16, OPERAND_REG_INLINE_C_V2INT32, OPERAND_REG_INLINE_C_V2FP32, @@ -232,10 +237,12 @@ enum OperandType : unsigned { /// Operands with an AccVGPR register or inline constant OPERAND_REG_INLINE_AC_INT16, OPERAND_REG_INLINE_AC_INT32, + OPERAND_REG_INLINE_AC_BF16, OPERAND_REG_INLINE_AC_FP16, OPERAND_REG_INLINE_AC_FP32, OPERAND_REG_INLINE_AC_FP64, OPERAND_REG_INLINE_AC_V2INT16, + OPERAND_REG_INLINE_AC_V2BF16, OPERAND_REG_INLINE_AC_V2FP16, OPERAND_REG_INLINE_AC_V2INT32, OPERAND_REG_INLINE_AC_V2FP32, diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index f5ec831234f2f9..b7236ddbb64880 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -4185,6 +4185,10 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO, case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: return AMDGPU::isInlinableLiteralV2F16(Imm); + case AMDGPU::OPERAND_REG_IMM_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: + return AMDGPU::isInlinableLiteralV2BF16(Imm); case AMDGPU::OPERAND_REG_IMM_FP16: case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED: case AMDGPU::OPERAND_REG_INLINE_C_FP16: @@ 
-4201,6 +4205,17 @@ bool SIInstrInfo::isInlineConstant(const MachineOperand &MO, return false; } + case AMDGPU::OPERAND_REG_IMM_BF16: + case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED: + case AMDGPU::OPERAND_REG_INLINE_C_BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_BF16: { + if (isInt<16>(Imm) || isUInt<16>(Imm)) { + int16_t Trunc = static_cast<int16_t>(Imm); + return ST.has16BitInsts() && + AMDGPU::isInlinableLiteralBF16(Trunc, ST.hasInv2PiInlineImm()); + } + return false; + } case AMDGPU::OPERAND_KIMM32: case AMDGPU::OPERAND_KIMM16: return false; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td index 22599773d562cb..b0daec4a350eb3 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td @@ -1497,20 +1497,17 @@ class getVOPSrc0ForVT<ValueType VT, bit IsTrue16, bit IsFake16 = 1> { RegisterOperand ret = !if(VT.isFP, !if(!eq(VT.Size, 64), - VSrc_f64, - !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), - !if(IsTrue16, - !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), - VSrc_f16 - ), - !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), - VSrc_v2f16, - !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), - AVSrc_64, - VSrc_f32 + VSrc_f64, + !if(!eq(VT.Value, f16.Value), + !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16), + !if(!eq(VT.Value, bf16.Value), + !if(IsTrue16, !if(IsFake16, VSrcFake16_bf16_Lo128, VSrcT_bf16_Lo128), VSrc_bf16), + !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), + !if(!eq(VT.Value, v2f16.Value), VSrc_v2f16, VSrc_v2bf16), + !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32) + ) ) - ) - ) + ) ), !if(!eq(VT.Size, 64), VSrc_b64, @@ -1569,16 +1566,20 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> { !if(!eq(VT.Value, i1.Value), SSrc_i1, !if(VT.isFP, - !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), - !if(IsTrue16, VSrcT_f16, VSrc_f16), - !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), - VSrc_v2f16, - !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), - AVSrc_64, - VSrc_f32 - ) - ) - ), + !if(!eq(VT.Value, f16.Value), + !if(IsTrue16, VSrcT_f16, VSrc_f16), + !if(!eq(VT.Value, bf16.Value), + !if(IsTrue16, VSrcT_bf16, VSrc_bf16), + !if(!eq(VT.Value, v2f16.Value), + VSrc_v2f16, + !if(!eq(VT.Value, v2bf16.Value), + VSrc_v2bf16, + !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), + AVSrc_64, VSrc_f32) + ) + ) + ) + ), !if(!eq(VT.Value, i16.Value), !if(IsTrue16, VSrcT_b16, VSrc_b16), !if(!eq(VT.Value, v2i16.Value), @@ -1597,8 +1598,13 @@ class getVOP3DPPSrcForVT<ValueType VT> { RegisterOperand ret = !if (!eq(VT.Value, i1.Value), SSrc_i1, !if (VT.isFP, - !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16, - !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)), + !if(!eq(VT.Value, f16.Value), VCSrc_f16, + !if(!eq(VT.Value, bf16.Value), VCSrc_bf16, + !if(!eq(VT.Value, v2f16.Value), VCSrc_v2f16, + !if(!eq(VT.Value, v2bf16.Value), VCSrc_v2bf16, VCSrc_f32) + ) + ) + ), !if (!eq(VT.Value, i16.Value), VCSrc_b16, !if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16, VCSrc_b32)))); @@ -2528,7 +2534,7 @@ def VOP_V2I16_F32_F32 : VOPProfile <[v2i16, f32, f32, untyped]>; def VOP_V2I16_I32_I32 : VOPProfile <[v2i16, i32, i32, untyped]>; def VOP_F16_V2F16_V2F16_F16 : VOPProfile <[f16, v2f16, v2f16, f16]>; -def VOP_I16_V2I16_V2I16_I16 : VOPProfile <[i16, v2i16, v2i16, i16]>; +def 
VOP_BF16_V2BF16_V2BF16_BF16: VOPProfile <[bf16, v2bf16, v2bf16, bf16]>; def VOP_F32_V2I16_V2I16_F32 : VOPProfile <[f32, v2i16, v2i16, f32]>; def VOP_F32_V2F16_V2F16_V2F16 : VOPProfile <[f32, v2f16, v2f16, v2f16]>; diff --git a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td index d4a1e8d185a1d5..b846d999fc9cfc 100644 --- a/llvm/lib/Target/AMDGPU/SIRegisterInfo.td +++ b/llvm/lib/Target/AMDGPU/SIRegisterInfo.td @@ -1066,7 +1066,7 @@ multiclass AVRegClass<int numRegs, list<ValueType> regTypes, // Define the regular class. def "" : VRegClassBase<numRegs, regTypes, (add vregList, aregList)>; - // Define 2-aligned variant + // Define 2-aligned variant def _Align2 : VRegClassBase<numRegs, regTypes, (add (decimate vregList, 2), (decimate aregList, 2))> { @@ -1119,6 +1119,7 @@ class SrcRegOrImm9<RegisterClass regClass, string opWidth, string operandType, } def SSrc_b16 : SrcRegOrImm9 <SReg_32, "OPW32", "OPERAND_REG_IMM_INT16", 16>; +def SSrc_bf16: SrcRegOrImm9 <SReg_32, "OPW32", "OPERAND_REG_IMM_BF16", 16>; def SSrc_f16 : SrcRegOrImm9 <SReg_32, "OPW32", "OPERAND_REG_IMM_FP16", 16>; def SSrc_b32 : SrcRegOrImm9 <SReg_32, "OPW32", "OPERAND_REG_IMM_INT32", 32>; def SSrc_f32 : SrcRegOrImm9 <SReg_32, "OPW32", "OPERAND_REG_IMM_FP32", 32>; @@ -1153,6 +1154,7 @@ def SCSrc_b64 : SrcRegOrImm9 <SReg_64, "OPW64", "OPERAND_REG_INLINE_C_INT64", 64 // The current and temporary future default used case for VOP3. def VSrc_b16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_INT16", 16>; +def VSrc_bf16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_BF16", 16>; def VSrc_f16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_FP16", 16>; // True16 VOP3 operands. @@ -1160,6 +1162,10 @@ def VSrcT_b16 : RegOrImmOperand <VS_16, "OPERAND_REG_IMM_INT16"> { let EncoderMethod = "getMachineOpValueT16"; let DecoderMethod = "decodeOperand_VSrcT16"; } +def VSrcT_bf16 : RegOrImmOperand <VS_16, "OPERAND_REG_IMM_BF16"> { + let EncoderMethod = "getMachineOpValueT16"; + let DecoderMethod = "decodeOperand_VSrcT16"; +} def VSrcT_f16 : RegOrImmOperand <VS_16, "OPERAND_REG_IMM_FP16"> { let EncoderMethod = "getMachineOpValueT16"; let DecoderMethod = "decodeOperand_VSrcT16"; @@ -1170,6 +1176,10 @@ def VSrcT_b16_Lo128 : RegOrImmOperand <VS_16_Lo128, "OPERAND_REG_IMM_INT16"> { let EncoderMethod = "getMachineOpValueT16Lo128"; let DecoderMethod = "decodeOperand_VSrcT16_Lo128"; } +def VSrcT_bf16_Lo128 : RegOrImmOperand <VS_16_Lo128, "OPERAND_REG_IMM_BF16"> { + let EncoderMethod = "getMachineOpValueT16Lo128"; + let DecoderMethod = "decodeOperand_VSrcT16_Lo128"; +} def VSrcT_f16_Lo128 : RegOrImmOperand <VS_16_Lo128, "OPERAND_REG_IMM_FP16"> { let EncoderMethod = "getMachineOpValueT16Lo128"; let DecoderMethod = "decodeOperand_VSrcT16_Lo128"; @@ -1178,11 +1188,13 @@ def VSrcT_f16_Lo128 : RegOrImmOperand <VS_16_Lo128, "OPERAND_REG_IMM_FP16"> { // The current and temporary future default used case for fake VOP1/2/C. // For VOP1,2,C True16 instructions. _Lo128 use first 128 32-bit VGPRs only. 
def VSrcFake16_b16_Lo128 : SrcRegOrImm9 <VS_32_Lo128, "OPW16", "OPERAND_REG_IMM_INT16", 16>; +def VSrcFake16_bf16_Lo128 : SrcRegOrImm9 <VS_32_Lo128, "OPW16", "OPERAND_REG_IMM_BF16", 16>; def VSrcFake16_f16_Lo128 : SrcRegOrImm9 <VS_32_Lo128, "OPW16", "OPERAND_REG_IMM_FP16", 16>; def VSrc_b32 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_INT32", 32>; def VSrc_f32 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_FP32", 32>; def VSrc_v2b16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_V2INT16", 32>; +def VSrc_v2bf16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_V2BF16", 16>; def VSrc_v2f16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_IMM_V2FP16", 16>; def VSrc_b64 : SrcRegOrImm9 <VS_64, "OPW64", "OPERAND_REG_IMM_INT64", 64>; def VSrc_f64 : SrcRegOrImm9 <VS_64, "OPW64", "OPERAND_REG_IMM_FP64", 64> { @@ -1196,9 +1208,12 @@ def VSrc_v2f32 : SrcRegOrImm9 <VS_64, "OPW64", "OPERAND_REG_IMM_V2FP32", 32>; // with FMAMK/FMAAK //===----------------------------------------------------------------------===// +def VSrc_bf16_Deferred : SrcRegOrImmDeferred9<VS_32, "OPW16", "OPERAND_REG_IMM_BF16_DEFERRED", 16>; def VSrc_f16_Deferred : SrcRegOrImmDeferred9<VS_32, "OPW16", "OPERAND_REG_IMM_FP16_DEFERRED", 16>; def VSrc_f32_Deferred : SrcRegOrImmDeferred9<VS_32, "OPW32", "OPERAND_REG_IMM_FP32_DEFERRED", 32>; +def VSrcFake16_bf16_Lo128_Deferred + : SrcRegOrImmDeferred9<VS_32_Lo128, "OPW16", "OPERAND_REG_IMM_BF16_DEFERRED", 16>; def VSrcFake16_f16_Lo128_Deferred : SrcRegOrImmDeferred9<VS_32_Lo128, "OPW16", "OPERAND_REG_IMM_FP16_DEFERRED", 16>; @@ -1258,19 +1273,23 @@ def ARegSrc_32 : AVOperand<AGPR_32, "decodeSrcA9", "OPW32">; //===----------------------------------------------------------------------===// def VCSrc_b16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_INT16", 16>; +def VCSrc_bf16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_BF16", 16>; def VCSrc_f16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_FP16", 16>; def VCSrc_b32 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_INT32", 32>; def VCSrc_f32 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_FP32", 32>; def VCSrc_v2b16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_V2INT16", 32>; +def VCSrc_v2bf16: SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_V2BF16", 16>; def VCSrc_v2f16 : SrcRegOrImm9 <VS_32, "OPW32", "OPERAND_REG_INLINE_C_V2FP16", 16>; //===----------------------------------------------------------------------===// // VISrc_* Operands with a VGPR or an inline constant //===----------------------------------------------------------------------===// +def VISrc_64_bf16 : SrcRegOrImm9 <VReg_64, "OPW64", "OPERAND_REG_INLINE_C_BF16", 16>; def VISrc_64_f16 : SrcRegOrImm9 <VReg_64, "OPW64", "OPERAND_REG_INLINE_C_FP16", 16>; def VISrc_64_b32 : SrcRegOrImm9 <VReg_64, "OPW64", "OPERAND_REG_INLINE_C_INT32", 32>; def VISrc_64_f64 : SrcRegOrImm9 <VReg_64, "OPW64", "OPERAND_REG_INLINE_C_FP64", 64>; +def VISrc_128_bf16 : SrcRegOrImm9 <VReg_128, "OPW128", "OPERAND_REG_INLINE_C_BF16", 16>; def VISrc_128_f16 : SrcRegOrImm9 <VReg_128, "OPW128", "OPERAND_REG_INLINE_C_FP16", 16>; def VISrc_128_b32 : SrcRegOrImm9 <VReg_128, "OPW128", "OPERAND_REG_INLINE_C_INT32", 32>; def VISrc_128_f32 : SrcRegOrImm9 <VReg_128, "OPW128", "OPERAND_REG_INLINE_C_FP32", 32>; diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp index 800dfcf3076dd3..dacdf7b5cd9a8e 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp 
@@ -2652,6 +2652,23 @@ bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi) { (Val == 0x3e22f983 && HasInv2Pi); } +bool isInlinableLiteralBF16(int16_t Literal, bool HasInv2Pi) { + if (!HasInv2Pi) + return false; + if (isInlinableIntLiteral(Literal)) + return true; + uint16_t Val = static_cast<uint16_t>(Literal); + return Val == 0x3F00 || // 0.5 + Val == 0xBF00 || // -0.5 + Val == 0x3F80 || // 1.0 + Val == 0xBF80 || // -1.0 + Val == 0x4000 || // 2.0 + Val == 0xC000 || // -2.0 + Val == 0x4080 || // 4.0 + Val == 0xC080 || // -4.0 + Val == 0x3E22; // 1.0 / (2.0 * pi) +} + bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi) { if (!HasInv2Pi) return false; @@ -2730,6 +2747,34 @@ std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal) { return getInlineEncodingV216(false, Literal); } +// Encoding of the literal as an inline constant for a V_PK_*_BF16 instruction +// or nullopt. +std::optional<unsigned> getInlineEncodingV2BF16(uint32_t Literal) { + int32_t Signed = static_cast<int32_t>(Literal); + if (Signed >= 0 && Signed <= 64) + return 128 + Signed; + + if (Signed >= -16 && Signed <= -1) + return 192 + std::abs(Signed); + + // clang-format off + switch (Literal) { + case 0x3F00: return 240; // 0.5 + case 0xBF00: return 241; // -0.5 + case 0x3F80: return 242; // 1.0 + case 0xBF80: return 243; // -1.0 + case 0x4000: return 244; // 2.0 + case 0xC000: return 245; // -2.0 + case 0x4080: return 246; // 4.0 + case 0xC080: return 247; // -4.0 + case 0x3E22: return 248; // 1.0 / (2.0 * pi) + default: break; + } + // clang-format on + + return std::nullopt; +} + // Encoding of the literal as an inline constant for a V_PK_*_F16 instruction // or nullopt. std::optional<unsigned> getInlineEncodingV2F16(uint32_t Literal) { @@ -2747,6 +2792,10 @@ bool isInlinableLiteralV216(uint32_t Literal, uint8_t OpType) { case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: return getInlineEncodingV216(true, Literal).has_value(); + case AMDGPU::OPERAND_REG_IMM_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: + return isInlinableLiteralV2BF16(Literal); default: llvm_unreachable("bad packed operand type"); } @@ -2757,6 +2806,11 @@ bool isInlinableLiteralV2I16(uint32_t Literal) { return getInlineEncodingV2I16(Literal).has_value(); } +// Whether the given literal can be inlined for a V_PK_*_BF16 instruction. +bool isInlinableLiteralV2BF16(uint32_t Literal) { + return getInlineEncodingV2BF16(Literal).has_value(); +} + // Whether the given literal can be inlined for a V_PK_*_F16 instruction. 
bool isInlinableLiteralV2F16(uint32_t Literal) { return getInlineEncodingV2F16(Literal).has_value(); diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h index b56025f55519a5..f35e7744528290 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h @@ -1335,17 +1335,24 @@ inline unsigned getOperandSize(const MCOperandInfo &OpInfo) { return 8; case AMDGPU::OPERAND_REG_IMM_INT16: + case AMDGPU::OPERAND_REG_IMM_BF16: case AMDGPU::OPERAND_REG_IMM_FP16: + case AMDGPU::OPERAND_REG_IMM_BF16_DEFERRED: case AMDGPU::OPERAND_REG_IMM_FP16_DEFERRED: case AMDGPU::OPERAND_REG_INLINE_C_INT16: + case AMDGPU::OPERAND_REG_INLINE_C_BF16: case AMDGPU::OPERAND_REG_INLINE_C_FP16: case AMDGPU::OPERAND_REG_INLINE_C_V2INT16: + case AMDGPU::OPERAND_REG_INLINE_C_V2BF16: case AMDGPU::OPERAND_REG_INLINE_C_V2FP16: case AMDGPU::OPERAND_REG_INLINE_AC_INT16: + case AMDGPU::OPERAND_REG_INLINE_AC_BF16: case AMDGPU::OPERAND_REG_INLINE_AC_FP16: case AMDGPU::OPERAND_REG_INLINE_AC_V2INT16: + case AMDGPU::OPERAND_REG_INLINE_AC_V2BF16: case AMDGPU::OPERAND_REG_INLINE_AC_V2FP16: case AMDGPU::OPERAND_REG_IMM_V2INT16: + case AMDGPU::OPERAND_REG_IMM_V2BF16: case AMDGPU::OPERAND_REG_IMM_V2FP16: return 2; @@ -1373,12 +1380,18 @@ bool isInlinableLiteral64(int64_t Literal, bool HasInv2Pi); LLVM_READNONE bool isInlinableLiteral32(int32_t Literal, bool HasInv2Pi); +LLVM_READNONE +bool isInlinableLiteralBF16(int16_t Literal, bool HasInv2Pi); + LLVM_READNONE bool isInlinableLiteral16(int16_t Literal, bool HasInv2Pi); LLVM_READNONE std::optional<unsigned> getInlineEncodingV2I16(uint32_t Literal); +LLVM_READNONE +std::optional<unsigned> getInlineEncodingV2BF16(uint32_t Literal); + LLVM_READNONE std::optional<unsigned> getInlineEncodingV2F16(uint32_t Literal); @@ -1388,6 +1401,9 @@ bool isInlinableLiteralV216(uint32_t Literal, uint8_t OpType); LLVM_READNONE bool isInlinableLiteralV2I16(uint32_t Literal); +LLVM_READNONE +bool isInlinableLiteralV2BF16(uint32_t Literal); + LLVM_READNONE bool isInlinableLiteralV2F16(uint32_t Literal); diff --git a/llvm/lib/Target/AMDGPU/VOP3Instructions.td b/llvm/lib/Target/AMDGPU/VOP3Instructions.td index 8d965d3b9041d5..35cffa22f45929 100644 --- a/llvm/lib/Target/AMDGPU/VOP3Instructions.td +++ b/llvm/lib/Target/AMDGPU/VOP3Instructions.td @@ -904,7 +904,7 @@ let SubtargetPredicate = isGFX12Plus, ReadsModeReg = 0 in { let SubtargetPredicate = HasDot9Insts, IsDOT=1 in { defm V_DOT2_F16_F16 : VOP3Inst<"v_dot2_f16_f16", VOP3_DOT_Profile<VOP_F16_V2F16_V2F16_F16>, int_amdgcn_fdot2_f16_f16>; - defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_I16_V2I16_V2I16_I16>, int_amdgcn_fdot2_bf16_bf16>; + defm V_DOT2_BF16_BF16 : VOP3Inst<"v_dot2_bf16_bf16", VOP3_DOT_Profile<VOP_BF16_V2BF16_V2BF16_BF16>, int_amdgcn_fdot2_bf16_bf16>; } class VOP_Pseudo_Scalar<RegisterClass Dst, RegisterOperand SrcOp, diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll index 645e00b6a4a819..ca06a57be19ccd 100644 --- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll +++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.fdot2.bf16.bf16.ll @@ -1,8 +1,9 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py ; RUN: llc -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11,SDAG-GFX11 -; RUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s 
--check-prefixes=GFX11,GISEL-GFX11 +; FIXME: GlobalIsel doesn't support BF16 for now. +; xUN: llc -global-isel -mtriple=amdgcn -mcpu=gfx1100 -verify-machineinstrs < %s | FileCheck %s --check-prefixes=GFX11,GISEL-GFX11 -declare i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a, <2 x i16> %b, i16 %c) +declare bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a, <2 x bfloat> %b, bfloat %c) define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16( ; GFX11-LABEL: test_llvm_amdgcn_fdot2_bf16_bf16: @@ -24,11 +25,11 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16( ptr addrspace(1) %b, ptr addrspace(1) %c) { entry: - %a.val = load <2 x i16>, ptr addrspace(1) %a - %b.val = load <2 x i16>, ptr addrspace(1) %b - %c.val = load i16, ptr addrspace(1) %c - %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a.val, <2 x i16> %b.val, i16 %c.val) - store i16 %r.val, ptr addrspace(1) %r + %a.val = load <2 x bfloat>, ptr addrspace(1) %a + %b.val = load <2 x bfloat>, ptr addrspace(1) %b + %c.val = load bfloat, ptr addrspace(1) %c + %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a.val, <2 x bfloat> %b.val, bfloat %c.val) + store bfloat %r.val, ptr addrspace(1) %r ret void } @@ -61,14 +62,14 @@ define amdgpu_kernel void @test_llvm_amdgcn_fdot2_bf16_bf16_dpp( ptr addrspace(5) %b, ptr addrspace(5) %c) { entry: - %a.val = load <2 x i16>, ptr addrspace(5) %a - %b.val = load <2 x i16>, ptr addrspace(5) %b - %c.val = load i16, ptr addrspace(5) %c - %a.val.i32 = bitcast <2 x i16> %a.val to i32 + %a.val = load <2 x bfloat>, ptr addrspace(5) %a + %b.val = load <2 x bfloat>, ptr addrspace(5) %b + %c.val = load bfloat, ptr addrspace(5) %c + %a.val.i32 = bitcast <2 x bfloat> %a.val to i32 %dpp = call i32 @llvm.amdgcn.update.dpp.i32(i32 %a.val.i32, i32 %a.val.i32, i32 1, i32 15, i32 15, i1 1) - %a.val.dpp.v2i16 = bitcast i32 %dpp to <2 x i16> - %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a.val.dpp.v2i16, <2 x i16> %b.val, i16 %c.val) - store i16 %r.val, ptr addrspace(5) %r + %a.val.dpp.v2bfloat = bitcast i32 %dpp to <2 x bfloat> + %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a.val.dpp.v2bfloat, <2 x bfloat> %b.val, bfloat %c.val) + store bfloat %r.val, ptr addrspace(5) %r ret void } @@ -79,17 +80,17 @@ define amdgpu_ps void @test_llvm_amdgcn_fdot2_bf16_bf16_sis( ; GFX11: ; %bb.0: ; %entry ; GFX11-NEXT: v_mov_b32_e32 v2, s1 ; GFX11-NEXT: s_delay_alu instid0(VALU_DEP_1) -; GFX11-NEXT: v_dot2_bf16_bf16 v2, s0, 0x10001, v2 +; GFX11-NEXT: v_dot2_bf16_bf16 v2, s0, 0x3f803f80, v2 ; GFX11-NEXT: global_store_b16 v[0:1], v2, off ; GFX11-NEXT: s_nop 0 ; GFX11-NEXT: s_sendmsg sendmsg(MSG_DEALLOC_VGPRS) ; GFX11-NEXT: s_endpgm ptr addrspace(1) %r, - <2 x i16> inreg %a, - i16 inreg %c) { + <2 x bfloat> inreg %a, + bfloat inreg %c) { entry: - %r.val = call i16 @llvm.amdgcn.fdot2.bf16.bf16(<2 x i16> %a, <2 x i16> <i16 1, i16 1>, i16 %c) - store i16 %r.val, ptr addrspace(1) %r + %r.val = call bfloat @llvm.amdgcn.fdot2.bf16.bf16(<2 x bfloat> %a, <2 x bfloat> <bfloat 1.0, bfloat 1.0>, bfloat %c) + store bfloat %r.val, ptr addrspace(1) %r ret void } diff --git a/llvm/test/MC/AMDGPU/bf16_imm.s b/llvm/test/MC/AMDGPU/bf16_imm.s new file mode 100644 index 00000000000000..5ad1171b21f327 --- /dev/null +++ b/llvm/test/MC/AMDGPU/bf16_imm.s @@ -0,0 +1,8 @@ +// RUN: llvm-mc -arch=amdgcn -mcpu=gfx1100 -show-encoding %s | FileCheck %s +// RUN: llvm-mc -arch=amdgcn -mcpu=gfx1200 -show-encoding %s | FileCheck %s + +v_dot2_bf16_bf16 v5, v1, v2, 100.0 +// CHECK: v_dot2_bf16_bf16 v5, v1, v2, 
0x42c8 ; encoding: [0x05,0x00,0x67,0xd6,0x01,0x05,0xfe,0x03,0xc8,0x42,0x00,0x00] + +v_dot2_bf16_bf16 v5, v1, v2, 1.0 +// CHECK: v_dot2_bf16_bf16 v5, v1, v2, 1.0 ; encoding: [0x05,0x00,0x67,0xd6,0x01,0x05,0xca,0x03] diff --git a/llvm/test/MC/Disassembler/AMDGPU/bf16_imm.txt b/llvm/test/MC/Disassembler/AMDGPU/bf16_imm.txt new file mode 100644 index 00000000000000..1cca2987fad869 --- /dev/null +++ b/llvm/test/MC/Disassembler/AMDGPU/bf16_imm.txt @@ -0,0 +1,8 @@ +# RUN: llvm-mc -triple=amdgcn -mcpu=gfx1100 -disassemble -show-encoding < %s | FileCheck %s +# RUN: llvm-mc -triple=amdgcn -mcpu=gfx1200 -disassemble -show-encoding < %s | FileCheck %s + +# CHECK: v_dot2_bf16_bf16 v5, v1, v2, 0x42c8 +0x05,0x00,0x67,0xd6,0x01,0x05,0xfe,0x03,0xc8,0x42,0x00,0x00 + +# CHECK: v_dot2_bf16_bf16 v5, v1, v2, 0x3c00 +0x05,0x00,0x67,0xd6,0x01,0x05,0xca,0x03
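
For reference, here is a minimal standalone C++ sketch of the bf16 inline-constant mapping that `isInlinableLiteralBF16`, `getLitBF16Encoding`, and `getInlineEncodingV2BF16` implement above. The helper name `bf16InlineEncoding` and the `main` driver are made up for illustration; the bit patterns and the 240-248 encodings are taken from the patch. Small integers and a fixed set of bfloat16 values get hardware inline encodings, while everything else (e.g. 100.0, bf16 bit pattern 0x42C8, as in the new bf16_imm.s test) has to be emitted as a 32-bit literal (encoding 255).

#include <cstdint>
#include <cstdio>
#include <optional>

// Sketch (not the patch's actual code): returns the hardware inline-constant
// encoding for a raw bfloat16 bit pattern, or std::nullopt if the value has
// to be emitted as a literal. Mirrors isInlinableLiteralBF16's early-out on
// targets without the inv-2pi inline immediate.
std::optional<uint32_t> bf16InlineEncoding(uint16_t Bits, bool HasInv2Pi) {
  if (!HasInv2Pi)
    return std::nullopt;
  int16_t Signed = static_cast<int16_t>(Bits);
  if (Signed >= 0 && Signed <= 64)
    return 128 + Signed;   // inline integers 0..64 -> encodings 128..192
  if (Signed >= -16 && Signed <= -1)
    return 192 - Signed;   // inline integers -1..-16 -> encodings 193..208
  switch (Bits) {
  case 0x3F00: return 240; // 0.5
  case 0xBF00: return 241; // -0.5
  case 0x3F80: return 242; // 1.0
  case 0xBF80: return 243; // -1.0
  case 0x4000: return 244; // 2.0
  case 0xC000: return 245; // -2.0
  case 0x4080: return 246; // 4.0
  case 0xC080: return 247; // -4.0
  case 0x3E22: return 248; // 1.0/(2.0*pi); bf16 truncation of f32 0x3E22F983
  default:     return std::nullopt;
  }
}

int main() {
  printf("%u\n", *bf16InlineEncoding(0x3F80, true));                 // 242: 1.0 inlines
  printf("%d\n", (int)bf16InlineEncoding(0x42C8, true).has_value()); // 0: 100.0 needs a literal
  return 0;
}

This is also why the assembler test above encodes `v_dot2_bf16_bf16 v5, v1, v2, 1.0` without a trailing literal dword but appends `0xc8,0x42,0x00,0x00` for the 100.0 operand.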