https://github.com/E00N777 updated https://github.com/llvm/llvm-project/pull/187460
>From aad7d00334f135c04783744f8f440a2dc278684c Mon Sep 17 00:00:00 2001 From: E0N777 <[email protected]> Date: Thu, 19 Mar 2026 16:45:45 +0800 Subject: [PATCH] [CIR][AArch64] Support BF16/FP16 NEON types and lower vdup lane builtins Made-with: Cursor --- .../lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp | 34 +++++- .../CodeGen/AArch64/bf16-getset-intrinsics.c | 76 ------------- clang/test/CodeGen/AArch64/neon/bf16-getset.c | 106 ++++++++++++++++-- 3 files changed, 124 insertions(+), 92 deletions(-) diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp index a3488bfcc3dec..f86abc6cfe1c7 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp @@ -134,10 +134,9 @@ static cir::VectorType getNeonType(CIRGenFunction *cgf, NeonTypeFlags typeFlags, v1Ty ? 1 : (4 << isQuad)); case NeonTypeFlags::BFloat16: if (allowBFloatArgsAndRet) - cgf->getCIRGenModule().errorNYI(loc, std::string("NEON type: BFloat16")); - else - cgf->getCIRGenModule().errorNYI(loc, std::string("NEON type: BFloat16")); - [[fallthrough]]; + return cir::VectorType::get(cgf->getCIRGenModule().bFloat16Ty, + v1Ty ? 1 : (4 << isQuad)); + return cir::VectorType::get(cgf->uInt16Ty, v1Ty ? 
1 : (4 << isQuad)); case NeonTypeFlags::Float16: if (hasLegalHalfType) cgf->getCIRGenModule().errorNYI(loc, std::string("NEON type: Float16")); @@ -169,6 +168,20 @@ static cir::VectorType getNeonType(CIRGenFunction *cgf, NeonTypeFlags typeFlags, llvm_unreachable("Unknown vector element type!"); } +static int64_t getIntValueFromConstOp(mlir::Value val) { + return val.getDefiningOp<cir::ConstantOp>().getIntValue().getSExtValue(); +} + +static mlir::Value emitNeonSplat(CIRGenBuilderTy &builder, mlir::Location loc, + mlir::Value v, mlir::Value lane, + unsigned int resEltCnt) { + assert(isa<cir::ConstantOp>(lane.getDefiningOp()) && + "lane number is not a constant!"); + int64_t laneCst = getIntValueFromConstOp(lane); + llvm::SmallVector<int64_t, 4> shuffleMask(resEltCnt, laneCst); + return builder.createVecShuffle(loc, v, shuffleMask); +} + static mlir::Value emitCommonNeonBuiltinExpr( CIRGenFunction &cgf, unsigned builtinID, unsigned llvmIntrinsic, unsigned altLLVMIntrinsic, const char *nameHint, unsigned modifier, @@ -191,7 +204,8 @@ static mlir::Value emitCommonNeonBuiltinExpr( // The value of allowBFloatArgsAndRet is true for AArch64, but it should // come from ABI info. 
- const bool allowBFloatArgsAndRet = false; + // TODO(cir): Use ABIInfo to extract this information + const bool allowBFloatArgsAndRet = cgf.getTarget().hasFastHalfType(); // FIXME // getTargetHooks().getABIInfo().allowBFloatArgsAndRet(); @@ -205,7 +219,15 @@ static mlir::Value emitCommonNeonBuiltinExpr( case NEON::BI__builtin_neon_splat_lane_v: case NEON::BI__builtin_neon_splat_laneq_v: case NEON::BI__builtin_neon_splatq_lane_v: - case NEON::BI__builtin_neon_splatq_laneq_v: + case NEON::BI__builtin_neon_splatq_laneq_v: { + uint64_t numElements = vTy.getSize(); + if (builtinID == NEON::BI__builtin_neon_splatq_lane_v) + numElements *= 2; + if (builtinID == NEON::BI__builtin_neon_splat_laneq_v) + numElements /= 2; + ops[0] = cgf.getBuilder().createBitcast(loc, ops[0], vTy); + return emitNeonSplat(cgf.getBuilder(), loc, ops[0], ops[1], numElements); + } case NEON::BI__builtin_neon_vpadd_v: case NEON::BI__builtin_neon_vpaddq_v: case NEON::BI__builtin_neon_vabs_v: diff --git a/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c b/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c index 55eb5210829d2..69171902c7e69 100644 --- a/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c +++ b/clang/test/CodeGen/AArch64/bf16-getset-intrinsics.c @@ -14,82 +14,6 @@ bfloat16x4_t test_vcreate_bf16(uint64_t a) { return vcreate_bf16(a); } -// CHECK-LABEL: @test_vdup_n_bf16( -// CHECK-NEXT: entry: -// CHECK-NEXT: [[VECINIT_I:%.*]] = insertelement <4 x bfloat> poison, bfloat [[V:%.*]], i32 0 -// CHECK-NEXT: [[VECINIT1_I:%.*]] = insertelement <4 x bfloat> [[VECINIT_I]], bfloat [[V]], i32 1 -// CHECK-NEXT: [[VECINIT2_I:%.*]] = insertelement <4 x bfloat> [[VECINIT1_I]], bfloat [[V]], i32 2 -// CHECK-NEXT: [[VECINIT3_I:%.*]] = insertelement <4 x bfloat> [[VECINIT2_I]], bfloat [[V]], i32 3 -// CHECK-NEXT: ret <4 x bfloat> [[VECINIT3_I]] -// -bfloat16x4_t test_vdup_n_bf16(bfloat16_t v) { - return vdup_n_bf16(v); -} - -// CHECK-LABEL: @test_vdupq_n_bf16( -// CHECK-NEXT: entry: -// CHECK-NEXT:
[[VECINIT_I:%.*]] = insertelement <8 x bfloat> poison, bfloat [[V:%.*]], i32 0 -// CHECK-NEXT: [[VECINIT1_I:%.*]] = insertelement <8 x bfloat> [[VECINIT_I]], bfloat [[V]], i32 1 -// CHECK-NEXT: [[VECINIT2_I:%.*]] = insertelement <8 x bfloat> [[VECINIT1_I]], bfloat [[V]], i32 2 -// CHECK-NEXT: [[VECINIT3_I:%.*]] = insertelement <8 x bfloat> [[VECINIT2_I]], bfloat [[V]], i32 3 -// CHECK-NEXT: [[VECINIT4_I:%.*]] = insertelement <8 x bfloat> [[VECINIT3_I]], bfloat [[V]], i32 4 -// CHECK-NEXT: [[VECINIT5_I:%.*]] = insertelement <8 x bfloat> [[VECINIT4_I]], bfloat [[V]], i32 5 -// CHECK-NEXT: [[VECINIT6_I:%.*]] = insertelement <8 x bfloat> [[VECINIT5_I]], bfloat [[V]], i32 6 -// CHECK-NEXT: [[VECINIT7_I:%.*]] = insertelement <8 x bfloat> [[VECINIT6_I]], bfloat [[V]], i32 7 -// CHECK-NEXT: ret <8 x bfloat> [[VECINIT7_I]] -// -bfloat16x8_t test_vdupq_n_bf16(bfloat16_t v) { - return vdupq_n_bf16(v); -} - -// CHECK-LABEL: @test_vdup_lane_bf16( -// CHECK-NEXT: entry: -// CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x bfloat> [[V:%.*]] to <4 x i16> -// CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to <8 x i8> -// CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <4 x bfloat> -// CHECK-NEXT: [[LANE:%.*]] = shufflevector <4 x bfloat> [[TMP2]], <4 x bfloat> [[TMP2]], <4 x i32> <i32 1, i32 1, i32 1, i32 1> -// CHECK-NEXT: ret <4 x bfloat> [[LANE]] -// -bfloat16x4_t test_vdup_lane_bf16(bfloat16x4_t v) { - return vdup_lane_bf16(v, 1); -} - -// CHECK-LABEL: @test_vdupq_lane_bf16( -// CHECK-NEXT: entry: -// CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x bfloat> [[V:%.*]] to <4 x i16> -// CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to <8 x i8> -// CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <4 x bfloat> -// CHECK-NEXT: [[LANE:%.*]] = shufflevector <4 x bfloat> [[TMP2]], <4 x bfloat> [[TMP2]], <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1> -// CHECK-NEXT: ret <8 x bfloat> [[LANE]] -// -bfloat16x8_t test_vdupq_lane_bf16(bfloat16x4_t v) { - return 
vdupq_lane_bf16(v, 1); -} - -// CHECK-LABEL: @test_vdup_laneq_bf16( -// CHECK-NEXT: entry: -// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> [[V:%.*]] to <8 x i16> -// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i16> [[TMP0]] to <16 x i8> -// CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <8 x bfloat> -// CHECK-NEXT: [[LANE:%.*]] = shufflevector <8 x bfloat> [[TMP2]], <8 x bfloat> [[TMP2]], <4 x i32> <i32 7, i32 7, i32 7, i32 7> -// CHECK-NEXT: ret <4 x bfloat> [[LANE]] -// -bfloat16x4_t test_vdup_laneq_bf16(bfloat16x8_t v) { - return vdup_laneq_bf16(v, 7); -} - -// CHECK-LABEL: @test_vdupq_laneq_bf16( -// CHECK-NEXT: entry: -// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x bfloat> [[V:%.*]] to <8 x i16> -// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i16> [[TMP0]] to <16 x i8> -// CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <8 x bfloat> -// CHECK-NEXT: [[LANE:%.*]] = shufflevector <8 x bfloat> [[TMP2]], <8 x bfloat> [[TMP2]], <8 x i32> <i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7> -// CHECK-NEXT: ret <8 x bfloat> [[LANE]] -// -bfloat16x8_t test_vdupq_laneq_bf16(bfloat16x8_t v) { - return vdupq_laneq_bf16(v, 7); -} - // CHECK-LABEL: @test_vcombine_bf16( // CHECK-NEXT: entry: // CHECK-NEXT: [[SHUFFLE_I:%.*]] = shufflevector <4 x bfloat> [[LOW:%.*]], <4 x bfloat> [[HIGH:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7> diff --git a/clang/test/CodeGen/AArch64/neon/bf16-getset.c b/clang/test/CodeGen/AArch64/neon/bf16-getset.c index faae31cb013dd..eeb9209fe97a5 100644 --- a/clang/test/CodeGen/AArch64/neon/bf16-getset.c +++ b/clang/test/CodeGen/AArch64/neon/bf16-getset.c @@ -15,22 +15,108 @@ //===------------------------------------------------------===// // 2.4.1.2. Set all lanes to the same value -// -// TODO: Add the remaining intrinsics from this group. 
//===------------------------------------------------------===// // ALL-LABEL: @test_vduph_lane_bf16( +// LLVM-SAME: <4 x bfloat> {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { bfloat16_t test_vduph_lane_bf16(bfloat16x4_t v) { - // CIR: cir.vec.extract %{{.*}}[%{{.*}} : !s32i] : !cir.vector<4 x !cir.bf16> - // LLVM: [[VGET_LANE:%.*]] = extractelement <4 x bfloat> %{{.*}}, i32 1 - // LLVM: ret bfloat [[VGET_LANE]] - return vduph_lane_bf16(v, 1); +// CIR: cir.vec.extract %{{.*}}[%{{.*}} : !s32i] : !cir.vector<4 x !cir.bf16> + +// LLVM: [[VGET_LANE:%.*]] = extractelement <4 x bfloat> %{{.*}}, i32 1 +// LLVM: ret bfloat [[VGET_LANE]] +return vduph_lane_bf16(v, 1); } // ALL-LABEL: @test_vduph_laneq_bf16( +// LLVM-SAME: <8 x bfloat> {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { bfloat16_t test_vduph_laneq_bf16(bfloat16x8_t v) { - // CIR: cir.vec.extract %{{.*}}[%{{.*}} : !s32i] : !cir.vector<8 x !cir.bf16> - // LLVM: [[VGETQ_LANE:%.*]] = extractelement <8 x bfloat> %{{.*}}, i32 7 - // LLVM: ret bfloat [[VGETQ_LANE]] - return vduph_laneq_bf16(v, 7); +// CIR: cir.vec.extract %{{.*}}[%{{.*}} : !s32i] : !cir.vector<8 x !cir.bf16> + +// LLVM: [[VGETQ_LANE:%.*]] = extractelement <8 x bfloat> %{{.*}}, i32 7 +// LLVM: ret bfloat [[VGETQ_LANE]] +return vduph_laneq_bf16(v, 7); +} + +// ALL-LABEL: @test_vdup_lane_bf16( +// LLVM-SAME: <4 x bfloat> {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { +bfloat16x4_t test_vdup_lane_bf16(bfloat16x4_t v) { + // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<4 x !cir.bf16>) [#cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i] : !cir.vector<4 x !cir.bf16> + + // LLVM: [[TMP0:%.*]] = bitcast <4 x bfloat> [[V]] to <4 x i16> + // LLVM: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to <8 x i8> + // LLVM: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <4 x bfloat> + // LLVM: [[SHUF:%.*]] = shufflevector <4 x bfloat> [[TMP2]], {{.*}}, <4 x i32> <i32 1, i32 1, i32 1, i32 1> + // LLVM: ret <4 x bfloat> {{%.*}} + return vdup_lane_bf16(v, 1); +} + +// 
ALL-LABEL: @test_vdupq_lane_bf16( +// LLVM-SAME: <4 x bfloat> {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { +bfloat16x8_t test_vdupq_lane_bf16(bfloat16x4_t v) { + // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<4 x !cir.bf16>) [#cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i, #cir.int<1> : !s32i] : !cir.vector<8 x !cir.bf16> + + // LLVM: [[TMP0:%.*]] = bitcast <4 x bfloat> [[V]] to <4 x i16> + // LLVM: [[TMP1:%.*]] = bitcast <4 x i16> [[TMP0]] to <8 x i8> + // LLVM: [[TMP2:%.*]] = bitcast <8 x i8> [[TMP1]] to <4 x bfloat> + // LLVM: [[SHUF:%.*]] = shufflevector <4 x bfloat> [[TMP2]], {{.*}}, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1> + // LLVM: ret <8 x bfloat> {{%.*}} + return vdupq_lane_bf16(v, 1); +} + +// ALL-LABEL: @test_vdup_laneq_bf16( +// LLVM-SAME: <8 x bfloat> {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { +bfloat16x4_t test_vdup_laneq_bf16(bfloat16x8_t v) { + // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<8 x !cir.bf16>) [#cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i] : !cir.vector<4 x !cir.bf16> + + // LLVM: [[TMP0:%.*]] = bitcast <8 x bfloat> [[V]] to <8 x i16> + // LLVM: [[TMP1:%.*]] = bitcast <8 x i16> [[TMP0]] to <16 x i8> + // LLVM: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <8 x bfloat> + // LLVM: [[SHUF:%.*]] = shufflevector <8 x bfloat> [[TMP2]], {{.*}}, <4 x i32> <i32 7, i32 7, i32 7, i32 7> + // LLVM: ret <4 x bfloat> {{%.*}} + return vdup_laneq_bf16(v, 7); } + +// ALL-LABEL: @test_vdupq_laneq_bf16( +// LLVM-SAME: <8 x bfloat> {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { +bfloat16x8_t test_vdupq_laneq_bf16(bfloat16x8_t v) { + // CIR: cir.vec.shuffle({{%.*}}, {{%.*}} : !cir.vector<8 x !cir.bf16>) [#cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i, #cir.int<7> : !s32i] : !cir.vector<8 x !cir.bf16> + + // 
LLVM: [[TMP0:%.*]] = bitcast <8 x bfloat> [[V]] to <8 x i16> + // LLVM: [[TMP1:%.*]] = bitcast <8 x i16> [[TMP0]] to <16 x i8> + // LLVM: [[TMP2:%.*]] = bitcast <16 x i8> [[TMP1]] to <8 x bfloat> + // LLVM: [[SHUF:%.*]] = shufflevector <8 x bfloat> [[TMP2]], {{.*}}, <8 x i32> <i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7, i32 7> + // LLVM: ret <8 x bfloat> {{%.*}} + return vdupq_laneq_bf16(v, 7); +} + +// LLVM-LABEL: @test_vdup_n_bf16( +// LLVM-SAME: bfloat {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { +// CIR-LABEL: @vdup_n_bf16( + bfloat16x4_t test_vdup_n_bf16(bfloat16_t v) { + // CIR: cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !cir.bf16, !cir.bf16, !cir.bf16, !cir.bf16) : !cir.vector<4 x !cir.bf16> + + // LLVM: [[VECINIT_I:%.*]] = insertelement <4 x bfloat> poison, bfloat [[V]], i{{32|64}} 0 + // LLVM-NEXT: [[VECINIT1_I:%.*]] = insertelement <4 x bfloat> [[VECINIT_I]], bfloat [[V]], i{{32|64}} 1 + // LLVM-NEXT: [[VECINIT2_I:%.*]] = insertelement <4 x bfloat> [[VECINIT1_I]], bfloat [[V]], i{{32|64}} 2 + // LLVM-NEXT: [[VECINIT3_I:%.*]] = insertelement <4 x bfloat> [[VECINIT2_I]], bfloat [[V]], i{{32|64}} 3 + // LLVM-NEXT: ret <4 x bfloat> [[VECINIT3_I]] + return vdup_n_bf16(v); + } + + // LLVM-LABEL: @test_vdupq_n_bf16( + // LLVM-SAME: bfloat {{.*}}[[V:%.*]]) #[[ATTR0:[0-9]+]] { + // CIR-LABEL: @vdupq_n_bf16( + bfloat16x8_t test_vdupq_n_bf16(bfloat16_t v) { + // CIR: cir.vec.create(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : !cir.bf16, !cir.bf16, !cir.bf16, !cir.bf16, !cir.bf16, !cir.bf16, !cir.bf16, !cir.bf16) : !cir.vector<8 x !cir.bf16> + + // LLVM: [[VECINIT_I:%.*]] = insertelement <8 x bfloat> poison, bfloat [[V]], i{{32|64}} 0 + // LLVM-NEXT: [[VECINIT1_I:%.*]] = insertelement <8 x bfloat> [[VECINIT_I]], bfloat [[V]], i{{32|64}} 1 + // LLVM-NEXT: [[VECINIT2_I:%.*]] = insertelement <8 x bfloat> [[VECINIT1_I]], bfloat [[V]], i{{32|64}} 2 + // LLVM-NEXT: [[VECINIT3_I:%.*]] = insertelement <8 x bfloat> [[VECINIT2_I]], bfloat 
[[V]], i{{32|64}} 3 + // LLVM-NEXT: [[VECINIT4_I:%.*]] = insertelement <8 x bfloat> [[VECINIT3_I]], bfloat [[V]], i{{32|64}} 4 + // LLVM-NEXT: [[VECINIT5_I:%.*]] = insertelement <8 x bfloat> [[VECINIT4_I]], bfloat [[V]], i{{32|64}} 5 + // LLVM-NEXT: [[VECINIT6_I:%.*]] = insertelement <8 x bfloat> [[VECINIT5_I]], bfloat [[V]], i{{32|64}} 6 + // LLVM-NEXT: [[VECINIT7_I:%.*]] = insertelement <8 x bfloat> [[VECINIT6_I]], bfloat [[V]], i{{32|64}} 7 + // LLVM-NEXT: ret <8 x bfloat> [[VECINIT7_I]] + return vdupq_n_bf16(v); + } _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
