kushanam updated this revision to Diff 531089.
kushanam added a comment.

Renamed `useBF16Math` to `hasBF16Math`.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D144911/new/

https://reviews.llvm.org/D144911

Files:
  clang/include/clang/Basic/BuiltinsNVPTX.def
  llvm/include/llvm/IR/IntrinsicsNVVM.td
  llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
  llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
  llvm/lib/Target/NVPTX/NVPTXMCExpr.h
  llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
  llvm/lib/Target/NVPTX/NVPTXSubtarget.h
  llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
  llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Index: llvm/test/CodeGen/NVPTX/bf16-instructions.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -0,0 +1,142 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %}
+
+; CHECK: .b8 bfloat_array[8] = {1, 2, 3, 4, 5, 6, 7, 8};
+@"bfloat_array" = addrspace(1) constant [4 x bfloat]
+                [bfloat 0xR0201, bfloat 0xR0403, bfloat 0xR0605, bfloat 0xR0807]
+
+; CHECK-LABEL: test_fadd(
+; CHECK-DAG:  ld.param.b16    [[A:%rs[0-9]+]], [test_fadd_param_0];
+; CHECK-DAG:  ld.param.b16    [[B:%rs[0-9]+]], [test_fadd_param_1];
+; CHECK-NEXT: add.rn.bf16     [[R:%rs[0-9]+]], [[A]], [[B]];
+; CHECK-NEXT: st.param.b16    [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fadd(bfloat %0, bfloat %1) {
+  %3 = fadd bfloat %0, %1
+  ret bfloat %3
+}
+
+; CHECK-LABEL: test_fsub(
+; CHECK-DAG:  ld.param.b16    [[A:%rs[0-9]+]], [test_fsub_param_0];
+; CHECK-DAG:  ld.param.b16    [[B:%rs[0-9]+]], [test_fsub_param_1];
+; CHECK-NEXT: sub.rn.bf16     [[R:%rs[0-9]+]], [[A]], [[B]];
+; CHECK-NEXT: st.param.b16    [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fsub(bfloat %0, bfloat %1) {
+  %3 = fsub bfloat %0, %1
+  ret bfloat %3
+}
+
+; CHECK-LABEL: test_faddx2(
+; CHECK-DAG:  ld.param.b32    [[A:%r[0-9]+]], [test_faddx2_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%r[0-9]+]], [test_faddx2_param_1];
+; CHECK-NEXT: add.rn.bf16x2   [[R:%r[0-9]+]], [[A]], [[B]];
+
+; CHECK:      st.param.b32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+
+define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fadd <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fsubx2(
+; CHECK-DAG:  ld.param.b32    [[A:%r[0-9]+]], [test_fsubx2_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%r[0-9]+]], [test_fsubx2_param_1];
+; CHECK-NEXT: sub.rn.bf16x2   [[R:%r[0-9]+]], [[A]], [[B]];
+
+; CHECK:      st.param.b32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+
+define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fsub <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fmulx2(
+; CHECK-DAG:  ld.param.b32    [[A:%r[0-9]+]], [test_fmulx2_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%r[0-9]+]], [test_fmulx2_param_1];
+; CHECK-NEXT: mul.rn.bf16x2   [[R:%r[0-9]+]], [[A]], [[B]];
+
+; CHECK:      st.param.b32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+
+define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fmul <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fdiv(
+; CHECK-DAG:  ld.param.b32    [[A:%r[0-9]+]], [test_fdiv_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%r[0-9]+]], [test_fdiv_param_1];
+; CHECK-DAG:  mov.b32         {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG:  mov.b32         {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]
+; CHECK-DAG:  cvt.f32.bf16     [[FA0:%f[0-9]+]], [[A0]];
+; CHECK-DAG:  cvt.f32.bf16     [[FA1:%f[0-9]+]], [[A1]];
+; CHECK-DAG:  cvt.f32.bf16     [[FB0:%f[0-9]+]], [[B0]];
+; CHECK-DAG:  cvt.f32.bf16     [[FB1:%f[0-9]+]], [[B1]];
+; CHECK-DAG:  div.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; CHECK-DAG:  div.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[FR0]];
+; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[FR1]];
+; CHECK-NEXT: mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-NEXT: st.param.b32    [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fdiv <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_extract_0(
+; CHECK:      ld.param.b16    [[A:%rs[0-9]+]], [test_extract_0_param_0];
+; CHECK:      st.param.b16    [func_retval0+0], [[A]];
+; CHECK:      ret;
+
+define bfloat @test_extract_0(<2 x bfloat> %a) #0 {
+  %e = extractelement <2 x bfloat> %a, i32 0
+  ret bfloat %e
+}
+
+; CHECK-LABEL: test_extract_1(
+; CHECK:      ld.param.b16    [[A:%rs[0-9]+]], [test_extract_1_param_0+2];
+; CHECK:      st.param.b16    [func_retval0+0], [[A]];
+; CHECK:      ret;
+
+define bfloat @test_extract_1(<2 x bfloat> %a) #0 {
+  %e = extractelement <2 x bfloat> %a, i32 1
+  ret bfloat %e
+}
+
+; CHECK-LABEL: test_fpext_float(
+; CHECK:      ld.param.b16    [[A:%rs[0-9]+]], [test_fpext_float_param_0];
+; CHECK:      cvt.f32.bf16     [[R:%f[0-9]+]], [[A]];
+; CHECK:      st.param.f32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define float @test_fpext_float(bfloat %a) #0 {
+  %r = fpext bfloat %a to float
+  ret float %r
+}
+
+; CHECK-LABEL: test_fptrunc_float(
+; CHECK:      ld.param.f32    [[A:%f[0-9]+]], [test_fptrunc_float_param_0];
+; CHECK:      cvt.rn.bf16.f32  [[R:%rs[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define bfloat @test_fptrunc_float(float %a) #0 {
+  %r = fptrunc float %a to bfloat
+  ret bfloat %r
+}
+
+; CHECK-LABEL: test_fadd_imm_1(
+; CHECK:      ld.param.b16    [[B:%rs[0-9]+]], [test_fadd_imm_1_param_0];
+; CHECK:      mov.b16        [[A:%rs[0-9]+]], 0x3F80;
+; CHECK:      add.rn.bf16    [[R:%rs[0-9]+]], [[B]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+define bfloat @test_fadd_imm_1(bfloat %a) #0 {
+  %r = fadd bfloat %a, 1.0
+  ret bfloat %r
+}
\ No newline at end of file
Index: llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -204,6 +204,14 @@
       return {Intrinsic::fma, FTZ_MustBeOff, true};
     case Intrinsic::nvvm_fma_rn_ftz_f16x2:
       return {Intrinsic::fma, FTZ_MustBeOn, true};
+    case Intrinsic::nvvm_fma_rn_bf16:
+      return {Intrinsic::fma, FTZ_MustBeOff, true};
+    case Intrinsic::nvvm_fma_rn_ftz_bf16:
+      return {Intrinsic::fma, FTZ_MustBeOn, true};
+    case Intrinsic::nvvm_fma_rn_bf16x2:
+      return {Intrinsic::fma, FTZ_MustBeOff, true};
+    case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
+      return {Intrinsic::fma, FTZ_MustBeOn, true};
     case Intrinsic::nvvm_fmax_d:
       return {Intrinsic::maxnum, FTZ_Any};
     case Intrinsic::nvvm_fmax_f:
Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -76,6 +76,7 @@
   inline bool hasHWROT32() const { return SmVersion >= 32; }
   bool hasImageHandles() const;
   bool hasFP16Math() const { return SmVersion >= 53; }
+  bool hasBF16Math() const { return SmVersion >= 80; } // sm_80 and newer
   bool allowFP16Math() const;
   bool hasMaskOperator() const { return PTXVersion >= 71; }
   bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
+++ llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
@@ -26,7 +26,6 @@
     NoF16Math("nvptx-no-f16-math", cl::Hidden,
               cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
               cl::init(false));
-
 // Pin the vtable to this file.
 void NVPTXSubtarget::anchor() {}
 
Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXMCExpr.h
+++ llvm/lib/Target/NVPTX/NVPTXMCExpr.h
@@ -21,6 +21,7 @@
 public:
   enum VariantKind {
     VK_NVPTX_None,
+    VK_NVPTX_BFLOAT_PREC_FLOAT, // FP constant in bfloat-precision
     VK_NVPTX_HALF_PREC_FLOAT,   // FP constant in half-precision
     VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision
     VK_NVPTX_DOUBLE_PREC_FLOAT  // FP constant in double-precision
@@ -40,6 +41,11 @@
   static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt,
                                         MCContext &Ctx);
 
+  static const NVPTXFloatMCExpr *createConstantBFPHalf(const APFloat &Flt,
+                                                       MCContext &Ctx) {
+    return create(VK_NVPTX_BFLOAT_PREC_FLOAT, Flt, Ctx);
+  }
+
   static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt,
                                                         MCContext &Ctx) {
     return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx);
Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
+++ llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
@@ -34,6 +34,11 @@
     NumHex = 4;
     APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored);
     break;
+  case VK_NVPTX_BFLOAT_PREC_FLOAT:
+    OS << "0x";
+    NumHex = 4;
+    APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &Ignored);
+    break;
   case VK_NVPTX_SINGLE_PREC_FLOAT:
     OS << "0f";
     NumHex = 8;
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -998,6 +998,18 @@
     FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, Int16Regs,
       [hasPTX<70>, hasSM<80>]>,
 
+    FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, Int16Regs,
+      [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, Int16Regs,
+      [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, Int16Regs,
+      [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs,
+      [hasPTX<70>, hasSM<80>]>,
+    FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, Int16Regs,
+      [hasPTX<70>, hasSM<80>]>,
+
     FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, Int32Regs,
       [hasPTX<42>, hasSM<53>]>,
     FMA_TUPLE<"_rn_ftz_f16x2", int_nvvm_fma_rn_ftz_f16x2, Int32Regs,
@@ -1010,11 +1022,6 @@
       [hasPTX<70>, hasSM<80>]>,
     FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
       Int32Regs, [hasPTX<70>, hasSM<80>]>,
-
-    FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX<70>, hasSM<80>]>,
-    FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs,
-      [hasPTX<70>, hasSM<80>]>,
-
     FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, Int32Regs,
       [hasPTX<70>, hasSM<80>]>,
     FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, Int32Regs,
@@ -1268,24 +1275,6 @@
 def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
           (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
 
-def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
-def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
-def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
-
-def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
-def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
-def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
-
 def CVT_tf32_f32 :
    NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
                    "cvt.rna.tf32.f32 \t$dest, $a;",
@@ -2207,10 +2196,6 @@
   : VLDU_G_ELE_V2<"v2.u16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
 defm INT_PTX_LDU_G_v2i32_ELE
   : VLDU_G_ELE_V2<"v2.u32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>;
-defm INT_PTX_LDU_G_v2f16_ELE
-  : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
-defm INT_PTX_LDU_G_v2f16x2_ELE
-  : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>;
 defm INT_PTX_LDU_G_v2f32_ELE
   : VLDU_G_ELE_V2<"v2.f32 \t{{$dst1, $dst2}}, [$src];", Float32Regs>;
 defm INT_PTX_LDU_G_v2i64_ELE
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -19,6 +19,8 @@
 
 let OperandType = "OPERAND_IMMEDIATE" in {
   def f16imm : Operand<f16>;
+  def bf16imm : Operand<bf16>;
+
 }
 
 // List of vector specific properties
@@ -154,6 +156,7 @@
 
 def useShortPtr : Predicate<"useShortPointers()">;
 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
+def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">;
 
 // Helper class to aid conversion between ValueType and a matching RegisterClass.
 
@@ -304,6 +307,31 @@
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
                [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
                Requires<[useFP16Math]>;
+   def bf16rr_ftz :
+     NVPTXInst<(outs Int16Regs:$dst),
+               (ins Int16Regs:$a, Int16Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               Requires<[hasBF16Math, doF32FTZ]>;
+   def bf16rr :
+     NVPTXInst<(outs Int16Regs:$dst),
+               (ins Int16Regs:$a, Int16Regs:$b),
+               !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               Requires<[hasBF16Math]>;
+
+   def bf16x2rr_ftz :
+     NVPTXInst<(outs Int32Regs:$dst),
+               (ins Int32Regs:$a, Int32Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               Requires<[hasBF16Math, doF32FTZ]>;
+   def bf16x2rr :
+     NVPTXInst<(outs Int32Regs:$dst),
+               (ins Int32Regs:$a, Int32Regs:$b),
+               !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               Requires<[hasBF16Math]>;
 }
 
 // Template for instructions which take three FP args.  The
@@ -378,7 +406,31 @@
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
                [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
                Requires<[useFP16Math, allowFMA]>;
+   def bf16rr_ftz :
+     NVPTXInst<(outs Int16Regs:$dst),
+               (ins Int16Regs:$a, Int16Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
+   def bf16rr :
+     NVPTXInst<(outs Int16Regs:$dst),
+               (ins Int16Regs:$a, Int16Regs:$b),
+               !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               Requires<[hasBF16Math, allowFMA]>;
 
+   def bf16x2rr_ftz :
+     NVPTXInst<(outs Int32Regs:$dst),
+               (ins Int32Regs:$a, Int32Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+               [(set (v2bf16 Int32Regs:$dst), (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
+   def bf16x2rr :
+     NVPTXInst<(outs Int32Regs:$dst),
+               (ins Int32Regs:$a, Int32Regs:$b),
+               !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               Requires<[hasBF16Math, allowFMA]>;
    // These have strange names so we don't perturb existing mir tests.
    def _rnf64rr :
      NVPTXInst<(outs Float64Regs:$dst),
@@ -440,6 +492,30 @@
                !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
                [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
                Requires<[useFP16Math, noFMA]>;
+  def _rnbf16rr_ftz :
+     NVPTXInst<(outs Int16Regs:$dst),
+               (ins Int16Regs:$a, Int16Regs:$b),
+               !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"),
+               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               Requires<[hasBF16Math, noFMA, doF32FTZ]>;
+   def _rnbf16rr :
+     NVPTXInst<(outs Int16Regs:$dst),
+               (ins Int16Regs:$a, Int16Regs:$b),
+               !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"),
+               [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+               Requires<[hasBF16Math, noFMA]>;
+   def _rnbf16x2rr_ftz :
+     NVPTXInst<(outs Int32Regs:$dst),
+               (ins Int32Regs:$a, Int32Regs:$b),
+               !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"),
+               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               Requires<[hasBF16Math, noFMA, doF32FTZ]>;
+   def _rnbf16x2rr :
+     NVPTXInst<(outs Int32Regs:$dst),
+               (ins Int32Regs:$a, Int32Regs:$b),
+               !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
+               [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+               Requires<[hasBF16Math, noFMA]>;
 }
 
 // Template for operations which take two f32 or f64 operands.  Provides three
@@ -516,6 +592,11 @@
                 (ins Int16Regs:$src, CvtMode:$mode),
                 !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
                 FromName, ".f16 \t$dst, $src;"), []>;
+    def _bf16 :
+      NVPTXInst<(outs RC:$dst),
+                (ins Int16Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
+                FromName, ".bf16 \t$dst, $src;"), []>;
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src, CvtMode:$mode),
@@ -538,6 +619,7 @@
   defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>;
   defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>;
   defm CVT_f16 : CVT_FROM_ALL<"f16", Int16Regs>;
+  defm CVT_bf16 : CVT_FROM_ALL<"bf16", Int16Regs>;
   defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>;
   defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>;
 
@@ -555,19 +637,8 @@
                                     "cvt.s64.s16 \t$dst, $src;", []>;
   def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
                                     "cvt.s64.s32 \t$dst, $src;", []>;
-
-multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
-    def _f32 :
-      NVPTXInst<(outs RC:$dst),
-                (ins Float32Regs:$src, CvtMode:$mode),
-                !strconcat("cvt${mode:base}${mode:relu}.",
-                FromName, ".f32 \t$dst, $src;"), []>,
-                Requires<[hasPTX<70>, hasSM<80>]>;
-  }
-
-  defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
-
-    multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
+  
+  multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src1, Float32Regs:$src2,  CvtMode:$mode),
@@ -641,6 +712,7 @@
 defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
 defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
 defm SELP_f16 : SELP_PATTERN<"b16", f16, Int16Regs, f16imm, fpimm>;
+defm SELP_bf16 : SELP_PATTERN<"b16", bf16, Int16Regs, bf16imm, fpimm>;
 
 defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
 defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
@@ -1005,7 +1077,9 @@
 def LOAD_CONST_F16 :
   NVPTXInst<(outs Int16Regs:$dst), (ins f16imm:$a),
             "mov.b16 \t$dst, $a;", []>;
-
+def LOAD_CONST_BF16 :
+  NVPTXInst<(outs Int16Regs:$dst), (ins bf16imm:$a),
+            "mov.b16 \t$dst, $a;", []>;
 defm FADD : F3_fma_component<"add", fadd>;
 defm FSUB : F3_fma_component<"sub", fsub>;
 defm FMUL : F3_fma_component<"mul", fmul>;
@@ -1033,6 +1107,20 @@
 def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
 def FNEG16x2     : FNEG_F16_F16X2<"neg.f16x2", v2f16, Int32Regs, True>;
 
+//
+// BF16 NEG: fneg on bf16 / v2bf16 -> neg[.ftz].bf16[x2]
+//
+
+class FNEG_BF16_F16X2<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> :
+      NVPTXInst<(outs RC:$dst), (ins RC:$src),
+                !strconcat(OpcStr, " \t$dst, $src;"),
+                [(set RC:$dst, (fneg (T RC:$src)))]>,
+                Requires<[hasBF16Math, hasPTX<70>, hasSM<80>, Pred]>;
+def BFNEG16_ftz   : FNEG_BF16_F16X2<"neg.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
+def BFNEG16       : FNEG_BF16_F16X2<"neg.bf16", bf16, Int16Regs, True>;
+def BFNEG16x2_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>;
+def BFNEG16x2     : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>;
+
 //
 // F64 division
 //
@@ -1211,13 +1299,24 @@
                        Requires<[useFP16Math, Pred]>;
 }
 
-defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
-defm FMA16     : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
-defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
-defm FMA16x2     : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
-defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
-defm FMA32     : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
-defm FMA64     : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
+multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
+   def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
+                       [(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>,
+                       Requires<[hasBF16Math, Pred]>;
+}
+
+defm FMA16_ftz    : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
+defm FMA16        : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
+defm FMA16x2_ftz  : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
+defm FMA16x2      : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
+defm BFMA16_ftz   : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
+defm BFMA16       : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
+defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>;
+defm BFMA16x2     : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
+defm FMA32_ftz    : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
+defm FMA32        : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
+defm FMA64        : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
 
 // sin/cos
 def SINF:  NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
@@ -1661,6 +1760,18 @@
                 "setp${cmp:base}${cmp:ftz}.f16x2 \t$p|$q, $a, $b;",
                 []>,
                 Requires<[useFP16Math]>;
+def SETP_bf16rr :
+      NVPTXInst<(outs Int1Regs:$dst),
+                (ins Int16Regs:$a, Int16Regs:$b, CmpMode:$cmp),
+                "setp${cmp:base}${cmp:ftz}.bf16 \t$dst, $a, $b;",
+                []>, Requires<[hasBF16Math]>;
+
+def SETP_bf16x2rr :
+      NVPTXInst<(outs Int1Regs:$p, Int1Regs:$q),
+                (ins Int32Regs:$a, Int32Regs:$b, CmpMode:$cmp),
+                "setp${cmp:base}${cmp:ftz}.bf16x2 \t$p|$q, $a, $b;",
+                []>,
+                Requires<[hasBF16Math]>;
 
 
 // FIXME: This doesn't appear to be correct.  The "set" mnemonic has the form
@@ -1691,6 +1802,7 @@
 defm SET_s64 : SET<"s64", Int64Regs, i64imm>;
 defm SET_u64 : SET<"u64", Int64Regs, i64imm>;
 defm SET_f16 : SET<"f16", Int16Regs, f16imm>;
+defm SET_bf16 : SET<"bf16", Int16Regs, bf16imm>;
 defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
 defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
 
@@ -1959,6 +2071,26 @@
             (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Int16Regs:$b, Mode)>,
         Requires<[useFP16Math]>;
 
+  // bf16 -> pred
+  def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+            (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>,
+        Requires<[hasBF16Math,doF32FTZ]>;
+  def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+            (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>,
+        Requires<[hasBF16Math]>;
+  def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+            (SETP_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+        Requires<[hasBF16Math,doF32FTZ]>;
+  def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+            (SETP_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+        Requires<[hasBF16Math]>;
+  def : Pat<(i1 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+            (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, ModeFTZ)>,
+        Requires<[hasBF16Math,doF32FTZ]>;
+  def : Pat<(i1 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+            (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, Mode)>,
+        Requires<[hasBF16Math]>;
+
   // f32 -> pred
   def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)),
             (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
@@ -2004,6 +2136,26 @@
             (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Int16Regs:$b, Mode)>,
         Requires<[useFP16Math]>;
 
+  // bf16 -> i32
+  def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+            (SET_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>,
+        Requires<[hasBF16Math, doF32FTZ]>;
+  def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+            (SET_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>,
+        Requires<[hasBF16Math]>;
+  def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+            (SET_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+        Requires<[hasBF16Math, doF32FTZ]>;
+  def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+            (SET_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+        Requires<[hasBF16Math]>;
+  def : Pat<(i32 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+            (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, ModeFTZ)>,
+        Requires<[hasBF16Math, doF32FTZ]>;
+  def : Pat<(i32 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+            (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, Mode)>,
+        Requires<[hasBF16Math]>;
+
   // f32 -> i32
   def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)),
             (SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
@@ -2801,6 +2953,26 @@
 def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
           (CVT_f16_u64 Int64Regs:$a, CvtRN)>;
 
+// sint -> bf16
+def : Pat<(bf16 (sint_to_fp Int1Regs:$a)),
+          (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int16Regs:$a)),
+          (CVT_bf16_s16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int32Regs:$a)),
+          (CVT_bf16_s32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int64Regs:$a)),
+          (CVT_bf16_s64 Int64Regs:$a, CvtRN)>;
+
+// uint -> bf16
+def : Pat<(bf16 (uint_to_fp Int1Regs:$a)),
+          (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int16Regs:$a)),
+          (CVT_bf16_u16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int32Regs:$a)),
+          (CVT_bf16_u32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int64Regs:$a)),
+          (CVT_bf16_u64 Int64Regs:$a, CvtRN)>;
+
 // sint -> f32
 def : Pat<(f32 (sint_to_fp Int1Regs:$a)),
           (CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
@@ -2862,6 +3034,25 @@
 def : Pat<(i64 (fp_to_uint (f16 Int16Regs:$a))),
           (CVT_u64_f16 Int16Regs:$a, CvtRZI)>;
 
+// bf16 -> sint
+def : Pat<(i1 (fp_to_sint (bf16 Int16Regs:$a))),
+          (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_sint (bf16 Int16Regs:$a))),
+          (CVT_s16_bf16 (bf16 Int16Regs:$a), CvtRZI)>;
+def : Pat<(i32 (fp_to_sint (bf16 Int16Regs:$a))),
+          (CVT_s32_bf16 (bf16 Int16Regs:$a), CvtRZI)>;
+def : Pat<(i64 (fp_to_sint (bf16 Int16Regs:$a))),
+          (CVT_s64_bf16 Int16Regs:$a, CvtRZI)>;
+
+// bf16 -> uint
+def : Pat<(i1 (fp_to_uint (bf16 Int16Regs:$a))),
+          (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_uint (bf16 Int16Regs:$a))),
+          (CVT_u16_bf16 Int16Regs:$a, CvtRZI)>;
+def : Pat<(i32 (fp_to_uint (bf16 Int16Regs:$a))),
+          (CVT_u32_bf16 Int16Regs:$a, CvtRZI)>;
+def : Pat<(i64 (fp_to_uint (bf16 Int16Regs:$a))),
+          (CVT_u64_bf16 Int16Regs:$a, CvtRZI)>;
 // f32 -> sint
 def : Pat<(i1 (fp_to_sint Float32Regs:$a)),
           (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>;
@@ -3009,6 +3200,9 @@
 def : Pat<(select Int32Regs:$pred, (f16 Int16Regs:$a), (f16 Int16Regs:$b)),
           (SELP_f16rr Int16Regs:$a, Int16Regs:$b,
           (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
+def : Pat<(select Int32Regs:$pred, (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)),
+          (SELP_bf16rr Int16Regs:$a, Int16Regs:$b,
+          (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
 def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b),
           (SELP_f32rr Float32Regs:$a, Float32Regs:$b,
           (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
@@ -3080,6 +3274,13 @@
 def : Pat<(v2f16 (build_vector (f16 Int16Regs:$a), (f16 Int16Regs:$b))),
           (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>;
 
+def : Pat<(bf16 (extractelt (v2bf16 Int32Regs:$src), 0)),
+          (I32toI16L Int32Regs:$src)>;
+def : Pat<(bf16 (extractelt (v2bf16 Int32Regs:$src), 1)),
+          (I32toI16H Int32Regs:$src)>;
+def : Pat<(v2bf16 (build_vector (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+          (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>;
+
 // Count leading zeros
 let hasSideEffects = false in {
   def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a),
@@ -3147,10 +3348,17 @@
 def : Pat<(f16 (fpround Float32Regs:$a)),
           (CVT_f16_f32 Float32Regs:$a, CvtRN)>;
 
+// fpround f32 -> bf16
+def : Pat<(bf16 (fpround Float32Regs:$a)),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+
 // fpround f64 -> f16
 def : Pat<(f16 (fpround Float64Regs:$a)),
           (CVT_f16_f64 Float64Regs:$a, CvtRN)>;
 
+// fpround f64 -> bf16
+def : Pat<(bf16 (fpround Float64Regs:$a)),
+          (CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
 // fpround f64 -> f32
 def : Pat<(f32 (fpround Float64Regs:$a)),
           (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3162,11 +3370,20 @@
           (CVT_f32_f16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
 def : Pat<(f32 (fpextend (f16 Int16Regs:$a))),
           (CVT_f32_f16 Int16Regs:$a, CvtNONE)>;
+// fpextend bf16 -> f32
+def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
+          (CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
+          (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>;
 
 // fpextend f16 -> f64
 def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
           (CVT_f64_f16 Int16Regs:$a, CvtNONE)>;
 
+// fpextend bf16 -> f64
+def : Pat<(f64 (fpextend (bf16 Int16Regs:$a))),
+          (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>;
+
 // fpextend f32 -> f64
 def : Pat<(f64 (fpextend Float32Regs:$a)),
           (CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
@@ -3181,6 +3398,8 @@
 multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
   def : Pat<(OpNode (f16 Int16Regs:$a)),
             (CVT_f16_f16 Int16Regs:$a, Mode)>;
+  def : Pat<(OpNode (bf16 Int16Regs:$a)),
+            (CVT_bf16_bf16 Int16Regs:$a, Mode)>;
   def : Pat<(OpNode Float32Regs:$a),
             (CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>;
   def : Pat<(OpNode Float32Regs:$a),
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -149,6 +149,16 @@
   }
 }
 
+/// Returns true if \p VT is one of the packed 16-bit FP pairs (v2f16/v2bf16).
+static bool Isv2f16Orv2bf16Type(EVT VT) {
+  return VT == MVT::v2f16 || VT == MVT::v2bf16;
+}
+
+/// Returns true if \p VT is a scalar 16-bit FP type (f16 or bf16).
+static bool Isf16Orbf16Type(EVT VT) {
+  return VT == MVT::f16 || VT == MVT::bf16;
+}
+
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
 /// into their primitive components.
@@ -199,7 +207,7 @@
       // Vectors with an even number of f16 elements will be passed to
       // us as an array of v2f16/v2bf16 elements. We must match this so we
       // stay in sync with Ins/Outs.
-      if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+      if ((Isf16Orbf16Type(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
         EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
         NumElts /= 2;
       }
@@ -404,6 +412,11 @@
     setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
   };
 
+  auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+                                    LegalizeAction NoBF16Action) {
+    setOperationAction(Op, VT, STI.hasBF16Math() ? Action : NoBF16Action);
+  };
+
   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
@@ -426,6 +439,16 @@
   setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
   setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
 
+  // Conversion to/from BF16/BF16x2 is always legal.
+  setOperationAction(ISD::SINT_TO_FP, MVT::bf16, Legal);
+  setOperationAction(ISD::FP_TO_SINT, MVT::bf16, Legal);
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom);
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
+  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);
+
+  setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
+  setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
   // Operations not directly supported by NVPTX.
   for (MVT VT : {MVT::f16, MVT::v2f16, MVT::f32, MVT::f64, MVT::i1, MVT::i8,
                  MVT::i16, MVT::i32, MVT::i64}) {
@@ -482,17 +505,25 @@
   // Turn FP extload into load/fpextend
   setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
   // Turn FP truncstore into trunc + store.
   // FIXME: vector types should also be expanded
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
 
   // PTX does not support load / store predicate registers
@@ -569,9 +600,9 @@
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
                        ISD::SREM, ISD::UREM});
 
-  // setcc for f16x2 needs special handling to prevent legalizer's
-  // attempt to scalarize it due to v2i1 not being legal.
-  if (STI.allowFP16Math())
+  // setcc for f16x2 and bf16x2 needs special handling to prevent
+  // legalizer's attempt to scalarize it due to v2i1 not being legal.
+  if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -583,6 +614,8 @@
   for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
     setFP16OperationAction(Op, MVT::f16, Legal, Promote);
     setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
+    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
   }
 
   // f16/f16x2 neg was introduced in PTX 60, SM_53.
@@ -593,19 +626,25 @@
     setOperationAction(ISD::FNEG, VT,
                        IsFP16FP16x2NegAvailable ? Legal : Expand);
 
+  setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
+  setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
   // (would be) Library functions.
 
   // These map to conversion instructions for scalar FP types.
   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
                          ISD::FROUNDEVEN, ISD::FTRUNC}) {
+    setOperationAction(Op, MVT::bf16, Legal);
     setOperationAction(Op, MVT::f16, Legal);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
+    setOperationAction(Op, MVT::v2bf16, Expand);
   }
 
   setOperationAction(ISD::FROUND, MVT::f16, Promote);
   setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
+  setOperationAction(ISD::FROUND, MVT::bf16, Promote);
+  setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
   setOperationAction(ISD::FROUND, MVT::f32, Custom);
   setOperationAction(ISD::FROUND, MVT::f64, Custom);
 
@@ -613,6 +652,8 @@
   // 'Expand' implements FCOPYSIGN without calling an external library.
   setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
+  setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+  setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
 
@@ -622,9 +663,11 @@
   for (const auto &Op :
        {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
     setOperationAction(Op, MVT::f16, Promote);
+    setOperationAction(Op, MVT::bf16, Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
+    setOperationAction(Op, MVT::v2bf16, Expand);
   }
   // max.f16, max.f16x2 and max.NaN are supported on sm_80+.
   auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
@@ -633,14 +676,18 @@
   };
   for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
     setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
   }
   for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
     setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
     setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
   }
 
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
@@ -1258,7 +1305,7 @@
   if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
       VT.getScalarType() == MVT::i1)
     return TypeSplitVector;
-  if (VT == MVT::v2f16)
+  if (Isv2f16Orv2bf16Type(VT))
     return TypeLegal;
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
@@ -1408,7 +1455,7 @@
         sz = promoteScalarArgumentSize(sz);
       } else if (isa<PointerType>(Ty)) {
         sz = PtrVT.getSizeInBits();
-      } else if (Ty->isHalfTy())
+      } else if (Ty->isHalfTy() || Ty->isBFloatTy())
         // PTX ABI requires all scalar parameters to be at least 32
         // bits in size.  fp16 normally uses .b16 as its storage type
         // in PTX, so its size must be adjusted here, too.
@@ -2043,7 +2090,7 @@
 // generates good SASS in both cases.
 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                SelectionDAG &DAG) const {
-  if (!(Op->getValueType(0) == MVT::v2f16 &&
+  if (!(Isv2f16Orv2bf16Type(Op->getValueType(0).getSimpleVT()) &&
         isa<ConstantFPSDNode>(Op->getOperand(0)) &&
         isa<ConstantFPSDNode>(Op->getOperand(1))))
     return Op;
@@ -2054,7 +2101,7 @@
       cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
   SDValue Const =
       DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
-  return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+  return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
 }
 
 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -2415,7 +2462,7 @@
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // loads and have to handle it here.
-  if (Op.getValueType() == MVT::v2f16) {
+  if (Isv2f16Orv2bf16Type(Op.getValueType().getSimpleVT())) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2460,7 +2507,7 @@
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // stores and have to handle it here.
-  if (VT == MVT::v2f16 &&
+  if (Isv2f16Orv2bf16Type(VT.getSimpleVT()) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
@@ -2547,7 +2594,7 @@
       // v8f16 is a special case. PTX doesn't have st.v8.f16
       // instruction. Instead, we split the vector into v2f16 chunks and
       // store them with st.v4.b32.
-      assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+      assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
              "Wrong type for the vector.");
       Opcode = NVPTXISD::StoreV4;
       StoreF16x2 = true;
@@ -2563,11 +2610,12 @@
       // Combine f16,f16 -> v2f16
       NumElts /= 2;
       for (unsigned i = 0; i < NumElts; ++i) {
-        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                  DAG.getIntPtrConstant(i * 2, DL));
-        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                  DAG.getIntPtrConstant(i * 2 + 1, DL));
-        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+        EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
+        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
         Ops.push_back(V2);
       }
     } else {
@@ -2733,9 +2781,9 @@
           EVT LoadVT = EltVT;
           if (EltVT == MVT::i1)
             LoadVT = MVT::i8;
-          else if (EltVT == MVT::v2f16)
+          else if (Isv2f16Orv2bf16Type(EltVT.getSimpleVT()))
             // getLoad needs a vector type, but it can't handle
-            // vectors which contain v2f16 elements. So we must load
+            // vectors which contain v2f16 or v2bf16 elements. So we must load
             // using i32 here and then bitcast back.
             LoadVT = MVT::i32;
 
@@ -5189,7 +5237,7 @@
     // v8f16 is a special case. PTX doesn't have ld.v8.f16
     // instruction. Instead, we split the vector into v2f16 chunks and
     // load them with ld.v4.b32.
-    assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+    assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
            "Unsupported v8 vector type.");
     LoadF16x2 = true;
     Opcode = NVPTXISD::LoadV4;
Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -71,7 +71,7 @@
   bool tryTextureIntrinsic(SDNode *N);
   bool trySurfaceIntrinsic(SDNode *N);
   bool tryBFE(SDNode *N);
-  bool tryConstantFP16(SDNode *N);
+  bool tryConstantFP(SDNode *N);
   bool SelectSETP_F16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
 
Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -500,7 +500,7 @@
     SelectAddrSpaceCast(N);
     return;
   case ISD::ConstantFP:
-    if (tryConstantFP16(N))
+    if (tryConstantFP(N))
       return;
     break;
   default:
@@ -524,15 +524,17 @@
   }
 }
 
-// There's no way to specify FP16 immediates in .f16 ops, so we have to
-// load them into an .f16 register first.
-bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
-  if (N->getValueType(0) != MVT::f16)
+// There's no way to specify FP16 or BF16 immediates in .f16/.bf16 ops, so we
+// have to load them into a register first (LOAD_CONST_F16 / LOAD_CONST_BF16).
+bool NVPTXDAGToDAGISel::tryConstantFP(SDNode *N) {
+  if (N->getValueType(0) != MVT::f16 && N->getValueType(0) != MVT::bf16)
     return false;
   SDValue Val = CurDAG->getTargetConstantFP(
-      cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16);
-  SDNode *LoadConstF16 =
-      CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val);
+      cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), N->getValueType(0));
+  SDNode *LoadConstF16 = CurDAG->getMachineNode(
+      (N->getValueType(0) == MVT::f16 ? NVPTX::LOAD_CONST_F16
+                                      : NVPTX::LOAD_CONST_BF16),
+      SDLoc(N), N->getValueType(0), Val);
   ReplaceNode(N, LoadConstF16);
   return true;
 }
@@ -1258,10 +1260,11 @@
     NumElts = EltVT.getVectorNumElements();
     EltVT = EltVT.getVectorElementType();
     // vectors of f16 are loaded/stored as multiples of v2f16 elements.
-    if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) {
-      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
-      EltVT = MVT::v2f16;
-      NumElts /= 2;
+    if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
+        (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
+      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+      EltVT = N->getValueType(0);
+      NumElts /= 2;
     }
   }
 
Index: llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -272,6 +272,10 @@
       MCOp = MCOperand::createExpr(
         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
       break;
+    case Type::BFloatTyID:
+      MCOp = MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
+      break;
     case Type::FloatTyID:
       MCOp = MCOperand::createExpr(
         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
@@ -1355,8 +1359,10 @@
     }
     break;
   }
+  case Type::BFloatTyID:
   case Type::HalfTyID:
-    // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
+    // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
+    // PTX assembly.
     return "b16";
   case Type::FloatTyID:
     return "f32";
@@ -1581,7 +1587,7 @@
       } else if (PTy) {
         assert(PTySizeInBits && "Invalid pointer size");
         sz = PTySizeInBits;
-      } else if (Ty->isHalfTy())
+      } else if (Ty->isHalfTy() || Ty->isBFloatTy())
         // PTX ABI requires all scalar parameters to be at least 32
         // bits in size.  fp16 normally uses .b16 as its storage type
         // in PTX, so its size must be adjusted here, too.
Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -595,16 +595,18 @@
           [IntrNoMem, IntrSpeculatable, Commutative]>;
     }
 
-    foreach variant = ["_bf16", "_nan_bf16", "_xorsign_abs_bf16",
-      "_nan_xorsign_abs_bf16"] in {
+    foreach variant = ["_bf16", "_ftz_bf16", "_nan_bf16", "_ftz_nan_bf16",
+      "_xorsign_abs_bf16", "_ftz_xorsign_abs_bf16", "_nan_xorsign_abs_bf16",
+      "_ftz_nan_xorsign_abs_bf16"] in {
       def int_nvvm_f # operation # variant :
         ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
         DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
           [IntrNoMem, IntrSpeculatable, Commutative]>;
     }
 
-    foreach variant = ["_bf16x2", "_nan_bf16x2", "_xorsign_abs_bf16x2",
-      "_nan_xorsign_abs_bf16x2"] in {
+    foreach variant = ["_bf16x2", "_ftz_bf16x2", "_nan_bf16x2",
+      "_ftz_nan_bf16x2", "_xorsign_abs_bf16x2", "_ftz_xorsign_abs_bf16x2",
+      "_nan_xorsign_abs_bf16x2", "_ftz_nan_xorsign_abs_bf16x2"]  in {
       def int_nvvm_f # operation # variant :
         ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
         DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
@@ -870,17 +872,19 @@
       [IntrNoMem, IntrSpeculatable]>;
   }
 
-  foreach variant = ["_rn_bf16", "_rn_relu_bf16"] in {
+  foreach variant = ["_rn_bf16", "_rn_ftz_bf16", "_rn_sat_bf16",
+    "_rn_ftz_sat_bf16", "_rn_relu_bf16", "_rn_ftz_relu_bf16"] in {
     def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
-      DefaultAttrsIntrinsic<[llvm_i16_ty],
-        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
+      DefaultAttrsIntrinsic<[llvm_bfloat_ty],
+        [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty],
         [IntrNoMem, IntrSpeculatable]>;
   }
 
-  foreach variant = ["_rn_bf16x2", "_rn_relu_bf16x2"] in {
+  foreach variant = ["_rn_bf16x2", "_rn_ftz_bf16x2", "_rn_sat_bf16x2",
+    "_rn_ftz_sat_bf16x2", "_rn_relu_bf16x2", "_rn_ftz_relu_bf16x2"] in {
     def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
-      DefaultAttrsIntrinsic<[llvm_i32_ty],
-        [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
+      DefaultAttrsIntrinsic<[llvm_v2bf16_ty],
+        [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty],
         [IntrNoMem, IntrSpeculatable]>;
   }
 
@@ -1232,6 +1236,11 @@
   def int_nvvm_f2h_rn : ClangBuiltin<"__nvvm_f2h_rn">,
       DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
 
+  def int_nvvm_bf2h_rn_ftz : ClangBuiltin<"__nvvm_bf2h_rn_ftz">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_bf2h_rn : ClangBuiltin<"__nvvm_bf2h_rn">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+
   def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">,
        Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
   def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -174,12 +174,16 @@
 TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
@@ -216,12 +220,16 @@
 TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to