https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/175105
>From c74bf28df75691dfdc3537462a8f1d735b51f865 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Thu, 8 Jan 2026 17:06:34 -0800 Subject: [PATCH 1/2] handle waveballot struct return type --- clang/include/clang/Basic/Builtins.td | 2 +- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 29 +++++++++++++++++-- clang/lib/Sema/SemaHLSL.cpp | 5 ++++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +- llvm/lib/Target/DirectX/DXIL.td | 8 ++--- llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 17 +++++++++-- llvm/test/CodeGen/DirectX/WaveActiveBallot.ll | 12 ++++---- 7 files changed, 58 insertions(+), 17 deletions(-) diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 0ab50b06e11cf..ccbc0abe3f0b4 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5058,7 +5058,7 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> { def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_wave_active_ballot"]; let Attributes = [NoThrow, Const]; - let Prototype = "_ExtVector<4, unsigned int>(bool)"; + let Prototype = "void(bool)"; } def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> { diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 1b6c3714f7821..c5a072bfa3974 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -160,6 +160,31 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) { return LastInst; } +static Value *handleHlslWaveActiveBallot(const CallExpr *E, + CodeGenFunction *CGF) { + Value *Cond = CGF->EmitScalarExpr(E->getArg(0)); + llvm::Type *I32 = CGF->Int32Ty; + llvm::StructType *RetTy = llvm::StructType::get(I32, I32, I32, I32); + + if (CGF->CGM.getTarget().getTriple().isDXIL()) { + // dx.op.waveActiveBallot(opcode, i1) + return CGF->Builder.CreateIntrinsic(RetTy, Intrinsic::dx_wave_ballot, + {Cond}, nullptr, "wave.active.ballot"); + } + + if (CGF->CGM.getTarget().getTriple().isSPIRV()) { + // spv.wave.ballot(i1) -> <4 x i32>, then bitcast to struct + llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4); + return CGF->Builder.CreateIntrinsic(VecTy, Intrinsic::spv_wave_ballot, + {Cond}, nullptr, "spv.wave.ballot"); + } + + CGF->CGM.Error(E->getExprLoc(), + "waveActiveBallot is not supported for this target"); + + return llvm::UndefValue::get(RetTy); +} + static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF, const CallExpr *E) { Value *Op0 = CGF.EmitScalarExpr(E->getArg(0)); @@ -834,9 +859,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, assert(Op->getType()->isIntegerTy(1) && "Intrinsic WaveActiveBallot operand must be a bool"); - Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic(); - return EmitRuntimeCall( - Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op}); + return handleHlslWaveActiveBallot(E, this); } case Builtin::BI__builtin_hlsl_wave_active_count_bits: { Value *OpExpr = EmitScalarExpr(E->getArg(0)); diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index a6de1cd550212..51f74c10677a9 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3507,6 +3507,11 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } + case Builtin::BI__builtin_hlsl_wave_active_ballot: { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + break; + } case Builtin::BI__builtin_hlsl_elementwise_splitdouble: { if (SemaRef.checkArgCount(TheCall, 3)) return true; diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 6e6eb2d0ece9d..f79945785566c 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -153,7 +153,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>] def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; -def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; +def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>; def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 6d04732d92ecf..23701e2218e57 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -58,6 +58,7 @@ def ResPropsTy : DXILOpParamType; def SplitDoubleTy : DXILOpParamType; def BinaryWithCarryTy : DXILOpParamType; def DimensionsTy : DXILOpParamType; +def Fouri32s : DXILOpParamType; class DXILOpClass; @@ -212,13 +213,12 @@ defset list<DXILOpClass> OpClasses = { def unpack4x8 : DXILOpClass; def viewID : DXILOpClass; def waveActiveAllEqual : DXILOpClass; - def waveActiveBallot : DXILOpClass; def waveActiveBit : DXILOpClass; def waveActiveOp : DXILOpClass; def waveAllOp : DXILOpClass; def waveAllTrue : DXILOpClass; def waveAnyTrue : DXILOpClass; - def waveBallot : DXILOpClass; + def waveActiveBallot : DXILOpClass; def waveGetLaneCount : DXILOpClass; def waveGetLaneIndex : DXILOpClass; def waveIsFirstLane : DXILOpClass; @@ -1072,11 +1072,11 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> { let stages = [Stages<DXIL1_0, [all_stages]>]; } -def WaveActiveBallot : DXILOp<118, waveBallot> { +def WaveActiveBallot : DXILOp<116, waveActiveBallot> { let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave."; let intrinsics = [IntrinSelect<int_dx_wave_ballot>]; let arguments = [Int1Ty]; - let result = OverloadTy; + let result = Fouri32s; let stages = [Stages<DXIL1_0, [all_stages]>]; } diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 944b2e6433988..1f41d2457e5bc 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -261,10 +261,18 @@ static StructType *getBinaryWithCarryType(LLVMContext &Context) { return StructType::create({Int32Ty, Int1Ty}, "dx.types.i32c"); } -static StructType *getDimensionsType(LLVMContext &Ctx) { - Type *Int32Ty = Type::getInt32Ty(Ctx); +static StructType *getDimensionsType(LLVMContext &Context) { + Type *Int32Ty = Type::getInt32Ty(Context); return getOrCreateStructType("dx.types.Dimensions", - {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Ctx); + {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context); +} + +static StructType *getFouri32sType(LLVMContext &Context) { + if (auto *ST = StructType::getTypeByName(Context, "dx.types.fouri32")) + return ST; + Type *Int32Ty = Type::getInt32Ty(Context); + return getOrCreateStructType("dx.types.fouri32", + {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context); } static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, @@ -326,7 +334,10 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, return getBinaryWithCarryType(Ctx); case OpParamType::DimensionsTy: return getDimensionsType(Ctx); + case OpParamType::Fouri32s: + return getFouri32sType(Ctx); } + llvm_unreachable("Invalid parameter kind"); return nullptr; } diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll index cf6255de3a734..31a64cbcf061e 100644 --- a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll +++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll @@ -1,10 +1,12 @@ ; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s -define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) { +%dx.types.fouri32 = type { i32, i32, i32, i32 } + +define noundef %dx.types.fouri32 @wave_ballot_simple(i1 noundef %p1) { entry: -; CHECK: call <4 x i32> @dx.op.waveBallot.void(i32 118, i1 %p1) - %ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1) - ret <4 x i32> %ret +; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1) + %ret = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1) + ret %dx.types.fouri32 %ret } -declare <4 x i32> @llvm.dx.wave.ballot(i1) +declare %dx.types.fouri32 @llvm.dx.wave.ballot(i1) >From b54b8d5a525cb9817603488ba7c8a2220bc6baab Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Thu, 8 Jan 2026 19:53:00 -0800 Subject: [PATCH 2/2] update codegen to use emitruntimecall to force use of convergence token --- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 16 +++++++--------- .../CodeGenHLSL/builtins/WaveActiveBallot.hlsl | 11 ++++++++--- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index c5a072bfa3974..1e3f5611e69d1 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -164,25 +164,23 @@ static Value *handleHlslWaveActiveBallot(const CallExpr *E, CodeGenFunction *CGF) { Value *Cond = CGF->EmitScalarExpr(E->getArg(0)); llvm::Type *I32 = CGF->Int32Ty; - llvm::StructType *RetTy = llvm::StructType::get(I32, I32, I32, I32); if (CGF->CGM.getTarget().getTriple().isDXIL()) { - // dx.op.waveActiveBallot(opcode, i1) - return CGF->Builder.CreateIntrinsic(RetTy, Intrinsic::dx_wave_ballot, - {Cond}, nullptr, "wave.active.ballot"); + return CGF->EmitRuntimeCall( + CGF->CGM.getIntrinsic(Intrinsic::dx_wave_ballot, {I32}), Cond); } if (CGF->CGM.getTarget().getTriple().isSPIRV()) { - // spv.wave.ballot(i1) -> <4 x i32>, then bitcast to struct llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4); - return CGF->Builder.CreateIntrinsic(VecTy, Intrinsic::spv_wave_ballot, - {Cond}, nullptr, "spv.wave.ballot"); + + return CGF->EmitRuntimeCall( + CGF->CGM.getIntrinsic(Intrinsic::spv_wave_ballot), Cond); } CGF->CGM.Error(E->getExprLoc(), - "waveActiveBallot is not supported for this target"); + "WaveActiveBallot is not supported for this target"); - return llvm::UndefValue::get(RetTy); + return llvm::PoisonValue::get(llvm::FixedVectorType::get(I32, 4)); } static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF, diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl index 61b077eb1fead..ceee9eb015512 100644 --- a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl @@ -10,8 +10,13 @@ // CHECK-LABEL: define {{.*}}test uint4 test(bool p1) { // CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry() - // CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ] - // CHECK-DXIL: %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) - // CHECK: ret <4 x i32> %[[RET]] + // CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ] + // CHECK-DXIL: %[[RETVAL:.*]] = alloca <4 x i32>, align 16 + // CHECK-DXIL: %[[WAB:.*]] = call { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1 %{{[a-zA-Z0-9]+}}) + // CHECK-DXIL: store { i32, i32, i32, i32 } %[[WAB]], ptr %[[RETVAL]], align 16 + // CHECK-DXIL: %[[LOAD:.*]] = load <4 x i32>, ptr %[[RETVAL]], align 16 + // CHECK-DXIL: ret <4 x i32> %[[LOAD]] + // CHECK-SPIRV: ret <4 x i32> %[[RET]] + return WaveActiveBallot(p1); } _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
