https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/174638
>From 72b0527f4c7183fa4c41a929ea8c006a3f5ee81a Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Tue, 6 Jan 2026 11:30:56 -0800
Subject: [PATCH] first attempt

---
 clang/include/clang/Basic/Builtins.td         |  6 +++++
 clang/lib/CodeGen/CGHLSLBuiltins.cpp          |  9 ++++++++
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 .../lib/Headers/hlsl/hlsl_alias_intrinsics.h  | 12 ++++++++++
 .../builtins/WaveActiveBallot.hlsl            | 17 ++++++++++++++
 .../BuiltIns/WaveActiveBallot-errors.hlsl     | 21 ++++++++++++++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  1 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  1 +
 llvm/lib/Target/DirectX/DXIL.td               |  8 +++++++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  3 +++
 llvm/test/CodeGen/DirectX/WaveActiveBallot.ll | 10 +++++++++
 .../SPIRV/hlsl-intrinsics/WaveActiveBallot.ll | 22 +++++++++++++++++++
 12 files changed, 111 insertions(+)
 create mode 100644 clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index a5be656cdacf4..5199578e8f666 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5055,6 +5055,12 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool(bool)";
 }
 
+def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_active_ballot"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "_ExtVector<4, unsigned int>(bool)";
+}
+
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 317e64d595243..1b6c3714f7821 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -829,6 +829,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     return EmitRuntimeCall(
         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
   }
+  case Builtin::BI__builtin_hlsl_wave_active_ballot: {
+    Value *Op = EmitScalarExpr(E->getArg(0));
+    assert(Op->getType()->isIntegerTy(1) &&
+           "Intrinsic WaveActiveBallot operand must be a bool");
+
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic();
+    return EmitRuntimeCall(
+        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+  }
   case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ba2ca2c358388..7a5643052ed84 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -146,6 +146,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBallot, wave_ballot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index b065e5dd8447f..a1fafb40f2c6c 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2410,6 +2410,18 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_any_true)
 __attribute__((convergent)) bool WaveActiveAnyTrue(bool Val);
 
+/// \brief Returns a uint4 containing a bitmask of the evaluation of the
+/// boolean expression for all active lanes in the current wave.
+/// The least-significant bit corresponds to the lane with index zero.
+/// The bits corresponding to inactive lanes will be zero. The bits that
+/// are greater than or equal to WaveGetLaneCount will be zero.
+///
+/// \param Val The boolean expression to evaluate.
+/// \return uint4 bitmask
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_ballot)
+__attribute__((convergent)) uint4 WaveActiveBallot(bool Val);
+
 /// \brief Counts the number of boolean variables which evaluate to true across
 /// all active lanes in the current wave.
 ///
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
new file mode 100644
index 0000000000000..61b077eb1fead
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call for bool values.
+
+// CHECK-LABEL: define {{.*}}test
+uint4 test(bool p1) {
+  // CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
+  // CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
+  // CHECK-DXIL: %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 %{{[a-zA-Z0-9]+}})
+  // CHECK: ret <4 x i32> %[[RET]]
+  return WaveActiveBallot(p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl
new file mode 100644
index 0000000000000..ae39068494864
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveBallot-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+uint4 test_too_few_arg() {
+  return __builtin_hlsl_wave_active_ballot();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+uint4 test_too_many_arg(bool p0) {
+  return __builtin_hlsl_wave_active_ballot(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+struct Foo
+{
+  int a;
+};
+
+uint4 test_type_check(Foo p0) {
+  return __builtin_hlsl_wave_active_ballot(p0);
+  // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 8ca93731ffa04..6e6eb2d0ece9d 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -153,6 +153,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 402235ec7cd9c..0e23d19a01c20 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+  def int_spv_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_min : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index b221fa2d7fe87..d84fa4a94f489 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1071,6 +1071,14 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
+def WaveActiveBallot : DXILOp<116, waveActiveBallot> {
+  let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
+  let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
+  let arguments = [Int1Ty];
+  let result = OverloadTy; // FIXME: DXIL returns %dx.types.fouri32 (non-overloaded); needs a dedicated DXILOpParamType + op-builder support.
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def WaveActiveOp : DXILOp<119, waveActiveOp> {
   let Doc = "returns the result of the operation across waves";
   let intrinsics = [
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index f991938c14dfe..c403b297b57ee 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3803,6 +3803,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
   case Intrinsic::spv_wave_any:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
+  case Intrinsic::spv_wave_ballot:
+    return selectWaveOpInst(ResVReg, ResType, I,
+                            SPIRV::OpGroupNonUniformBallot);
   case Intrinsic::spv_wave_is_first_lane:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
   case Intrinsic::spv_wave_reduce_umax:
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
new file mode 100644
index 0000000000000..5c3f8fc2d7643
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
+entry:
+; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1)
+  %ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1)
+  ret <4 x i32> %ret
+}
+
+declare <4 x i32> @llvm.dx.wave.ballot(i1)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
new file mode 100644
index 0000000000000..6831888f038fd
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
@@ -0,0 +1,22 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
+; CHECK-DAG: %[[#bitmask:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: OpCapability GroupNonUniformBallot
+
+; CHECK-LABEL: Begin function test_wave_ballot
+define <4 x i32> @test_wave_ballot(i1 %p1) #0 {
+entry:
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
+; CHECK: %{{.+}} = OpGroupNonUniformBallot %[[#bitmask]] %[[#scope]] %[[#param]]
+  %0 = call token @llvm.experimental.convergence.entry()
+  %ret = call <4 x i32> @llvm.spv.wave.ballot(i1 %p1) [ "convergencectrl"(token %0) ]
+  ret <4 x i32> %ret
+}
+
+declare <4 x i32> @llvm.spv.wave.ballot(i1) #0
+
+attributes #0 = { convergent }

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
