jchlanda updated this revision to Diff 406322.
jchlanda added a comment.

Set correct SM and PTX version.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D118977/new/

https://reviews.llvm.org/D118977

Files:
  clang/include/clang/Basic/BuiltinsNVPTX.def
  clang/test/CodeGen/builtins-nvptx-native-half-type.c
  clang/test/CodeGen/builtins-nvptx.c
  llvm/include/llvm/IR/IntrinsicsNVVM.td
  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
  llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
  llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-instcombine.ll
  llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
  llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
  llvm/test/CodeGen/NVPTX/math-intrins.ll

Index: llvm/test/CodeGen/NVPTX/math-intrins.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -30,6 +30,15 @@
 declare float @llvm.fma.f32(float, float, float) #0
 declare double @llvm.fma.f64(double, double, double) #0
 
+declare half @llvm.nvvm.fma.rn.f16(half, half, half)
+declare half @llvm.nvvm.fma.rn.ftz.f16(half, half, half)
+declare half @llvm.nvvm.fma.rn.sat.f16(half, half, half)
+declare half @llvm.nvvm.fma.rn.ftz.sat.f16(half, half, half)
+declare <2 x half> @llvm.nvvm.fma.rn.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fma.rn.sat.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2(<2 x half>, <2 x half>, <2 x half>)
+
 ; ---- ceil ----
 
 ; CHECK-LABEL: ceil_float
@@ -328,5 +337,69 @@
   ret double %x
 }
 
+; CHECK-LABEL: fma_rn_f16
+define half @fma_rn_f16(half %0, half %1, half %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.f16
+  %res = call half @llvm.nvvm.fma.rn.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_f16
+define half @fma_rn_ftz_f16(half %0, half %1, half %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.ftz.f16
+  %res = call half @llvm.nvvm.fma.rn.ftz.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_sat_f16
+define half @fma_rn_sat_f16(half %0, half %1, half %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.sat.f16
+  %res = call half @llvm.nvvm.fma.rn.sat.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_sat_f16
+define half @fma_rn_ftz_sat_f16(half %0, half %1, half %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.ftz.sat.f16
+  %res = call half @llvm.nvvm.fma.rn.ftz.sat.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_f16x2
+define <2 x half> @fma_rn_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.f16x2
+  %res = call <2 x half> @llvm.nvvm.fma.rn.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_f16x2
+define <2 x half> @fma_rn_ftz_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.ftz.f16x2
+  %res = call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_sat_f16x2
+define <2 x half> @fma_rn_sat_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.sat.f16x2
+  %res = call <2 x half> @llvm.nvvm.fma.rn.sat.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_sat_f16x2
+define <2 x half> @fma_rn_ftz_sat_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.ftz.sat.f16x2
+  %res = call <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
 attributes #0 = { nounwind readnone }
 attributes #1 = { "denormal-fp-math-f32" = "preserve-sign" }
Index: llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
+++ llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
@@ -36,6 +36,7 @@
 
 ; CHECK-LABEL: fmin_xorsign_abs_f16
 define half @fmin_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmin.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -43,6 +44,7 @@
 
 ; CHECK-LABEL: fmin_ftz_xorsign_abs_f16
 define half @fmin_ftz_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmin.ftz.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -50,6 +52,7 @@
 
 ; CHECK-LABEL: fmin_nan_xorsign_abs_f16
 define half @fmin_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmin.nan.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -57,6 +60,7 @@
 
 ; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f16
 define half @fmin_ftz_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.NaN.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -64,6 +68,7 @@
 
 ; CHECK-LABEL: fmin_xorsign_abs_f16x2
 define <2 x half> @fmin_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -71,6 +76,7 @@
 
 ; CHECK-LABEL: fmin_ftz_xorsign_abs_f16x2
 define <2 x half> @fmin_ftz_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -78,6 +84,7 @@
 
 ; CHECK-LABEL: fmin_nan_xorsign_abs_f16x2
 define <2 x half> @fmin_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -85,6 +92,7 @@
 
 ; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f16x2
 define <2 x half> @fmin_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.NaN.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -92,6 +100,7 @@
 
 ; CHECK-LABEL: fmin_xorsign_abs_bf16
 define i16 @fmin_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.xorsign.abs.bf16
   %res = call i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -99,6 +108,7 @@
 
 ; CHECK-LABEL: fmin_nan_xorsign_abs_bf16
 define i16 @fmin_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.xorsign.abs.bf16
   %res = call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -106,6 +116,7 @@
 
 ; CHECK-LABEL: fmin_xorsign_abs_bf16x2
 define i32 @fmin_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.xorsign.abs.bf16x2
   %res = call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32 %0, i32 %1)
   ret i32 %res
@@ -113,6 +124,7 @@
 
 ; CHECK-LABEL: fmin_nan_xorsign_abs_bf16x2
 define i32 @fmin_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.xorsign.abs.bf16x2
   %res = call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
   ret i32 %res
@@ -120,6 +132,7 @@
 
 ; CHECK-LABEL: fmin_xorsign_abs_f
 define float @fmin_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.xorsign.abs.f
   %res = call float @llvm.nvvm.fmin.xorsign.abs.f(float %0, float %1)
   ret float %res
@@ -127,6 +140,7 @@
 
 ; CHECK-LABEL: fmin_ftz_xorsign_abs_f
 define float @fmin_ftz_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.xorsign.abs.f
   %res = call float @llvm.nvvm.fmin.ftz.xorsign.abs.f(float %0, float %1)
   ret float %res
@@ -134,6 +148,7 @@
 
 ; CHECK-LABEL: fmin_nan_xorsign_abs_f
 define float @fmin_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.xorsign.abs.f
   %res = call float @llvm.nvvm.fmin.nan.xorsign.abs.f(float %0, float %1)
   ret float %res
@@ -141,6 +156,7 @@
 
 ; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f
 define float @fmin_ftz_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.NaN.xorsign.abs.f
   %res = call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f(float %0, float %1)
   ret float %res
@@ -148,6 +164,7 @@
 
 ; CHECK-LABEL: fmax_xorsign_abs_f16
 define half @fmax_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmax.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -155,6 +172,7 @@
 
 ; CHECK-LABEL: fmax_ftz_xorsign_abs_f16
 define half @fmax_ftz_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmax.ftz.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -162,6 +180,7 @@
 
 ; CHECK-LABEL: fmax_nan_xorsign_abs_f16
 define half @fmax_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmax.nan.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -169,6 +188,7 @@
 
 ; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f16
 define half @fmax_ftz_nan_xorsign_abs_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.NaN.xorsign.abs.f16
   %res = call half @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16(half %0, half %1)
   ret half %res
@@ -176,6 +196,7 @@
 
 ; CHECK-LABEL: fmax_xorsign_abs_f16x2
 define <2 x half> @fmax_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -183,6 +204,7 @@
 
 ; CHECK-LABEL: fmax_ftz_xorsign_abs_f16x2
 define <2 x half> @fmax_ftz_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -190,6 +212,7 @@
 
 ; CHECK-LABEL: fmax_nan_xorsign_abs_f16x2
 define <2 x half> @fmax_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -197,6 +220,7 @@
 
 ; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f16x2
 define <2 x half> @fmax_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.NaN.xorsign.abs.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -204,6 +228,7 @@
 
 ; CHECK-LABEL: fmax_xorsign_abs_bf16
 define i16 @fmax_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.xorsign.abs.bf16
   %res = call i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -211,6 +236,7 @@
 
 ; CHECK-LABEL: fmax_nan_xorsign_abs_bf16
 define i16 @fmax_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.xorsign.abs.bf16
   %res = call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -218,6 +244,7 @@
 
 ; CHECK-LABEL: fmax_xorsign_abs_bf16x2
 define i32 @fmax_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.xorsign.abs.bf16x2
   %res = call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32 %0, i32 %1)
   ret i32 %res
@@ -225,6 +252,7 @@
 
 ; CHECK-LABEL: fmax_nan_xorsign_abs_bf16x2
 define i32 @fmax_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.xorsign.abs.bf16x2
   %res = call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
   ret i32 %res
@@ -232,6 +260,7 @@
 
 ; CHECK-LABEL: fmax_xorsign_abs_f
 define float @fmax_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.xorsign.abs.f
   %res = call float @llvm.nvvm.fmax.xorsign.abs.f(float %0, float %1)
   ret float %res
@@ -239,6 +268,7 @@
 
 ; CHECK-LABEL: fmax_ftz_xorsign_abs_f
 define float @fmax_ftz_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.xorsign.abs.f
   %res = call float @llvm.nvvm.fmax.ftz.xorsign.abs.f(float %0, float %1)
   ret float %res
@@ -246,6 +276,7 @@
 
 ; CHECK-LABEL: fmax_nan_xorsign_abs_f
 define float @fmax_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.xorsign.abs.f
   %res = call float @llvm.nvvm.fmax.nan.xorsign.abs.f(float %0, float %1)
   ret float %res
@@ -253,6 +284,7 @@
 
 ; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f
 define float @fmax_ftz_nan_xorsign_abs_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.NaN.xorsign.abs.f
   %res = call float @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f(float %0, float %1)
   ret float %res
Index: llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
+++ llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
@@ -35,8 +35,18 @@
 declare i32 @llvm.nvvm.fmax.bf16x2(i32, i32)
 declare i32 @llvm.nvvm.fmax.nan.bf16x2(i32, i32)
 
+declare half @llvm.nvvm.fma.rn.relu.f16(half, half, half)
+declare half @llvm.nvvm.fma.rn.ftz.relu.f16(half, half, half)
+declare <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare i16 @llvm.nvvm.fma.rn.bf16(i16, i16, i16)
+declare i16 @llvm.nvvm.fma.rn.relu.bf16(i16, i16, i16)
+declare i32 @llvm.nvvm.fma.rn.bf16x2(i32, i32, i32)
+declare i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32, i32, i32)
+
 ; CHECK-LABEL: abs_bf16
 define i16 @abs_bf16(i16 %0) {
+  ; CHECK-NOT: call
   ; CHECK: abs.bf16
   %res = call i16 @llvm.nvvm.abs.bf16(i16 %0);
   ret i16 %res
@@ -44,6 +54,7 @@
 
 ; CHECK-LABEL: abs_bf16x2
 define i32 @abs_bf16x2(i32 %0) {
+  ; CHECK-NOT: call
   ; CHECK: abs.bf16x2
   %res = call i32 @llvm.nvvm.abs.bf16x2(i32 %0);
   ret i32 %res
@@ -51,6 +62,7 @@
 
 ; CHECK-LABEL: neg_bf16
 define i16 @neg_bf16(i16 %0) {
+  ; CHECK-NOT: call
   ; CHECK: neg.bf16
   %res = call i16 @llvm.nvvm.neg.bf16(i16 %0);
   ret i16 %res
@@ -58,6 +70,7 @@
 
 ; CHECK-LABEL: neg_bf16x2
 define i32 @neg_bf16x2(i32 %0) {
+  ; CHECK-NOT: call
   ; CHECK: neg.bf16x2
   %res = call i32 @llvm.nvvm.neg.bf16x2(i32 %0);
   ret i32 %res
@@ -65,6 +78,7 @@
 
 ; CHECK-LABEL: fmin_nan_f
 define float @fmin_nan_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.f32
   %res = call float @llvm.nvvm.fmin.nan.f(float %0, float %1);
   ret float %res
@@ -72,6 +86,7 @@
 
 ; CHECK-LABEL: fmin_ftz_nan_f
 define float @fmin_ftz_nan_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.NaN.f32
   %res = call float @llvm.nvvm.fmin.ftz.nan.f(float %0, float %1);
   ret float %res
@@ -79,6 +94,7 @@
 
 ; CHECK-LABEL: fmin_f16
 define half @fmin_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.f16
   %res = call half @llvm.nvvm.fmin.f16(half %0, half %1)
   ret half %res
@@ -86,6 +102,7 @@
 
 ; CHECK-LABEL: fmin_ftz_f16
 define half @fmin_ftz_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.f16
   %res = call half @llvm.nvvm.fmin.ftz.f16(half %0, half %1)
   ret half %res
@@ -93,6 +110,7 @@
 
 ; CHECK-LABEL: fmin_nan_f16
 define half @fmin_nan_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.f16
   %res = call half @llvm.nvvm.fmin.nan.f16(half %0, half %1)
   ret half %res
@@ -100,6 +118,7 @@
 
 ; CHECK-LABEL: fmin_ftz_nan_f16
 define half @fmin_ftz_nan_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.NaN.f16
   %res = call half @llvm.nvvm.fmin.ftz.nan.f16(half %0, half %1)
   ret half %res
@@ -107,6 +126,7 @@
 
 ; CHECK-LABEL: fmin_f16x2
 define <2 x half> @fmin_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -114,6 +134,7 @@
 
 ; CHECK-LABEL: fmin_ftz_f16x2
 define <2 x half> @fmin_ftz_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -121,6 +142,7 @@
 
 ; CHECK-LABEL: fmin_nan_f16x2
 define <2 x half> @fmin_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -128,6 +150,7 @@
 
 ; CHECK-LABEL: fmin_ftz_nan_f16x2
 define <2 x half> @fmin_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.ftz.NaN.f16x2
   %res = call <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -135,6 +158,7 @@
 
 ; CHECK-LABEL: fmin_bf16
 define i16 @fmin_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.bf16
   %res = call i16 @llvm.nvvm.fmin.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -142,6 +166,7 @@
 
 ; CHECK-LABEL: fmin_nan_bf16
 define i16 @fmin_nan_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.bf16
   %res = call i16 @llvm.nvvm.fmin.nan.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -149,6 +174,7 @@
 
 ; CHECK-LABEL: fmin_bf16x2
 define i32 @fmin_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.bf16x2
   %res = call i32 @llvm.nvvm.fmin.bf16x2(i32 %0, i32 %1)
   ret i32 %res
@@ -156,6 +182,7 @@
 
 ; CHECK-LABEL: fmin_nan_bf16x2
 define i32 @fmin_nan_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: min.NaN.bf16x2
   %res = call i32 @llvm.nvvm.fmin.nan.bf16x2(i32 %0, i32 %1)
   ret i32 %res
@@ -163,6 +190,7 @@
 
 ; CHECK-LABEL: fmax_nan_f
 define float @fmax_nan_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.f32
   %res = call float @llvm.nvvm.fmax.nan.f(float %0, float %1);
   ret float %res
@@ -170,6 +198,7 @@
 
 ; CHECK-LABEL: fmax_ftz_nan_f
 define float @fmax_ftz_nan_f(float %0, float %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.NaN.f32
   %res = call float @llvm.nvvm.fmax.ftz.nan.f(float %0, float %1);
   ret float %res
@@ -177,6 +206,7 @@
 
 ; CHECK-LABEL: fmax_f16
 define half @fmax_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.f16
   %res = call half @llvm.nvvm.fmax.f16(half %0, half %1)
   ret half %res
@@ -184,6 +214,7 @@
 
 ; CHECK-LABEL: fmax_ftz_f16
 define half @fmax_ftz_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.f16
   %res = call half @llvm.nvvm.fmax.ftz.f16(half %0, half %1)
   ret half %res
@@ -191,6 +222,7 @@
 
 ; CHECK-LABEL: fmax_nan_f16
 define half @fmax_nan_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.f16
   %res = call half @llvm.nvvm.fmax.nan.f16(half %0, half %1)
   ret half %res
@@ -198,6 +230,7 @@
 
 ; CHECK-LABEL: fmax_ftz_nan_f16
 define half @fmax_ftz_nan_f16(half %0, half %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.NaN.f16
   %res = call half @llvm.nvvm.fmax.ftz.nan.f16(half %0, half %1)
   ret half %res
@@ -205,6 +238,7 @@
 
 ; CHECK-LABEL: fmax_f16x2
 define <2 x half> @fmax_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -212,6 +246,7 @@
 
 ; CHECK-LABEL: fmax_ftz_f16x2
 define <2 x half> @fmax_ftz_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -219,6 +254,7 @@
 
 ; CHECK-LABEL: fmax_nan_f16x2
 define <2 x half> @fmax_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -226,6 +262,7 @@
 
 ; CHECK-LABEL: fmax_ftz_nan_f16x2
 define <2 x half> @fmax_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.ftz.NaN.f16x2
   %res = call <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half> %0, <2 x half> %1)
   ret <2 x half> %res
@@ -233,6 +270,7 @@
 
 ; CHECK-LABEL: fmax_bf16
 define i16 @fmax_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.bf16
   %res = call i16 @llvm.nvvm.fmax.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -240,6 +278,7 @@
 
 ; CHECK-LABEL: fmax_nan_bf16
 define i16 @fmax_nan_bf16(i16 %0, i16 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.bf16
   %res = call i16 @llvm.nvvm.fmax.nan.bf16(i16 %0, i16 %1)
   ret i16 %res
@@ -247,6 +286,7 @@
 
 ; CHECK-LABEL: fmax_bf16x2
 define i32 @fmax_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.bf16x2
   %res = call i32 @llvm.nvvm.fmax.bf16x2(i32 %0, i32 %1)
   ret i32 %res
@@ -254,7 +294,72 @@
 
 ; CHECK-LABEL: fmax_nan_bf16x2
 define i32 @fmax_nan_bf16x2(i32 %0, i32 %1) {
+  ; CHECK-NOT: call
   ; CHECK: max.NaN.bf16x2
   %res = call i32 @llvm.nvvm.fmax.nan.bf16x2(i32 %0, i32 %1)
   ret i32 %res
 }
+
+; CHECK-LABEL: fma_rn_relu_f16
+define half @fma_rn_relu_f16(half %0, half %1, half %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.relu.f16
+  %res = call half @llvm.nvvm.fma.rn.relu.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_relu_f16
+define half @fma_rn_ftz_relu_f16(half %0, half %1, half %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.ftz.relu.f16
+  %res = call half @llvm.nvvm.fma.rn.ftz.relu.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_relu_f16x2
+define <2 x half> @fma_rn_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.relu.f16x2
+  %res = call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_relu_f16x2
+define <2 x half> @fma_rn_ftz_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.ftz.relu.f16x2
+  %res = call <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_bf16
+define i16 @fma_rn_bf16(i16 %0, i16 %1, i16 %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.bf16
+  %res = call i16 @llvm.nvvm.fma.rn.bf16(i16 %0, i16 %1, i16 %2)
+  ret i16 %res
+}
+
+; CHECK-LABEL: fma_rn_relu_bf16
+define i16 @fma_rn_relu_bf16(i16 %0, i16 %1, i16 %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.relu.bf16
+  %res = call i16 @llvm.nvvm.fma.rn.relu.bf16(i16 %0, i16 %1, i16 %2)
+  ret i16 %res
+}
+
+; CHECK-LABEL: fma_rn_bf16x2
+define i32 @fma_rn_bf16x2(i32 %0, i32 %1, i32 %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.bf16x2
+  %res = call i32 @llvm.nvvm.fma.rn.bf16x2(i32 %0, i32 %1, i32 %2)
+  ret i32 %res
+}
+
+; CHECK-LABEL: fma_rn_relu_bf16x2
+define i32 @fma_rn_relu_bf16x2(i32 %0, i32 %1, i32 %2) {
+  ; CHECK-NOT: call
+  ; CHECK: fma.rn.relu.bf16x2
+  %res = call i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32 %0, i32 %1, i32 %2)
+  ret i32 %res
+}
Index: llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-instcombine.ll
===================================================================
--- llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-instcombine.ll
+++ llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-instcombine.ll
@@ -24,6 +24,12 @@
 declare <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half>, <2 x half>)
 declare <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half>, <2 x half>)
 
+; f16 and f16x2 fma are available since ptx 4.2 and sm_53.
+declare half @llvm.nvvm.fma.rn.f16(half, half, half)
+declare half @llvm.nvvm.fma.rn.ftz.f16(half, half, half)
+declare <2 x half> @llvm.nvvm.fma.rn.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half>, <2 x half>, <2 x half>)
+
 ; CHECK-LABEL: fmin_f16
 define half @fmin_f16(half %0, half %1) {
   ; CHECK-NOT: @llvm.nvvm.fmin.f16
@@ -264,5 +270,53 @@
   ret <2 x half> %res
 }
 
+; CHECK-LABEL: fma_rn_f16
+define half @fma_rn_f16(half %0, half %1, half %2) {
+  ; CHECK-NOT: @llvm.nvvm.fma.rn.f16
+  ; CHECK: @llvm.fma.f16
+  %res = call half @llvm.nvvm.fma.rn.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_f16_no_attr
+define half @fma_rn_ftz_f16_no_attr(half %0, half %1, half %2) {
+  ; CHECK-NOT: @llvm.fma.f16
+  ; CHECK: @llvm.nvvm.fma.rn.ftz.f16
+  %res = call half @llvm.nvvm.fma.rn.ftz.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_f16
+define half @fma_rn_ftz_f16(half %0, half %1, half %2) #0 {
+  ; CHECK-NOT: @llvm.nvvm.fma.rn.ftz.f16
+  ; CHECK: @llvm.fma.f16
+  %res = call half @llvm.nvvm.fma.rn.ftz.f16(half %0, half %1, half %2)
+  ret half %res
+}
+
+; CHECK-LABEL: fma_rn_f16x2
+define <2 x half> @fma_rn_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: @llvm.nvvm.fma.rn.f16x2
+  ; CHECK: @llvm.fma.v2f16
+  %res = call <2 x half> @llvm.nvvm.fma.rn.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_f16x2
+define <2 x half> @fma_rn_ftz_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) #0 {
+  ; CHECK-NOT: @llvm.nvvm.fma.rn.ftz.f16x2
+  ; CHECK: @llvm.fma.v2f16
+  %res = call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_f16x2_no_attr
+define <2 x half> @fma_rn_ftz_f16x2_no_attr(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+  ; CHECK-NOT: @llvm.fma.v2f16
+  ; CHECK: @llvm.nvvm.fma.rn.ftz.f16x2
+  %res = call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+  ret <2 x half> %res
+}
+
 attributes #0 = { "denormal-fp-math"="preserve-sign" }
 attributes #1 = { "denormal-fp-math-f32"="preserve-sign" }
Index: llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -195,6 +195,14 @@
       return {Intrinsic::fma, FTZ_MustBeOff};
     case Intrinsic::nvvm_fma_rn_ftz_f:
       return {Intrinsic::fma, FTZ_MustBeOn};
+    case Intrinsic::nvvm_fma_rn_f16:
+      return {Intrinsic::fma, FTZ_MustBeOff, true};
+    case Intrinsic::nvvm_fma_rn_ftz_f16:
+      return {Intrinsic::fma, FTZ_MustBeOn, true};
+    case Intrinsic::nvvm_fma_rn_f16x2:
+      return {Intrinsic::fma, FTZ_MustBeOff, true};
+    case Intrinsic::nvvm_fma_rn_ftz_f16x2:
+      return {Intrinsic::fma, FTZ_MustBeOn, true};
     case Intrinsic::nvvm_fmax_d:
       return {Intrinsic::maxnum, FTZ_Any};
     case Intrinsic::nvvm_fmax_f:
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -568,12 +568,13 @@
 
 class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
   NVPTXRegClass s0_regclass, NVPTXRegClass s1_regclass,
-  NVPTXRegClass s2_regclass, Intrinsic IntOP>
+  NVPTXRegClass s2_regclass, Intrinsic IntOP, list<Predicate> Preds = []>
             : NVPTXInst<(outs t_regclass:$dst),
               (ins s0_regclass:$src0, s1_regclass:$src1, s2_regclass:$src2),
             OpcStr,
         [(set t_regclass:$dst,
-          (IntOP s0_regclass:$src0, s1_regclass:$src1, s2_regclass:$src2))]>;
+          (IntOP s0_regclass:$src0, s1_regclass:$src1, s2_regclass:$src2))]>,
+          Requires<Preds>;
 
 //
 // MISC
@@ -648,6 +649,7 @@
 //
 // Min Max f16, f16x2, bf16, bf16x2
 //
+
 class MIN_MAX_TUPLE<string V, Intrinsic I, NVPTXRegClass RC,
                     list<Predicate> Preds = [hasPTX70, hasSM80]> {
   string Variant = V;
@@ -931,35 +933,75 @@
 // Fma
 //
 
-def INT_NVVM_FMA_RN_FTZ_F
-  : F_MATH_3<"fma.rn.ftz.f32 \t$dst, $src0, $src1, $src2;", Float32Regs,
-    Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rn_ftz_f>;
-def INT_NVVM_FMA_RN_F : F_MATH_3<"fma.rn.f32 \t$dst, $src0, $src1, $src2;",
-  Float32Regs, Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rn_f>;
-def INT_NVVM_FMA_RZ_FTZ_F
-  : F_MATH_3<"fma.rz.ftz.f32 \t$dst, $src0, $src1, $src2;", Float32Regs,
-    Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rz_ftz_f>;
-def INT_NVVM_FMA_RZ_F : F_MATH_3<"fma.rz.f32 \t$dst, $src0, $src1, $src2;",
-  Float32Regs, Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rz_f>;
-def INT_NVVM_FMA_RM_FTZ_F
-  : F_MATH_3<"fma.rm.ftz.f32 \t$dst, $src0, $src1, $src2;", Float32Regs,
-    Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rm_ftz_f>;
-def INT_NVVM_FMA_RM_F : F_MATH_3<"fma.rm.f32 \t$dst, $src0, $src1, $src2;",
-  Float32Regs, Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rm_f>;
-def INT_NVVM_FMA_RP_FTZ_F
-  : F_MATH_3<"fma.rp.ftz.f32 \t$dst, $src0, $src1, $src2;", Float32Regs,
-    Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rp_ftz_f>;
-def INT_NVVM_FMA_RP_F : F_MATH_3<"fma.rp.f32 \t$dst, $src0, $src1, $src2;",
-  Float32Regs, Float32Regs, Float32Regs, Float32Regs, int_nvvm_fma_rp_f>;
-
-def INT_NVVM_FMA_RN_D : F_MATH_3<"fma.rn.f64 \t$dst, $src0, $src1, $src2;",
-  Float64Regs, Float64Regs, Float64Regs, Float64Regs, int_nvvm_fma_rn_d>;
-def INT_NVVM_FMA_RZ_D : F_MATH_3<"fma.rz.f64 \t$dst, $src0, $src1, $src2;",
-  Float64Regs, Float64Regs, Float64Regs, Float64Regs, int_nvvm_fma_rz_d>;
-def INT_NVVM_FMA_RM_D : F_MATH_3<"fma.rm.f64 \t$dst, $src0, $src1, $src2;",
-  Float64Regs, Float64Regs, Float64Regs, Float64Regs, int_nvvm_fma_rm_d>;
-def INT_NVVM_FMA_RP_D : F_MATH_3<"fma.rp.f64 \t$dst, $src0, $src1, $src2;",
-  Float64Regs, Float64Regs, Float64Regs, Float64Regs, int_nvvm_fma_rp_d>;
+// Describes one fma variant instantiated by FMA_INST:
+//   V     - record-name suffix; with every "_" replaced by "." it is also the
+//           PTX opcode suffix, e.g. "_rn_ftz_f16" -> "fma.rn.ftz.f16".
+//   I     - NVVM intrinsic the instruction is selected from.
+//   RC    - register class of the result and of all three operands.
+//   Preds - subtarget requirements. The default (PTX 7.0, sm_80) covers the
+//           relu and bf16 variants; the remaining f16/f16x2 variants need
+//           PTX 4.2 and sm_53, and the f32/f64 variants have no requirement.
+class FMA_TUPLE<string V, Intrinsic I, NVPTXRegClass RC,
+                list<Predicate> Preds = [hasPTX70, hasSM80]> {
+  string Variant = V;
+  Intrinsic Intr = I;
+  NVPTXRegClass RegClass = RC;
+  list<Predicate> Predicates = Preds;
+}
+
+// Instantiates one F_MATH_3 instruction per FMA_TUPLE in the list below.
+multiclass FMA_INST {
+  foreach P = [
+    FMA_TUPLE<"_rn_f64", int_nvvm_fma_rn_d, Float64Regs, []>,
+    FMA_TUPLE<"_rz_f64", int_nvvm_fma_rz_d, Float64Regs, []>,
+    FMA_TUPLE<"_rm_f64", int_nvvm_fma_rm_d, Float64Regs, []>,
+    FMA_TUPLE<"_rp_f64", int_nvvm_fma_rp_d, Float64Regs, []>,
+
+    FMA_TUPLE<"_rn_ftz_f32", int_nvvm_fma_rn_ftz_f, Float32Regs, []>,
+    FMA_TUPLE<"_rn_f32", int_nvvm_fma_rn_f, Float32Regs, []>,
+    FMA_TUPLE<"_rz_ftz_f32", int_nvvm_fma_rz_ftz_f, Float32Regs, []>,
+    FMA_TUPLE<"_rz_f32", int_nvvm_fma_rz_f, Float32Regs, []>,
+    FMA_TUPLE<"_rm_f32", int_nvvm_fma_rm_f, Float32Regs, []>,
+    FMA_TUPLE<"_rm_ftz_f32", int_nvvm_fma_rm_ftz_f, Float32Regs, []>,
+    FMA_TUPLE<"_rp_f32", int_nvvm_fma_rp_f, Float32Regs, []>,
+    FMA_TUPLE<"_rp_ftz_f32", int_nvvm_fma_rp_ftz_f, Float32Regs, []>,
+
+    FMA_TUPLE<"_rn_f16", int_nvvm_fma_rn_f16, Float16Regs, [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_ftz_f16", int_nvvm_fma_rn_ftz_f16, Float16Regs,
+      [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_sat_f16", int_nvvm_fma_rn_sat_f16, Float16Regs,
+      [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_ftz_sat_f16", int_nvvm_fma_rn_ftz_sat_f16, Float16Regs,
+      [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_relu_f16", int_nvvm_fma_rn_relu_f16, Float16Regs>,
+    FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, Float16Regs>,
+
+    FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, Float16x2Regs,
+      [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_ftz_f16x2", int_nvvm_fma_rn_ftz_f16x2, Float16x2Regs,
+      [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_sat_f16x2", int_nvvm_fma_rn_sat_f16x2, Float16x2Regs,
+      [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_ftz_sat_f16x2", int_nvvm_fma_rn_ftz_sat_f16x2,
+      Float16x2Regs, [hasPTX42, hasSM53]>,
+    FMA_TUPLE<"_rn_relu_f16x2", int_nvvm_fma_rn_relu_f16x2, Float16x2Regs>,
+    FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
+      Float16x2Regs>,
+
+    FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs>,
+    FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs>,
+
+    FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, Int32Regs>,
+    FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, Int32Regs>
+  ] in {
+    def P.Variant :
+      F_MATH_3<!strconcat("fma",
+        !subst("_", ".", P.Variant), " \t$dst, $src0, $src1, $src2;"),
+        P.RegClass, P.RegClass, P.RegClass, P.RegClass, P.Intr, P.Predicates>;
+  }
+}
+
+defm INT_NVVM_FMA : FMA_INST;
 
 //
 // Rcp
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -145,6 +145,7 @@
 def True : Predicate<"true">;
 
 def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
+def hasPTX42 : Predicate<"Subtarget->getPTXVersion() >= 42">; // gates f16/f16x2 fma.rn
 def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
 def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
 def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
@@ -155,6 +156,7 @@
 def hasPTX72 : Predicate<"Subtarget->getPTXVersion() >= 72">;
 
 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
+def hasSM53 : Predicate<"Subtarget->getSmVersion() >= 53">; // gates f16/f16x2 fma.rn
 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
 def hasSM72 : Predicate<"Subtarget->getSmVersion() >= 72">;
 def hasSM75 : Predicate<"Subtarget->getSmVersion() >= 75">;
Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -854,47 +854,56 @@
 // Fma
 //
 
-  def int_nvvm_fma_rn_ftz_f : GCCBuiltin<"__nvvm_fma_rn_ftz_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
-        [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rn_f : GCCBuiltin<"__nvvm_fma_rn_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
-        [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rz_ftz_f : GCCBuiltin<"__nvvm_fma_rz_ftz_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
-        [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rz_f : GCCBuiltin<"__nvvm_fma_rz_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
-        [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rm_ftz_f : GCCBuiltin<"__nvvm_fma_rm_ftz_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
-        [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rm_f : GCCBuiltin<"__nvvm_fma_rm_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
-        [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rp_ftz_f : GCCBuiltin<"__nvvm_fma_rp_ftz_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
-        [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rp_f : GCCBuiltin<"__nvvm_fma_rp_f">,
-      DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty, llvm_float_ty],
+  // Half-precision (f16) variants.
+  foreach variant = ["_rn_f16", "_rn_ftz_f16", "_rn_sat_f16",
+    "_rn_ftz_sat_f16", "_rn_relu_f16", "_rn_ftz_relu_f16"] in {
+    def int_nvvm_fma # variant : GCCBuiltin<!strconcat("__nvvm_fma", variant)>,
+      DefaultAttrsIntrinsic<[llvm_half_ty],
+        [llvm_half_ty, llvm_half_ty, llvm_half_ty],
+        [IntrNoMem, IntrSpeculatable]>;
+  }
+
+  // Two-element half vector (f16x2) variants.
+  foreach variant = ["_rn_f16x2", "_rn_ftz_f16x2", "_rn_sat_f16x2",
+    "_rn_ftz_sat_f16x2", "_rn_relu_f16x2", "_rn_ftz_relu_f16x2"] in {
+    def int_nvvm_fma # variant : GCCBuiltin<!strconcat("__nvvm_fma", variant)>,
+      DefaultAttrsIntrinsic<[llvm_v2f16_ty],
+        [llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
         [IntrNoMem, IntrSpeculatable]>;
+  }
 
-  def int_nvvm_fma_rn_d : GCCBuiltin<"__nvvm_fma_rn_d">,
-      DefaultAttrsIntrinsic<[llvm_double_ty],
-        [llvm_double_ty, llvm_double_ty, llvm_double_ty],
+  // bf16 variants; operands and result are passed as i16.
+  foreach variant = ["_rn_bf16", "_rn_relu_bf16"] in {
+    def int_nvvm_fma # variant : GCCBuiltin<!strconcat("__nvvm_fma", variant)>,
+      DefaultAttrsIntrinsic<[llvm_i16_ty],
+        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
         [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rz_d : GCCBuiltin<"__nvvm_fma_rz_d">,
-      DefaultAttrsIntrinsic<[llvm_double_ty],
-        [llvm_double_ty, llvm_double_ty, llvm_double_ty],
+  }
+
+  // bf16x2 variants; operands and result are passed as i32.
+  foreach variant = ["_rn_bf16x2", "_rn_relu_bf16x2"] in {
+    def int_nvvm_fma # variant : GCCBuiltin<!strconcat("__nvvm_fma", variant)>,
+      DefaultAttrsIntrinsic<[llvm_i32_ty],
+        [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
         [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rm_d : GCCBuiltin<"__nvvm_fma_rm_d">,
-      DefaultAttrsIntrinsic<[llvm_double_ty],
-        [llvm_double_ty, llvm_double_ty, llvm_double_ty],
+  }
+
+  // Single-precision (f32) variants.
+  foreach variant = ["_rn_ftz_f", "_rn_f", "_rz_ftz_f", "_rz_f", "_rm_ftz_f",
+    "_rm_f", "_rp_ftz_f", "_rp_f"] in {
+    def int_nvvm_fma # variant : GCCBuiltin<!strconcat("__nvvm_fma", variant)>,
+      DefaultAttrsIntrinsic<[llvm_float_ty],
+        [llvm_float_ty, llvm_float_ty, llvm_float_ty],
         [IntrNoMem, IntrSpeculatable]>;
-  def int_nvvm_fma_rp_d : GCCBuiltin<"__nvvm_fma_rp_d">,
+  }
+
+  // Double-precision (f64) variants.
+  foreach variant = ["_rn_d", "_rz_d", "_rm_d", "_rp_d"] in {
+    def int_nvvm_fma # variant : GCCBuiltin<!strconcat("__nvvm_fma", variant)>,
       DefaultAttrsIntrinsic<[llvm_double_ty],
         [llvm_double_ty, llvm_double_ty, llvm_double_ty],
         [IntrNoMem, IntrSpeculatable]>;
+  }
 
 //
 // Rcp
Index: clang/test/CodeGen/builtins-nvptx.c
===================================================================
--- clang/test/CodeGen/builtins-nvptx.c
+++ clang/test/CodeGen/builtins-nvptx.c
@@ -866,6 +866,22 @@
 #endif
   // CHECK: ret void
 }
+
+// CHECK-LABEL: nvvm_fma_bf16_bf16x2_sm80
+__device__ void nvvm_fma_bf16_bf16x2_sm80() {
+#if __CUDA_ARCH__ >= 800 // bf16 fma builtins require SM_80 and PTX70
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.bf16
+  __nvvm_fma_rn_bf16(0x1234, 0x7FBF, 0x1234); // bf16 operands passed as unsigned short
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.relu.bf16
+  __nvvm_fma_rn_relu_bf16(0x1234, 0x7FBF, 0x1234);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.bf16x2
+  __nvvm_fma_rn_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF); // bf16x2 operands passed as 32-bit unsigned
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.relu.bf16x2
+  __nvvm_fma_rn_relu_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
+#endif
+  // CHECK: ret void
+}
+
 // CHECK-LABEL: nvvm_min_max_sm86
 __device__ void nvvm_min_max_sm86() {
 #if __CUDA_ARCH__ >= 860
Index: clang/test/CodeGen/builtins-nvptx-native-half-type.c
===================================================================
--- clang/test/CodeGen/builtins-nvptx-native-half-type.c
+++ clang/test/CodeGen/builtins-nvptx-native-half-type.c
@@ -20,6 +20,16 @@
 // RUN:   -fnative-half-type -S -emit-llvm -o - -x cuda %s \
 // RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX72_SM86 %s
 
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx-unknown-unknown \
+// RUN:   -target-cpu sm_53 -target-feature +ptx42 -fcuda-is-device \
+// RUN:   -fnative-half-type -S -emit-llvm -o - -x cuda %s \
+// RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX42_SM53 %s
+
+// RUN: %clang_cc1 -ffp-contract=off -triple nvptx64-unknown-unknown \
+// RUN:   -target-cpu sm_53 -target-feature +ptx42 -fcuda-is-device \
+// RUN:   -fnative-half-type -S -emit-llvm -o - -x cuda %s \
+// RUN:   | FileCheck -check-prefix=CHECK -check-prefix=CHECK_PTX42_SM53 %s
+
 #define __device__ __attribute__((device))
 
 // CHECK-LABEL: nvvm_min_max_sm80
@@ -62,6 +72,52 @@
   // CHECK: ret void
 }
 
+// CHECK-LABEL: nvvm_fma_f16_f16x2_sm80
+__device__ void nvvm_fma_f16_f16x2_sm80() {
+#if __CUDA_ARCH__ >= 800 // relu fma builtins require SM_80 and PTX70
+  // CHECK_PTX70_SM80: call half @llvm.nvvm.fma.rn.relu.f16
+  __nvvm_fma_rn_relu_f16(0.1f16, 0.1f16, 0.1f16);
+  // CHECK_PTX70_SM80: call half @llvm.nvvm.fma.rn.ftz.relu.f16
+  __nvvm_fma_rn_ftz_relu_f16(0.1f16, 0.1f16, 0.1f16);
+  // f16x2 variants: every operand is a two-element half vector.
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2
+  __nvvm_fma_rn_relu_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
+                           {0.1f16, 0.7f16});
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2
+  __nvvm_fma_rn_ftz_relu_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
+                               {0.1f16, 0.7f16});
+#endif
+  // CHECK: ret void
+}
+
+// CHECK-LABEL: nvvm_fma_f16_f16x2_sm53
+__device__ void nvvm_fma_f16_f16x2_sm53() {
+#if __CUDA_ARCH__ >= 530 // verified by the sm_53 + ptx42 RUN lines
+  // CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.f16
+  __nvvm_fma_rn_f16(0.1f16, 0.1f16, 0.1f16);
+  // CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.ftz.f16
+  __nvvm_fma_rn_ftz_f16(0.1f16, 0.1f16, 0.1f16);
+  // CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.sat.f16
+  __nvvm_fma_rn_sat_f16(0.1f16, 0.1f16, 0.1f16);
+  // CHECK_PTX42_SM53: call half @llvm.nvvm.fma.rn.ftz.sat.f16
+  __nvvm_fma_rn_ftz_sat_f16(0.1f16, 0.1f16, 0.1f16);
+  // f16x2 variants: every operand is a two-element half vector.
+  // CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.f16x2
+  __nvvm_fma_rn_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
+                      {0.1f16, 0.7f16});
+  // CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.ftz.f16x2
+  __nvvm_fma_rn_ftz_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
+                          {0.1f16, 0.7f16});
+  // CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.sat.f16x2
+  __nvvm_fma_rn_sat_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
+                          {0.1f16, 0.7f16});
+  // CHECK_PTX42_SM53: call <2 x half> @llvm.nvvm.fma.rn.ftz.sat.f16x2
+  __nvvm_fma_rn_ftz_sat_f16x2({0.1f16, 0.7f16}, {0.1f16, 0.7f16},
+                              {0.1f16, 0.7f16});
+#endif
+  // CHECK: ret void
+}
+
 // CHECK-LABEL: nvvm_min_max_sm86
 __device__ void nvvm_min_max_sm86() {
 #if __CUDA_ARCH__ >= 860
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -293,6 +293,22 @@
 
 // Fma
 
+BUILTIN(__nvvm_fma_rn_f16, "hhhh", "") // NOTE(review): ungated, but the NVPTX f16 fma patterns need [hasPTX42, hasSM53]
+BUILTIN(__nvvm_fma_rn_ftz_f16, "hhhh", "")
+BUILTIN(__nvvm_fma_rn_sat_f16, "hhhh", "")
+BUILTIN(__nvvm_fma_rn_ftz_sat_f16, "hhhh", "")
+TARGET_BUILTIN(__nvvm_fma_rn_relu_f16, "hhhh", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16, "hhhh", "", AND(SM_80, PTX70))
+BUILTIN(__nvvm_fma_rn_f16x2, "V2hV2hV2hV2h", "") // NOTE(review): same gating question as the f16 builtins
+BUILTIN(__nvvm_fma_rn_ftz_f16x2, "V2hV2hV2hV2h", "")
+BUILTIN(__nvvm_fma_rn_sat_f16x2, "V2hV2hV2hV2h", "")
+BUILTIN(__nvvm_fma_rn_ftz_sat_f16x2, "V2hV2hV2hV2h", "")
+TARGET_BUILTIN(__nvvm_fma_rn_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_bf16, "UsUsUsUs", "", AND(SM_80, PTX70)) // bf16 passed as unsigned short
+TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70)) // bf16x2 passed as uint32_t
+TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
 BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")
 BUILTIN(__nvvm_fma_rn_f, "ffff", "")
 BUILTIN(__nvvm_fma_rz_ftz_f, "ffff", "")
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to