https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/183634

>From 8f3b1eca648517042b010d92c380995212afd59b Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Mon, 23 Feb 2026 17:04:58 -0800
Subject: [PATCH 1/6] first attempt

---
 clang/include/clang/Basic/Builtins.td         |   6 +
 clang/lib/CodeGen/CGHLSLBuiltins.cpp          |   7 +
 clang/lib/CodeGen/CGHLSLRuntime.h             |  27 +++-
 .../lib/Headers/hlsl/hlsl_alias_intrinsics.h  | 124 ++++++++++++++++++
 clang/lib/Sema/SemaHLSL.cpp                   |  13 ++
 .../builtins/WaveActiveAllEqual.hlsl          |  45 +++++++
 .../BuiltIns/WaveActiveAllEqual-errors.hlsl   |  28 ++++
 .../BuiltIns/WaveActiveAllTrue-errors.hlsl    |  49 ++++---
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   1 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |   1 +
 llvm/lib/Target/DirectX/DXIL.td               |  10 ++
 llvm/lib/Target/DirectX/DXILShaderFlags.cpp   |   2 +-
 .../DirectX/DirectXTargetTransformInfo.cpp    |   1 +
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |   3 +
 .../CodeGen/DirectX/ShaderFlags/wave-ops.ll   |   7 +
 .../CodeGen/DirectX/WaveActiveAllEqual.ll     |  87 ++++++++++++
 .../hlsl-intrinsics/WaveActiveAllEqual.ll     |  41 ++++++
 17 files changed, 426 insertions(+), 26 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 78dd26aa2c455..c66c029900453 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5132,6 +5132,12 @@ def HLSLAsDouble : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLWaveActiveAllEqual : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_active_all_equal"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "void(...)";
+}
+
 def HLSLWaveActiveAllTrue : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_all_true"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 70891eac39425..09dae2ab931ee 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -1088,6 +1088,13 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
   }
+  case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
+    Value *Op = EmitScalarExpr(E->getArg(0));
+
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllEqualIntrinsic();
+    return EmitRuntimeCall(
+        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+  }
   case Builtin::BI__builtin_hlsl_wave_active_all_true: {
     Value *Op = EmitScalarExpr(E->getArg(0));
     assert(Op->getType()->isIntegerTy(1) &&
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index dbbc887353cec..940e3bbae8df2 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -34,16 +34,33 @@
 
 // A function generator macro for picking the right intrinsic
 // for the target backend
-#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix)       
\
+#define _GEN_INTRIN_CHOOSER(_1, _2, _3, NAME, ...) NAME
+
+#define GENERATE_HLSL_INTRINSIC_FUNCTION(...)                                  
\
+  _GEN_INTRIN_CHOOSER(__VA_ARGS__, GENERATE_HLSL_INTRINSIC_FUNCTION3,          
\
+                      GENERATE_HLSL_INTRINSIC_FUNCTION2,                       
\
+                      /* dummy to solve pre-C++20 errors */ ignored)(          
\
+      __VA_ARGS__)
+
+// 2-arg form: same postfix for both backends (uses the identity)
+#define GENERATE_HLSL_INTRINSIC_FUNCTION2(FunctionName, IntrinsicPostfix)      
\
+  llvm::Intrinsic::ID get##FunctionName##Intrinsic() {                         
\
+    llvm::Triple::ArchType Arch = getArch();                                   
\
+    switch (Arch) {}                                                           
\
+  }
+
+// 3-arg form: explicit SPIR-V postfix override (perfect for wave->subgroup)
+#define GENERATE_HLSL_INTRINSIC_FUNCTION3(FunctionName, DxilPostfix,           
\
+                                          SpirvPostfix)                        
\
   llvm::Intrinsic::ID get##FunctionName##Intrinsic() {                         
\
     llvm::Triple::ArchType Arch = getArch();                                   
\
     switch (Arch) {                                                            
\
     case llvm::Triple::dxil:                                                   
\
-      return llvm::Intrinsic::dx_##IntrinsicPostfix;                           
\
+      return llvm::Intrinsic::dx_##DxilPostfix;                                
\
     case llvm::Triple::spirv:                                                  
\
-      return llvm::Intrinsic::spv_##IntrinsicPostfix;                          
\
+      return llvm::Intrinsic::spv_##SpirvPostfix;                              
\
     default:                                                                   
\
-      llvm_unreachable("Intrinsic " #IntrinsicPostfix                          
\
+      llvm_unreachable("Intrinsic " #DxilPostfix                               
\
                        " not supported by target architecture");               
\
     }                                                                          
\
   }
@@ -144,6 +161,8 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllEqual, wave_all_equal,
+                                   subgroup_all_equal)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveMax, wave_reduce_max)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index 2543401bdfbf9..e4a9c5dc7b4a8 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2413,6 +2413,130 @@ float4 trunc(float4);
 // Wave* builtins
 
//===----------------------------------------------------------------------===//
 
+/// \brief Evaluates a value for all active invocations in the group. The
+/// result is true if Value is equal for all active invocations in the
+/// group. Otherwise, the result is false.
+/// \param Value The value to compare with
+/// \return True if all values across all lanes are equal, false otherwise
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half WaveActiveAllEqual(half);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half2 WaveActiveAllEqual(half2);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half3 WaveActiveAllEqual(half3);
+_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) half4 WaveActiveAllEqual(half4);
+
+#ifdef __HLSL_ENABLE_16_BIT
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t WaveActiveAllEqual(int16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t2 WaveActiveAllEqual(int16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t3 WaveActiveAllEqual(int16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int16_t4 WaveActiveAllEqual(int16_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t WaveActiveAllEqual(uint16_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t2 WaveActiveAllEqual(uint16_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t3 WaveActiveAllEqual(uint16_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint16_t4 WaveActiveAllEqual(uint16_t4);
+#endif
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int WaveActiveAllEqual(int);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int2 WaveActiveAllEqual(int2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int3 WaveActiveAllEqual(int3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int4 WaveActiveAllEqual(int4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint WaveActiveAllEqual(uint);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint2 WaveActiveAllEqual(uint2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint3 WaveActiveAllEqual(uint3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint4 WaveActiveAllEqual(uint4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t WaveActiveAllEqual(int64_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t2 WaveActiveAllEqual(int64_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t3 WaveActiveAllEqual(int64_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) int64_t4 WaveActiveAllEqual(int64_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t WaveActiveAllEqual(uint64_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t2 WaveActiveAllEqual(uint64_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t3 WaveActiveAllEqual(uint64_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) uint64_t4 WaveActiveAllEqual(uint64_t4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float WaveActiveAllEqual(float);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float2 WaveActiveAllEqual(float2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float3 WaveActiveAllEqual(float3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) float4 WaveActiveAllEqual(float4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double WaveActiveAllEqual(double);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double2 WaveActiveAllEqual(double2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double3 WaveActiveAllEqual(double3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
+__attribute__((convergent)) double4 WaveActiveAllEqual(double4);
+
 /// \brief Returns true if the expression is true in all active lanes in the
 /// current wave.
 ///
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 802a1bdbccfdd..249d8dc58b866 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3809,6 +3809,19 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     TheCall->setType(ArgTyA);
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_active_all_true: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    // Ensure input expr type is a scalar/vector
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+      return true;
+
+    // set return type to bool
+    TheCall->setType(getASTContext().BoolTy);
+
+    break;
+  }
   case Builtin::BI__builtin_hlsl_wave_active_max:
   case Builtin::BI__builtin_hlsl_wave_active_min:
   case Builtin::BI__builtin_hlsl_wave_active_sum: {
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
new file mode 100644
index 0000000000000..4b4149d05eb3f
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
@@ -0,0 +1,45 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o 
- | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_int
+bool test_int(int expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i1 
@llvm.spv.wave.all.equal.i32([[TY]] %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32([[TY]] 
%[[#]])
+  // CHECK:  ret i1 %[[RET]]
+  return WaveActiveAllEqual(expr);
+}
+
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_uint64_t
+bool test_uint64_t(uint64_t expr) {
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i1 
@llvm.spv.wave.all.equal.i64(i64 %[[#]])
+  // CHECK-DXIL:  %[[RET:.*]] = call i1 @llvm.dx.wave.uproduct.i64(i64 %[[#]])
+  // CHECK:  ret i1 %[[RET]]
+  return WaveActiveAllEqual(expr);
+}
+
+// CHECK-DXIL: declare i1 @llvm.dx.wave.uproduct.i64(i64 #[[#attr:]]
+// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i64(i64) #[[#attr:]]
+
+// Test basic lowering to runtime function call with array and float value.
+
+// CHECK-LABEL: test_floatv4
+bool test_floatv4(float4 expr) {
+  // CHECK-SPIRV:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn 
spir_func i1 @llvm.spv.wave.all.equal.v4f32(i32 %[[#]]
+  // CHECK-DXIL:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn i1 
@llvm.dx.wave.all.equal.v4f32(i32 %[[#]])
+  // CHECK:  ret [[TY1]] %[[RET1]]
+  return WaveActiveAllEqual(expr);
+}
+
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.v4f32(i32) #[[#attr]]
+// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.v4f32(i32) #[[#attr]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
new file mode 100644
index 0000000000000..2c838cb51dd78
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
@@ -0,0 +1,28 @@
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+  return __builtin_hlsl_wave_active_all_equal();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 
0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_active_all_equal(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 
2}}
+}
+
+bool test_expr_bool_type_check(bool p0) {
+  return __builtin_hlsl_wave_active_all_equal(p0);
+  // expected-error@-1 {{invalid operand of type 'bool'}}
+}
+
+bool2 test_expr_bool_vec_type_check(bool2 p0) {
+  return __builtin_hlsl_wave_active_all_equal(p0);
+  // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 
2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+  return __builtin_hlsl_wave_active_all_equal(p0);
+  // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector 
is required}}
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
index b0d0fdfca5e18..af926d60624c6 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
@@ -1,21 +1,28 @@
-// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
-
-bool test_too_few_arg() {
-  return __builtin_hlsl_wave_active_all_true();
-  // expected-error@-1 {{too few arguments to function call, expected 1, have 
0}}
-}
-
-bool test_too_many_arg(bool p0) {
-  return __builtin_hlsl_wave_active_all_true(p0, p0);
-  // expected-error@-1 {{too many arguments to function call, expected 1, have 
2}}
-}
-
-struct Foo
-{
-  int a;
-};
-
-bool test_type_check(Foo p0) {
-  return __builtin_hlsl_wave_active_all_true(p0);
-  // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
-}
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+  return __builtin_hlsl_wave_active_product();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 
0}}
+}
+
+float2 test_too_many_arg(float2 p0) {
+  return __builtin_hlsl_wave_active_product(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 
2}}
+}
+
+bool test_expr_bool_type_check(bool p0) {
+  return __builtin_hlsl_wave_active_product(p0);
+  // expected-error@-1 {{invalid operand of type 'bool'}}
+}
+
+bool2 test_expr_bool_vec_type_check(bool2 p0) {
+  return __builtin_hlsl_wave_active_product(p0);
+  // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 
2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+  return __builtin_hlsl_wave_active_product(p0);
+  // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector 
is required}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 909482d72aa88..a688da131ce75 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -213,6 +213,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
 def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], 
[LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], 
[IntrConvergent, IntrNoMem]>;
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
 def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, 
LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 77f49ae721ad5..59a9612d1ff50 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
   def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
   def int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+  def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], 
[llvm_any_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], 
[IntrConvergent, IntrNoMem]>;
   def int_spv_subgroup_ballot : 
ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 59a5b7fe4d508..a378f0d665d44 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -217,6 +217,7 @@ defset list<DXILOpClass> OpClasses = {
   def waveActiveOp : DXILOpClass;
   def waveAllOp : DXILOpClass;
   def waveAllTrue : DXILOpClass;
+  def waveAllEqual : DXILOpClass;
   def waveAnyTrue : DXILOpClass;
   def waveActiveBallot : DXILOpClass;
   def waveGetLaneCount : DXILOpClass;
@@ -1062,6 +1063,15 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
+def WaveActiveAllEqual : DXILOp<115, waveAllEqual> {
+  let Doc = "returns true if the expression is equal in all of the active 
lanes "
+            "in the current wave";
+  let intrinsics = [IntrinSelect<int_dx_wave_all_equal>];
+  let arguments = [OverloadTy];
+  let result = Int1Ty;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
 def WaveActiveBallot : DXILOp<116, waveActiveBallot> {
   let Doc = "returns uint4 containing a bitmask of the evaluation of the 
boolean expression for all active lanes in the current wave.";
   let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index 52993ee1c1220..1d14079407cbe 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -64,7 +64,6 @@ static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM,
 static bool checkWaveOps(Intrinsic::ID IID) {
   // Currently unsupported intrinsics
   // case Intrinsic::dx_wave_getlanecount:
-  // case Intrinsic::dx_wave_allequal:
   // case Intrinsic::dx_wave_readfirst:
   // case Intrinsic::dx_wave_reduce.and:
   // case Intrinsic::dx_wave_reduce.or:
@@ -85,6 +84,7 @@ static bool checkWaveOps(Intrinsic::ID IID) {
   case Intrinsic::dx_wave_is_first_lane:
   case Intrinsic::dx_wave_getlaneindex:
   case Intrinsic::dx_wave_any:
+  case Intrinsic::dx_wave_all_equal:
   case Intrinsic::dx_wave_all:
   case Intrinsic::dx_wave_readlane:
   case Intrinsic::dx_wave_active_countbits:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 8018b09c9f248..eca2343227577 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -57,6 +57,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
   case Intrinsic::dx_rsqrt:
   case Intrinsic::dx_saturate:
   case Intrinsic::dx_splitdouble:
+  case Intrinsic::dx_wave_all_equal:
   case Intrinsic::dx_wave_readlane:
   case Intrinsic::dx_wave_reduce_max:
   case Intrinsic::dx_wave_reduce_min:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3d3e311eeedb7..b9c6cb1e67595 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -4084,6 +4084,9 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectWavePrefixBitCount(ResVReg, ResType, I);
   case Intrinsic::spv_wave_active_countbits:
     return selectWaveActiveCountBits(ResVReg, ResType, I);
+  case Intrinsic::spv_subgroup_all_equal:
+    return selectWaveOpInst(ResVReg, ResType, I,
+                            SPIRV::OpGroupNonUniformAllEqual);
   case Intrinsic::spv_wave_all:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
   case Intrinsic::spv_wave_any:
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
index be53d19aca8f2..6c29ac73719e6 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
@@ -42,6 +42,13 @@ entry:
   ret i1 %ret
 }
 
+define noundef i1 @wave_all_equal(i1 %x) {
+entry:
+  ; CHECK: Function wave_all_equal : [[WAVE_FLAG]]
+  %ret = call i1 @llvm.dx.wave.all.equal(i1 %x)
+  ret i1 %ret
+}
+
 define noundef i1 @wave_readlane(i1 %x, i32 %idx) {
 entry:
   ; CHECK: Function wave_readlane : [[WAVE_FLAG]]
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
new file mode 100644
index 0000000000000..702f2ad1dde5f
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
@@ -0,0 +1,87 @@
+; RUN: opt -S -scalarizer -dxil-op-lower 
-mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
+
+; Test that for scalar values, WaveAcitveProduct maps down to the DirectX op
+
+define noundef half @wave_active_all_equal_half(half noundef %expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr, i8 1, i8 
0)
+  %ret = call half @llvm.dx.wave.all.equal.f16(half %expr)
+  ret half %ret
+}
+
+define noundef float @wave_active_all_equal_float(float noundef %expr) {
+entry:
+; CHECK: call float @dx.op.waveActiveAllEqual.f32(i32 119, float %expr, i8 1, 
i8 0)
+  %ret = call float @llvm.dx.wave.all.equal.f32(float %expr)
+  ret float %ret
+}
+
+define noundef double @wave_active_all_equal_double(double noundef %expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr, i8 
1, i8 0)
+  %ret = call double @llvm.dx.wave.all.equal.f64(double %expr)
+  ret double %ret
+}
+
+define noundef i16 @wave_active_all_equal_i16(i16 noundef %expr) {
+entry:
+; CHECK: call i16 @dx.op.waveActiveAllEqual.i16(i32 119, i16 %expr, i8 1, i8 0)
+  %ret = call i16 @llvm.dx.wave.all.equal.i16(i16 %expr)
+  ret i16 %ret
+}
+
+define noundef i32 @wave_active_all_equal_i32(i32 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr, i8 1, i8 0)
+  %ret = call i32 @llvm.dx.wave.all.equal.i32(i32 %expr)
+  ret i32 %ret
+}
+
+define noundef i64 @wave_active_all_equal_i64(i64 noundef %expr) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveAllEqual.i64(i32 119, i64 %expr, i8 1, i8 0)
+  %ret = call i64 @llvm.dx.wave.all.equal.i64(i64 %expr)
+  ret i64 %ret
+}
+
+declare half @llvm.dx.wave.all.equal.f16(half)
+declare float @llvm.dx.wave.all.equal.f32(float)
+declare double @llvm.dx.wave.all.equal.f64(double)
+
+declare i16 @llvm.dx.wave.all.equal.i16(i16)
+declare i32 @llvm.dx.wave.all.equal.i32(i32)
+declare i64 @llvm.dx.wave.all.equal.i64(i64)
+
+; Test that for vector values, WaveAcitveProduct scalarizes and maps down to 
the
+; DirectX op
+
+define noundef <2 x half> @wave_active_all_equal_v2half(<2 x half> noundef 
%expr) {
+entry:
+; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i0, i8 1, 
i8 0)
+; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i1, i8 1, 
i8 0)
+  %ret = call <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr)
+  ret <2 x half> %ret
+}
+
+define noundef <3 x i32> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) 
{
+entry:
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i0, i8 1, 
i8 0)
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i1, i8 1, 
i8 0)
+; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i2, i8 1, 
i8 0)
+  %ret = call <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr)
+  ret <3 x i32> %ret
+}
+
+define noundef <4 x double> @wave_active_all_equal_v4f64(<4 x double> noundef 
%expr) {
+entry:
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i0, 
i8 1, i8 0)
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i1, 
i8 1, i8 0)
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i2, 
i8 1, i8 0)
+; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i3, 
i8 1, i8 0)
+  %ret = call <4 x double> @llvm.dx.wave.all.equal.v464(<4 x double> %expr)
+  ret <4 x double> %ret
+}
+
+declare <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half>)
+declare <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32>)
+declare <4 x double> @llvm.dx.wave.all.equal.v4f64(<4 x double>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
new file mode 100644
index 0000000000000..e871dc9a7aa28
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -0,0 +1,41 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | 
FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - 
-filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend for various types and scalar/vector
+
+; CHECK-DAG:   %[[#f16:]] = OpTypeFloat 16
+; CHECK-DAG:   %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:   %[[#v4_half:]] = OpTypeVector %[[#f16]] 4
+; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_float
+; CHECK:   %[[#fexpr:]] = OpFunctionParameter %[[#f32]]
+define i1 @test_float(float %fexpr) {
+entry:
+; CHECK:   %[[#fret:]] = OpGroupNonUniformAllEqual %[[#f32]] %[[#scope]] 
Reduce %[[#fexpr]]
+  %0 = call i1 @llvm.spv.wave.all.equal.f32(float %fexpr)
+  ret i1 %0
+}
+
+; CHECK-LABEL: Begin function test_int
+; CHECK:   %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
+define i1 @test_int(i32 %iexpr) {
+entry:
+; CHECK:   %[[#iret:]] = OpGroupNonUniformAllEqual %[[#uint]] %[[#scope]] 
Reduce %[[#iexpr]]
+  %0 = call i1 @llvm.spv.wave.all.equal.i32(i32 %iexpr)
+  ret i1 %0
+}
+
+; CHECK-LABEL: Begin function test_vhalf
+; CHECK:   %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]]
+define i1 @test_vhalf(<4 x half> %vbexpr) {
+entry:
+; CHECK:   %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#v4_half]] 
%[[#scope]] Reduce %[[#vbexpr]]
+  %0 = call i1 @llvm.spv.wave.all.equal.v4half(<4 x half> %vbexpr)
+  ret i1 %0
+}
+
+declare i1 @llvm.spv.wave.all.equal.f32(float)
+declare i1 @llvm.spv.wave.all.equal.i32(i32)
+declare i1 @llvm.spv.wave.all.equal.v4half(<4 x half>)

>From c523c88934ae38d62d4b65729f4374e6f9eb617e Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Fri, 27 Feb 2026 13:39:08 -0800
Subject: [PATCH 2/6] fix return type, self review

---
 clang/lib/CodeGen/CGHLSLBuiltins.cpp          |   5 +-
 clang/lib/CodeGen/CGHLSLRuntime.h             |  10 +-
 .../lib/Headers/hlsl/hlsl_alias_intrinsics.h  |  72 +++++-----
 clang/lib/Sema/SemaHLSL.cpp                   |  18 ++-
 .../builtins/WaveActiveAllEqual.hlsl          |  30 ++---
 .../BuiltIns/WaveActiveAllEqual-errors.hlsl   |  16 +--
 .../BuiltIns/WaveActiveAllTrue-errors.hlsl    |  31 ++---
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |   2 +-
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |   2 +-
 llvm/lib/Target/DirectX/DXIL.td               |   7 +-
 .../DirectX/DirectXTargetTransformInfo.cpp    |   3 +
 .../CodeGen/DirectX/WaveActiveAllEqual.ll     | 124 +++++++++++-------
 .../hlsl-intrinsics/WaveActiveAllEqual.ll     |  24 ++--
 13 files changed, 187 insertions(+), 157 deletions(-)

diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 09dae2ab931ee..47b7e2b18d942 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -1092,8 +1092,9 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     Value *Op = EmitScalarExpr(E->getArg(0));
 
     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllEqualIntrinsic();
-    return EmitRuntimeCall(
-        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+    return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
+                               &CGM.getModule(), ID, {Op->getType()}),
+                           {Op});
   }
   case Builtin::BI__builtin_hlsl_wave_active_all_true: {
     Value *Op = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 940e3bbae8df2..d6055d89e3c84 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -46,7 +46,15 @@
 #define GENERATE_HLSL_INTRINSIC_FUNCTION2(FunctionName, IntrinsicPostfix)      \
   llvm::Intrinsic::ID get##FunctionName##Intrinsic() {                         \
     llvm::Triple::ArchType Arch = getArch();                                   \
-    switch (Arch) {}                                                           \
+    switch (Arch) {                                                            \
+    case llvm::Triple::dxil:                                                   \
+      return llvm::Intrinsic::dx_##IntrinsicPostfix;                           \
+    case llvm::Triple::spirv:                                                  \
+      return llvm::Intrinsic::spv_##IntrinsicPostfix;                          \
+    default:                                                                   \
+      llvm_unreachable("Intrinsic " #IntrinsicPostfix                          \
+                       " not supported by target architecture");               \
+    }                                                                          \
   }
 
 // 3-arg form: explicit SPIR-V postfix override (perfect for wave->subgroup)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index e4a9c5dc7b4a8..5ca4713c2d520 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2420,122 +2420,122 @@ float4 trunc(float4);
 /// \return True if all values across all lanes are equal, false otherwise
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half WaveActiveAllEqual(half);
+__attribute__((convergent)) bool WaveActiveAllEqual(half);
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half2 WaveActiveAllEqual(half2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(half2);
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half3 WaveActiveAllEqual(half3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(half3);
 _HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) half4 WaveActiveAllEqual(half4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(half4);
 
 #ifdef __HLSL_ENABLE_16_BIT
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t WaveActiveAllEqual(int16_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(int16_t);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t2 WaveActiveAllEqual(int16_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(int16_t2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t3 WaveActiveAllEqual(int16_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(int16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int16_t4 WaveActiveAllEqual(int16_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(int16_t4);
 
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t WaveActiveAllEqual(uint16_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(uint16_t);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t2 WaveActiveAllEqual(uint16_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(uint16_t2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t3 WaveActiveAllEqual(uint16_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(uint16_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint16_t4 WaveActiveAllEqual(uint16_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(uint16_t4);
 #endif
 
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int WaveActiveAllEqual(int);
+__attribute__((convergent)) bool WaveActiveAllEqual(int);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int2 WaveActiveAllEqual(int2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(int2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int3 WaveActiveAllEqual(int3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(int3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int4 WaveActiveAllEqual(int4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(int4);
 
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint WaveActiveAllEqual(uint);
+__attribute__((convergent)) bool WaveActiveAllEqual(uint);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint2 WaveActiveAllEqual(uint2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(uint2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint3 WaveActiveAllEqual(uint3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(uint3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint4 WaveActiveAllEqual(uint4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(uint4);
 
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t WaveActiveAllEqual(int64_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(int64_t);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t2 WaveActiveAllEqual(int64_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(int64_t2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t3 WaveActiveAllEqual(int64_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(int64_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) int64_t4 WaveActiveAllEqual(int64_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(int64_t4);
 
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t WaveActiveAllEqual(uint64_t);
+__attribute__((convergent)) bool WaveActiveAllEqual(uint64_t);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t2 WaveActiveAllEqual(uint64_t2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(uint64_t2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t3 WaveActiveAllEqual(uint64_t3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(uint64_t3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) uint64_t4 WaveActiveAllEqual(uint64_t4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(uint64_t4);
 
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float WaveActiveAllEqual(float);
+__attribute__((convergent)) bool WaveActiveAllEqual(float);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float2 WaveActiveAllEqual(float2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(float2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float3 WaveActiveAllEqual(float3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(float3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) float4 WaveActiveAllEqual(float4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(float4);
 
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double WaveActiveAllEqual(double);
+__attribute__((convergent)) bool WaveActiveAllEqual(double);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double2 WaveActiveAllEqual(double2);
+__attribute__((convergent)) bool2 WaveActiveAllEqual(double2);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double3 WaveActiveAllEqual(double3);
+__attribute__((convergent)) bool3 WaveActiveAllEqual(double3);
 _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_all_equal)
-__attribute__((convergent)) double4 WaveActiveAllEqual(double4);
+__attribute__((convergent)) bool4 WaveActiveAllEqual(double4);
 
 /// \brief Returns true if the expression is true in all active lanes in the
 /// current wave.
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 249d8dc58b866..46cc3835c85a8 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3809,7 +3809,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     TheCall->setType(ArgTyA);
     break;
   }
-  case Builtin::BI__builtin_hlsl_wave_active_all_true: {
+  case Builtin::BI__builtin_hlsl_wave_active_all_equal: {
     if (SemaRef.checkArgCount(TheCall, 1))
       return true;
 
@@ -3817,9 +3817,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
       return true;
 
-    // set return type to bool
-    TheCall->setType(getASTContext().BoolTy);
+    QualType InputTy = TheCall->getArg(0)->getType();
+    ASTContext &Ctx = getASTContext();
 
+    QualType RetTy;
+
+    // If vector, construct bool vector of same size
+    if (const auto *VecTy = InputTy->getAs<ExtVectorType>()) {
+      unsigned NumElts = VecTy->getNumElements();
+      RetTy = Ctx.getExtVectorType(Ctx.BoolTy, NumElts);
+    } else {
+      // Scalar case
+      RetTy = Ctx.BoolTy;
+    }
+
+    TheCall->setType(RetTy);
     break;
   }
   case Builtin::BI__builtin_hlsl_wave_active_max:
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
index 4b4149d05eb3f..65d15633eb6cf 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveAllEqual.hlsl
@@ -9,37 +9,37 @@
 
 // CHECK-LABEL: test_int
 bool test_int(int expr) {
-  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i32([[TY]] %[[#]])
-  // CHECK-DXIL:  %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32([[TY]] %[[#]])
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i1 @llvm.spv.subgroup.all.equal.i32(i32
+  // CHECK-DXIL:  %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i32(i32
   // CHECK:  ret i1 %[[RET]]
   return WaveActiveAllEqual(expr);
 }
 
-// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32([[TY]]) #[[#attr:]]
-// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i32([[TY]]) #[[#attr:]]
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i32(i32) #[[attr:.*]]
+// CHECK-SPIRV: declare i1 @llvm.spv.subgroup.all.equal.i32(i32) #[[attr:.*]]
 
 // CHECK-LABEL: test_uint64_t
 bool test_uint64_t(uint64_t expr) {
-  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.all.equal.i64(i64 %[[#]])
-  // CHECK-DXIL:  %[[RET:.*]] = call i1 @llvm.dx.wave.uproduct.i64(i64 %[[#]])
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i1 @llvm.spv.subgroup.all.equal.i64(i64 
+  // CHECK-DXIL:  %[[RET:.*]] = call i1 @llvm.dx.wave.all.equal.i64(i64
   // CHECK:  ret i1 %[[RET]]
   return WaveActiveAllEqual(expr);
 }
 
-// CHECK-DXIL: declare i1 @llvm.dx.wave.uproduct.i64(i64 #[[#attr:]]
-// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.i64(i64) #[[#attr:]]
+// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.i64(i64) #[[attr]]
+// CHECK-SPIRV: declare i1 @llvm.spv.subgroup.all.equal.i64(i64) #[[attr]]
 
 // Test basic lowering to runtime function call with array and float value.
 
 // CHECK-LABEL: test_floatv4
-bool test_floatv4(float4 expr) {
-  // CHECK-SPIRV:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn spir_func i1 @llvm.spv.wave.all.equal.v4f32(i32 %[[#]]
-  // CHECK-DXIL:  %[[RET1:.*]] = call reassoc nnan ninf nsz arcp afn i1 @llvm.dx.wave.all.equal.v4f32(i32 %[[#]])
-  // CHECK:  ret [[TY1]] %[[RET1]]
+bool4 test_floatv4(float4 expr) {
+  // CHECK-SPIRV:  %[[RET1:.*]] = call spir_func <4 x i1> @llvm.spv.subgroup.all.equal.v4f32(<4 x float> 
+  // CHECK-DXIL:  %[[RET1:.*]] = call <4 x i1> @llvm.dx.wave.all.equal.v4f32(<4 x float> 
+  // CHECK:  ret <4 x i1> %[[RET1]]
   return WaveActiveAllEqual(expr);
 }
 
-// CHECK-DXIL: declare i1 @llvm.dx.wave.all.equal.v4f32(i32) #[[#attr]]
-// CHECK-SPIRV: declare i1 @llvm.spv.wave.all.equal.v4f32(i32) #[[#attr]]
+// CHECK-DXIL: declare <4 x i1> @llvm.dx.wave.all.equal.v4f32(<4 x float>) #[[attr]]
+// CHECK-SPIRV: declare <4 x i1> @llvm.spv.subgroup.all.equal.v4f32(<4 x float>) #[[attr]]
 
-// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
+// CHECK: attributes #[[attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
index 2c838cb51dd78..1b5d7955baffc 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllEqual-errors.hlsl
@@ -1,28 +1,18 @@
 // RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
 
-int test_too_few_arg() {
+bool test_too_few_arg() {
   return __builtin_hlsl_wave_active_all_equal();
   // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
 }
 
-float2 test_too_many_arg(float2 p0) {
+bool test_too_many_arg(float2 p0) {
   return __builtin_hlsl_wave_active_all_equal(p0, p0);
   // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
 }
 
-bool test_expr_bool_type_check(bool p0) {
-  return __builtin_hlsl_wave_active_all_equal(p0);
-  // expected-error@-1 {{invalid operand of type 'bool'}}
-}
-
-bool2 test_expr_bool_vec_type_check(bool2 p0) {
-  return __builtin_hlsl_wave_active_all_equal(p0);
-  // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
-}
-
 struct S { float f; };
 
-S test_expr_struct_type_check(S p0) {
+bool test_expr_struct_type_check(S p0) {
   return __builtin_hlsl_wave_active_all_equal(p0);
   // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
 }
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
index af926d60624c6..0975ad649e714 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
@@ -1,28 +1,21 @@
 // RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
 
-int test_too_few_arg() {
-  return __builtin_hlsl_wave_active_product();
+bool test_too_few_arg() {
+  return __builtin_hlsl_wave_active_all_true();
   // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
 }
 
-float2 test_too_many_arg(float2 p0) {
-  return __builtin_hlsl_wave_active_product(p0, p0);
+bool test_too_many_arg(bool p0) {
+  return __builtin_hlsl_wave_active_all_true(p0, p0);
   // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
 }
 
-bool test_expr_bool_type_check(bool p0) {
-  return __builtin_hlsl_wave_active_product(p0);
-  // expected-error@-1 {{invalid operand of type 'bool'}}
-}
-
-bool2 test_expr_bool_vec_type_check(bool2 p0) {
-  return __builtin_hlsl_wave_active_product(p0);
-  // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
-}
+struct Foo
+{
+  int a;
+};
 
-struct S { float f; };
-
-S test_expr_struct_type_check(S p0) {
-  return __builtin_hlsl_wave_active_product(p0);
-  // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
-}
+bool test_type_check(Foo p0) {
+  return __builtin_hlsl_wave_active_all_true(p0);
+  // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
+}
\ No newline at end of file
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index a688da131ce75..6774a33556c09 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -213,7 +213,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
 def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_all_equal : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 59a9612d1ff50..b91905f350506 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,7 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
   def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
   def int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-  def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
+  def int_spv_subgroup_all_equal : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>], [llvm_any_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index a378f0d665d44..e64909b059d29 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -217,7 +217,6 @@ defset list<DXILOpClass> OpClasses = {
   def waveActiveOp : DXILOpClass;
   def waveAllOp : DXILOpClass;
   def waveAllTrue : DXILOpClass;
-  def waveAllEqual : DXILOpClass;
   def waveAnyTrue : DXILOpClass;
   def waveActiveBallot : DXILOpClass;
   def waveGetLaneCount : DXILOpClass;
@@ -1063,9 +1062,9 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
-def WaveActiveAllEqual : DXILOp<115, waveAllEqual> {
-  let Doc = "returns true if the expression is equal in all of the active lanes "
-            "in the current wave";
+def WaveActiveAllEqual : DXILOp<115, waveActiveAllEqual> {
+  let Doc = "returns true for each scalar element of the expression if the "
+            "expression is equal in all of the active lanes in the current wave";
   let intrinsics = [IntrinSelect<int_dx_wave_all_equal>];
   let arguments = [OverloadTy];
   let result = Int1Ty;
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index eca2343227577..a2d7ffefbb5a2 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -36,8 +36,11 @@ bool DirectXTTIImpl::isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
   case Intrinsic::dx_isnan:
   case Intrinsic::dx_legacyf16tof32:
   case Intrinsic::dx_legacyf32tof16:
+  case Intrinsic::dx_wave_all_equal:
     return OpdIdx == 0;
   default:
+    // All DX intrinsics are overloaded on return type unless specified
+    // otherwise
     return OpdIdx == -1;
   }
 }
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
index 702f2ad1dde5f..f6dcd59c33958 100644
--- a/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/DirectX/WaveActiveAllEqual.ll
@@ -2,86 +2,108 @@
 
 ; Test that for scalar values, WaveAcitveProduct maps down to the DirectX op
 
-define noundef half @wave_active_all_equal_half(half noundef %expr) {
+define noundef i1 @wave_active_all_equal_half(half noundef %expr) {
 entry:
-; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr, i8 1, i8 0)
-  %ret = call half @llvm.dx.wave.all.equal.f16(half %expr)
-  ret half %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %expr)
+  %ret = call i1 @llvm.dx.wave.all.equal.f16(half %expr)
+  ret i1 %ret
 }
 
-define noundef float @wave_active_all_equal_float(float noundef %expr) {
+define noundef i1 @wave_active_all_equal_float(float noundef %expr) {
 entry:
-; CHECK: call float @dx.op.waveActiveAllEqual.f32(i32 119, float %expr, i8 1, i8 0)
-  %ret = call float @llvm.dx.wave.all.equal.f32(float %expr)
-  ret float %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.f32(i32 115, float %expr)
+  %ret = call i1 @llvm.dx.wave.all.equal.f32(float %expr)
+  ret i1 %ret
 }
 
-define noundef double @wave_active_all_equal_double(double noundef %expr) {
+define noundef i1 @wave_active_all_equal_double(double noundef %expr) {
 entry:
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr, i8 1, i8 0)
-  %ret = call double @llvm.dx.wave.all.equal.f64(double %expr)
-  ret double %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %expr)
+  %ret = call i1 @llvm.dx.wave.all.equal.f64(double %expr)
+  ret i1 %ret
 }
 
-define noundef i16 @wave_active_all_equal_i16(i16 noundef %expr) {
+define noundef i1 @wave_active_all_equal_i16(i16 noundef %expr) {
 entry:
-; CHECK: call i16 @dx.op.waveActiveAllEqual.i16(i32 119, i16 %expr, i8 1, i8 0)
-  %ret = call i16 @llvm.dx.wave.all.equal.i16(i16 %expr)
-  ret i16 %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.i16(i32 115, i16 %expr)
+  %ret = call i1 @llvm.dx.wave.all.equal.i16(i16 %expr)
+  ret i1 %ret
 }
 
-define noundef i32 @wave_active_all_equal_i32(i32 noundef %expr) {
+define noundef i1 @wave_active_all_equal_i32(i32 noundef %expr) {
 entry:
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr, i8 1, i8 0)
-  %ret = call i32 @llvm.dx.wave.all.equal.i32(i32 %expr)
-  ret i32 %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %expr)
+  %ret = call i1 @llvm.dx.wave.all.equal.i32(i32 %expr)
+  ret i1 %ret
 }
 
-define noundef i64 @wave_active_all_equal_i64(i64 noundef %expr) {
+define noundef i1 @wave_active_all_equal_i64(i64 noundef %expr) {
 entry:
-; CHECK: call i64 @dx.op.waveActiveAllEqual.i64(i32 119, i64 %expr, i8 1, i8 0)
-  %ret = call i64 @llvm.dx.wave.all.equal.i64(i64 %expr)
-  ret i64 %ret
+; CHECK: call i1 @dx.op.waveActiveAllEqual.i64(i32 115, i64 %expr)
+  %ret = call i1 @llvm.dx.wave.all.equal.i64(i64 %expr)
+  ret i1 %ret
 }
 
-declare half @llvm.dx.wave.all.equal.f16(half)
-declare float @llvm.dx.wave.all.equal.f32(float)
-declare double @llvm.dx.wave.all.equal.f64(double)
+declare i1 @llvm.dx.wave.all.equal.f16(half)
+declare i1 @llvm.dx.wave.all.equal.f32(float)
+declare i1 @llvm.dx.wave.all.equal.f64(double)
 
-declare i16 @llvm.dx.wave.all.equal.i16(i16)
-declare i32 @llvm.dx.wave.all.equal.i32(i32)
-declare i64 @llvm.dx.wave.all.equal.i64(i64)
+declare i1 @llvm.dx.wave.all.equal.i16(i16)
+declare i1 @llvm.dx.wave.all.equal.i32(i32)
+declare i1 @llvm.dx.wave.all.equal.i64(i64)
 
 ; Test that for vector values, WaveAcitveProduct scalarizes and maps down to the
 ; DirectX op
 
-define noundef <2 x half> @wave_active_all_equal_v2half(<2 x half> noundef %expr) {
+define noundef <2 x i1> @wave_active_all_equal_v2half(<2 x half> noundef %expr) {
 entry:
-; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i0, i8 1, i8 0)
-; CHECK: call half @dx.op.waveActiveAllEqual.f16(i32 119, half %expr.i1, i8 1, i8 0)
-  %ret = call <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr)
-  ret <2 x half> %ret
+; CHECK: %[[EXPR0:.*]] = extractelement <2 x half> %expr, i64 0
+; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %[[EXPR0]])
+; CHECK: %[[EXPR1:.*]] = extractelement <2 x half> %expr, i64 1
+; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.f16(i32 115, half %[[EXPR1]])
+; CHECK: %[[RETUPTO0:.*]] = insertelement <2 x i1> poison, i1 %[[RET0]], i64 0
+; CHECK: %ret = insertelement <2 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1
+; CHECK: ret <2 x i1> %ret
+
+  %ret = call <2 x i1> @llvm.dx.wave.all.equal.v2f16(<2 x half> %expr)
+  ret <2 x i1> %ret
 }
 
-define noundef <3 x i32> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) {
+define noundef <3 x i1> @wave_active_all_equal_v3i32(<3 x i32> noundef %expr) {
 entry:
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i0, i8 1, i8 0)
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i1, i8 1, i8 0)
-; CHECK: call i32 @dx.op.waveActiveAllEqual.i32(i32 119, i32 %expr.i2, i8 1, i8 0)
-  %ret = call <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr)
-  ret <3 x i32> %ret
+; CHECK: %[[EXPR0:.*]] = extractelement <3 x i32> %expr, i64 0
+; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR0]])
+; CHECK: %[[EXPR1:.*]] = extractelement <3 x i32> %expr, i64 1
+; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR1]])
+; CHECK: %[[EXPR2:.*]] = extractelement <3 x i32> %expr, i64 2
+; CHECK: %[[RET2:.*]] = call i1 @dx.op.waveActiveAllEqual.i32(i32 115, i32 %[[EXPR2]])
+; CHECK: %[[RETUPTO0:.*]] = insertelement <3 x i1> poison, i1 %[[RET0]], i64 0
+; CHECK: %[[RETUPTO1:.*]] = insertelement <3 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1
+; CHECK: %ret = insertelement <3 x i1> %[[RETUPTO1]], i1 %[[RET2]], i64 2
+
+  %ret = call <3 x i1> @llvm.dx.wave.all.equal.v3i32(<3 x i32> %expr)
+  ret <3 x i1> %ret
 }
 
-define noundef <4 x double> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) {
+define noundef <4 x i1> @wave_active_all_equal_v4f64(<4 x double> noundef %expr) {
 entry:
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i0, i8 1, i8 0)
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i1, i8 1, i8 0)
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i2, i8 1, i8 0)
-; CHECK: call double @dx.op.waveActiveAllEqual.f64(i32 119, double %expr.i3, i8 1, i8 0)
-  %ret = call <4 x double> @llvm.dx.wave.all.equal.v464(<4 x double> %expr)
-  ret <4 x double> %ret
+; CHECK: %[[EXPR0:.*]] = extractelement <4 x double> %expr, i64 0
+; CHECK: %[[RET0:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR0]])
+; CHECK: %[[EXPR1:.*]] = extractelement <4 x double> %expr, i64 1
+; CHECK: %[[RET1:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR1]])
+; CHECK: %[[EXPR2:.*]] = extractelement <4 x double> %expr, i64 2
+; CHECK: %[[RET2:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR2]])
+; CHECK: %[[EXPR3:.*]] = extractelement <4 x double> %expr, i64 3
+; CHECK: %[[RET3:.*]] = call i1 @dx.op.waveActiveAllEqual.f64(i32 115, double %[[EXPR3]])
+; CHECK: %[[RETUPTO0:.*]] = insertelement <4 x i1> poison, i1 %[[RET0]], i64 0
+; CHECK: %[[RETUPTO1:.*]] = insertelement <4 x i1> %[[RETUPTO0]], i1 %[[RET1]], i64 1
+; CHECK: %[[RETUPTO2:.*]] = insertelement <4 x i1> %[[RETUPTO1]], i1 %[[RET2]], i64 2
+; CHECK: %ret = insertelement <4 x i1> %[[RETUPTO2]], i1 %[[RET3]], i64 3
+
+  %ret = call <4 x i1> @llvm.dx.wave.all.equal.v464(<4 x double> %expr)
+  ret <4 x i1> %ret
 }
 
-declare <2 x half> @llvm.dx.wave.all.equal.v2f16(<2 x half>)
-declare <3 x i32> @llvm.dx.wave.all.equal.v3i32(<3 x i32>)
-declare <4 x double> @llvm.dx.wave.all.equal.v4f64(<4 x double>)
+declare <2 x i1> @llvm.dx.wave.all.equal.v2f16(<2 x half>)
+declare <3 x i1> @llvm.dx.wave.all.equal.v3i32(<3 x i32>)
+declare <4 x i1> @llvm.dx.wave.all.equal.v4f64(<4 x double>)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
index e871dc9a7aa28..c64e5770b2d6d 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -5,6 +5,8 @@
 
 ; CHECK-DAG:   %[[#f16:]] = OpTypeFloat 16
 ; CHECK-DAG:   %[[#f32:]] = OpTypeFloat 32
+; CHECK-DAG:   %[[#bool:]] = OpTypeBool
+; CHECK-DAG:   %[[#bool4:]] = OpTypeVector %[[#bool]] 4
 ; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
 ; CHECK-DAG:   %[[#v4_half:]] = OpTypeVector %[[#f16]] 4
 ; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 3
@@ -13,8 +15,8 @@
 ; CHECK:   %[[#fexpr:]] = OpFunctionParameter %[[#f32]]
 define i1 @test_float(float %fexpr) {
 entry:
-; CHECK:   %[[#fret:]] = OpGroupNonUniformAllEqual %[[#f32]] %[[#scope]] Reduce %[[#fexpr]]
-  %0 = call i1 @llvm.spv.wave.all.equal.f32(float %fexpr)
+; CHECK:   %[[#fret:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] %[[#fexpr]]
+  %0 = call i1 @llvm.spv.subgroup.all.equal.f32(float %fexpr)
   ret i1 %0
 }
 
@@ -22,20 +24,20 @@ entry:
 ; CHECK:   %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
 define i1 @test_int(i32 %iexpr) {
 entry:
-; CHECK:   %[[#iret:]] = OpGroupNonUniformAllEqual %[[#uint]] %[[#scope]] 
Reduce %[[#iexpr]]
-  %0 = call i1 @llvm.spv.wave.all.equal.i32(i32 %iexpr)
+; CHECK:   %[[#iret:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#iexpr]]
+  %0 = call i1 @llvm.spv.subgroup.all.equal.i32(i32 %iexpr)
   ret i1 %0
 }
 
 ; CHECK-LABEL: Begin function test_vhalf
 ; CHECK:   %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]]
-define i1 @test_vhalf(<4 x half> %vbexpr) {
+define <4 x i1> @test_vhalf(<4 x half> %vbexpr) {
 entry:
-; CHECK:   %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#v4_half]] 
%[[#scope]] Reduce %[[#vbexpr]]
-  %0 = call i1 @llvm.spv.wave.all.equal.v4half(<4 x half> %vbexpr)
-  ret i1 %0
+; CHECK:   %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#bool4]] %[[#scope]] 
%[[#vbexpr]]
+  %0 = call <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half> %vbexpr)
+  ret <4 x i1> %0
 }
 
-declare i1 @llvm.spv.wave.all.equal.f32(float)
-declare i1 @llvm.spv.wave.all.equal.i32(i32)
-declare i1 @llvm.spv.wave.all.equal.v4half(<4 x half>)
+declare i1 @llvm.spv.subgroup.all.equal.f32(float)
+declare i1 @llvm.spv.subgroup.all.equal.i32(i32)
+declare <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half>)

>From c78d96eecf4435cae10014da20025c8ceef4d384 Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Fri, 27 Feb 2026 13:41:26 -0800
Subject: [PATCH 3/6] revert changes to WaveActiveAllTrue-errors.hlsl

---
 .../BuiltIns/WaveActiveAllTrue-errors.hlsl    | 42 +++++++++----------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl 
b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
index 0975ad649e714..b0d0fdfca5e18 100644
--- a/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAllTrue-errors.hlsl
@@ -1,21 +1,21 @@
-// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
-
-bool test_too_few_arg() {
-  return __builtin_hlsl_wave_active_all_true();
-  // expected-error@-1 {{too few arguments to function call, expected 1, have 
0}}
-}
-
-bool test_too_many_arg(bool p0) {
-  return __builtin_hlsl_wave_active_all_true(p0, p0);
-  // expected-error@-1 {{too many arguments to function call, expected 1, have 
2}}
-}
-
-struct Foo
-{
-  int a;
-};
-
-bool test_type_check(Foo p0) {
-  return __builtin_hlsl_wave_active_all_true(p0);
-  // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
-}
\ No newline at end of file
+// RUN: %clang_cc1 -finclude-default-header -triple 
dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+bool test_too_few_arg() {
+  return __builtin_hlsl_wave_active_all_true();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 
0}}
+}
+
+bool test_too_many_arg(bool p0) {
+  return __builtin_hlsl_wave_active_all_true(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 
2}}
+}
+
+struct Foo
+{
+  int a;
+};
+
+bool test_type_check(Foo p0) {
+  return __builtin_hlsl_wave_active_all_true(p0);
+  // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
+}

>From df45116b4e350d1bfa36f7f3d737cca9fb0b8879 Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Mon, 2 Mar 2026 12:51:59 -0800
Subject: [PATCH 4/6] scalarize vector WaveActiveAllEqual in SPIR-V instruction selection

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 96 ++++++++++++++++++-
 .../hlsl-intrinsics/WaveActiveAllEqual.ll     | 15 ++-
 2 files changed, 107 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp 
b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b9c6cb1e67595..3794092703470 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -337,6 +337,9 @@ class SPIRVInstructionSelector : public InstructionSelector 
{
   bool selectWaveActiveCountBits(Register ResVReg, SPIRVTypeInst ResType,
                                  MachineInstr &I) const;
 
+  bool selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType,
+                                 MachineInstr &I) const;
+
   bool selectUnmergeValues(MachineInstr &I) const;
 
   bool selectHandleFromBinding(Register &ResVReg, SPIRVTypeInst ResType,
@@ -2830,6 +2833,96 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
   return true;
 }
 
+unsigned getVectorSizeOrOne(SPIRVTypeInst Type) {
+
+  if (Type->getOpcode() != SPIRV::OpTypeVector)
+    return 1;
+
+  // Operand(2) is the vector size
+  return Type->getOperand(2).getImm();
+}
+
+bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
+                                                        SPIRVTypeInst ResType,
+                                                        MachineInstr &I) const 
{
+
+  MachineBasicBlock &BB = *I.getParent();
+  const DebugLoc &DL = I.getDebugLoc();
+
+  SPIRVTypeInst SpvTy = GR.getSPIRVTypeForVReg(ResVReg);
+  unsigned NumElems = getVectorSizeOrOne(SpvTy);
+  bool IsVector = NumElems > 1;
+
+  // Subgroup scope constant
+  SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+
+  Register ScopeConst = GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, 
IntTy,
+                                               TII, !STI.isShader());
+
+  Register InputReg = I.getOperand(2).getReg();
+
+  SmallVector<Register, 4> ElementResults;
+
+  // If vector, determine element type once
+  SPIRVTypeInst ElemInputType = SpvTy;
+  SPIRVTypeInst ElemBoolType = ResType;
+
+  if (IsVector) {
+    Register ElemTypeReg = SpvTy->getOperand(1).getReg();
+    ElemInputType = GR.getSPIRVTypeForVReg(ElemTypeReg);
+
+    Register BoolElemReg = ResType->getOperand(1).getReg();
+    ElemBoolType = GR.getSPIRVTypeForVReg(BoolElemReg);
+  }
+
+  for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
+
+    Register ElemInput = InputReg;
+
+    if (IsVector) {
+      Register Extracted =
+          MRI->createVirtualRegister(GR.getRegClass(ElemInputType));
+
+      BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract))
+          .addDef(Extracted)
+          .addUse(GR.getSPIRVTypeID(ElemInputType))
+          .addUse(InputReg)
+          .addImm(Idx)
+          .constrainAllUses(TII, TRI, RBI);
+
+      ElemInput = Extracted;
+    }
+
+    Register ElemResult =
+        IsVector ? MRI->createVirtualRegister(GR.getRegClass(ElemBoolType))
+                 : ResVReg;
+
+    BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
+        .addDef(ElemResult)
+        .addUse(GR.getSPIRVTypeID(ElemBoolType))
+        .addUse(ScopeConst)
+        .addUse(ElemInput)
+        .constrainAllUses(TII, TRI, RBI);
+
+    ElementResults.push_back(ElemResult);
+  }
+
+  if (!IsVector)
+    return true;
+
+  auto MIB = BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeConstruct))
+                 .addDef(ResVReg)
+                 .addUse(GR.getSPIRVTypeID(ResType));
+
+  for (Register R : ElementResults)
+    MIB.addUse(R);
+
+  MIB.constrainAllUses(TII, TRI, RBI);
+
+  return true;
+}
+
+
 bool SPIRVInstructionSelector::selectWavePrefixBitCount(Register ResVReg,
                                                         SPIRVTypeInst ResType,
                                                         MachineInstr &I) const 
{
@@ -4085,8 +4178,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register 
ResVReg,
   case Intrinsic::spv_wave_active_countbits:
     return selectWaveActiveCountBits(ResVReg, ResType, I);
   case Intrinsic::spv_subgroup_all_equal:
-    return selectWaveOpInst(ResVReg, ResType, I,
-                            SPIRV::OpGroupNonUniformAllEqual);
+    return selectWaveActiveAllEqual(ResVReg, ResType, I);
   case Intrinsic::spv_wave_all:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
   case Intrinsic::spv_wave_any:
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll 
b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
index c64e5770b2d6d..9c63539e4b9e4 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -30,10 +30,21 @@ entry:
 }
 
 ; CHECK-LABEL: Begin function test_vhalf
-; CHECK:   %[[#vbexpr:]] = OpFunctionParameter %[[#v4_half]]
+; For a vector input, each element is extracted and checked with its own
+; OpGroupNonUniformAllEqual, and the results are recombined into a bool vector
 define <4 x i1> @test_vhalf(<4 x half> %vbexpr) {
 entry:
-; CHECK:   %[[#vhalfret:]] = OpGroupNonUniformAllEqual %[[#bool4]] %[[#scope]] 
%[[#vbexpr]]
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#v4float:]]
+; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#bool]] %[[#param]] 0
+; CHECK-NEXT: %[[#res1:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext1]]
+; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#bool]] %[[#param]] 1
+; CHECK-NEXT: %[[#res2:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext2]]
+; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#bool]] %[[#param]] 2
+; CHECK-NEXT: %[[#res3:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext3]]
+; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#bool]] %[[#param]] 3
+; CHECK-NEXT: %[[#res4:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext4]]
+; CHECK-NEXT: %[[#ret:]] = OpCompositeConstruct %[[#bool4]] %[[#res1:]] 
%[[#res2:]] %[[#res3:]] %[[#res4:]]
+; CHECK-NEXT: OpReturnValue %[[#ret]]
   %0 = call <4 x i1> @llvm.spv.subgroup.all.equal.v4half(<4 x half> %vbexpr)
   ret <4 x i1> %0
 }

>From 8ce17b4829f1d8c32c4e027c607ef5511460d695 Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Mon, 2 Mar 2026 13:27:59 -0800
Subject: [PATCH 5/6] clang format

---
 llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp 
b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 3794092703470..fe16ceaf28c14 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -338,7 +338,7 @@ class SPIRVInstructionSelector : public InstructionSelector 
{
                                  MachineInstr &I) const;
 
   bool selectWaveActiveAllEqual(Register ResVReg, SPIRVTypeInst ResType,
-                                 MachineInstr &I) const;
+                                MachineInstr &I) const;
 
   bool selectUnmergeValues(MachineInstr &I) const;
 
@@ -2922,7 +2922,6 @@ bool 
SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
   return true;
 }
 
-
 bool SPIRVInstructionSelector::selectWavePrefixBitCount(Register ResVReg,
                                                         SPIRVTypeInst ResType,
                                                         MachineInstr &I) const 
{

>From 9a58dc845943b20c414c238dee6440a2f9e96d1d Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Mon, 2 Mar 2026 14:49:19 -0800
Subject: [PATCH 6/6] fix WaveActiveAllEqual SPIR-V lowering to pass spirv-val

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 77 ++++++++++---------
 .../hlsl-intrinsics/WaveActiveAllEqual.ll     |  8 +-
 2 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp 
b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index fe16ceaf28c14..1a9e07eb54e8c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2845,57 +2845,63 @@ unsigned getVectorSizeOrOne(SPIRVTypeInst Type) {
 bool SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
                                                         SPIRVTypeInst ResType,
                                                         MachineInstr &I) const 
{
-
   MachineBasicBlock &BB = *I.getParent();
   const DebugLoc &DL = I.getDebugLoc();
 
-  SPIRVTypeInst SpvTy = GR.getSPIRVTypeForVReg(ResVReg);
-  unsigned NumElems = getVectorSizeOrOne(SpvTy);
+  // Input to the intrinsic
+  Register InputReg = I.getOperand(2).getReg();
+  SPIRVTypeInst InputType = GR.getSPIRVTypeForVReg(InputReg);
+
+  // Determine if input is vector
+  unsigned NumElems = getVectorSizeOrOne(InputType);
   bool IsVector = NumElems > 1;
 
+  // Determine element types
+  SPIRVTypeInst ElemInputType = InputType;
+  SPIRVTypeInst ElemBoolType = ResType;
+  if (IsVector) {
+    ElemInputType = GR.getSPIRVTypeForVReg(InputType->getOperand(1).getReg());
+    ElemBoolType = GR.getSPIRVTypeForVReg(ResType->getOperand(1).getReg());
+  }
+
   // Subgroup scope constant
   SPIRVTypeInst IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
-
   Register ScopeConst = GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, 
IntTy,
                                                TII, !STI.isShader());
 
-  Register InputReg = I.getOperand(2).getReg();
+  // === Scalar case ===
+  if (!IsVector) {
+    BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ElemBoolType))
+        .addUse(ScopeConst)
+        .addUse(InputReg)
+        .constrainAllUses(TII, TRI, RBI);
+    return true;
+  }
 
+  // === Vector case ===
   SmallVector<Register, 4> ElementResults;
-
-  // If vector, determine element type once
-  SPIRVTypeInst ElemInputType = SpvTy;
-  SPIRVTypeInst ElemBoolType = ResType;
-
-  if (IsVector) {
-    Register ElemTypeReg = SpvTy->getOperand(1).getReg();
-    ElemInputType = GR.getSPIRVTypeForVReg(ElemTypeReg);
-
-    Register BoolElemReg = ResType->getOperand(1).getReg();
-    ElemBoolType = GR.getSPIRVTypeForVReg(BoolElemReg);
-  }
+  ElementResults.reserve(NumElems);
 
   for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
-
+    // Extract element
     Register ElemInput = InputReg;
+    Register Extracted =
+        MRI->createVirtualRegister(GR.getRegClass(ElemInputType));
+
+    BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract))
+        .addDef(Extracted)
+        .addUse(GR.getSPIRVTypeID(ElemInputType))
+        .addUse(InputReg)
+        .addImm(Idx)
+        .constrainAllUses(TII, TRI, RBI);
 
-    if (IsVector) {
-      Register Extracted =
-          MRI->createVirtualRegister(GR.getRegClass(ElemInputType));
-
-      BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeExtract))
-          .addDef(Extracted)
-          .addUse(GR.getSPIRVTypeID(ElemInputType))
-          .addUse(InputReg)
-          .addImm(Idx)
-          .constrainAllUses(TII, TRI, RBI);
-
-      ElemInput = Extracted;
-    }
+    ElemInput = Extracted;
 
+    // Emit per-element AllEqual
     Register ElemResult =
-        IsVector ? MRI->createVirtualRegister(GR.getRegClass(ElemBoolType))
-                 : ResVReg;
+        MRI->createVirtualRegister(GR.getRegClass(ElemBoolType));
 
     BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformAllEqual))
         .addDef(ElemResult)
@@ -2907,13 +2913,10 @@ bool 
SPIRVInstructionSelector::selectWaveActiveAllEqual(Register ResVReg,
     ElementResults.push_back(ElemResult);
   }
 
-  if (!IsVector)
-    return true;
-
+  // Reconstruct vector<bool>
   auto MIB = BuildMI(BB, I, DL, TII.get(SPIRV::OpCompositeConstruct))
                  .addDef(ResVReg)
                  .addUse(GR.getSPIRVTypeID(ResType));
-
   for (Register R : ElementResults)
     MIB.addUse(R);
 
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll 
b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
index 9c63539e4b9e4..8733505942c4c 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAllEqual.ll
@@ -35,13 +35,13 @@ entry:
 define <4 x i1> @test_vhalf(<4 x half> %vbexpr) {
 entry:
 ; CHECK: %[[#param:]] = OpFunctionParameter %[[#v4float:]]
-; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#bool]] %[[#param]] 0
+; CHECK: %[[#ext1:]] = OpCompositeExtract %[[#f16]] %[[#param]] 0
 ; CHECK-NEXT: %[[#res1:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext1]]
-; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#bool]] %[[#param]] 1
+; CHECK-NEXT: %[[#ext2:]] = OpCompositeExtract %[[#f16]] %[[#param]] 1
 ; CHECK-NEXT: %[[#res2:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext2]]
-; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#bool]] %[[#param]] 2
+; CHECK-NEXT: %[[#ext3:]] = OpCompositeExtract %[[#f16]] %[[#param]] 2
 ; CHECK-NEXT: %[[#res3:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext3]]
-; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#bool]] %[[#param]] 3
+; CHECK-NEXT: %[[#ext4:]] = OpCompositeExtract %[[#f16]] %[[#param]] 3
 ; CHECK-NEXT: %[[#res4:]] = OpGroupNonUniformAllEqual %[[#bool]] %[[#scope]] 
%[[#ext4]]
 ; CHECK-NEXT: %[[#ret:]] = OpCompositeConstruct %[[#bool4]] %[[#res1:]] 
%[[#res2:]] %[[#res3:]] %[[#res4:]]
 ; CHECK-NEXT: OpReturnValue %[[#ret]]

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to