https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/183634
>From 8f3b1eca648517042b010d92c380995212afd59b Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Mon, 23 Feb 2026 17:04:58 -0800 Subject: [PATCH 1/6] first attempt --- clang/include/clang/Basic/Builtins.td | 6 + clang/lib/CodeGen/CGHLSLBuiltins.cpp | 7 + clang/lib/CodeGen/CGHLSLRuntime.h | 27 +++- .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 124 ++++++++++++++++++ clang/lib/Sema/SemaHLSL.cpp | 13 ++ .../builtins/WaveActiveAllEqual.hlsl | 45 +++++++ .../BuiltIns/WaveActiveAllEqual-errors.hlsl | 28 ++++ .../BuiltIns/WaveActiveAllTrue-errors.hlsl | 49 ++++--- llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 + llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + llvm/lib/Target/DirectX/DXIL.td | 10 ++ llvm/lib/Target/DirectX/DXILShaderFlags.cpp | 2 +- .../DirectX/DirectXTargetTransformInfo.cpp | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 3 + .../CodeGen/DirectX/ShaderFlags/wave-ops.ll | 7 + .../CodeGen/DirectX/WaveActiveAllEqual.ll | 87 ++++++++++++ .../hlsl-intrinsics/WaveActiveAllEqual.ll | 41 ++++++ 17 files changed, 426 insertions(+), 26 deletions(-) create mode 100644 clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 78dd26aa2c455..c66c029900453 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -5132,6 +5132,12 @@ def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } +def HLSLWaveActiveAllEqual : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_wave_active_all_equal"]; + let Attributes = [NoThrow, Const]; + let Prototype = "void(...)"; +} + def HLSLWaveActiveAllTrue : LangBuiltin<"HLSL_LANG"> { let Spellings = 
["__builtin_hlsl_wave_active_all_true"]; let Attributes = [NoThrow, Const]; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 70891eac39425..09dae2ab931ee 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -1088,6 +1088,13 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(), ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step"); } + case Builtin::BI__builtin_hlsl_wave_active_all_equal: { + Value *Op = EmitScalarExpr(E->getArg(0)); + + Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllEqualIntrinsic(); + return EmitRuntimeCall( + Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op}); + } case Builtin::BI__builtin_hlsl_wave_active_all_true: { Value *Op = EmitScalarExpr(E->getArg(0)); assert(Op->getType()->isIntegerTy(1) && diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index dbbc887353cec..940e3bbae8df2 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -34,16 +34,33 @@ // A function generator macro for picking the right intrinsic // for the target backend -#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix) \ +#define _GEN_INTRIN_CHOOSER(_1, _2, _3, NAME, ...) NAME + +#define GENERATE_HLSL_INTRINSIC_FUNCTION(...) 
\ + _GEN_INTRIN_CHOOSER(__VA_ARGS__, GENERATE_HLSL_INTRINSIC_FUNCTION3, \ + GENERATE_HLSL_INTRINSIC_FUNCTION2, \ + /* dummy to solve pre-C++20 errors */ ignored)( \ + __VA_ARGS__) + +// 2-arg form: same postfix for both backends (uses the identity) +#define GENERATE_HLSL_INTRINSIC_FUNCTION2(FunctionName, IntrinsicPostfix) \ + llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \ + llvm::Triple::ArchType Arch = getArch(); \ + switch (Arch) {} \ + } + +// 3-arg form: explicit SPIR-V postfix override (perfect for wave->subgroup) +#define GENERATE_HLSL_INTRINSIC_FUNCTION3(FunctionName, DxilPostfix, \ + SpirvPostfix) \ llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \ llvm::Triple::ArchType Arch = getArch(); \ switch (Arch) { \ case llvm::Triple::dxil: \ - return llvm::Intrinsic::dx_##IntrinsicPostfix; \ + return llvm::Intrinsic::dx_##DxilPostfix; \ case llvm::Triple::spirv: \ - return llvm::Intrinsic::spv_##IntrinsicPostfix; \ + return llvm::Intrinsic::spv_##SpirvPostfix; \ default: \ - llvm_unreachable("Intrinsic " #IntrinsicPostfix \ + llvm_unreachable("Intrinsic " #DxilPostfix \ " not supported by target architecture"); \ } \ } @@ -144,6 +161,8 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot) GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed) GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed) + GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllEqual, wave_all_equal, + subgroup_all_equal) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any) GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveMax, wave_reduce_max) diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index 2543401bdfbf9..e4a9c5dc7b4a8 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -2413,6 +2413,130 @@ float4 trunc(float4); // Wave* builtins 
//===----------------------------------------------------------------------===// +/// \brief Evaluates a value for all active invocations in the group. The +/// result is true if Value is equal for all active invocations in the +/// group. Otherwise, the result is false. +/// \param Value The value to compare with +/// \return True if all values across all lanes are equal, false otherwise +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) half WaveActiveAllEqual(half); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) half2 WaveActiveAllEqual(half2); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) half3 WaveActiveAllEqual(half3); +_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) half4 WaveActiveAllEqual(half4); + +#ifdef __HLSL_ENABLE_16_BIT +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int16_t WaveActiveAllEqual(int16_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int16_t2 WaveActiveAllEqual(int16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int16_t3 WaveActiveAllEqual(int16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int16_t4 WaveActiveAllEqual(int16_t4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint16_t WaveActiveAllEqual(uint16_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) 
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint16_t2 WaveActiveAllEqual(uint16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint16_t3 WaveActiveAllEqual(uint16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint16_t4 WaveActiveAllEqual(uint16_t4); +#endif + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int WaveActiveAllEqual(int); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int2 WaveActiveAllEqual(int2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int3 WaveActiveAllEqual(int3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int4 WaveActiveAllEqual(int4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint WaveActiveAllEqual(uint); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint2 WaveActiveAllEqual(uint2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint3 WaveActiveAllEqual(uint3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint4 WaveActiveAllEqual(uint4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int64_t WaveActiveAllEqual(int64_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) 
+__attribute__((convergent)) int64_t2 WaveActiveAllEqual(int64_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int64_t3 WaveActiveAllEqual(int64_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) int64_t4 WaveActiveAllEqual(int64_t4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint64_t WaveActiveAllEqual(uint64_t); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint64_t2 WaveActiveAllEqual(uint64_t2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint64_t3 WaveActiveAllEqual(uint64_t3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) uint64_t4 WaveActiveAllEqual(uint64_t4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) float WaveActiveAllEqual(float); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) float2 WaveActiveAllEqual(float2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) float3 WaveActiveAllEqual(float3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) float4 WaveActiveAllEqual(float4); + +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) double WaveActiveAllEqual(double); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) 
double2 WaveActiveAllEqual(double2); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) double3 WaveActiveAllEqual(double3); +_HLSL_AVAILABILITY(shadermodel, 6.0) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) +__attribute__((convergent)) double4 WaveActiveAllEqual(double4); + /// \brief Returns true if the expression is true in all active lanes in the /// current wave. /// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 802a1bdbccfdd..249d8dc58b866 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3809,6 +3809,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { TheCall->setType(ArgTyA); break; } + case Builtin::BI__builtin_hlsl_wave_active_all_true: { + if (SemaRef.checkArgCount(TheCall, 1)) + return true; + + // Ensure input expr type is a scalar/vector + if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) + return true; + + // set return type to bool + TheCall->setType(getASTContext().BoolTy); + + break; + } case Builtin::BI__builtin_hlsl_wave_active_max: case Builtin::BI__builtin_hlsl_wave_active_min: case Builtin::BI__builtin_hlsl_wave_active_sum: { diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl new file mode 100644 index 0000000000000..4b4149d05eb3f --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl @@ -0,0 +1,45 @@ +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. 
+ +// CHECK-LABEL: test_int +bool test_int(int expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i32([[TY]] %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32([[TY]] %[[#]]) + // CHECK: ret i1 %[[RET]] + return WaveActiveAllEqual(expr); +} + +// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32([[TY]]) #[[#attr:]] +// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i32([[TY]]) #[[#attr:]] + +// CHECK-LABEL: test_uint64_t +bool test_uint64_t(uint64_t expr) { + // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i64(i64 %[[#]]) + // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.uproduct.i64(i64 %[[#]]) + // CHECK: ret i1 %[[RET]] + return WaveActiveAllEqual(expr); +} + +// CHECK-DXIL: declare i1 @llvm.dx.wave.uproduct.i64(i64 #[[#attr:]] +// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i64(i64) #[[#attr:]] + +// Test basic lowering to runtime function call with array and float value. + +// CHECK-LABEL: test_floatv4 +bool test_floatv4(float4 expr) { + // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func i1 @llvm.spv.wave.all.equal.v4f32(i32 %[[#]] + // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn i1 @llvm.dx.wave.all.equal.v4f32(i32 %[[#]]) + // CHECK: ret [[TY1]] %[[RET1]] + return WaveActiveAllEqual(expr); +} + +// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.v4f32(i32) #[[#attr]] +// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.v4f32(i32) #[[#attr]] + +// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl new file mode 100644 index 0000000000000..2c838cb51dd78 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl @@ -0,0 +1,28 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +int 
test_too_few_arg() { + return __builtin_hlsl_wave_active_all_equal(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +float2 test_too_many_arg(float2 p0) { + return __builtin_hlsl_wave_active_all_equal(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +bool test_expr_bool_type_check(bool p0) { + return __builtin_hlsl_wave_active_all_equal(p0); + // expected-error@-1 {{invalid operand of type 'bool'}} +} + +bool2 test_expr_bool_vec_type_check(bool2 p0) { + return __builtin_hlsl_wave_active_all_equal(p0); + // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}} +} + +struct S { float f; }; + +S test_expr_struct_type_check(S p0) { + return __builtin_hlsl_wave_active_all_equal(p0); + // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}} +} diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl index b0d0fdfca5e18..af926d60624c6 100644 --- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl @@ -1,21 +1,28 @@ -// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify - -bool test_too_few_arg() { - return __builtin_hlsl_wave_active_all_true(); - // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} -} - -bool test_too_many_arg(bool p0) { - return __builtin_hlsl_wave_active_all_true(p0, p0); - // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} -} - -struct Foo -{ - int a; -}; - -bool test_type_check(Foo p0) { - return __builtin_hlsl_wave_active_all_true(p0); - // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}} -} +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes 
-verify + +int test_too_few_arg() { + return __builtin_hlsl_wave_active_product(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +float2 test_too_many_arg(float2 p0) { + return __builtin_hlsl_wave_active_product(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +bool test_expr_bool_type_check(bool p0) { + return __builtin_hlsl_wave_active_product(p0); + // expected-error@-1 {{invalid operand of type 'bool'}} +} + +bool2 test_expr_bool_vec_type_check(bool2 p0) { + return __builtin_hlsl_wave_active_product(p0); + // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}} +} + +struct S { float f; }; + +S test_expr_struct_type_check(S p0) { + return __builtin_hlsl_wave_active_product(p0); + // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 909482d72aa88..a688da131ce75 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -213,6 +213,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; +def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 77f49ae721ad5..59a9612d1ff50 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; + def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">, diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 59a5b7fe4d508..a378f0d665d44 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -217,6 +217,7 @@ defset list<DXILOpClass> OpClasses = { def waveActiveOp : DXILOpClass; def waveAllOp : DXILOpClass; def waveAllTrue : DXILOpClass; + def waveAllEqual : DXILOpClass; def waveAnyTrue : DXILOpClass; def waveActiveBallot : DXILOpClass; def waveGetLaneCount : DXILOpClass; @@ -1062,6 +1063,15 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> { let stages = [Stages<DXIL1_0, [all_stages]>]; } +def WaveActiveAllEqual : DXILOp<115, waveAllEqual> { + let Doc = "returns true if the expression is equal in all of the active lanes " + "in the current wave"; + let intrinsics = [IntrinSelect<int_dx_wave_all_equal>]; + let arguments = [OverloadTy]; + let 
result = Int1Ty; + let stages = [Stages<DXIL1_0, [all_stages]>]; +} + def WaveActiveBallot : DXILOp<116, waveActiveBallot> { let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave."; let intrinsics = [IntrinSelect<int_dx_wave_ballot>]; diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 52993ee1c1220..1d14079407cbe 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -64,7 +64,6 @@ static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM, static bool checkWaveOps(Intrinsic::ID IID) { // Currently unsupported intrinsics // case Intrinsic::dx_wave_getlanecount: - // case Intrinsic::dx_wave_allequal: // case Intrinsic::dx_wave_readfirst: // case Intrinsic::dx_wave_reduce.and: // case Intrinsic::dx_wave_reduce.or: @@ -85,6 +84,7 @@ static bool checkWaveOps(Intrinsic::ID IID) { case Intrinsic::dx_wave_is_first_lane: case Intrinsic::dx_wave_getlaneindex: case Intrinsic::dx_wave_any: + case Intrinsic::dx_wave_all_equal: case Intrinsic::dx_wave_all: case Intrinsic::dx_wave_readlane: case Intrinsic::dx_wave_active_countbits: diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index 8018b09c9f248..eca2343227577 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -57,6 +57,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( case Intrinsic::dx_rsqrt: case Intrinsic::dx_saturate: case Intrinsic::dx_splitdouble: + case Intrinsic::dx_wave_all_equal: case Intrinsic::dx_wave_readlane: case Intrinsic::dx_wave_reduce_max: case Intrinsic::dx_wave_reduce_min: diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3d3e311eeedb7..b9c6cb1e67595 100644 --- 
a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -4084,6 +4084,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, return selectWavePrefixBitCount(ResVReg, ResType, I); case Intrinsic::spv_wave_active_countbits: return selectWaveActiveCountBits(ResVReg, ResType, I); + case Intrinsic::spv_subgroup_all_equal: + return selectWaveOpInst(ResVReg, ResType, I, + SPIRV::OpGroupNonUniformAllEqual); case Intrinsic::spv_wave_all: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll); case Intrinsic::spv_wave_any: diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll index be53d19aca8f2..6c29ac73719e6 100644 --- a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll @@ -42,6 +42,13 @@ entry: ret i1 %ret } +define noundef i1 @wave_all_equal(i1 %x) { +entry: + ; CHECK: Function wave_all_equal : [[WAVE_FLAG]] + %ret = call i1 @llvm.dx.wave.all.equal(i1 %x) + ret i1 %ret +} + define noundef i1 @wave_readlane(i1 %x, i32 %idx) { entry: ; CHECK: Function wave_readlane : [[WAVE_FLAG]] diff --git a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll new file mode 100644 index 0000000000000..702f2ad1dde5f --- /dev/null +++ b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll @@ -0,0 +1,87 @@ +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s + +; Test that for scalar values, WaveAcitveProduct maps down to the DirectX op + +define noundef half @wave_active_all_equal_half(half noundef %expr) { +entry: +; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr, i8 1, i8 0) + %ret = call half @llvm.dx.wave.all.equal.f16(half %expr) + ret half %ret +} + +define noundef float @wave_active_all_equal_float(float noundef %expr) { +entry: +; CHECK: call float 
@dx.op.waveActiveAllEqual.f32(i32 119, float %expr, i8 1, i8 0) + %ret = call float @llvm.dx.wave.all.equal.f32(float %expr) + ret float %ret +} + +define noundef double @wave_active_all_equal_double(double noundef %expr) { +entry: +; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr, i8 1, i8 0) + %ret = call double @llvm.dx.wave.all.equal.f64(double %expr) + ret double %ret +} + +define noundef i16 @wave_active_all_equal_i16(i16 noundef %expr) { +entry: +; CHECK: call i16 @dx.op.waveActiveAllEqual.i16(i32 119, i16 %expr, i8 1, i8 0) + %ret = call i16 @llvm.dx.wave.all.equal.i16(i16 %expr) + ret i16 %ret +} + +define noundef i32 @wave_active_all_equal_i32(i32 noundef %expr) { +entry: +; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr, i8 1, i8 0) + %ret = call i32 @llvm.dx.wave.all.equal.i32(i32 %expr) + ret i32 %ret +} + +define noundef i64 @wave_active_all_equal_i64(i64 noundef %expr) { +entry: +; CHECK: call i64 @dx.op.waveActiveAllEqual.i64(i32 119, i64 %expr, i8 1, i8 0) + %ret = call i64 @llvm.dx.wave.all.equal.i64(i64 %expr) + ret i64 %ret +} + +declare half @llvm.dx.wave.all.equal.f16(half) +declare float @llvm.dx.wave.all.equal.f32(float) +declare double @llvm.dx.wave.all.equal.f64(double) + +declare i16 @llvm.dx.wave.all.equal.i16(i16) +declare i32 @llvm.dx.wave.all.equal.i32(i32) +declare i64 @llvm.dx.wave.all.equal.i64(i64) + +; Test that for vector values, WaveAcitveProduct scalarizes and maps down to the +; DirectX op + +define noundef <2 x half> @wave_active_all_equal_v2half(<2 x half> noundef %expr) { +entry: +; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i0, i8 1, i8 0) +; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i1, i8 1, i8 0) + %ret = call <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr) + ret <2 x half> %ret +} + +define noundef <3 x i32> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) { +entry: +; CHECK: call i32 
@dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i0, i8 1, i8 0) +; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i1, i8 1, i8 0) +; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i2, i8 1, i8 0) + %ret = call <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr) + ret <3 x i32> %ret +} + +define noundef <4 x double> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) { +entry: +; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i0, i8 1, i8 0) +; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i1, i8 1, i8 0) +; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i2, i8 1, i8 0) +; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i3, i8 1, i8 0) + %ret = call <4 x double> @llvm.dx.wave.all.equal.v464(<4 x double> %expr) + ret <4 x double> %ret +} + +declare <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half>) +declare <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32>) +declare <4 x double> @llvm.dx.wave.all.equal.v4f64(<4 x double>) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll new file mode 100644 index 0000000000000..e871dc9a7aa28 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll @@ -0,0 +1,41 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %} + +; Test lowering to spir-v backend for various types and scalar/vector + +; CHECK-DAG: %[[#f16:]] = OpTypeFloat 16 +; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v4_half:]] = OpTypeVector %[[#f16]] 4 +; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3 + +; CHECK-LABEL: Begin function test_float +; CHECK: %[[#fexpr:]] = OpFunctionParameter %[[#f32]] +define i1 
@test_float(float %fexpr) { +entry: +; CHECK: %[[#fret:]] = OpGroupNonUniformAllEqual %[[#f32]] %[[#scope]] Reduce %[[#fexpr]] + %0 = call i1 @llvm.spv.wave.all.equal.f32(float %fexpr) + ret i1 %0 +} + +; CHECK-LABEL: Begin function test_int +; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]] +define i1 @test_int(i32 %iexpr) { +entry: +; CHECK: %[[#iret:]] = OpGroupNonUniformAllEqual %[[#uint]] %[[#scope]] Reduce %[[#iexpr]] + %0 = call i1 @llvm.spv.wave.all.equal.i32(i32 %iexpr) + ret i1 %0 +} + +; CHECK-LABEL: Begin function test_vhalf +; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]] +define i1 @test_vhalf(<4 x half> %vbexpr) { +entry: +; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#v4_half]] %[[#scope]] Reduce %[[#vbexpr]] + %0 = call i1 @llvm.spv.wave.all.equal.v4half(<4 x half> %vbexpr) + ret i1 %0 +} + +declare i1 @llvm.spv.wave.all.equal.f32(float) +declare i1 @llvm.spv.wave.all.equal.i32(i32) +declare i1 @llvm.spv.wave.all.equal.v4half(<4 x half>) >From c523c88934ae38d62d4b65729f4374e6f9eb617e Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Fri, 27 Feb 2026 13:39:08 -0800 Subject: [PATCH 2/6] fix return type, self review --- clang/lib/CodeGen/CGHLSLBuiltins.cpp | 5 +- clang/lib/CodeGen/CGHLSLRuntime.h | 10 +- .../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 72 +++++----- clang/lib/Sema/SemaHLSL.cpp | 18 ++- .../builtins/WaveActiveAllEqual.hlsl | 30 ++--- .../BuiltIns/WaveActiveAllEqual-errors.hlsl | 16 +-- .../BuiltIns/WaveActiveAllTrue-errors.hlsl | 31 ++--- llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +- llvm/include/llvm/IR/IntrinsicsSPIRV.td | 2 +- llvm/lib/Target/DirectX/DXIL.td | 7 +- .../DirectX/DirectXTargetTransformInfo.cpp | 3 + .../CodeGen/DirectX/WaveActiveAllEqual.ll | 124 +++++++++++------- .../hlsl-intrinsics/WaveActiveAllEqual.ll | 24 ++-- 13 files changed, 187 insertions(+), 157 deletions(-) diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 
09dae2ab931ee..47b7e2b18d942 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -1092,8 +1092,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, Value *Op = EmitScalarExpr(E->getArg(0)); Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllEqualIntrinsic(); - return EmitRuntimeCall( - Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op}); + return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( + &CGM.getModule(), ID, {Op->getType()}), + {Op}); } case Builtin::BI__builtin_hlsl_wave_active_all_true: { Value *Op = EmitScalarExpr(E->getArg(0)); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index 940e3bbae8df2..d6055d89e3c84 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -46,7 +46,15 @@ #define GENERATE_HLSL_INTRINSIC_FUNCTION2(FunctionName, IntrinsicPostfix) \ llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \ llvm::Triple::ArchType Arch = getArch(); \ - switch (Arch) {} \ + switch (Arch) { \ + case llvm::Triple::dxil: \ + return llvm::Intrinsic::dx_##IntrinsicPostfix; \ + case llvm::Triple::spirv: \ + return llvm::Intrinsic::spv_##IntrinsicPostfix; \ + default: \ + llvm_unreachable("Intrinsic " #IntrinsicPostfix \ + " not supported by target architecture"); \ + } \ } // 3-arg form: explicit SPIR-V postfix override (perfect for wave->subgroup) diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index e4a9c5dc7b4a8..5ca4713c2d520 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -2420,122 +2420,122 @@ float4 trunc(float4); /// \return True if all values across all lanes are equal, false otherwise _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) half WaveActiveAllEqual(half); +__attribute__((convergent)) bool 
WaveActiveAllEqual(half); _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) half2 WaveActiveAllEqual(half2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(half2); _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) half3 WaveActiveAllEqual(half3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(half3); _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) half4 WaveActiveAllEqual(half4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(half4); #ifdef __HLSL_ENABLE_16_BIT _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int16_t WaveActiveAllEqual(int16_t); +__attribute__((convergent)) bool WaveActiveAllEqual(int16_t); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int16_t2 WaveActiveAllEqual(int16_t2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(int16_t2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int16_t3 WaveActiveAllEqual(int16_t3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(int16_t3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int16_t4 WaveActiveAllEqual(int16_t4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(int16_t4); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint16_t WaveActiveAllEqual(uint16_t); +__attribute__((convergent)) bool WaveActiveAllEqual(uint16_t); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint16_t2 
WaveActiveAllEqual(uint16_t2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(uint16_t2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint16_t3 WaveActiveAllEqual(uint16_t3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(uint16_t3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint16_t4 WaveActiveAllEqual(uint16_t4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(uint16_t4); #endif _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int WaveActiveAllEqual(int); +__attribute__((convergent)) bool WaveActiveAllEqual(int); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int2 WaveActiveAllEqual(int2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(int2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int3 WaveActiveAllEqual(int3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(int3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int4 WaveActiveAllEqual(int4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(int4); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint WaveActiveAllEqual(uint); +__attribute__((convergent)) bool WaveActiveAllEqual(uint); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint2 WaveActiveAllEqual(uint2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(uint2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint3 
WaveActiveAllEqual(uint3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(uint3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint4 WaveActiveAllEqual(uint4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(uint4); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int64_t WaveActiveAllEqual(int64_t); +__attribute__((convergent)) bool WaveActiveAllEqual(int64_t); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int64_t2 WaveActiveAllEqual(int64_t2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(int64_t2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int64_t3 WaveActiveAllEqual(int64_t3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(int64_t3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) int64_t4 WaveActiveAllEqual(int64_t4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(int64_t4); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint64_t WaveActiveAllEqual(uint64_t); +__attribute__((convergent)) bool WaveActiveAllEqual(uint64_t); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint64_t2 WaveActiveAllEqual(uint64_t2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(uint64_t2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) uint64_t3 WaveActiveAllEqual(uint64_t3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(uint64_t3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) 
-__attribute__((convergent)) uint64_t4 WaveActiveAllEqual(uint64_t4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(uint64_t4); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) float WaveActiveAllEqual(float); +__attribute__((convergent)) bool WaveActiveAllEqual(float); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) float2 WaveActiveAllEqual(float2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(float2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) float3 WaveActiveAllEqual(float3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(float3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) float4 WaveActiveAllEqual(float4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(float4); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) double WaveActiveAllEqual(double); +__attribute__((convergent)) bool WaveActiveAllEqual(double); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) double2 WaveActiveAllEqual(double2); +__attribute__((convergent)) bool2 WaveActiveAllEqual(double2); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) double3 WaveActiveAllEqual(double3); +__attribute__((convergent)) bool3 WaveActiveAllEqual(double3); _HLSL_AVAILABILITY(shadermodel, 6.0) _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal) -__attribute__((convergent)) double4 WaveActiveAllEqual(double4); +__attribute__((convergent)) bool4 WaveActiveAllEqual(double4); /// \brief Returns true if the expression is true in all active lanes in the /// 
current wave. diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 249d8dc58b866..46cc3835c85a8 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3809,7 +3809,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { TheCall->setType(ArgTyA); break; } - case Builtin::BI__builtin_hlsl_wave_active_all_true: { + case Builtin::BI__builtin_hlsl_wave_active_all_equal: { if (SemaRef.checkArgCount(TheCall, 1)) return true; @@ -3817,9 +3817,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) return true; - // set return type to bool - TheCall->setType(getASTContext().BoolTy); + QualType InputTy = TheCall->getArg(0)->getType(); + ASTContext &Ctx = getASTContext(); + QualType RetTy; + + // If vector, construct bool vector of same size + if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) { + unsigned NumElts = VecTy->getNumElements(); + RetTy = Ctx.getExtVectorType(Ctx.BoolTy, NumElts); + } else { + // Scalar case + RetTy = Ctx.BoolTy; + } + + TheCall->setType(RetTy); break; } case Builtin::BI__builtin_hlsl_wave_active_max: diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl index 4b4149d05eb3f..65d15633eb6cf 100644 --- a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl +++ b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl @@ -9,37 +9,37 @@ // CHECK-LABEL: test_int bool test_int(int expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i32([[TY]] %[[#]]) - // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32([[TY]] %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.subgroup.all.equal.i32(i32 + // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32(i32 // CHECK: ret i1 %[[RET]] return WaveActiveAllEqual(expr); } -// CHECK-DXIL: declare i1 
@llvm.dx.wave.all.equal.i32([[TY]]) #[[#attr:]] -// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i32([[TY]]) #[[#attr:]] +// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32(i32) #[[attr:.*]] +// CHECK-SPIRV: declare i1 @llvm.spv.subgroup.all.equal.i32(i32) #[[attr:.*]] // CHECK-LABEL: test_uint64_t bool test_uint64_t(uint64_t expr) { - // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i64(i64 %[[#]]) - // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.uproduct.i64(i64 %[[#]]) + // CHECK-SPIRV: %[[RET:.*]] = call spir_func i1 @llvm.spv.subgroup.all.equal.i64(i64 + // CHECK-DXIL: %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i64(i64 // CHECK: ret i1 %[[RET]] return WaveActiveAllEqual(expr); } -// CHECK-DXIL: declare i1 @llvm.dx.wave.uproduct.i64(i64 #[[#attr:]] -// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i64(i64) #[[#attr:]] +// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i64(i64) #[[attr]] +// CHECK-SPIRV: declare i1 @llvm.spv.subgroup.all.equal.i64(i64) #[[attr]] // Test basic lowering to runtime function call with array and float value. 
// CHECK-LABEL: test_floatv4 -bool test_floatv4(float4 expr) { - // CHECK-SPIRV: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func i1 @llvm.spv.wave.all.equal.v4f32(i32 %[[#]] - // CHECK-DXIL: %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn i1 @llvm.dx.wave.all.equal.v4f32(i32 %[[#]]) - // CHECK: ret [[TY1]] %[[RET1]] +bool4 test_floatv4(float4 expr) { + // CHECK-SPIRV: %[[RET1:.*]] = call spir_func <4 x i1> @llvm.spv.subgroup.all.equal.v4f32(<4 x float> + // CHECK-DXIL: %[[RET1:.*]] = call <4 x i1> @llvm.dx.wave.all.equal.v4f32(<4 x float> + // CHECK: ret <4 x i1> %[[RET1]] return WaveActiveAllEqual(expr); } -// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.v4f32(i32) #[[#attr]] -// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.v4f32(i32) #[[#attr]] +// CHECK-DXIL: declare <4 x i1> @llvm.dx.wave.all.equal.v4f32(<4 x float>) #[[attr]] +// CHECK-SPIRV: declare <4 x i1> @llvm.spv.subgroup.all.equal.v4f32(<4 x float>) #[[attr]] -// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}} +// CHECK: attributes #[[attr]] = {{{.*}} convergent {{.*}}} diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl index 2c838cb51dd78..1b5d7955baffc 100644 --- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl @@ -1,28 +1,18 @@ // RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify -int test_too_few_arg() { +bool test_too_few_arg() { return __builtin_hlsl_wave_active_all_equal(); // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} } -float2 test_too_many_arg(float2 p0) { +bool test_too_many_arg(float2 p0) { return __builtin_hlsl_wave_active_all_equal(p0, p0); // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} } -bool test_expr_bool_type_check(bool p0) { - return 
__builtin_hlsl_wave_active_all_equal(p0); - // expected-error@-1 {{invalid operand of type 'bool'}} -} - -bool2 test_expr_bool_vec_type_check(bool2 p0) { - return __builtin_hlsl_wave_active_all_equal(p0); - // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}} -} - struct S { float f; }; -S test_expr_struct_type_check(S p0) { +bool test_expr_struct_type_check(S p0) { return __builtin_hlsl_wave_active_all_equal(p0); // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}} } diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl index af926d60624c6..0975ad649e714 100644 --- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl @@ -1,28 +1,21 @@ // RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify -int test_too_few_arg() { - return __builtin_hlsl_wave_active_product(); +bool test_too_few_arg() { + return __builtin_hlsl_wave_active_all_true(); // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} } -float2 test_too_many_arg(float2 p0) { - return __builtin_hlsl_wave_active_product(p0, p0); +bool test_too_many_arg(bool p0) { + return __builtin_hlsl_wave_active_all_true(p0, p0); // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} } -bool test_expr_bool_type_check(bool p0) { - return __builtin_hlsl_wave_active_product(p0); - // expected-error@-1 {{invalid operand of type 'bool'}} -} - -bool2 test_expr_bool_vec_type_check(bool2 p0) { - return __builtin_hlsl_wave_active_product(p0); - // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}} -} +struct Foo +{ + int a; +}; -struct S { float f; }; - -S test_expr_struct_type_check(S p0) { - return __builtin_hlsl_wave_active_product(p0); - // expected-error@-1 {{invalid 
operand of type 'S' where a scalar or vector is required}} -} +bool test_type_check(Foo p0) { + return __builtin_hlsl_wave_active_all_true(p0); + // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}} +} \ No newline at end of file diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index a688da131ce75..6774a33556c09 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -213,7 +213,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>; def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; -def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>; +def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], [llvm_any_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 59a9612d1ff50..b91905f350506 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -120,7 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty] def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def 
int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; - def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>; + def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], [llvm_any_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>; def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">, diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index a378f0d665d44..e64909b059d29 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -217,7 +217,6 @@ defset list<DXILOpClass> OpClasses = { def waveActiveOp : DXILOpClass; def waveAllOp : DXILOpClass; def waveAllTrue : DXILOpClass; - def waveAllEqual : DXILOpClass; def waveAnyTrue : DXILOpClass; def waveActiveBallot : DXILOpClass; def waveGetLaneCount : DXILOpClass; @@ -1063,9 +1062,9 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> { let stages = [Stages<DXIL1_0, [all_stages]>]; } -def WaveActiveAllEqual : DXILOp<115, waveAllEqual> { - let Doc = "returns true if the expression is equal in all of the active lanes " - "in the current wave"; +def WaveActiveAllEqual : DXILOp<115, waveActiveAllEqual> { + let Doc = "returns true for each scalar element of the expression if the " + "expression is equal in all of the active lanes in the current wave"; let intrinsics = [IntrinSelect<int_dx_wave_all_equal>]; let arguments = [OverloadTy]; let result = Int1Ty; diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp index 
eca2343227577..a2d7ffefbb5a2 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -36,8 +36,11 @@ bool DirectXTTIImpl::isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, case Intrinsic::dx_isnan: case Intrinsic::dx_legacyf16tof32: case Intrinsic::dx_legacyf32tof16: + case Intrinsic::dx_wave_all_equal: return OpdIdx == 0; default: + // All DX intrinsics are overloaded on return type unless specified + // otherwise return OpdIdx == -1; } } diff --git a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll index 702f2ad1dde5f..f6dcd59c33958 100644 --- a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll +++ b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll @@ -2,86 +2,108 @@ ; Test that for scalar values, WaveAcitveProduct maps down to the DirectX op -define noundef half @wave_active_all_equal_half(half noundef %expr) { +define noundef i1 @wave_active_all_equal_half(half noundef %expr) { entry: -; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr, i8 1, i8 0) - %ret = call half @llvm.dx.wave.all.equal.f16(half %expr) - ret half %ret +; CHECK: call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %expr) + %ret = call i1 @llvm.dx.wave.all.equal.f16(half %expr) + ret i1 %ret } -define noundef float @wave_active_all_equal_float(float noundef %expr) { +define noundef i1 @wave_active_all_equal_float(float noundef %expr) { entry: -; CHECK: call float @dx.op.waveActiveAllEqual.f32(i32 119, float %expr, i8 1, i8 0) - %ret = call float @llvm.dx.wave.all.equal.f32(float %expr) - ret float %ret +; CHECK: call i1 @dx.op.waveActiveAllEqual.f32(i32 115, float %expr) + %ret = call i1 @llvm.dx.wave.all.equal.f32(float %expr) + ret i1 %ret } -define noundef double @wave_active_all_equal_double(double noundef %expr) { +define noundef i1 @wave_active_all_equal_double(double noundef %expr) { entry: -; CHECK: call double 
@dx.op.waveActiveAllEqual.f64(i32 119, double %expr, i8 1, i8 0) - %ret = call double @llvm.dx.wave.all.equal.f64(double %expr) - ret double %ret +; CHECK: call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %expr) + %ret = call i1 @llvm.dx.wave.all.equal.f64(double %expr) + ret i1 %ret } -define noundef i16 @wave_active_all_equal_i16(i16 noundef %expr) { +define noundef i1 @wave_active_all_equal_i16(i16 noundef %expr) { entry: -; CHECK: call i16 @dx.op.waveActiveAllEqual.i16(i32 119, i16 %expr, i8 1, i8 0) - %ret = call i16 @llvm.dx.wave.all.equal.i16(i16 %expr) - ret i16 %ret +; CHECK: call i1 @dx.op.waveActiveAllEqual.i16(i32 115, i16 %expr) + %ret = call i1 @llvm.dx.wave.all.equal.i16(i16 %expr) + ret i1 %ret } -define noundef i32 @wave_active_all_equal_i32(i32 noundef %expr) { +define noundef i1 @wave_active_all_equal_i32(i32 noundef %expr) { entry: -; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr, i8 1, i8 0) - %ret = call i32 @llvm.dx.wave.all.equal.i32(i32 %expr) - ret i32 %ret +; CHECK: call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %expr) + %ret = call i1 @llvm.dx.wave.all.equal.i32(i32 %expr) + ret i1 %ret } -define noundef i64 @wave_active_all_equal_i64(i64 noundef %expr) { +define noundef i1 @wave_active_all_equal_i64(i64 noundef %expr) { entry: -; CHECK: call i64 @dx.op.waveActiveAllEqual.i64(i32 119, i64 %expr, i8 1, i8 0) - %ret = call i64 @llvm.dx.wave.all.equal.i64(i64 %expr) - ret i64 %ret +; CHECK: call i1 @dx.op.waveActiveAllEqual.i64(i32 115, i64 %expr) + %ret = call i1 @llvm.dx.wave.all.equal.i64(i64 %expr) + ret i1 %ret } -declare half @llvm.dx.wave.all.equal.f16(half) -declare float @llvm.dx.wave.all.equal.f32(float) -declare double @llvm.dx.wave.all.equal.f64(double) +declare i1 @llvm.dx.wave.all.equal.f16(half) +declare i1 @llvm.dx.wave.all.equal.f32(float) +declare i1 @llvm.dx.wave.all.equal.f64(double) -declare i16 @llvm.dx.wave.all.equal.i16(i16) -declare i32 @llvm.dx.wave.all.equal.i32(i32) -declare i64 
@llvm.dx.wave.all.equal.i64(i64) +declare i1 @llvm.dx.wave.all.equal.i16(i16) +declare i1 @llvm.dx.wave.all.equal.i32(i32) +declare i1 @llvm.dx.wave.all.equal.i64(i64) ; Test that for vector values, WaveAcitveProduct scalarizes and maps down to the ; DirectX op -define noundef <2 x half> @wave_active_all_equal_v2half(<2 x half> noundef %expr) { +define noundef <2 x i1> @wave_active_all_equal_v2half(<2 x half> noundef %expr) { entry: -; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i0, i8 1, i8 0) -; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i1, i8 1, i8 0) - %ret = call <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr) - ret <2 x half> %ret +; CHECK: %[[EXPR0:.*]] = extractelement <2 x half> %expr, i64 0 +; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %[[EXPR0]]) +; CHECK: %[[EXPR1:.*]] = extractelement <2 x half> %expr, i64 1 +; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %[[EXPR1]]) +; CHECK: %[[RETUPTO0:.*]] = insertelement <2 x i1> poison, i1 %[[RET0]], i64 0 +; CHECK: %ret = insertelement <2 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1 +; CHECK: ret <2 x i1> %ret + + %ret = call <2 x i1> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr) + ret <2 x i1> %ret } -define noundef <3 x i32> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) { +define noundef <3 x i1> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) { entry: -; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i0, i8 1, i8 0) -; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i1, i8 1, i8 0) -; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i2, i8 1, i8 0) - %ret = call <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr) - ret <3 x i32> %ret +; CHECK: %[[EXPR0:.*]] = extractelement <3 x i32> %expr, i64 0 +; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR0]]) +; CHECK: %[[EXPR1:.*]] = extractelement <3 
x i32> %expr, i64 1 +; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR1]]) +; CHECK: %[[EXPR2:.*]] = extractelement <3 x i32> %expr, i64 2 +; CHECK: %[[RET2:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR2]]) +; CHECK: %[[RETUPTO0:.*]] = insertelement <3 x i1> poison, i1 %[[RET0]], i64 0 +; CHECK: %[[RETUPTO1:.*]] = insertelement <3 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1 +; CHECK: %ret = insertelement <3 x i1> %[[RETUPTO1]], i1 %[[RET2]], i64 2 + + %ret = call <3 x i1> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr) + ret <3 x i1> %ret } -define noundef <4 x double> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) { +define noundef <4 x i1> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) { entry: -; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i0, i8 1, i8 0) -; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i1, i8 1, i8 0) -; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i2, i8 1, i8 0) -; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i3, i8 1, i8 0) - %ret = call <4 x double> @llvm.dx.wave.all.equal.v464(<4 x double> %expr) - ret <4 x double> %ret +; CHECK: %[[EXPR0:.*]] = extractelement <4 x double> %expr, i64 0 +; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR0]]) +; CHECK: %[[EXPR1:.*]] = extractelement <4 x double> %expr, i64 1 +; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR1]]) +; CHECK: %[[EXPR2:.*]] = extractelement <4 x double> %expr, i64 2 +; CHECK: %[[RET2:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR2]]) +; CHECK: %[[EXPR3:.*]] = extractelement <4 x double> %expr, i64 3 +; CHECK: %[[RET3:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR3]]) +; CHECK: %[[RETUPTO0:.*]] = insertelement <4 x i1> poison, i1 %[[RET0]], i64 0 +; CHECK: %[[RETUPTO1:.*]] = insertelement <4 x i1> 
%[[RETUPTO0]], i1 %[[RET1]], i64 1 +; CHECK: %[[RETUPTO2:.*]] = insertelement <4 x i1> %[[RETUPTO1]], i1 %[[RET2]], i64 2 +; CHECK: %ret = insertelement <4 x i1> %[[RETUPTO2]], i1 %[[RET3]], i64 3 + + %ret = call <4 x i1> @llvm.dx.wave.all.equal.v4f64(<4 x double> %expr) + ret <4 x i1> %ret } -declare <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half>) -declare <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32>) -declare <4 x double> @llvm.dx.wave.all.equal.v4f64(<4 x double>) +declare <2 x i1> @llvm.dx.wave.all.equal.v2f16(<2 x half>) +declare <3 x i1> @llvm.dx.wave.all.equal.v3i32(<3 x i32>) +declare <4 x i1> @llvm.dx.wave.all.equal.v4f64(<4 x double>) diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll index e871dc9a7aa28..c64e5770b2d6d 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll @@ -5,6 +5,8 @@ ; CHECK-DAG: %[[#f16:]] = OpTypeFloat 16 ; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32 +; CHECK-DAG: %[[#bool:]] = OpTypeBool +; CHECK-DAG: %[[#bool4:]] = OpTypeVector %[[#bool]] 4 ; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0 ; CHECK-DAG: %[[#v4_half:]] = OpTypeVector %[[#f16]] 4 ; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3 @@ -13,8 +15,8 @@ ; CHECK: %[[#fexpr:]] = OpFunctionParameter %[[#f32]] define i1 @test_float(float %fexpr) { entry: -; CHECK: %[[#fret:]] = OpGroupNonUniformAllEqual %[[#f32]] %[[#scope]] Reduce %[[#fexpr]] - %0 = call i1 @llvm.spv.wave.all.equal.f32(float %fexpr) +; CHECK: %[[#fret:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#fexpr]] + %0 = call i1 @llvm.spv.subgroup.all.equal.f32(float %fexpr) ret i1 %0 } @@ -22,20 +24,20 @@ entry: ; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]] define i1 @test_int(i32 %iexpr) { entry: -; CHECK: %[[#iret:]] = OpGroupNonUniformAllEqual %[[#uint]] %[[#scope]] Reduce %[[#iexpr]] - %0 = call i1 
@llvm.spv.wave.all.equal.i32(i32 %iexpr) +; CHECK: %[[#iret:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#iexpr]] + %0 = call i1 @llvm.spv.subgroup.all.equal.i32(i32 %iexpr) ret i1 %0 } ; CHECK-LABEL: Begin function test_vhalf ; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]] -define i1 @test_vhalf(<4 x half> %vbexpr) { +define <4 x i1> @test_vhalf(<4 x half> %vbexpr) { entry: -; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#v4_half]] %[[#scope]] Reduce %[[#vbexpr]] - %0 = call i1 @llvm.spv.wave.all.equal.v4half(<4 x half> %vbexpr) - ret i1 %0 +; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#bool4]] %[[#scope]] %[[#vbexpr]] + %0 = call <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half> %vbexpr) + ret <4 x i1> %0 } -declare i1 @llvm.spv.wave.all.equal.f32(float) -declare i1 @llvm.spv.wave.all.equal.i32(i32) -declare i1 @llvm.spv.wave.all.equal.v4half(<4 x half>) +declare i1 @llvm.spv.subgroup.all.equal.f32(float) +declare i1 @llvm.spv.subgroup.all.equal.i32(i32) +declare <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half>) >From c78d96eecf4435cae10014da20025c8ceef4d384 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Fri, 27 Feb 2026 13:41:26 -0800 Subject: [PATCH 3/6] revert file changes --- .../BuiltIns/WaveActiveAllTrue-errors.hlsl | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl index 0975ad649e714..b0d0fdfca5e18 100644 --- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl @@ -1,21 +1,21 @@ -// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify - -bool test_too_few_arg() { - return __builtin_hlsl_wave_active_all_true(); - // expected-error@-1 {{too few arguments to function call, 
expected 1, have 0}} -} - -bool test_too_many_arg(bool p0) { - return __builtin_hlsl_wave_active_all_true(p0, p0); - // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} -} - -struct Foo -{ - int a; -}; - -bool test_type_check(Foo p0) { - return __builtin_hlsl_wave_active_all_true(p0); - // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}} -} \ No newline at end of file +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +bool test_too_few_arg() { + return __builtin_hlsl_wave_active_all_true(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} +} + +bool test_too_many_arg(bool p0) { + return __builtin_hlsl_wave_active_all_true(p0, p0); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} +} + +struct Foo +{ + int a; +}; + +bool test_type_check(Foo p0) { + return __builtin_hlsl_wave_active_all_true(p0); + // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}} +} >From df45116b4e350d1bfa36f7f3d737cca9fb0b8879 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Mon, 2 Mar 2026 12:51:59 -0800 Subject: [PATCH 4/6] perform manual scalarization --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 96 ++++++++++++++++++- .../hlsl-intrinsics/WaveActiveAllEqual.ll | 15 ++- 2 files changed, 107 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index b9c6cb1e67595..3794092703470 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -337,6 +337,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectWaveActiveCountBits(Register ResVReg, SPIRVTypeInst ResType, MachineInstr &I) const; + bool selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType, + MachineInstr &I) const; 
+ bool selectUnmergeValues(MachineInstr &I) const; bool selectHandleFromBinding(Register &ResVReg, SPIRVTypeInst ResType, @@ -2830,6 +2833,96 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits( return true; } +unsigned getVectorSizeOrOne(SPIRVTypeInst Type) { + + if (Type->getOpcode() != SPIRV::OpTypeVector) + return 1; + + // Operand(2) is the vector size + return Type->getOperand(2).getImm(); +} + +bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg, + SPIRVTypeInst ResType, + MachineInstr &I) const { + + MachineBasicBlock &BB = *I.getParent(); + const DebugLoc &DL = I.getDebugLoc(); + + SPIRVTypeInst SpvTy = GR.getSPIRVTypeForVReg(ResVReg); + unsigned NumElems = getVectorSizeOrOne(SpvTy); + bool IsVector = NumElems > 1; + + // Subgroup scope constant + SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); + + Register ScopeConst = GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, + TII, !STI.isShader()); + + Register InputReg = I.getOperand(2).getReg(); + + SmallVector<Register, 4> ElementResults; + + // If vector, determine element type once + SPIRVTypeInst ElemInputType = SpvTy; + SPIRVTypeInst ElemBoolType = ResType; + + if (IsVector) { + Register ElemTypeReg = SpvTy->getOperand(1).getReg(); + ElemInputType = GR.getSPIRVTypeForVReg(ElemTypeReg); + + Register BoolElemReg = ResType->getOperand(1).getReg(); + ElemBoolType = GR.getSPIRVTypeForVReg(BoolElemReg); + } + + for (unsigned Idx = 0; Idx < NumElems; ++Idx) { + + Register ElemInput = InputReg; + + if (IsVector) { + Register Extracted = + MRI->createVirtualRegister(GR.getRegClass(ElemInputType)); + + BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract)) + .addDef(Extracted) + .addUse(GR.getSPIRVTypeID(ElemInputType)) + .addUse(InputReg) + .addImm(Idx) + .constrainAllUses(TII, TRI, RBI); + + ElemInput = Extracted; + } + + Register ElemResult = + IsVector ? 
MRI->createVirtualRegister(GR.getRegClass(ElemBoolType)) + : ResVReg; + + BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual)) + .addDef(ElemResult) + .addUse(GR.getSPIRVTypeID(ElemBoolType)) + .addUse(ScopeConst) + .addUse(ElemInput) + .constrainAllUses(TII, TRI, RBI); + + ElementResults.push_back(ElemResult); + } + + if (!IsVector) + return true; + + auto MIB = BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeConstruct)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ResType)); + + for (Register R : ElementResults) + MIB.addUse(R); + + MIB.constrainAllUses(TII, TRI, RBI); + + return true; +} + + bool SPIRVInstructionSelector::selectWavePrefixBitCount(Register ResVReg, SPIRVTypeInst ResType, MachineInstr &I) const { @@ -4085,8 +4178,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, case Intrinsic::spv_wave_active_countbits: return selectWaveActiveCountBits(ResVReg, ResType, I); case Intrinsic::spv_subgroup_all_equal: - return selectWaveOpInst(ResVReg, ResType, I, - SPIRV::OpGroupNonUniformAllEqual); + return selectWaveActiveAllEqual(ResVReg, ResType, I); case Intrinsic::spv_wave_all: return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll); case Intrinsic::spv_wave_any: diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll index c64e5770b2d6d..9c63539e4b9e4 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll @@ -30,10 +30,21 @@ entry: } ; CHECK-LABEL: Begin function test_vhalf -; CHECK: %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]] +; Here there's a vector, so we scalarize and then recombine the +; result back into one vector define <4 x i1> @test_vhalf(<4 x half> %vbexpr) { entry: -; CHECK: %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#bool4]] %[[#scope]] %[[#vbexpr]] +; CHECK: %[[#param:]] = OpFunctionParameter %[[#v4float:]] +; CHECK: 
%[[#ext1:]] = OpCompositeExtract %[[#bool]] %[[#param]] 0 +; CHECK-NEXT: %[[#res1:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext1]] +; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#bool]] %[[#param]] 1 +; CHECK-NEXT: %[[#res2:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext2]] +; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#bool]] %[[#param]] 2 +; CHECK-NEXT: %[[#res3:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext3]] +; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#bool]] %[[#param]] 3 +; CHECK-NEXT: %[[#res4:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext4]] +; CHECK-NEXT: %[[#ret:]] = OpCompositeConstruct %[[#bool4]] %[[#res1:]] %[[#res2:]] %[[#res3:]] %[[#res4:]] +; CHECK-NEXT: OpReturnValue %[[#ret]] %0 = call <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half> %vbexpr) ret <4 x i1> %0 } >From 8ce17b4829f1d8c32c4e027c607ef5511460d695 Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Mon, 2 Mar 2026 13:27:59 -0800 Subject: [PATCH 5/6] clang format --- llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3794092703470..fe16ceaf28c14 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -338,7 +338,7 @@ class SPIRVInstructionSelector : public InstructionSelector { MachineInstr &I) const; bool selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType, - MachineInstr &I) const; + MachineInstr &I) const; bool selectUnmergeValues(MachineInstr &I) const; @@ -2922,7 +2922,6 @@ bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg, return true; } - bool SPIRVInstructionSelector::selectWavePrefixBitCount(Register ResVReg, SPIRVTypeInst ResType, MachineInstr &I) const { >From 
9a58dc845943b20c414c238dee6440a2f9e96d1d Mon Sep 17 00:00:00 2001 From: Joshua Batista <[email protected]> Date: Mon, 2 Mar 2026 14:49:19 -0800 Subject: [PATCH 6/6] more repairs to pass spirv-val --- .../Target/SPIRV/SPIRVInstructionSelector.cpp | 77 ++++++++++--------- .../hlsl-intrinsics/WaveActiveAllEqual.ll | 8 +- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index fe16ceaf28c14..1a9e07eb54e8c 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -2845,57 +2845,63 @@ unsigned getVectorSizeOrOne(SPIRVTypeInst Type) { bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType, MachineInstr &I) const { - MachineBasicBlock &BB = *I.getParent(); const DebugLoc &DL = I.getDebugLoc(); - SPIRVTypeInst SpvTy = GR.getSPIRVTypeForVReg(ResVReg); - unsigned NumElems = getVectorSizeOrOne(SpvTy); + // Input to the intrinsic + Register InputReg = I.getOperand(2).getReg(); + SPIRVTypeInst InputType = GR.getSPIRVTypeForVReg(InputReg); + + // Determine if input is vector + unsigned NumElems = getVectorSizeOrOne(InputType); bool IsVector = NumElems > 1; + // Determine element types + SPIRVTypeInst ElemInputType = InputType; + SPIRVTypeInst ElemBoolType = ResType; + if (IsVector) { + ElemInputType = GR.getSPIRVTypeForVReg(InputType->getOperand(1).getReg()); + ElemBoolType = GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg()); + } + // Subgroup scope constant SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII); - Register ScopeConst = GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, !STI.isShader()); - Register InputReg = I.getOperand(2).getReg(); + // === Scalar case === + if (!IsVector) { + BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual)) + .addDef(ResVReg) + .addUse(GR.getSPIRVTypeID(ElemBoolType)) + 
.addUse(ScopeConst) + .addUse(InputReg) + .constrainAllUses(TII, TRI, RBI); + return true; + } + // === Vector case === SmallVector<Register, 4> ElementResults; - - // If vector, determine element type once - SPIRVTypeInst ElemInputType = SpvTy; - SPIRVTypeInst ElemBoolType = ResType; - - if (IsVector) { - Register ElemTypeReg = SpvTy->getOperand(1).getReg(); - ElemInputType = GR.getSPIRVTypeForVReg(ElemTypeReg); - - Register BoolElemReg = ResType->getOperand(1).getReg(); - ElemBoolType = GR.getSPIRVTypeForVReg(BoolElemReg); - } + ElementResults.reserve(NumElems); for (unsigned Idx = 0; Idx < NumElems; ++Idx) { - + // Extract element Register ElemInput = InputReg; + Register Extracted = + MRI->createVirtualRegister(GR.getRegClass(ElemInputType)); + + BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract)) + .addDef(Extracted) + .addUse(GR.getSPIRVTypeID(ElemInputType)) + .addUse(InputReg) + .addImm(Idx) + .constrainAllUses(TII, TRI, RBI); - if (IsVector) { - Register Extracted = - MRI->createVirtualRegister(GR.getRegClass(ElemInputType)); - - BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract)) - .addDef(Extracted) - .addUse(GR.getSPIRVTypeID(ElemInputType)) - .addUse(InputReg) - .addImm(Idx) - .constrainAllUses(TII, TRI, RBI); - - ElemInput = Extracted; - } + ElemInput = Extracted; + // Emit per-element AllEqual Register ElemResult = - IsVector ? 
MRI->createVirtualRegister(GR.getRegClass(ElemBoolType)) - : ResVReg; + MRI->createVirtualRegister(GR.getRegClass(ElemBoolType)); BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual)) .addDef(ElemResult) @@ -2907,13 +2913,10 @@ bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg, ElementResults.push_back(ElemResult); } - if (!IsVector) - return true; - + // Reconstruct vector<bool> auto MIB = BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeConstruct)) .addDef(ResVReg) .addUse(GR.getSPIRVTypeID(ResType)); - for (Register R : ElementResults) MIB.addUse(R); diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll index 9c63539e4b9e4..8733505942c4c 100644 --- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll @@ -35,13 +35,13 @@ entry: define <4 x i1> @test_vhalf(<4 x half> %vbexpr) { entry: ; CHECK: %[[#param:]] = OpFunctionParameter %[[#v4float:]] -; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#bool]] %[[#param]] 0 +; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#f16]] %[[#param]] 0 ; CHECK-NEXT: %[[#res1:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext1]] -; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#bool]] %[[#param]] 1 +; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#f16]] %[[#param]] 1 ; CHECK-NEXT: %[[#res2:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext2]] -; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#bool]] %[[#param]] 2 +; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#f16]] %[[#param]] 2 ; CHECK-NEXT: %[[#res3:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext3]] -; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#bool]] %[[#param]] 3 +; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#f16]] %[[#param]] 3 ; CHECK-NEXT: %[[#res4:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#ext4]] ; CHECK-NEXT: %[[#ret:]] = 
OpCompositeConstruct %[[#bool4]] %[[#res1:]] %[[#res2:]] %[[#res3:]] %[[#res4:]] ; CHECK-NEXT: OpReturnValue %[[#ret]] _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
