liutianle created this revision.
liutianle added reviewers: craig.topper, smaslov, LuoYuanke, wxiao3, annita.zhang.
Herald added subscribers: cfe-commits, mgorny.
Herald added a project: clang.

1. Enable the infrastructure for AVX512_BF16, the BFLOAT16 extension supported
on Cooper Lake;
2. Enable intrinsics for the VCVTNE2PS2BF16, VCVTNEPS2BF16 and DPBF16PS
instructions. These are Vector Neural Network Instructions that operate on
BFLOAT16 inputs, plus instructions that convert from IEEE single precision to
BFLOAT16 (see the usage sketch below).

For more details about the BF16 intrinsics, please refer to the latest ISE
document:
https://software.intel.com/en-us/download/intel-architecture-instruction-set-extensions-programming-reference
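
As a rough illustration of how the new intrinsics fit together, here is a
minimal usage sketch. The helper name bf16_dot_acc is invented for this
example; it relies only on the _mm512_cvtne2ps_pbh and _mm512_dpbf16_ps
definitions added in avx512bf16intrin.h below:

  #include <immintrin.h>

  // Pack 2x16 single-precision elements per operand into 32 BF16 elements,
  // then accumulate pairwise products into 16 single-precision lanes:
  //   acc[i] += a[2*i] * b[2*i] + a[2*i+1] * b[2*i+1]
  __m512 bf16_dot_acc(__m512 acc, __m512 a_hi, __m512 a_lo,
                      __m512 b_hi, __m512 b_lo) {
    __m512bh a = _mm512_cvtne2ps_pbh(a_hi, a_lo);
    __m512bh b = _mm512_cvtne2ps_pbh(b_hi, b_lo);
    return _mm512_dpbf16_ps(acc, a, b);
  }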


Repository:
  rC Clang

https://reviews.llvm.org/D60552

Files:
  docs/ClangCommandLineReference.rst
  include/clang/Basic/BuiltinsX86.def
  include/clang/Driver/Options.td
  lib/Basic/Targets/X86.cpp
  lib/Basic/Targets/X86.h
  lib/CodeGen/CGBuiltin.cpp
  lib/Headers/CMakeLists.txt
  lib/Headers/avx512bf16intrin.h
  lib/Headers/avx512vlbf16intrin.h
  lib/Headers/cpuid.h
  lib/Headers/immintrin.h
  test/CodeGen/attr-target-x86.c
  test/CodeGen/avx512bf16-builtins.c
  test/CodeGen/avx512vlbf16-builtins.c

Index: test/CodeGen/avx512vlbf16-builtins.c
===================================================================
--- /dev/null
+++ test/CodeGen/avx512vlbf16-builtins.c
@@ -0,0 +1,142 @@
+//  RUN: %clang_cc1 -ffreestanding %s -triple=x86_64-apple-darwin \
+//  RUN:            -target-feature +avx512bf16 -target-feature \
+//  RUN:            +avx512vl -emit-llvm -o - -Wall -Werror | FileCheck %s
+
+#include <immintrin.h>
+
+__m128bh test_mm_cvtne2ps2bf16(__m128 A, __m128 B) {
+  // CHECK-LABEL: @test_mm_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.128
+  return _mm_cvtne2ps_pbh(A, B);
+}
+
+__m128bh test_mm_maskz_cvtne2ps2bf16(__m128 A, __m128 B, __mmask8 U) {
+  // CHECK-LABEL: @test_mm_maskz_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.128
+  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
+  return _mm_maskz_cvtne2ps_pbh(U, A, B);
+}
+
+__m128bh test_mm_mask_cvtne2ps2bf16(__m128bh C, __mmask8 U, __m128 A, __m128 B) {
+  // CHECK-LABEL: @test_mm_mask_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.128
+  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
+  return _mm_mask_cvtne2ps_pbh(C, U, A, B);
+}
+
+__m256bh test_mm256_cvtne2ps2bf16(__m256 A, __m256 B) {
+  // CHECK-LABEL: @test_mm256_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.256
+  return _mm256_cvtne2ps_pbh(A, B);
+}
+
+__m256bh test_mm256_maskz_cvtne2ps2bf16(__m256 A, __m256 B, __mmask16 U) {
+  // CHECK-LABEL: @test_mm256_maskz_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.256
+  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
+  return _mm256_maskz_cvtne2ps_pbh(U, A, B);
+}
+
+__m256bh test_mm256_mask_cvtne2ps2bf16(__m256bh C, __mmask16 U, __m256 A, __m256 B) {
+  // CHECK-LABEL: @test_mm256_mask_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.256
+  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
+  return _mm256_mask_cvtne2ps_pbh(C, U, A, B);
+}
+
+__m512bh test_mm512_cvtne2ps2bf16(__m512 A, __m512 B) {
+  // CHECK-LABEL: @test_mm512_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
+  return _mm512_cvtne2ps_pbh(A, B);
+}
+
+__m512bh test_mm512_maskz_cvtne2ps2bf16(__m512 A, __m512 B, __mmask32 U) {
+  // CHECK-LABEL: @test_mm512_maskz_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
+  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
+  return _mm512_maskz_cvtne2ps_pbh(U, A, B);
+}
+
+__m512bh test_mm512_mask_cvtne2ps2bf16(__m512bh C, __mmask32 U, __m512 A, __m512 B) {
+  // CHECK-LABEL: @test_mm512_mask_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
+  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
+  return _mm512_mask_cvtne2ps_pbh(C, U, A, B);
+}
+
+__m128bh test_mm_cvtneps2bf16(__m128 A) {
+  // CHECK-LABEL: @test_mm_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+  return _mm_cvtneps_pbh(A);
+}
+
+__m128bh test_mm_mask_cvtneps2bf16(__m128bh C, __mmask8 U, __m128 A) {
+  // CHECK-LABEL: @test_mm_mask_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.
+  return _mm_mask_cvtneps_pbh(C, U, A);
+}
+
+__m128bh test_mm_maskz_cvtneps2bf16(__m128 A, __mmask8 U) {
+  // CHECK-LABEL: @test_mm_maskz_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+  return _mm_maskz_cvtneps_pbh(U, A);
+}
+
+__m128bh test_mm256_cvtneps2bf16(__m256 A) {
+  // CHECK-LABEL: @test_mm256_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.256
+  return _mm256_cvtneps_pbh(A);
+}
+
+__m128bh test_mm256_mask_cvtneps2bf16(__m128bh C, __mmask8 U, __m256 A) {
+  // CHECK-LABEL: @test_mm256_mask_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.256
+  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
+  return _mm256_mask_cvtneps_pbh(C, U, A);
+}
+
+__m128bh test_mm256_maskz_cvtneps2bf16(__m256 A, __mmask8 U) {
+  // CHECK-LABEL: @test_mm256_maskz_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.256
+  // CHECK: select <8 x i1> %{{.*}}, <8 x i16> %{{.*}}, <8 x i16> %{{.*}}
+  return _mm256_maskz_cvtneps_pbh(U, A);
+}
+
+__m128 test_mm_dpbf16_ps(__m128 D, __m128bh A, __m128bh B) {
+  // CHECK-LABEL: @test_mm_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.128
+  return _mm_dpbf16_ps(D, A, B);
+}
+
+__m128 test_mm_maskz_dpbf16_ps(__m128 D, __m128bh A, __m128bh B, __mmask8 U) {
+  // CHECK-LABEL: @test_mm_maskz_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.128
+  // CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
+  return _mm_maskz_dpbf16_ps(U, D, A, B);
+}
+
+__m128 test_mm_mask_dpbf16_ps(__m128 D, __m128bh A, __m128bh B, __mmask8 U) {
+  // CHECK-LABEL: @test_mm_mask_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.128
+  // CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
+  return _mm_mask_dpbf16_ps(D, U, A, B);
+}
+__m256 test_mm256_dpbf16_ps(__m256 D, __m256bh A, __m256bh B) {
+  // CHECK-LABEL: @test_mm256_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.256
+  return _mm256_dpbf16_ps(D, A, B);
+}
+
+__m256 test_mm256_maskz_dpbf16_ps(__m256 D, __m256bh A, __m256bh B, __mmask8 U) {
+  // CHECK-LABEL: @test_mm256_maskz_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.256
+  // CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
+  return _mm256_maskz_dpbf16_ps(U, D, A, B);
+}
+
+__m256 test_mm256_mask_dpbf16_ps(__m256 D, __m256bh A, __m256bh B, __mmask8 U) {
+  // CHECK-LABEL: @test_mm256_mask_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.256
+  // CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
+  return _mm256_mask_dpbf16_ps(D, U, A, B);
+}
Index: test/CodeGen/avx512bf16-builtins.c
===================================================================
--- /dev/null
+++ test/CodeGen/avx512bf16-builtins.c
@@ -0,0 +1,65 @@
+//  RUN: %clang_cc1 -ffreestanding %s -triple=x86_64-apple-darwin \
+//  RUN:            -target-feature +avx512bf16 -emit-llvm -o - -Wall -Werror \
+//  RUN:            | FileCheck %s
+
+#include <immintrin.h>
+
+__m512bh test_mm512_cvtne2ps2bf16(__m512 A, __m512 B) {
+  // CHECK-LABEL: @test_mm512_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
+  return _mm512_cvtne2ps_pbh(A, B);
+}
+
+__m512bh test_mm512_maskz_cvtne2ps2bf16(__m512 A, __m512 B, __mmask32 U) {
+  // CHECK-LABEL: @test_mm512_maskz_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
+  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
+  return _mm512_maskz_cvtne2ps_pbh(U, A, B);
+}
+
+__m512bh test_mm512_mask_cvtne2ps2bf16(__m512bh C, __mmask32 U, __m512 A, __m512 B) {
+  // CHECK-LABEL: @test_mm512_mask_cvtne2ps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512
+  // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}}
+  return _mm512_mask_cvtne2ps_pbh(C, U, A, B);
+}
+
+__m256bh test_mm512_cvtneps2bf16(__m512 A) {
+  // CHECK-LABEL: @test_mm512_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512
+  return _mm512_cvtneps_pbh(A);
+}
+
+__m256bh test_mm512_mask_cvtneps2bf16(__m256bh C, __mmask16 U, __m512 A) {
+  // CHECK-LABEL: @test_mm512_mask_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512
+  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
+  return _mm512_mask_cvtneps_pbh(C, U, A);
+}
+
+__m256bh test_mm512_maskz_cvtneps2bf16(__m512 A, __mmask16 U) {
+  // CHECK-LABEL: @test_mm512_maskz_cvtneps2bf16
+  // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512
+  // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}
+  return _mm512_maskz_cvtneps_pbh(U, A);
+}
+
+__m512 test_mm512_dpbf16_ps(__m512 D, __m512bh A, __m512bh B) {
+  // CHECK-LABEL: @test_mm512_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.512
+  return _mm512_dpbf16_ps(D, A, B);
+}
+
+__m512 test_mm512_maskz_dpbf16_ps(__m512 D, __m512bh A, __m512bh B, __mmask16 U) {
+  // CHECK-LABEL: @test_mm512_maskz_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.512
+  // CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
+  return _mm512_maskz_dpbf16_ps(U, D, A, B);
+}
+
+__m512 test_mm512_mask_dpbf16_ps(__m512 D, __m512bh A, __m512bh B, __mmask16 U) {
+  // CHECK-LABEL: @test_mm512_mask_dpbf16_ps
+  // CHECK: @llvm.x86.avx512bf16.dpbf16ps.512
+  // CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
+  return _mm512_mask_dpbf16_ps(D, U, A, B);
+}
Index: test/CodeGen/attr-target-x86.c
===================================================================
--- test/CodeGen/attr-target-x86.c
+++ test/CodeGen/attr-target-x86.c
@@ -50,9 +50,9 @@
 // CHECK: use_before_def{{.*}} #7
 // CHECK: #0 = {{.*}}"target-cpu"="i686" "target-features"="+cx8,+x87"
 // CHECK: #1 = {{.*}}"target-cpu"="ivybridge" "target-features"="+avx,+cx16,+cx8,+f16c,+fsgsbase,+fxsr,+mmx,+pclmul,+popcnt,+rdrnd,+sahf,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt"
-// CHECK: #2 = {{.*}}"target-cpu"="i686" "target-features"="+cx8,+x87,-aes,-avx,-avx2,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,-avx512vnni,-avx512vpopcntdq,-f16c,-fma,-fma4,-gfni,-pclmul,-sha,-sse2,-sse3,-sse4.1,-sse4.2,-sse4a,-ssse3,-vaes,-vpclmulqdq,-xop,-xsave,-xsaveopt"
+// CHECK: #2 = {{.*}}"target-cpu"="i686" "target-features"="+cx8,+x87,-aes,-avx,-avx2,-avx512bf16,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,-avx512vnni,-avx512vpopcntdq,-f16c,-fma,-fma4,-gfni,-pclmul,-sha,-sse2,-sse3,-sse4.1,-sse4.2,-sse4a,-ssse3,-vaes,-vpclmulqdq,-xop,-xsave,-xsaveopt"
 // CHECK: #3 = {{.*}}"target-cpu"="i686" "target-features"="+cx8,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87"
-// CHECK: #4 = {{.*}}"target-cpu"="i686" "target-features"="+cx8,+x87,-avx,-avx2,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,-avx512vnni,-avx512vpopcntdq,-f16c,-fma,-fma4,-sse4.1,-sse4.2,-vaes,-vpclmulqdq,-xop,-xsave,-xsaveopt"
+// CHECK: #4 = {{.*}}"target-cpu"="i686" "target-features"="+cx8,+x87,-avx,-avx2,-avx512bf16,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,-avx512vnni,-avx512vpopcntdq,-f16c,-fma,-fma4,-sse4.1,-sse4.2,-vaes,-vpclmulqdq,-xop,-xsave,-xsaveopt"
 // CHECK: #5 = {{.*}}"target-cpu"="ivybridge" "target-features"="+avx,+cx16,+cx8,+f16c,+fsgsbase,+fxsr,+mmx,+pclmul,+popcnt,+rdrnd,+sahf,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave,+xsaveopt,-aes,-vaes"
 // CHECK: #6 = {{.*}}"target-cpu"="i686" "target-features"="+cx8,+x87,-3dnow,-3dnowa,-mmx"
 // CHECK: #7 = {{.*}}"target-cpu"="lakemont" "target-features"="+cx8,+mmx"
Index: lib/Headers/immintrin.h
===================================================================
--- lib/Headers/immintrin.h
+++ lib/Headers/immintrin.h
@@ -181,6 +181,15 @@
 #include <avx512pfintrin.h>
 #endif
 
+#if !defined(_MSC_VER) || __has_feature(modules) || defined(__AVX512BF16__)
+#include <avx512bf16intrin.h>
+#endif
+
+#if !defined(_MSC_VER) || __has_feature(modules) || \
+    (defined(__AVX512VL__) && defined(__AVX512BF16__))
+#include <avx512vlbf16intrin.h>
+#endif
+
 #if !defined(_MSC_VER) || __has_feature(modules) || defined(__PKU__)
 #include <pkuintrin.h>
 #endif
Index: lib/Headers/cpuid.h
===================================================================
--- lib/Headers/cpuid.h
+++ lib/Headers/cpuid.h
@@ -184,6 +184,9 @@
 #define bit_PCONFIG       0x00040000
 #define bit_IBT           0x00100000
 
+/* Features in %eax for leaf 7 sub-leaf 1 */
+#define bit_AVX512BF16    0x00000020
+
 /* Features in %eax for leaf 13 sub-leaf 1 */
 #define bit_XSAVEOPT    0x00000001
 #define bit_XSAVEC      0x00000002
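
For reference, the new bit can be queried at run time with the
__get_cpuid_count helper that this header already provides. The sketch below
is illustrative only and is not part of this patch:

  #include <cpuid.h>

  // Returns non-zero if CPUID.(EAX=7,ECX=1):EAX reports AVX512_BF16.
  static int has_avx512bf16(void) {
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid_count(7, 1, &eax, &ebx, &ecx, &edx))
      return 0;
    return (eax & bit_AVX512BF16) != 0;
  }
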
Index: lib/Headers/avx512vlbf16intrin.h
===================================================================
--- /dev/null
+++ lib/Headers/avx512vlbf16intrin.h
@@ -0,0 +1,149 @@
+/*===--------- avx512vlbf16intrin.h - AVX512_BF16 intrinsics ---------------===
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *===-----------------------------------------------------------------------===
+ */
+#ifndef __IMMINTRIN_H
+#error "Never use <avx512vlbf16intrin.h> directly; include <immintrin.h> instead."
+#endif
+
+#ifndef __AVX512VLBF16INTRIN_H
+#define __AVX512VLBF16INTRIN_H
+
+typedef short __m128bh __attribute__((__vector_size__(16), __aligned__(16)));
+
+#define __DEFAULT_FN_ATTRS128 \
+  __attribute__((__always_inline__, __nodebug__, \
+                 __target__("avx512vl, avx512bf16"), __min_vector_width__(128)))
+#define __DEFAULT_FN_ATTRS256 \
+  __attribute__((__always_inline__, __nodebug__, \
+                 __target__("avx512vl, avx512bf16"), __min_vector_width__(256)))
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS128
+_mm_cvtne2ps_pbh(__m128 __A, __m128 __B) {
+  return (__m128bh)__builtin_ia32_cvtne2ps2bf16_128((__v4sf) __A,
+                                                    (__v4sf) __B);
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS128
+_mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
+  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
+                                             (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
+                                             (__v8hi)__W);
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS128
+_mm_maskz_cvtne2ps_pbh(__mmask8 __U, __m128 __A, __m128 __B) {
+  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
+                                             (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
+                                             (__v8hi)_mm_setzero_si128());
+}
+
+static __inline__ __m256bh __DEFAULT_FN_ATTRS256
+_mm256_cvtne2ps_pbh(__m256 __A, __m256 __B) {
+  return (__m256bh)__builtin_ia32_cvtne2ps2bf16_256((__v8sf) __A,
+                                                    (__v8sf) __B);
+}
+
+static __inline__ __m256bh __DEFAULT_FN_ATTRS256
+_mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
+  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
+                                         (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
+                                         (__v16hi)__W);
+}
+
+static __inline__ __m256bh __DEFAULT_FN_ATTRS256
+_mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
+  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
+                                         (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
+                                         (__v16hi)_mm256_setzero_si256());
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS128
+_mm_cvtneps_pbh(__m128 __A) {
+  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
+                                                  (__v8hi)_mm_undefined_si128(),
+                                                  (__mmask8)-1);
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS128
+_mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
+  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
+                                                        (__v8hi)__W,
+                                                        (__mmask8)__U);
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS128
+_mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
+  return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
+                                                    (__v8hi)_mm_setzero_si128(),
+                                                    (__mmask8)__U);
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS256
+_mm256_cvtneps_pbh(__m256 __A) {
+  return (__m128bh)__builtin_ia32_cvtneps2bf16_256((__v8sf)__A);
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS256
+_mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
+  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
+                                              (__v8hi)_mm256_cvtneps_pbh(__A),
+                                              (__v8hi)__W);
+}
+
+static __inline__ __m128bh __DEFAULT_FN_ATTRS256
+_mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
+  return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
+                                              (__v8hi)_mm256_cvtneps_pbh(__A),
+                                              (__v8hi)_mm_setzero_si128());
+}
+
+static __inline__ __m128 __DEFAULT_FN_ATTRS128
+_mm_dpbf16_ps(__m128 __D, __m128bh __A, __m128bh __B) {
+  return (__m128)__builtin_ia32_dpbf16ps_128((__v4sf)__D,
+                                             (__v4si)__A,
+                                             (__v4si)__B);
+}
+
+static __inline__ __m128 __DEFAULT_FN_ATTRS128
+_mm_mask_dpbf16_ps(__m128 __D, __mmask8 __U, __m128bh __A, __m128bh __B) {
+  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
+                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
+                                           (__v4sf)__D);
+}
+
+static __inline__ __m128 __DEFAULT_FN_ATTRS128
+_mm_maskz_dpbf16_ps(__mmask8 __U, __m128 __D, __m128bh __A, __m128bh __B) {
+  return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
+                                           (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
+                                           (__v4sf)_mm_setzero_ps());
+}
+
+static __inline__ __m256 __DEFAULT_FN_ATTRS256
+_mm256_dpbf16_ps(__m256 __D, __m256bh __A, __m256bh __B) {
+  return (__m256)__builtin_ia32_dpbf16ps_256((__v8sf)__D,
+                                             (__v8si)__A,
+                                             (__v8si)__B);
+}
+
+static __inline__ __m256 __DEFAULT_FN_ATTRS256
+_mm256_mask_dpbf16_ps(__m256 __D, __mmask8 __U, __m256bh __A, __m256bh __B) {
+  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
+                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
+                                        (__v8sf)__D);
+}
+
+static __inline__ __m256 __DEFAULT_FN_ATTRS256
+_mm256_maskz_dpbf16_ps(__mmask8 __U, __m256 __D, __m256bh __A, __m256bh __B) {
+  return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
+                                        (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
+                                        (__v8sf)_mm256_setzero_ps());
+}
+#undef __DEFAULT_FN_ATTRS128
+#undef __DEFAULT_FN_ATTRS256
+
+#endif
Index: lib/Headers/avx512bf16intrin.h
===================================================================
--- /dev/null
+++ lib/Headers/avx512bf16intrin.h
@@ -0,0 +1,83 @@
+/*===------------ avx512bf16intrin.h - AVX512_BF16 intrinsics --------------===
+ *
+ * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+ * See https://llvm.org/LICENSE.txt for license information.
+ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+ *
+ *===-----------------------------------------------------------------------===
+ */
+#ifndef __IMMINTRIN_H
+#error "Never use <avx512bf16intrin.h> directly; include <immintrin.h> instead."
+#endif
+
+#ifndef __AVX512BF16INTRIN_H
+#define __AVX512BF16INTRIN_H
+
+typedef short __m512bh __attribute__((__vector_size__(64), __aligned__(64)));
+typedef short __m256bh __attribute__((__vector_size__(32), __aligned__(32)));
+
+#define __DEFAULT_FN_ATTRS512 \
+  __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \
+                 __min_vector_width__(512)))
+
+static __inline__ __m512bh __DEFAULT_FN_ATTRS512
+_mm512_cvtne2ps_pbh(__m512 __A, __m512 __B) {
+  return (__m512bh)__builtin_ia32_cvtne2ps2bf16_512((__v16sf) __A,
+                                                    (__v16sf) __B);
+}
+
+static __inline__ __m512bh __DEFAULT_FN_ATTRS512
+_mm512_mask_cvtne2ps_pbh(__m512bh __W, __mmask32 __U, __m512 __A, __m512 __B) {
+  return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
+                                        (__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
+                                        (__v32hi)__W);
+}
+
+static __inline__ __m512bh __DEFAULT_FN_ATTRS512
+_mm512_maskz_cvtne2ps_pbh(__mmask32 __U, __m512 __A, __m512 __B) {
+  return (__m512bh)__builtin_ia32_selectw_512((__mmask32)__U,
+                                        (__v32hi)_mm512_cvtne2ps_pbh(__A, __B),
+                                        (__v32hi)_mm512_setzero_si512());
+}
+
+static __inline__ __m256bh __DEFAULT_FN_ATTRS512
+_mm512_cvtneps_pbh(__m512 __A) {
+  return (__m256bh)__builtin_ia32_cvtneps2bf16_512((__v16sf) __A);
+}
+static __inline__ __m256bh __DEFAULT_FN_ATTRS512
+_mm512_mask_cvtneps_pbh(__m256bh __W, __mmask16 __U, __m512 __A) {
+  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
+                                              (__v16hi)_mm512_cvtneps_pbh(__A),
+                                              (__v16hi)__W);
+}
+static __inline__ __m256bh __DEFAULT_FN_ATTRS512
+_mm512_maskz_cvtneps_pbh(__mmask16 __U, __m512 __A) {
+  return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
+                                              (__v16hi)_mm512_cvtneps_pbh(__A),
+                                              (__v16hi)_mm256_setzero_si256());
+}
+
+static __inline__ __m512 __DEFAULT_FN_ATTRS512
+_mm512_dpbf16_ps(__m512 __D, __m512bh __A, __m512bh __B) {
+  return (__m512)__builtin_ia32_dpbf16ps_512((__v16sf) __D,
+                                             (__v16si) __A,
+                                             (__v16si) __B);
+}
+
+static __inline__ __m512 __DEFAULT_FN_ATTRS512
+_mm512_mask_dpbf16_ps(__m512 __D, __mmask16 __U, __m512bh __A, __m512bh __B) {
+  return (__m512)__builtin_ia32_selectps_512((__mmask16)__U,
+                                       (__v16sf)_mm512_dpbf16_ps(__D, __A, __B),
+                                       (__v16sf)__D);
+}
+
+static __inline__ __m512 __DEFAULT_FN_ATTRS512
+_mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) {
+  return (__m512)__builtin_ia32_selectps_512((__mmask16)__U,
+                                       (__v16sf)_mm512_dpbf16_ps(__D, __A, __B),
+                                       (__v16sf)_mm512_setzero_si512());
+}
+
+#undef __DEFAULT_FN_ATTRS512
+
+#endif
Index: lib/Headers/CMakeLists.txt
===================================================================
--- lib/Headers/CMakeLists.txt
+++ lib/Headers/CMakeLists.txt
@@ -6,6 +6,7 @@
   armintr.h
   arm64intr.h
   avx2intrin.h
+  avx512bf16intrin.h
   avx512bwintrin.h
   avx512bitalgintrin.h
   avx512vlbitalgintrin.h
@@ -21,6 +22,7 @@
   avx512vbmivlintrin.h
   avx512vbmi2intrin.h
   avx512vlvbmi2intrin.h
+  avx512vlbf16intrin.h
   avx512vlbwintrin.h
   avx512vlcdintrin.h
   avx512vldqintrin.h
Index: lib/CodeGen/CGBuiltin.cpp
===================================================================
--- lib/CodeGen/CGBuiltin.cpp
+++ lib/CodeGen/CGBuiltin.cpp
@@ -11719,6 +11719,14 @@
   case X86::BI__builtin_ia32_cmpordsd:
     return getCmpIntrinsicCall(Intrinsic::x86_sse2_cmp_sd, 7);
 
+  // AVX512 bf16 intrinsics
+  case X86::BI__builtin_ia32_cvtneps2bf16_128_mask: {
+    Ops[2] = getMaskVecValue(*this, Ops[2],
+                             Ops[0]->getType()->getVectorNumElements());
+    Intrinsic::ID IID = Intrinsic::x86_avx512bf16_mask_cvtneps2bf16_128;
+    return Builder.CreateCall(CGM.getIntrinsic(IID), Ops);
+  }
+
   case X86::BI__emul:
   case X86::BI__emulu: {
     llvm::Type *Int64Ty = llvm::IntegerType::get(getLLVMContext(), 64);
Index: lib/Basic/Targets/X86.h
===================================================================
--- lib/Basic/Targets/X86.h
+++ lib/Basic/Targets/X86.h
@@ -68,6 +68,7 @@
   bool HasAVX512CD = false;
   bool HasAVX512VPOPCNTDQ = false;
   bool HasAVX512VNNI = false;
+  bool HasAVX512BF16 = false;
   bool HasAVX512ER = false;
   bool HasAVX512PF = false;
   bool HasAVX512DQ = false;
Index: lib/Basic/Targets/X86.cpp
===================================================================
--- lib/Basic/Targets/X86.cpp
+++ lib/Basic/Targets/X86.cpp
@@ -521,6 +521,7 @@
                 Features["avx512ifma"] = Features["avx512vpopcntdq"] =
                     Features["avx512bitalg"] = Features["avx512vnni"] =
                         Features["avx512vbmi2"] = false;
+    Features["avx512bf16"] = false;
     break;
   }
 }
@@ -652,12 +653,15 @@
              Name == "avx512dq" || Name == "avx512bw" || Name == "avx512vl" ||
              Name == "avx512vbmi" || Name == "avx512ifma" ||
              Name == "avx512vpopcntdq" || Name == "avx512bitalg" ||
+             Name == "avx512bf16" ||
              Name == "avx512vnni" || Name == "avx512vbmi2") {
     if (Enabled)
       setSSELevel(Features, AVX512F, Enabled);
     // Enable BWI instruction if VBMI/VBMI2/BITALG is being enabled.
     if ((Name.startswith("avx512vbmi") || Name == "avx512bitalg") && Enabled)
       Features["avx512bw"] = true;
+    if (Name == "avx512bf16" && Enabled)
+      Features["avx512bw"] = Features["avx512vl"] = true;
     // Also disable VBMI/VBMI2/BITALG if BWI is being disabled.
     if (Name == "avx512bw" && !Enabled)
       Features["avx512vbmi"] = Features["avx512vbmi2"] =
@@ -751,6 +755,8 @@
       HasAVX512VPOPCNTDQ = true;
     } else if (Feature == "+avx512vnni") {
       HasAVX512VNNI = true;
+    } else if (Feature == "+avx512bf16") {
+      HasAVX512BF16 = true;
     } else if (Feature == "+avx512er") {
       HasAVX512ER = true;
     } else if (Feature == "+avx512pf") {
@@ -1141,6 +1147,8 @@
     Builder.defineMacro("__AVX512VPOPCNTDQ__");
   if (HasAVX512VNNI)
     Builder.defineMacro("__AVX512VNNI__");
+  if (HasAVX512BF16)
+    Builder.defineMacro("__AVX512BF16__");
   if (HasAVX512ER)
     Builder.defineMacro("__AVX512ER__");
   if (HasAVX512PF)
@@ -1305,6 +1313,7 @@
       .Case("avx512cd", true)
       .Case("avx512vpopcntdq", true)
       .Case("avx512vnni", true)
+      .Case("avx512bf16", true)
       .Case("avx512er", true)
       .Case("avx512pf", true)
       .Case("avx512dq", true)
@@ -1383,6 +1392,7 @@
       .Case("avx512cd", HasAVX512CD)
       .Case("avx512vpopcntdq", HasAVX512VPOPCNTDQ)
       .Case("avx512vnni", HasAVX512VNNI)
+      .Case("avx512bf16", HasAVX512BF16)
       .Case("avx512er", HasAVX512ER)
       .Case("avx512pf", HasAVX512PF)
       .Case("avx512dq", HasAVX512DQ)
Index: include/clang/Driver/Options.td
===================================================================
--- include/clang/Driver/Options.td
+++ include/clang/Driver/Options.td
@@ -2854,6 +2854,8 @@
 def mno_avx2 : Flag<["-"], "mno-avx2">, Group<m_x86_Features_Group>;
 def mavx512f : Flag<["-"], "mavx512f">, Group<m_x86_Features_Group>;
 def mno_avx512f : Flag<["-"], "mno-avx512f">, Group<m_x86_Features_Group>;
+def mavx512bf16 : Flag<["-"], "mavx512bf16">, Group<m_x86_Features_Group>;
+def mno_avx512bf16 : Flag<["-"], "mno-avx512bf16">, Group<m_x86_Features_Group>;
 def mavx512bitalg : Flag<["-"], "mavx512bitalg">, Group<m_x86_Features_Group>;
 def mno_avx512bitalg : Flag<["-"], "mno-avx512bitalg">, Group<m_x86_Features_Group>;
 def mavx512bw : Flag<["-"], "mavx512bw">, Group<m_x86_Features_Group>;
Index: include/clang/Basic/BuiltinsX86.def
===================================================================
--- include/clang/Basic/BuiltinsX86.def
+++ include/clang/Basic/BuiltinsX86.def
@@ -1825,6 +1825,24 @@
 TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb512, "V64cV64cV64c", "ncV:512:", "avx512vbmi")
 TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb128, "V16cV16cV16c", "ncV:128:", "avx512vbmi,avx512vl")
 TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb256, "V32cV32cV32c", "ncV:256:", "avx512vbmi,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_128, "V8sV4fV4f", "ncV:128:",
+                                                 "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_256, "V16sV8fV8f", "ncV:256:",
+                                                 "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_512, "V32sV16fV16f", "ncV:512:",
+                                                 "avx512bf16")
+TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_128_mask, "V8sV4fV8sUc", "ncV:128:",
+                                                "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_256, "V8sV8f", "ncV:256:",
+                                                "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_cvtneps2bf16_512, "V16sV16f", "ncV:512:",
+                                                "avx512bf16")
+TARGET_BUILTIN(__builtin_ia32_dpbf16ps_128, "V4fV4fV4iV4i", "ncV:128:",
+                                            "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_dpbf16ps_256, "V8fV8fV8iV8i", "ncV:256:",
+                                            "avx512bf16,avx512vl")
+TARGET_BUILTIN(__builtin_ia32_dpbf16ps_512, "V16fV16fV16iV16i", "ncV:512:",
+                                            "avx512bf16")
 
 // generic select intrinsics
 TARGET_BUILTIN(__builtin_ia32_selectb_128, "V16cUsV16cV16c", "ncV:128:", "avx512bw,avx512vl")
Index: docs/ClangCommandLineReference.rst
===================================================================
--- docs/ClangCommandLineReference.rst
+++ docs/ClangCommandLineReference.rst
@@ -2610,6 +2610,8 @@
 
 .. option:: -mavx512bitalg, -mno-avx512bitalg
 
+.. option:: -mavx512bf16, -mno-avx512bf16
+
 .. option:: -mavx512bw, -mno-avx512bw
 
 .. option:: -mavx512cd, -mno-avx512cd