kushanam updated this revision to Diff 523106.
kushanam added a comment.

Adding clang builtins and removing bf16 registers
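
For illustration only (not part of the patch), a minimal sketch of how two of the new ftz bf16 builtins from the BuiltinsNVPTX.def hunk below could be called from CUDA device code. The wrapper names are hypothetical; per the "UsUsUs" and "ZUiZUiZUi" prototypes, the scalar builtins take and return bf16 values as raw 16-bit bit patterns, the *_bf16x2 variants use packed 32-bit words, and both are guarded by sm_80 + PTX 7.0:

  // Hypothetical wrappers around the newly added builtins (illustrative only).
  __device__ unsigned short fmin_bf16_ftz(unsigned short a, unsigned short b) {
    // bf16 operands/result are passed as their 16-bit bit patterns ("UsUsUs").
    return __nvvm_fmin_ftz_bf16(a, b);
  }

  __device__ unsigned int fmax_bf16x2_ftz_nan(unsigned int a, unsigned int b) {
    // bf16x2 operands/result are packed into 32-bit words ("ZUiZUiZUi").
    return __nvvm_fmax_ftz_nan_bf16x2(a, b);
  }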

Depends on D144911 <https://reviews.llvm.org/D144911>

Differential Revision: https://reviews.llvm.org/D144911


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D149976/new/

https://reviews.llvm.org/D149976

Files:
  clang/include/clang/Basic/BuiltinsNVPTX.def
  llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
  llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
  llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp

Index: llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -29,14 +29,13 @@
 std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
   if (RC == &NVPTX::Float32RegsRegClass)
     return ".f32";
-  if (RC == &NVPTX::Float16RegsRegClass || RC == &NVPTX::BFloat16RegsRegClass)
+  if (RC == &NVPTX::Float16RegsRegClass)
     // Ideally fp16 registers should be .f16, but this syntax is only
     // supported on sm_53+. On the other hand, .b16 registers are
     // accepted for all supported fp16 instructions on all GPU
     // variants, so we can use them instead.
     return ".b16";
-  if (RC == &NVPTX::Float16x2RegsRegClass ||
-      RC == &NVPTX::BFloat16x2RegsRegClass)
+  if (RC == &NVPTX::Float16x2RegsRegClass)
     return ".b32";
   if (RC == &NVPTX::Float64RegsRegClass)
     return ".f64";
@@ -74,10 +73,9 @@
 std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
   if (RC == &NVPTX::Float32RegsRegClass)
     return "%f";
-  if (RC == &NVPTX::Float16RegsRegClass || RC == &NVPTX::BFloat16RegsRegClass)
+  if (RC == &NVPTX::Float16RegsRegClass)
     return "%h";
-  if (RC == &NVPTX::Float16x2RegsRegClass ||
-      RC == &NVPTX::BFloat16x2RegsRegClass)
+  if (RC == &NVPTX::Float16x2RegsRegClass)
     return "%hh";
   if (RC == &NVPTX::Float64RegsRegClass)
     return "%fd";
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -998,9 +998,6 @@
     FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
       Float16x2Regs, [hasPTX70, hasSM80]>,
 
-    // FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, BFloat16Regs,
-    //   [hasPTX70, hasSM80]>,
-
     FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, BFloat16x2Regs,
       [hasPTX70, hasSM80]>,
     FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, BFloat16x2Regs,
@@ -1254,24 +1251,6 @@
 def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
           (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
 
-// def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
-//           (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
-// def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
-//           (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
-// def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
-//           (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
-// def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
-//           (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
-
-// def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
-//           (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
-// def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
-//           (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
-// def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
-//           (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
-// def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
-//           (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
-
 def CVT_tf32_f32 :
    NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
                    "cvt.rna.tf32.f32 \t$dest, $a;",
@@ -1387,11 +1366,6 @@
 def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
           (BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN))>;
 
-// def : Pat<(int_nvvm_bf2h_rn_ftz Float32Regs:$a),
-//           (BITCONVERT_16_BF2I (CVT_bf16_f32 Float32Regs:$a, CvtRN_FTZ))>;
-// def : Pat<(int_nvvm_f2h_rn BFloat16Regs:$a),
-//           (BITCONVERT_16_BF2I (CVT_bf16_f32 BFloat16Regs:$a, CvtRN))>;
-
 //
 // Bitcast
 //
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -656,15 +656,6 @@
   def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
                                     "cvt.s64.s32 \t$dst, $src;", []>;
 
-multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
-    def _f32 :
-      NVPTXInst<(outs RC:$dst),
-                (ins Float32Regs:$src, CvtMode:$mode),
-                !strconcat("cvt${mode:base}${mode:relu}.",
-                FromName, ".f32 \t$dst, $src;"), []>,
-                Requires<[hasPTX70, hasSM80]>;
-  }
-
   multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
     def _f32 :
       NVPTXInst<(outs RC:$dst),
@@ -753,12 +744,6 @@
               "selp.b32 \t$dst, $a, $b, $p;",
               [(set Float16x2Regs:$dst,
                     (select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>;
-def SELP_bf16x2rr :
-    NVPTXInst<(outs BFloat16x2Regs:$dst),
-              (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b, Int1Regs:$p),
-              "selp.b32 \t$dst, $a, $b, $p;",
-              [(set BFloat16x2Regs:$dst,
-                    (select Int1Regs:$p, (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>;
 
 //-----------------------------------
 // Test Instructions
@@ -2091,7 +2076,7 @@
             (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
         Requires<[useFP16Math]>;
 
-  //bf16 -> pred
+  // bf16 -> pred
   def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
             (SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
         Requires<[useBFP16Math,doF32FTZ]>;
@@ -2156,7 +2141,7 @@
             (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
         Requires<[useFP16Math]>;
   
-    // bf16 -> i32
+  // bf16 -> i32
   def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
             (SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
         Requires<[useBFP16Math, doF32FTZ]>;
@@ -2707,9 +2692,7 @@
   defm LD_i32 : LD<Int32Regs>;
   defm LD_i64 : LD<Int64Regs>;
   defm LD_f16 : LD<Float16Regs>;
-  defm LD_bf16 : LD<BFloat16Regs>;
   defm LD_f16x2 : LD<Float16x2Regs>;
-  defm LD_bf16x2 : LD<BFloat16x2Regs>;
   defm LD_f32 : LD<Float32Regs>;
   defm LD_f64 : LD<Float64Regs>;
 }
@@ -3366,29 +3349,29 @@
                                [(set BFloat16Regs:$dst,
                                  (extractelt (v2bf16 BFloat16x2Regs:$src), 1))]>;
 
-  // Coalesce two bf16 registers into bf16x2
-  def BuildBF16x2 : NVPTXInst<(outs BFloat16x2Regs:$dst),
-                             (ins BFloat16Regs:$a, BFloat16Regs:$b),
-                             "mov.b32 \t$dst, {{$a, $b}};",
-                             [(set (v2bf16 BFloat16x2Regs:$dst),
-                               (build_vector (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>;
-
-  // Directly initializing underlying the b32 register is one less SASS
-  // instruction than than vector-packing move.
-  def BuildBF16x2i : NVPTXInst<(outs BFloat16x2Regs:$dst), (ins i32imm:$src),
-                              "mov.b32 \t$dst, $src;",
-                              []>;
-
-  // Split f16x2 into two f16 registers.
-  def SplitBF16x2  : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
-                              (ins BFloat16x2Regs:$src),
-                              "mov.b32 \t{{$lo, $hi}}, $src;",
-                              []>;
-  // Split an i32 into two f16
-  def SplitI32toBF16x2  : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
-                                   (ins Int32Regs:$src),
-                                   "mov.b32 \t{{$lo, $hi}}, $src;",
-                                   []>;
+  // // Coalesce two bf16 registers into bf16x2
+  // def BuildBF16x2 : NVPTXInst<(outs BFloat16x2Regs:$dst),
+  //                            (ins BFloat16Regs:$a, BFloat16Regs:$b),
+  //                            "mov.b32 \t$dst, {{$a, $b}};",
+  //                            [(set (v2bf16 BFloat16x2Regs:$dst),
+  //                              (build_vector (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>;
+
+  // // Directly initializing the underlying b32 register is one less SASS
+  // // instruction than the vector-packing move.
+  // def BuildBF16x2i : NVPTXInst<(outs BFloat16x2Regs:$dst), (ins i32imm:$src),
+  //                             "mov.b32 \t$dst, $src;",
+  //                             []>;
+
+  // // Split bf16x2 into two bf16 registers.
+  // def SplitBF16x2  : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
+  //                             (ins BFloat16x2Regs:$src),
+  //                             "mov.b32 \t{{$lo, $hi}}, $src;",
+  //                             []>;
+  // // Split an i32 into two bf16
+  // def SplitI32toBF16x2  : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
+  //                                  (ins Int32Regs:$src),
+  //                                  "mov.b32 \t{{$lo, $hi}}, $src;",
+  //                                  []>;
 }
 
 // Count leading zeros
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -56,11 +56,6 @@
                                                : NVPTX::BITCONVERT_16_I2F);
   } else if (DestRC == &NVPTX::Float16x2RegsRegClass) {
     Op = NVPTX::IMOV32rr;
-  } else if (DestRC == &NVPTX::BFloat16RegsRegClass) {
-    Op = (SrcRC == &NVPTX::BFloat16RegsRegClass ? NVPTX::BFMOV16rr
-                                                : NVPTX::BITCONVERT_16_I2BF);
-  } else if (DestRC == &NVPTX::BFloat16x2RegsRegClass) {
-    Op = NVPTX::IMOV32rr;
   } else if (DestRC == &NVPTX::Float32RegsRegClass) {
     Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
                                                : NVPTX::BITCONVERT_32_I2F);
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -143,6 +143,26 @@
   }
 }
 
+static bool Isv2f16Orv2bf16Type(MVT VT) {
+  switch (VT.SimpleTy) {
+  default:
+    return false;
+  case MVT::v2f16:
+  case MVT::v2bf16:
+    return true;
+  }
+}
+
+static bool Isf16Orbf16Type(MVT VT) {
+  switch (VT.SimpleTy) {
+  default:
+    return false;
+  case MVT::f16:
+  case MVT::bf16:
+    return true;
+  }
+}
+
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
 /// into their primitive components.
@@ -193,7 +213,7 @@
       // Vectors with an even number of f16 elements will be passed to
       // us as an array of v2f16/v2bf16 elements. We must match this so we
       // stay in sync with Ins/Outs.
-      if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+      if (Isf16Orbf16Type(EltVT.getSimpleVT()) && NumElts % 2 == 0) {
         EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
         NumElts /= 2;
       }
@@ -411,8 +431,6 @@
   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
   addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
   addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
-  addRegisterClass(MVT::bf16, &NVPTX::BFloat16RegsRegClass);
-  addRegisterClass(MVT::v2bf16, &NVPTX::BFloat16x2RegsRegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
@@ -586,7 +604,7 @@
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
                        ISD::SREM, ISD::UREM});
 
-  // setcc for f16x2 and bf16x2 needs special handling to prevent 
+  // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
   if (STI.allowFP16Math() || STI.allowBF16Math())
     setTargetDAGCombine(ISD::SETCC);
@@ -616,8 +634,8 @@
                        IsFP16FP16x2NegAvailable ? Legal : Expand);
 
   const bool IsBFP16FP16x2NegAvailable = STI.getSmVersion() >= 80 &&
-                                        STI.getPTXVersion() >= 70 &&
-                                        STI.allowBF16Math();
+                                         STI.getPTXVersion() >= 70 &&
+                                         STI.allowBF16Math();
   for (const auto &VT : {MVT::bf16, MVT::v2bf16})
     setOperationAction(ISD::FNEG, VT,
                        IsBFP16FP16x2NegAvailable ? Legal : Expand);
@@ -631,6 +649,7 @@
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
+    setOperationAction(Op, MVT::v2bf16, Expand);
   }
 
   setOperationAction(ISD::FROUND, MVT::f16, Promote);
@@ -680,12 +699,10 @@
   for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
     setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
     setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
-  }
-  for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
     setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
-    setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
     setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
   }
+
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
   // No FPOW or FREM in PTX.
 
@@ -1301,7 +1318,7 @@
   if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
       VT.getScalarType() == MVT::i1)
     return TypeSplitVector;
-  if (VT == MVT::v2f16 || VT == MVT::v2bf16)
+  if (Isv2f16Orv2bf16Type(VT))
     return TypeLegal;
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
@@ -2086,8 +2103,7 @@
 // generates good SASS in both cases.
 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                SelectionDAG &DAG) const {
-  if (!((Op->getValueType(0) == MVT::v2f16 ||
-         Op->getValueType(0) == MVT::v2bf16) &&
+  if (!(Isv2f16Orv2bf16Type(Op->getValueType(0).getSimpleVT()) &&
         isa<ConstantFPSDNode>(Op->getOperand(0)) &&
         isa<ConstantFPSDNode>(Op->getOperand(1))))
     return Op;
@@ -2098,9 +2114,7 @@
       cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
   SDValue Const =
       DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
-  return Op->getValueType(0) == MVT::v2bf16
-             ? DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2bf16, Const)
-             : DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+  return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
 }
 
 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -2461,7 +2475,7 @@
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // loads and have to handle it here.
-  if (Op.getValueType() == MVT::v2f16 || Op.getValueType() == MVT::v2bf16) {
+  if (Isv2f16Orv2bf16Type(Op.getValueType().getSimpleVT())) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2506,7 +2520,7 @@
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // stores and have to handle it here.
-  if ((VT == MVT::v2f16 || VT == MVT::v2bf16) &&
+  if (Isv2f16Orv2bf16Type(VT.getSimpleVT()) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
@@ -2593,7 +2607,7 @@
       // v8f16 is a special case. PTX doesn't have st.v8.f16
       // instruction. Instead, we split the vector into v2f16 chunks and
       // store them with st.v4.b32.
-      assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+      assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
              "Wrong type for the vector.");
       Opcode = NVPTXISD::StoreV4;
       StoreF16x2 = true;
@@ -2608,24 +2622,14 @@
     if (StoreF16x2) {
       // Combine f16,f16 -> v2f16
       NumElts /= 2;
-      if (EltVT == MVT::f16) {
-        for (unsigned i = 0; i < NumElts; ++i) {
-          SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
-                                   DAG.getIntPtrConstant(i * 2, DL));
-          SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
-                                   DAG.getIntPtrConstant(i * 2 + 1, DL));
-          SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
-          Ops.push_back(V2);
-        }
-      } else {
-        for (unsigned i = 0; i < NumElts; ++i) {
-          SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::bf16, Val,
-                                   DAG.getIntPtrConstant(i * 2, DL));
-          SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::bf16, Val,
-                                   DAG.getIntPtrConstant(i * 2 + 1, DL));
-          SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2bf16, E0, E1);
-          Ops.push_back(V2);
-        }
+      for (unsigned i = 0; i < NumElts; ++i) {
+        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
+                                 DAG.getIntPtrConstant(i * 2, DL));
+        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
+                                 DAG.getIntPtrConstant(i * 2 + 1, DL));
+        EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
+        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
+        Ops.push_back(V2);
       }
     } else {
       // Then the split values
@@ -2796,7 +2800,7 @@
           EVT LoadVT = EltVT;
           if (EltVT == MVT::i1)
             LoadVT = MVT::i8;
-          else if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16)
+          else if (Isv2f16Orv2bf16Type(EltVT.getSimpleVT()))
             // getLoad needs a vector type, but it can't handle
             // vectors which contain v2f16 or v2bf16 elements. So we must load
             // using i32 here and then bitcast back.
@@ -5234,7 +5238,7 @@
     // v8f16 is a special case. PTX doesn't have ld.v8.f16
     // instruction. Instead, we split the vector into v2f16 chunks and
     // load them with ld.v4.b32.
-    assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+    assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
            "Unsupported v8 vector type.");
     LoadF16x2 = true;
     Opcode = NVPTXISD::LoadV4;
Index: llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -312,10 +312,6 @@
       Ret = (7 << 28);
     } else if (RC == &NVPTX::Float16x2RegsRegClass) {
       Ret = (8 << 28);
-    } else if (RC == &NVPTX::BFloat16RegsRegClass) {
-      Ret = (9 << 28);
-    } else if (RC == &NVPTX::BFloat16x2RegsRegClass) {
-      Ret = (10 << 28);
     } else {
       report_fatal_error("Bad register class");
     }
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -145,12 +145,16 @@
 TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
@@ -187,12 +191,16 @@
 TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
