This revision was not accepted when it landed; it landed in state "Needs 
Review".
This revision was automatically updated to reflect the committed changes.
Closed by commit rG8c24f33158d8: [IR][BFloat] Add BFloat IR type (authored by 
stuij).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D78190/new/

https://reviews.llvm.org/D78190

Files:
  clang/lib/Sema/SemaOpenMP.cpp
  llvm/docs/BitCodeFormat.rst
  llvm/docs/LangRef.rst
  llvm/include/llvm-c/Core.h
  llvm/include/llvm/ADT/APFloat.h
  llvm/include/llvm/Bitcode/LLVMBitCodes.h
  llvm/include/llvm/IR/Constants.h
  llvm/include/llvm/IR/DataLayout.h
  llvm/include/llvm/IR/IRBuilder.h
  llvm/include/llvm/IR/Type.h
  llvm/lib/AsmParser/LLLexer.cpp
  llvm/lib/AsmParser/LLParser.cpp
  llvm/lib/Bitcode/Reader/BitcodeReader.cpp
  llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
  llvm/lib/CodeGen/MIRParser/MILexer.cpp
  llvm/lib/IR/AsmWriter.cpp
  llvm/lib/IR/Constants.cpp
  llvm/lib/IR/Core.cpp
  llvm/lib/IR/DataLayout.cpp
  llvm/lib/IR/Function.cpp
  llvm/lib/IR/LLVMContextImpl.cpp
  llvm/lib/IR/LLVMContextImpl.h
  llvm/lib/IR/Type.cpp
  llvm/lib/Support/APFloat.cpp
  llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
  llvm/lib/Target/X86/X86ISelLowering.cpp
  llvm/test/Assembler/bfloat.ll
  llvm/tools/llvm-c-test/echo.cpp

Index: llvm/tools/llvm-c-test/echo.cpp
===================================================================
--- llvm/tools/llvm-c-test/echo.cpp
+++ llvm/tools/llvm-c-test/echo.cpp
@@ -72,6 +72,8 @@
         return LLVMVoidTypeInContext(Ctx);
       case LLVMHalfTypeKind:
         return LLVMHalfTypeInContext(Ctx);
+      case LLVMBFloatTypeKind:
+        return LLVMBFloatTypeInContext(Ctx);
       case LLVMFloatTypeKind:
         return LLVMFloatTypeInContext(Ctx);
       case LLVMDoubleTypeKind:
Index: llvm/test/Assembler/bfloat.ll
===================================================================
--- /dev/null
+++ llvm/test/Assembler/bfloat.ll
@@ -0,0 +1,38 @@
+; RUN: llvm-as < %s | llvm-dis | FileCheck %s --check-prefix=ASSEM-DISASS
+; RUN: opt < %s -O3 -S | FileCheck %s --check-prefix=OPT
+; RUN: verify-uselistorder %s
+; Basic smoke tests for bfloat type.
+
+define bfloat @check_bfloat(bfloat %A) {
+; ASSEM-DISASS: ret bfloat %A
+    ret bfloat %A
+}
+
+define bfloat @check_bfloat_literal() {
+; ASSEM-DISASS: ret bfloat 0xR3149
+    ret bfloat 0xR3149
+}
+
+define <4 x bfloat> @check_fixed_vector() {
+; ASSEM-DISASS: ret <4 x bfloat> %tmp
+  %tmp = fadd <4 x bfloat> undef, undef
+  ret <4 x bfloat> %tmp
+}
+
+define <vscale x 4 x bfloat> @check_vector() {
+; ASSEM-DISASS: ret <vscale x 4 x bfloat> %tmp
+  %tmp = fadd <vscale x 4 x bfloat> undef, undef
+  ret <vscale x 4 x bfloat> %tmp
+}
+
+define bfloat @check_bfloat_constprop() {
+  %tmp = fadd bfloat 0xR40C0, 0xR40C0
+; OPT: 0xR4140
+  ret bfloat %tmp
+}
+
+define float @check_bfloat_convert() {
+  %tmp = fpext bfloat 0xR4C8D to float
+; OPT: 0x4191A00000000000
+  ret float %tmp
+}
Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -11526,8 +11526,9 @@
   MVT LogicVT = VT;
   if (EltVT == MVT::f32 || EltVT == MVT::f64) {
     Zero = DAG.getConstantFP(0.0, DL, EltVT);
-    AllOnes = DAG.getConstantFP(
-        APFloat::getAllOnesValue(EltVT.getSizeInBits(), true), DL, EltVT);
+    APFloat AllOnesValue = APFloat::getAllOnesValue(
+        SelectionDAG::EVTToAPFloatSemantics(EltVT), EltVT.getSizeInBits());
+    AllOnes = DAG.getConstantFP(AllOnesValue, DL, EltVT);
     LogicVT =
         MVT::getVectorVT(EltVT == MVT::f64 ? MVT::i64 : MVT::i32, Mask.size());
   } else {
Index: llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
===================================================================
--- llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
+++ llvm/lib/Target/Hexagon/HexagonTargetObjectFile.cpp
@@ -323,6 +323,7 @@
   }
   case Type::FunctionTyID:
   case Type::VoidTyID:
+  case Type::BFloatTyID:
   case Type::X86_FP80TyID:
   case Type::FP128TyID:
   case Type::PPC_FP128TyID:
Index: llvm/lib/Support/APFloat.cpp
===================================================================
--- llvm/lib/Support/APFloat.cpp
+++ llvm/lib/Support/APFloat.cpp
@@ -69,6 +69,7 @@
   };
 
   static const fltSemantics semIEEEhalf = {15, -14, 11, 16};
+  static const fltSemantics semBFloat = {127, -126, 8, 16};
   static const fltSemantics semIEEEsingle = {127, -126, 24, 32};
   static const fltSemantics semIEEEdouble = {1023, -1022, 53, 64};
   static const fltSemantics semIEEEquad = {16383, -16382, 113, 128};
@@ -117,6 +118,8 @@
     switch (S) {
     case S_IEEEhalf:
       return IEEEhalf();
+    case S_BFloat:
+      return BFloat();
     case S_IEEEsingle:
       return IEEEsingle();
     case S_IEEEdouble:
@@ -135,6 +138,8 @@
   APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
     if (&Sem == &llvm::APFloat::IEEEhalf())
       return S_IEEEhalf;
+    else if (&Sem == &llvm::APFloat::BFloat())
+      return S_BFloat;
     else if (&Sem == &llvm::APFloat::IEEEsingle())
       return S_IEEEsingle;
     else if (&Sem == &llvm::APFloat::IEEEdouble())
@@ -152,6 +157,9 @@
   const fltSemantics &APFloatBase::IEEEhalf() {
     return semIEEEhalf;
   }
+  const fltSemantics &APFloatBase::BFloat() {
+    return semBFloat;
+  }
   const fltSemantics &APFloatBase::IEEEsingle() {
     return semIEEEsingle;
   }
@@ -3255,6 +3263,33 @@
                     (mysignificand & 0x7fffff)));
 }
 
+APInt IEEEFloat::convertBFloatAPFloatToAPInt() const {
+  // Pack as sign[15] | biased exponent[14:7] | stored mantissa[6:0].
+  assert(semantics == (const llvm::fltSemantics *)&semBFloat);
+  assert(partCount() == 1);
+
+  uint32_t myexponent, mysignificand;
+  if (isFiniteNonZero()) {
+    myexponent = exponent + 127; // bias
+    mysignificand = (uint32_t)*significandParts();
+    if (myexponent == 1 && !(mysignificand & 0x80))
+      myexponent = 0; // denormal: integer bit clear at minimum exponent
+  } else if (category == fcZero) {
+    myexponent = 0;
+    mysignificand = 0;
+  } else if (category == fcInfinity) {
+    myexponent = 0xff; // all-ones 8-bit exponent (not half's 5-bit 0x1f)
+    mysignificand = 0;
+  } else {
+    assert(category == fcNaN && "Unknown category!");
+    myexponent = 0xff; // all-ones 8-bit exponent (not half's 5-bit 0x1f)
+    mysignificand = (uint32_t)*significandParts();
+  }
+
+  return APInt(16, (((sign & 1) << 15) | ((myexponent & 0xff) << 7) |
+                    (mysignificand & 0x7f)));
+}
+
 APInt IEEEFloat::convertHalfAPFloatToAPInt() const {
   assert(semantics == (const llvm::fltSemantics*)&semIEEEhalf);
   assert(partCount()==1);
@@ -3290,6 +3325,9 @@
   if (semantics == (const llvm::fltSemantics*)&semIEEEhalf)
     return convertHalfAPFloatToAPInt();
 
+  if (semantics == (const llvm::fltSemantics *)&semBFloat)
+    return convertBFloatAPFloatToAPInt();
+
   if (semantics == (const llvm::fltSemantics*)&semIEEEsingle)
     return convertFloatAPFloatToAPInt();
 
@@ -3486,6 +3524,37 @@
   }
 }
 
+void IEEEFloat::initFromBFloatAPInt(const APInt &api) {
+  assert(api.getBitWidth() == 16); // 1 sign, 8 exponent, 7 mantissa bits
+  uint32_t i = (uint32_t)*api.getRawData();
+  uint32_t myexponent = (i >> 7) & 0xff; // bits [14:7]: biased exponent
+  uint32_t mysignificand = i & 0x7f; // bits [6:0]: stored mantissa
+
+  initialize(&semBFloat);
+  assert(partCount() == 1);
+
+  sign = i >> 15;
+  if (myexponent == 0 && mysignificand == 0) {
+    // all-zero exponent and mantissa: signed zero; neither field is stored
+    category = fcZero;
+  } else if (myexponent == 0xff && mysignificand == 0) {
+    // all-ones exponent, zero mantissa: infinity; neither field is stored
+    category = fcInfinity;
+  } else if (myexponent == 0xff && mysignificand != 0) {
+    // all-ones exponent, non-zero mantissa: NaN; mantissa kept as payload
+    category = fcNaN;
+    *significandParts() = mysignificand;
+  } else {
+    category = fcNormal;
+    exponent = myexponent - 127; // remove the bias
+    *significandParts() = mysignificand;
+    if (myexponent == 0) // denormal: no implicit bit, minimum exponent
+      exponent = -126;
+    else
+      *significandParts() |= 0x80; // make the implicit integer bit explicit
+  }
+}
+
 void IEEEFloat::initFromHalfAPInt(const APInt &api) {
   assert(api.getBitWidth()==16);
   uint32_t i = (uint32_t)*api.getRawData();
@@ -3524,6 +3593,8 @@
 void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
   if (Sem == &semIEEEhalf)
     return initFromHalfAPInt(api);
+  if (Sem == &semBFloat)
+    return initFromBFloatAPInt(api);
   if (Sem == &semIEEEsingle)
     return initFromFloatAPInt(api);
   if (Sem == &semIEEEdouble)
@@ -4763,26 +4834,9 @@
   llvm_unreachable("Unexpected semantics");
 }
 
-APFloat APFloat::getAllOnesValue(unsigned BitWidth, bool isIEEE) {
-  if (isIEEE) {
-    switch (BitWidth) {
-    case 16:
-      return APFloat(semIEEEhalf, APInt::getAllOnesValue(BitWidth));
-    case 32:
-      return APFloat(semIEEEsingle, APInt::getAllOnesValue(BitWidth));
-    case 64:
-      return APFloat(semIEEEdouble, APInt::getAllOnesValue(BitWidth));
-    case 80:
-      return APFloat(semX87DoubleExtended, APInt::getAllOnesValue(BitWidth));
-    case 128:
-      return APFloat(semIEEEquad, APInt::getAllOnesValue(BitWidth));
-    default:
-      llvm_unreachable("Unknown floating bit width");
-    }
-  } else {
-    assert(BitWidth == 128);
-    return APFloat(semPPCDoubleDouble, APInt::getAllOnesValue(BitWidth));
-  }
+APFloat APFloat::getAllOnesValue(const fltSemantics &Semantics,
+                                 unsigned BitWidth) {
+  return APFloat(Semantics, APInt::getAllOnesValue(BitWidth));
 }
 
 void APFloat::print(raw_ostream &OS) const {
Index: llvm/lib/IR/Type.cpp
===================================================================
--- llvm/lib/IR/Type.cpp
+++ llvm/lib/IR/Type.cpp
@@ -40,6 +40,7 @@
   switch (IDNumber) {
   case VoidTyID      : return getVoidTy(C);
   case HalfTyID      : return getHalfTy(C);
+  case BFloatTyID    : return getBFloatTy(C);
   case FloatTyID     : return getFloatTy(C);
   case DoubleTyID    : return getDoubleTy(C);
   case X86_FP80TyID  : return getX86_FP80Ty(C);
@@ -112,6 +113,7 @@
 TypeSize Type::getPrimitiveSizeInBits() const {
   switch (getTypeID()) {
   case Type::HalfTyID: return TypeSize::Fixed(16);
+  case Type::BFloatTyID: return TypeSize::Fixed(16);
   case Type::FloatTyID: return TypeSize::Fixed(32);
   case Type::DoubleTyID: return TypeSize::Fixed(64);
   case Type::X86_FP80TyID: return TypeSize::Fixed(80);
@@ -142,6 +144,7 @@
     return VTy->getElementType()->getFPMantissaWidth();
   assert(isFloatingPointTy() && "Not a floating point type!");
   if (getTypeID() == HalfTyID) return 11;
+  if (getTypeID() == BFloatTyID) return 8;
   if (getTypeID() == FloatTyID) return 24;
   if (getTypeID() == DoubleTyID) return 53;
   if (getTypeID() == X86_FP80TyID) return 64;
@@ -167,6 +170,7 @@
 Type *Type::getVoidTy(LLVMContext &C) { return &C.pImpl->VoidTy; }
 Type *Type::getLabelTy(LLVMContext &C) { return &C.pImpl->LabelTy; }
 Type *Type::getHalfTy(LLVMContext &C) { return &C.pImpl->HalfTy; }
+Type *Type::getBFloatTy(LLVMContext &C) { return &C.pImpl->BFloatTy; }
 Type *Type::getFloatTy(LLVMContext &C) { return &C.pImpl->FloatTy; }
 Type *Type::getDoubleTy(LLVMContext &C) { return &C.pImpl->DoubleTy; }
 Type *Type::getMetadataTy(LLVMContext &C) { return &C.pImpl->MetadataTy; }
@@ -191,6 +195,10 @@
   return getHalfTy(C)->getPointerTo(AS);
 }
 
+PointerType *Type::getBFloatPtrTy(LLVMContext &C, unsigned AS) {
+  return getBFloatTy(C)->getPointerTo(AS);
+}
+
 PointerType *Type::getFloatPtrTy(LLVMContext &C, unsigned AS) {
   return getFloatTy(C)->getPointerTo(AS);
 }
Index: llvm/lib/IR/LLVMContextImpl.h
===================================================================
--- llvm/lib/IR/LLVMContextImpl.h
+++ llvm/lib/IR/LLVMContextImpl.h
@@ -1342,7 +1342,8 @@
   std::unique_ptr<ConstantTokenNone> TheNoneToken;
 
   // Basic type instances.
-  Type VoidTy, LabelTy, HalfTy, FloatTy, DoubleTy, MetadataTy, TokenTy;
+  Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy,
+      TokenTy;
   Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy;
   IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty;
 
Index: llvm/lib/IR/LLVMContextImpl.cpp
===================================================================
--- llvm/lib/IR/LLVMContextImpl.cpp
+++ llvm/lib/IR/LLVMContextImpl.cpp
@@ -26,6 +26,7 @@
     VoidTy(C, Type::VoidTyID),
     LabelTy(C, Type::LabelTyID),
     HalfTy(C, Type::HalfTyID),
+    BFloatTy(C, Type::BFloatTyID),
     FloatTy(C, Type::FloatTyID),
     DoubleTy(C, Type::DoubleTyID),
     MetadataTy(C, Type::MetadataTyID),
Index: llvm/lib/IR/Function.cpp
===================================================================
--- llvm/lib/IR/Function.cpp
+++ llvm/lib/IR/Function.cpp
@@ -655,6 +655,7 @@
     case Type::VoidTyID:      Result += "isVoid";   break;
     case Type::MetadataTyID:  Result += "Metadata"; break;
     case Type::HalfTyID:      Result += "f16";      break;
+    case Type::BFloatTyID:    Result += "bf16";     break;
     case Type::FloatTyID:     Result += "f32";      break;
     case Type::DoubleTyID:    Result += "f64";      break;
     case Type::X86_FP80TyID:  Result += "f80";      break;
Index: llvm/lib/IR/DataLayout.cpp
===================================================================
--- llvm/lib/IR/DataLayout.cpp
+++ llvm/lib/IR/DataLayout.cpp
@@ -162,7 +162,7 @@
     {INTEGER_ALIGN, 16, Align(2), Align(2)},   // i16
     {INTEGER_ALIGN, 32, Align(4), Align(4)},   // i32
     {INTEGER_ALIGN, 64, Align(4), Align(8)},   // i64
-    {FLOAT_ALIGN, 16, Align(2), Align(2)},     // half
+    {FLOAT_ALIGN, 16, Align(2), Align(2)},     // half, bfloat
     {FLOAT_ALIGN, 32, Align(4), Align(4)},     // float
     {FLOAT_ALIGN, 64, Align(8), Align(8)},     // double
     {FLOAT_ALIGN, 128, Align(16), Align(16)},  // ppcf128, quad, ...
@@ -732,6 +732,7 @@
     AlignType = INTEGER_ALIGN;
     break;
   case Type::HalfTyID:
+  case Type::BFloatTyID:
   case Type::FloatTyID:
   case Type::DoubleTyID:
   // PPC_FP128TyID and FP128TyID have different data contents, but the
Index: llvm/lib/IR/Core.cpp
===================================================================
--- llvm/lib/IR/Core.cpp
+++ llvm/lib/IR/Core.cpp
@@ -477,6 +477,8 @@
     return LLVMVoidTypeKind;
   case Type::HalfTyID:
     return LLVMHalfTypeKind;
+  case Type::BFloatTyID:
+    return LLVMBFloatTypeKind;
   case Type::FloatTyID:
     return LLVMFloatTypeKind;
   case Type::DoubleTyID:
@@ -595,6 +597,9 @@
 LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getHalfTy(*unwrap(C));
 }
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C) {
+  return (LLVMTypeRef) Type::getBFloatTy(*unwrap(C));
+}
 LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C) {
   return (LLVMTypeRef) Type::getFloatTy(*unwrap(C));
 }
@@ -617,6 +622,9 @@
 LLVMTypeRef LLVMHalfType(void) {
   return LLVMHalfTypeInContext(LLVMGetGlobalContext());
 }
+LLVMTypeRef LLVMBFloatType(void) {
+  return LLVMBFloatTypeInContext(LLVMGetGlobalContext());
+}
 LLVMTypeRef LLVMFloatType(void) {
   return LLVMFloatTypeInContext(LLVMGetGlobalContext());
 }
Index: llvm/lib/IR/Constants.cpp
===================================================================
--- llvm/lib/IR/Constants.cpp
+++ llvm/lib/IR/Constants.cpp
@@ -332,6 +332,9 @@
   case Type::HalfTyID:
     return ConstantFP::get(Ty->getContext(),
                            APFloat::getZero(APFloat::IEEEhalf()));
+  case Type::BFloatTyID:
+    return ConstantFP::get(Ty->getContext(),
+                           APFloat::getZero(APFloat::BFloat()));
   case Type::FloatTyID:
     return ConstantFP::get(Ty->getContext(),
                            APFloat::getZero(APFloat::IEEEsingle()));
@@ -386,8 +389,8 @@
                             APInt::getAllOnesValue(ITy->getBitWidth()));
 
   if (Ty->isFloatingPointTy()) {
-    APFloat FL = APFloat::getAllOnesValue(Ty->getPrimitiveSizeInBits(),
-                                          !Ty->isPPC_FP128Ty());
+    APFloat FL = APFloat::getAllOnesValue(Ty->getFltSemantics(),
+                                          Ty->getPrimitiveSizeInBits());
     return ConstantFP::get(Ty->getContext(), FL);
   }
 
@@ -763,6 +766,8 @@
 static const fltSemantics *TypeToFloatSemantics(Type *Ty) {
   if (Ty->isHalfTy())
     return &APFloat::IEEEhalf();
+  if (Ty->isBFloatTy())
+    return &APFloat::BFloat();
   if (Ty->isFloatTy())
     return &APFloat::IEEEsingle();
   if (Ty->isDoubleTy())
@@ -880,6 +885,8 @@
     Type *Ty;
     if (&V.getSemantics() == &APFloat::IEEEhalf())
       Ty = Type::getHalfTy(Context);
+    else if (&V.getSemantics() == &APFloat::BFloat())
+      Ty = Type::getBFloatTy(Context);
     else if (&V.getSemantics() == &APFloat::IEEEsingle())
       Ty = Type::getFloatTy(Context);
     else if (&V.getSemantics() == &APFloat::IEEEdouble())
@@ -1029,7 +1036,7 @@
       Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
     else
       return nullptr;
-  return SequentialTy::getFP(V[0]->getContext(), Elts);
+  return SequentialTy::getFP(V[0]->getType(), Elts);
 }
 
 template <typename SequenceTy>
@@ -1048,7 +1055,7 @@
     else if (CI->getType()->isIntegerTy(64))
       return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
   } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
-    if (CFP->getType()->isHalfTy())
+    if (CFP->getType()->isHalfTy() || CFP->getType()->isBFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
     else if (CFP->getType()->isFloatTy())
       return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
@@ -1421,6 +1428,12 @@
     Val2.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &losesInfo);
     return !losesInfo;
   }
+  case Type::BFloatTyID: {
+    if (&Val2.getSemantics() == &APFloat::BFloat())
+      return true;
+    Val2.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &losesInfo);
+    return !losesInfo;
+  }
   case Type::FloatTyID: {
     if (&Val2.getSemantics() == &APFloat::IEEEsingle())
       return true;
@@ -1429,6 +1442,7 @@
   }
   case Type::DoubleTyID: {
     if (&Val2.getSemantics() == &APFloat::IEEEhalf() ||
+        &Val2.getSemantics() == &APFloat::BFloat() ||
         &Val2.getSemantics() == &APFloat::IEEEsingle() ||
         &Val2.getSemantics() == &APFloat::IEEEdouble())
       return true;
@@ -1437,16 +1451,19 @@
   }
   case Type::X86_FP80TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::x87DoubleExtended();
   case Type::FP128TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::IEEEquad();
   case Type::PPC_FP128TyID:
     return &Val2.getSemantics() == &APFloat::IEEEhalf() ||
+           &Val2.getSemantics() == &APFloat::BFloat() ||
            &Val2.getSemantics() == &APFloat::IEEEsingle() ||
            &Val2.getSemantics() == &APFloat::IEEEdouble() ||
            &Val2.getSemantics() == &APFloat::PPCDoubleDouble();
@@ -2562,7 +2579,8 @@
 }
 
 bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
-  if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+  if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+    return true;
   if (auto *IT = dyn_cast<IntegerType>(Ty)) {
     switch (IT->getBitWidth()) {
     case 8:
@@ -2680,26 +2698,29 @@
   Next = nullptr;
 }
 
-/// getFP() constructors - Return a constant with array type with an element
-/// count and element type of float with precision matching the number of
-/// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint16_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getHalfTy(Context), Elts.size());
+/// getFP() constructors - Return a constant of array type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'.  The bit width of the element type must match the bit
+/// width of the integer type carried by the passed-in ArrayRef (i.e. half or
+/// bfloat for 16 bits, float for 32 bits, double for 64 bits).  Note that
+/// this can return a ConstantAggregateZero object.
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
+  assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+         "Element type is not a 16-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 2), Ty);
 }
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint32_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
+  assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 4), Ty);
 }
-Constant *ConstantDataArray::getFP(LLVMContext &Context,
-                                   ArrayRef<uint64_t> Elts) {
-  Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
+Constant *ConstantDataArray::getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
+  assert(ElementType->isDoubleTy() &&
+         "Element type is not a 64-bit float type");
+  Type *Ty = ArrayType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
@@ -2751,26 +2772,32 @@
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
 
-/// getFP() constructors - Return a constant with vector type with an element
-/// count and element type of float with the precision matching the number of
-/// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
-/// double for 64bits) Note that this can return a ConstantAggregateZero
-/// object.
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+/// getFP() constructors - Return a constant of vector type with a float
+/// element type taken from argument `ElementType', and count taken from
+/// argument `Elts'.  The bit width of the element type must match the bit
+/// width of the integer type carried by the passed-in ArrayRef (i.e. half or
+/// bfloat for 16 bits, float for 32 bits, double for 64 bits).  Note that
+/// this can return a ConstantAggregateZero object.
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint16_t> Elts) {
-  Type *Ty = VectorType::get(Type::getHalfTy(Context), Elts.size());
+  assert((ElementType->isHalfTy() || ElementType->isBFloatTy()) &&
+         "Element type is not a 16-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 2), Ty);
 }
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint32_t> Elts) {
-  Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
+  assert(ElementType->isFloatTy() && "Element type is not a 32-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 4), Ty);
 }
-Constant *ConstantDataVector::getFP(LLVMContext &Context,
+Constant *ConstantDataVector::getFP(Type *ElementType,
                                     ArrayRef<uint64_t> Elts) {
-  Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
+  assert(ElementType->isDoubleTy() &&
+         "Element type is not a 64-bit float type");
+  Type *Ty = VectorType::get(ElementType, Elts.size());
   const char *Data = reinterpret_cast<const char *>(Elts.data());
   return getImpl(StringRef(Data, Elts.size() * 8), Ty);
 }
@@ -2800,17 +2827,22 @@
     if (CFP->getType()->isHalfTy()) {
       SmallVector<uint16_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
+    }
+    if (CFP->getType()->isBFloatTy()) {
+      SmallVector<uint16_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getType(), Elts);
     }
     if (CFP->getType()->isFloatTy()) {
       SmallVector<uint32_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
     }
     if (CFP->getType()->isDoubleTy()) {
       SmallVector<uint64_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-      return getFP(V->getContext(), Elts);
+      return getFP(V->getType(), Elts);
     }
   }
   return ConstantVector::getSplat({NumElts, false}, V);
@@ -2875,6 +2907,10 @@
     auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
     return APFloat(APFloat::IEEEhalf(), APInt(16, EltVal));
   }
+  case Type::BFloatTyID: {
+    auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+    return APFloat(APFloat::BFloat(), APInt(16, EltVal));
+  }
   case Type::FloatTyID: {
     auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
     return APFloat(APFloat::IEEEsingle(), APInt(32, EltVal));
@@ -2899,8 +2935,8 @@
 }
 
 Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
-  if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
-      getElementType()->isDoubleTy())
+  if (getElementType()->isHalfTy() || getElementType()->isBFloatTy() ||
+      getElementType()->isFloatTy() || getElementType()->isDoubleTy())
     return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
 
   return ConstantInt::get(getElementType(), getElementAsInteger(Elt));
Index: llvm/lib/IR/AsmWriter.cpp
===================================================================
--- llvm/lib/IR/AsmWriter.cpp
+++ llvm/lib/IR/AsmWriter.cpp
@@ -588,6 +588,7 @@
   switch (Ty->getTypeID()) {
   case Type::VoidTyID:      OS << "void"; return;
   case Type::HalfTyID:      OS << "half"; return;
+  case Type::BFloatTyID:    OS << "bfloat"; return;
   case Type::FloatTyID:     OS << "float"; return;
   case Type::DoubleTyID:    OS << "double"; return;
   case Type::X86_FP80TyID:  OS << "x86_fp80"; return;
@@ -1379,7 +1380,7 @@
       return;
     }
 
-    // Either half, or some form of long double.
+    // Either half, bfloat or some form of long double.
     // These appear as a magic letter identifying the type, then a
     // fixed number of hex digits.
     Out << "0x";
@@ -1407,6 +1408,10 @@
       Out << 'H';
       Out << format_hex_no_prefix(API.getZExtValue(), 4,
                                   /*Upper=*/true);
+    } else if (&APF.getSemantics() == &APFloat::BFloat()) {
+      Out << 'R';
+      Out << format_hex_no_prefix(API.getZExtValue(), 4,
+                                  /*Upper=*/true);
     } else
       llvm_unreachable("Unsupported floating point type");
     return;
Index: llvm/lib/CodeGen/MIRParser/MILexer.cpp
===================================================================
--- llvm/lib/CodeGen/MIRParser/MILexer.cpp
+++ llvm/lib/CodeGen/MIRParser/MILexer.cpp
@@ -534,7 +534,7 @@
 }
 
 static bool isValidHexFloatingPointPrefix(char C) {
-  return C == 'H' || C == 'K' || C == 'L' || C == 'M';
+  return C == 'H' || C == 'K' || C == 'L' || C == 'M' || C == 'R';
 }
 
 static Cursor lexFloatingPointLiteral(Cursor Range, Cursor C, MIToken &Token) {
Index: llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
===================================================================
--- llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -881,6 +881,7 @@
     switch (T->getTypeID()) {
     case Type::VoidTyID:      Code = bitc::TYPE_CODE_VOID;      break;
     case Type::HalfTyID:      Code = bitc::TYPE_CODE_HALF;      break;
+    case Type::BFloatTyID:    Code = bitc::TYPE_CODE_BFLOAT;    break;
     case Type::FloatTyID:     Code = bitc::TYPE_CODE_FLOAT;     break;
     case Type::DoubleTyID:    Code = bitc::TYPE_CODE_DOUBLE;    break;
     case Type::X86_FP80TyID:  Code = bitc::TYPE_CODE_X86_FP80;  break;
@@ -2387,7 +2388,8 @@
     } else if (const ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
       Code = bitc::CST_CODE_FLOAT;
       Type *Ty = CFP->getType();
-      if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) {
+      if (Ty->isHalfTy() || Ty->isBFloatTy() || Ty->isFloatTy() ||
+          Ty->isDoubleTy()) {
         Record.push_back(CFP->getValueAPF().bitcastToAPInt().getZExtValue());
       } else if (Ty->isX86_FP80Ty()) {
         // api needed to prevent premature destruction
Index: llvm/lib/Bitcode/Reader/BitcodeReader.cpp
===================================================================
--- llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1720,6 +1720,9 @@
     case bitc::TYPE_CODE_HALF:     // HALF
       ResultTy = Type::getHalfTy(Context);
       break;
+    case bitc::TYPE_CODE_BFLOAT:    // BFLOAT
+      ResultTy = Type::getBFloatTy(Context);
+      break;
     case bitc::TYPE_CODE_FLOAT:     // FLOAT
       ResultTy = Type::getFloatTy(Context);
       break;
@@ -2429,6 +2432,9 @@
       if (CurTy->isHalfTy())
         V = ConstantFP::get(Context, APFloat(APFloat::IEEEhalf(),
                                              APInt(16, (uint16_t)Record[0])));
+      else if (CurTy->isBFloatTy())
+        V = ConstantFP::get(Context, APFloat(APFloat::BFloat(),
+                                             APInt(16, (uint16_t)Record[0])));
       else if (CurTy->isFloatTy())
         V = ConstantFP::get(Context, APFloat(APFloat::IEEEsingle(),
                                              APInt(32, (uint32_t)Record[0])));
@@ -2526,21 +2532,27 @@
       } else if (EltTy->isHalfTy()) {
         SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
+        else
+          V = ConstantDataArray::getFP(EltTy, Elts);
+      } else if (EltTy->isBFloatTy()) {
+        SmallVector<uint16_t, 16> Elts(Record.begin(), Record.end());
+        if (isa<VectorType>(CurTy))
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isFloatTy()) {
         SmallVector<uint32_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else if (EltTy->isDoubleTy()) {
         SmallVector<uint64_t, 16> Elts(Record.begin(), Record.end());
         if (isa<VectorType>(CurTy))
-          V = ConstantDataVector::getFP(Context, Elts);
+          V = ConstantDataVector::getFP(EltTy, Elts);
         else
-          V = ConstantDataArray::getFP(Context, Elts);
+          V = ConstantDataArray::getFP(EltTy, Elts);
       } else {
         return error("Invalid type for value");
       }
Index: llvm/lib/AsmParser/LLParser.cpp
===================================================================
--- llvm/lib/AsmParser/LLParser.cpp
+++ llvm/lib/AsmParser/LLParser.cpp
@@ -5247,13 +5247,16 @@
         !ConstantFP::isValueValidForType(Ty, ID.APFloatVal))
       return Error(ID.Loc, "floating point constant invalid for type");
 
-    // The lexer has no type info, so builds all half, float, and double FP
-    // constants as double.  Fix this here.  Long double does not need this.
+    // The lexer has no type info, so builds all half, bfloat, float, and double
+    // FP constants as double.  Fix this here.  Long double does not need this.
     if (&ID.APFloatVal.getSemantics() == &APFloat::IEEEdouble()) {
       bool Ignored;
       if (Ty->isHalfTy())
         ID.APFloatVal.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
                               &Ignored);
+      else if (Ty->isBFloatTy())
+        ID.APFloatVal.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
+                              &Ignored);
       else if (Ty->isFloatTy())
         ID.APFloatVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
                               &Ignored);
Index: llvm/lib/AsmParser/LLLexer.cpp
===================================================================
--- llvm/lib/AsmParser/LLLexer.cpp
+++ llvm/lib/AsmParser/LLLexer.cpp
@@ -820,6 +820,7 @@
 
   TYPEKEYWORD("void",      Type::getVoidTy(Context));
   TYPEKEYWORD("half",      Type::getHalfTy(Context));
+  TYPEKEYWORD("bfloat",    Type::getBFloatTy(Context));
   TYPEKEYWORD("float",     Type::getFloatTy(Context));
   TYPEKEYWORD("double",    Type::getDoubleTy(Context));
   TYPEKEYWORD("x86_fp80",  Type::getX86_FP80Ty(Context));
@@ -985,11 +986,13 @@
 ///    HexFP128Constant  0xL[0-9A-Fa-f]+
 ///    HexPPC128Constant 0xM[0-9A-Fa-f]+
 ///    HexHalfConstant   0xH[0-9A-Fa-f]+
+///    HexBFloatConstant 0xR[0-9A-Fa-f]+
 lltok::Kind LLLexer::Lex0x() {
   CurPtr = TokStart + 2;
 
   char Kind;
-  if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H') {
+  if ((CurPtr[0] >= 'K' && CurPtr[0] <= 'M') || CurPtr[0] == 'H' ||
+      CurPtr[0] == 'R') {
     Kind = *CurPtr++;
   } else {
     Kind = 'J';
@@ -1007,7 +1010,7 @@
   if (Kind == 'J') {
     // HexFPConstant - Floating point constant represented in IEEE format as a
     // hexadecimal number for when exponential notation is not precise enough.
-    // Half, Float, and double only.
+    // Half, BFloat, Float, and double only.
     APFloatVal = APFloat(APFloat::IEEEdouble(),
                          APInt(64, HexIntToVal(TokStart + 2, CurPtr)));
     return lltok::APFloat;
@@ -1035,6 +1038,11 @@
     APFloatVal = APFloat(APFloat::IEEEhalf(),
                          APInt(16,HexIntToVal(TokStart+3, CurPtr)));
     return lltok::APFloat;
+  case 'R':
+    // Brain floating point
+    APFloatVal = APFloat(APFloat::BFloat(),
+                         APInt(16, HexIntToVal(TokStart + 3, CurPtr)));
+    return lltok::APFloat;
   }
 }
 
Index: llvm/include/llvm/IR/Type.h
===================================================================
--- llvm/include/llvm/IR/Type.h
+++ llvm/include/llvm/IR/Type.h
@@ -54,27 +54,28 @@
   ///
   enum TypeID {
     // PrimitiveTypes - make sure LastPrimitiveTyID stays up to date.
-    VoidTyID = 0,  ///<  0: type with no size
-    HalfTyID,      ///<  1: 16-bit floating point type
-    FloatTyID,     ///<  2: 32-bit floating point type
-    DoubleTyID,    ///<  3: 64-bit floating point type
-    X86_FP80TyID,  ///<  4: 80-bit floating point type (X87)
-    FP128TyID,     ///<  5: 128-bit floating point type (112-bit mantissa)
-    PPC_FP128TyID, ///<  6: 128-bit floating point type (two 64-bits, PowerPC)
-    LabelTyID,     ///<  7: Labels
-    MetadataTyID,  ///<  8: Metadata
-    X86_MMXTyID,   ///<  9: MMX vectors (64 bits, X86 specific)
-    TokenTyID,     ///< 10: Tokens
+    VoidTyID = 0,    ///<  0: type with no size
+    HalfTyID,        ///<  1: 16-bit floating point type
+    BFloatTyID,      ///<  2: 16-bit floating point type (7-bit significand)
+    FloatTyID,       ///<  3: 32-bit floating point type
+    DoubleTyID,      ///<  4: 64-bit floating point type
+    X86_FP80TyID,    ///<  5: 80-bit floating point type (X87)
+    FP128TyID,       ///<  6: 128-bit floating point type (112-bit significand)
+    PPC_FP128TyID,   ///<  7: 128-bit floating point type (two 64-bits, PowerPC)
+    LabelTyID,       ///<  8: Labels
+    MetadataTyID,    ///<  9: Metadata
+    X86_MMXTyID,     ///< 10: MMX vectors (64 bits, X86 specific)
+    TokenTyID,       ///< 11: Tokens
 
     // Derived types... see DerivedTypes.h file.
     // Make sure FirstDerivedTyID stays up to date!
-    IntegerTyID,       ///< 11: Arbitrary bit width integers
-    FunctionTyID,      ///< 12: Functions
-    StructTyID,        ///< 13: Structures
-    ArrayTyID,         ///< 14: Arrays
-    PointerTyID,       ///< 15: Pointers
-    FixedVectorTyID,   ///< 16: Fixed width SIMD vector type
-    ScalableVectorTyID ///< 17: Scalable SIMD vector type
+    IntegerTyID,       ///< 12: Arbitrary bit width integers
+    FunctionTyID,      ///< 13: Functions
+    StructTyID,        ///< 14: Structures
+    ArrayTyID,         ///< 15: Arrays
+    PointerTyID,       ///< 16: Pointers
+    FixedVectorTyID,   ///< 17: Fixed width SIMD vector type
+    ScalableVectorTyID ///< 18: Scalable SIMD vector type
   };
 
 private:
@@ -140,6 +141,9 @@
   /// Return true if this is 'half', a 16-bit IEEE fp type.
   bool isHalfTy() const { return getTypeID() == HalfTyID; }
 
+  /// Return true if this is 'bfloat', a 16-bit brain floating-point type.
+  bool isBFloatTy() const { return getTypeID() == BFloatTyID; }
+
   /// Return true if this is 'float', a 32-bit IEEE fp type.
   bool isFloatTy() const { return getTypeID() == FloatTyID; }
 
@@ -157,8 +161,8 @@
 
-  /// Return true if this is one of the six floating-point types
+  /// Return true if this is one of the seven floating-point types
   bool isFloatingPointTy() const {
-    return getTypeID() == HalfTyID || getTypeID() == FloatTyID ||
-           getTypeID() == DoubleTyID ||
+    return getTypeID() == HalfTyID || getTypeID() == BFloatTyID ||
+           getTypeID() == FloatTyID || getTypeID() == DoubleTyID ||
            getTypeID() == X86_FP80TyID || getTypeID() == FP128TyID ||
            getTypeID() == PPC_FP128TyID;
   }
@@ -166,6 +170,7 @@
   const fltSemantics &getFltSemantics() const {
     switch (getTypeID()) {
     case HalfTyID: return APFloat::IEEEhalf();
+    case BFloatTyID: return APFloat::BFloat();
     case FloatTyID: return APFloat::IEEEsingle();
     case DoubleTyID: return APFloat::IEEEdouble();
     case X86_FP80TyID: return APFloat::x87DoubleExtended();
@@ -387,6 +392,7 @@
   static Type *getVoidTy(LLVMContext &C);
   static Type *getLabelTy(LLVMContext &C);
   static Type *getHalfTy(LLVMContext &C);
+  static Type *getBFloatTy(LLVMContext &C);
   static Type *getFloatTy(LLVMContext &C);
   static Type *getDoubleTy(LLVMContext &C);
   static Type *getMetadataTy(LLVMContext &C);
@@ -422,6 +428,7 @@
   // types as pointee.
   //
   static PointerType *getHalfPtrTy(LLVMContext &C, unsigned AS = 0);
+  static PointerType *getBFloatPtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getFloatPtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getDoublePtrTy(LLVMContext &C, unsigned AS = 0);
   static PointerType *getX86_FP80PtrTy(LLVMContext &C, unsigned AS = 0);
Index: llvm/include/llvm/IR/IRBuilder.h
===================================================================
--- llvm/include/llvm/IR/IRBuilder.h
+++ llvm/include/llvm/IR/IRBuilder.h
@@ -477,6 +477,11 @@
     return Type::getHalfTy(Context);
   }
 
+  /// Fetch the type representing a 16-bit brain floating point value.
+  Type *getBFloatTy() {
+    return Type::getBFloatTy(Context);
+  }
+
   /// Fetch the type representing a 32-bit floating point value.
   Type *getFloatTy() {
     return Type::getFloatTy(Context);
Index: llvm/include/llvm/IR/DataLayout.h
===================================================================
--- llvm/include/llvm/IR/DataLayout.h
+++ llvm/include/llvm/IR/DataLayout.h
@@ -651,6 +651,7 @@
   case Type::IntegerTyID:
     return TypeSize::Fixed(Ty->getIntegerBitWidth());
   case Type::HalfTyID:
+  case Type::BFloatTyID:
     return TypeSize::Fixed(16);
   case Type::FloatTyID:
     return TypeSize::Fixed(32);
Index: llvm/include/llvm/IR/Constants.h
===================================================================
--- llvm/include/llvm/IR/Constants.h
+++ llvm/include/llvm/IR/Constants.h
@@ -721,14 +721,15 @@
     return getImpl(Data, Ty);
   }
 
-  /// getFP() constructors - Return a constant with array type with an element
-  /// count and element type of float with precision matching the number of
-  /// bits in the ArrayRef passed in. (i.e. half for 16bits, float for 32bits,
-  /// double for 64bits) Note that this can return a ConstantAggregateZero
-  /// object.
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+  /// getFP() constructors - Return a constant of array type with a float
+  /// element type taken from argument `ElementType', and count taken from
+  /// argument `Elts'.  The number of bits of the contained type must match the
+  /// number of bits of the type contained in the passed in ArrayRef.
+  /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+  /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
 
   /// This method constructs a CDS and initializes it with a text string.
   /// The default behavior (AddNull==true) causes a null terminator to
@@ -780,14 +781,15 @@
   static Constant *get(LLVMContext &Context, ArrayRef<float> Elts);
   static Constant *get(LLVMContext &Context, ArrayRef<double> Elts);
 
-  /// getFP() constructors - Return a constant with vector type with an element
-  /// count and element type of float with the precision matching the number of
-  /// bits in the ArrayRef passed in.  (i.e. half for 16bits, float for 32bits,
-  /// double for 64bits) Note that this can return a ConstantAggregateZero
-  /// object.
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint16_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint32_t> Elts);
-  static Constant *getFP(LLVMContext &Context, ArrayRef<uint64_t> Elts);
+  /// getFP() constructors - Return a constant of vector type with a float
+  /// element type taken from argument `ElementType', and count taken from
+  /// argument `Elts'.  The number of bits of the contained type must match the
+  /// number of bits of the type contained in the passed in ArrayRef.
+  /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+  /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts);
+  static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts);
 
   /// Return a ConstantVector with the specified constant in each element.
   /// The specified constant has to be a of a compatible type (i8/i16/
Index: llvm/include/llvm/Bitcode/LLVMBitCodes.h
===================================================================
--- llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -166,7 +166,9 @@
 
   TYPE_CODE_FUNCTION = 21, // FUNCTION: [vararg, retty, paramty x N]
 
-  TYPE_CODE_TOKEN = 22 // TOKEN
+  TYPE_CODE_TOKEN = 22, // TOKEN
+
+  TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT
 };
 
 enum OperandBundleTagCode {
Index: llvm/include/llvm/ADT/APFloat.h
===================================================================
--- llvm/include/llvm/ADT/APFloat.h
+++ llvm/include/llvm/ADT/APFloat.h
@@ -151,6 +151,7 @@
   /// @{
   enum Semantics {
     S_IEEEhalf,
+    S_BFloat,
     S_IEEEsingle,
     S_IEEEdouble,
     S_x87DoubleExtended,
@@ -162,6 +163,7 @@
   static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem);
 
   static const fltSemantics &IEEEhalf() LLVM_READNONE;
+  static const fltSemantics &BFloat() LLVM_READNONE;
   static const fltSemantics &IEEEsingle() LLVM_READNONE;
   static const fltSemantics &IEEEdouble() LLVM_READNONE;
   static const fltSemantics &IEEEquad() LLVM_READNONE;
@@ -541,6 +543,7 @@
   /// @}
 
   APInt convertHalfAPFloatToAPInt() const;
+  APInt convertBFloatAPFloatToAPInt() const;
   APInt convertFloatAPFloatToAPInt() const;
   APInt convertDoubleAPFloatToAPInt() const;
   APInt convertQuadrupleAPFloatToAPInt() const;
@@ -548,6 +551,7 @@
   APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
   void initFromAPInt(const fltSemantics *Sem, const APInt &api);
   void initFromHalfAPInt(const APInt &api);
+  void initFromBFloatAPInt(const APInt &api);
   void initFromFloatAPInt(const APInt &api);
   void initFromDoubleAPInt(const APInt &api);
   void initFromQuadrupleAPInt(const APInt &api);
@@ -954,9 +958,10 @@
 
   /// Returns a float which is bitcasted from an all one value int.
   ///
+  /// \param Semantics - Floating-point semantics of the returned value
   /// \param BitWidth - Select float type
-  /// \param isIEEE   - If 128 bit number, select between PPC and IEEE
-  static APFloat getAllOnesValue(unsigned BitWidth, bool isIEEE = false);
+  static APFloat getAllOnesValue(const fltSemantics &Semantics,
+                                 unsigned BitWidth);
 
   /// Used to insert APFloat objects, or objects that contain APFloat objects,
   /// into FoldingSets.
Index: llvm/include/llvm-c/Core.h
===================================================================
--- llvm/include/llvm-c/Core.h
+++ llvm/include/llvm-c/Core.h
@@ -146,6 +146,7 @@
 typedef enum {
   LLVMVoidTypeKind,        /**< type with no size */
   LLVMHalfTypeKind,        /**< 16 bit floating point type */
+  LLVMBFloatTypeKind,      /**< 16 bit brain floating point type */
   LLVMFloatTypeKind,       /**< 32 bit floating point type */
   LLVMDoubleTypeKind,      /**< 64 bit floating point type */
   LLVMX86_FP80TypeKind,    /**< 80 bit floating point type (X87) */
@@ -1164,6 +1165,11 @@
 LLVMTypeRef LLVMHalfTypeInContext(LLVMContextRef C);
 
 /**
+ * Obtain a 16-bit brain floating point type from a context.
+ */
+LLVMTypeRef LLVMBFloatTypeInContext(LLVMContextRef C);
+
+/**
  * Obtain a 32-bit floating point type from a context.
  */
 LLVMTypeRef LLVMFloatTypeInContext(LLVMContextRef C);
@@ -1195,6 +1201,7 @@
  * These map to the functions in this group of the same name.
  */
 LLVMTypeRef LLVMHalfType(void);
+LLVMTypeRef LLVMBFloatType(void);
 LLVMTypeRef LLVMFloatType(void);
 LLVMTypeRef LLVMDoubleType(void);
 LLVMTypeRef LLVMX86FP80Type(void);
Index: llvm/docs/LangRef.rst
===================================================================
--- llvm/docs/LangRef.rst
+++ llvm/docs/LangRef.rst
@@ -2963,6 +2963,12 @@
    * - ``half``
      - 16-bit floating-point value
 
+   * - ``bfloat``
+     - 16-bit "brain" floating-point value (7-bit significand).  Provides the
+       same number of exponent bits as ``float``, so that it matches its dynamic
+       range, but with greatly reduced precision.  Used in Intel's AVX-512 BF16
+       extensions and Arm's ARMv8.6-A extensions, among others.
+
    * - ``float``
      - 32-bit floating-point value
 
@@ -2970,7 +2976,7 @@
      - 64-bit floating-point value
 
    * - ``fp128``
-     - 128-bit floating-point value (112-bit mantissa)
+     - 128-bit floating-point value (112-bit significand)
 
    * - ``x86_fp80``
      -  80-bit floating-point value (X87)
@@ -3303,20 +3309,20 @@
 values are represented in their IEEE hexadecimal format so that assembly
 and disassembly do not cause any bits to change in the constants.
 
-When using the hexadecimal form, constants of types half, float, and
-double are represented using the 16-digit form shown above (which
-matches the IEEE754 representation for double); half and float values
-must, however, be exactly representable as IEEE 754 half and single
-precision, respectively. Hexadecimal format is always used for long
-double, and there are three forms of long double. The 80-bit format used
-by x86 is represented as ``0xK`` followed by 20 hexadecimal digits. The
-128-bit format used by PowerPC (two adjacent doubles) is represented by
-``0xM`` followed by 32 hexadecimal digits. The IEEE 128-bit format is
-represented by ``0xL`` followed by 32 hexadecimal digits. Long doubles
-will only work if they match the long double format on your target.
-The IEEE 16-bit format (half precision) is represented by ``0xH``
-followed by 4 hexadecimal digits. All hexadecimal formats are big-endian
-(sign bit at the left).
+When using the hexadecimal form, constants of types bfloat, half, float, and
+double are represented using the 16-digit form shown above (which matches the
+IEEE754 representation for double); bfloat, half and float values must, however,
+be exactly representable as bfloat, IEEE 754 half, and IEEE 754 single
+precision respectively. Hexadecimal format is always used for long double, and
+there are three forms of long double. The 80-bit format used by x86 is
+represented as ``0xK`` followed by 20 hexadecimal digits. The 128-bit format
+used by PowerPC (two adjacent doubles) is represented by ``0xM`` followed by 32
+hexadecimal digits. The IEEE 128-bit format is represented by ``0xL`` followed
+by 32 hexadecimal digits. Long doubles will only work if they match the long
+double format on your target.  The IEEE 16-bit format (half precision) is
+represented by ``0xH`` followed by 4 hexadecimal digits. The bfloat 16-bit
+format is represented by ``0xR`` followed by 4 hexadecimal digits. All
+hexadecimal formats are big-endian (sign bit at the left).
 
 There are no constants of type x86_mmx.
 
Index: llvm/docs/BitCodeFormat.rst
===================================================================
--- llvm/docs/BitCodeFormat.rst
+++ llvm/docs/BitCodeFormat.rst
@@ -1107,6 +1107,14 @@
 The ``HALF`` record (code 10) adds a ``half`` (16-bit floating point) type to
 the type table.
 
+TYPE_CODE_BFLOAT Record
+^^^^^^^^^^^^^^^^^^^^^^^
+
+``[BFLOAT]``
+
+The ``BFLOAT`` record (code 23) adds a ``bfloat`` (16-bit brain floating point)
+type to the type table.
+
 TYPE_CODE_FLOAT Record
 ^^^^^^^^^^^^^^^^^^^^^^
 
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -14936,9 +14936,9 @@
         if (auto *ComplexTy = OrigType->getAs<ComplexType>())
           Type = ComplexTy->getElementType();
         if (Type->isRealFloatingType()) {
-          llvm::APFloat InitValue =
-              llvm::APFloat::getAllOnesValue(Context.getTypeSize(Type),
-                                             /*isIEEE=*/true);
+          llvm::APFloat InitValue = llvm::APFloat::getAllOnesValue(
+              Context.getFloatTypeSemantics(Type),
+              Context.getTypeSize(Type));
           Init = FloatingLiteral::Create(Context, InitValue, /*isexact=*/true,
                                          Type, ELoc);
         } else if (Type->isScalarType()) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to