https://github.com/bob80905 updated 
https://github.com/llvm/llvm-project/pull/175105

>From c74bf28df75691dfdc3537462a8f1d735b51f865 Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Thu, 8 Jan 2026 17:06:34 -0800
Subject: [PATCH 1/6] handle waveballot struct return type

---
 clang/include/clang/Basic/Builtins.td         |  2 +-
 clang/lib/CodeGen/CGHLSLBuiltins.cpp          | 29 +++++++++++++++++--
 clang/lib/Sema/SemaHLSL.cpp                   |  5 ++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  2 +-
 llvm/lib/Target/DirectX/DXIL.td               |  8 ++---
 llvm/lib/Target/DirectX/DXILOpBuilder.cpp     | 17 +++++++++--
 llvm/test/CodeGen/DirectX/WaveActiveBallot.ll | 12 ++++----
 7 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 0ab50b06e11cf..ccbc0abe3f0b4 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5058,7 +5058,7 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
 def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_ballot"];
   let Attributes = [NoThrow, Const];
-  let Prototype = "_ExtVector<4, unsigned int>(bool)";
+  let Prototype = "void(bool)";
 }
 
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 1b6c3714f7821..c5a072bfa3974 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -160,6 +160,31 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
   return LastInst;
 }
 
+static Value *handleHlslWaveActiveBallot(const CallExpr *E,
+                                         CodeGenFunction *CGF) {
+  Value *Cond = CGF->EmitScalarExpr(E->getArg(0));
+  llvm::Type *I32 = CGF->Int32Ty;
+  llvm::StructType *RetTy = llvm::StructType::get(I32, I32, I32, I32);
+
+  if (CGF->CGM.getTarget().getTriple().isDXIL()) {
+    // dx.op.waveActiveBallot(opcode, i1)
+    return CGF->Builder.CreateIntrinsic(RetTy, Intrinsic::dx_wave_ballot,
+                                        {Cond}, nullptr, "wave.active.ballot");
+  }
+
+  if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
+    // spv.wave.ballot(i1) -> <4 x i32>, then bitcast to struct
+    llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4);
+    return CGF->Builder.CreateIntrinsic(VecTy, Intrinsic::spv_wave_ballot,
+                                        {Cond}, nullptr, "spv.wave.ballot");
+  }
+
+  CGF->CGM.Error(E->getExprLoc(),
+                 "waveActiveBallot is not supported for this target");
+
+  return llvm::UndefValue::get(RetTy);
+}
+
 static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF,
                                         const CallExpr *E) {
   Value *Op0 = CGF.EmitScalarExpr(E->getArg(0));
@@ -834,9 +859,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     assert(Op->getType()->isIntegerTy(1) &&
            "Intrinsic WaveActiveBallot operand must be a bool");
 
-    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic();
-    return EmitRuntimeCall(
-        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+    return handleHlslWaveActiveBallot(E, this);
   }
   case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a6de1cd550212..51f74c10677a9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3507,6 +3507,11 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_active_ballot: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+    break;
+  }
   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 6e6eb2d0ece9d..f79945785566c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -153,7 +153,7 @@ def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 6d04732d92ecf..23701e2218e57 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -58,6 +58,7 @@ def ResPropsTy : DXILOpParamType;
 def SplitDoubleTy : DXILOpParamType;
 def BinaryWithCarryTy : DXILOpParamType;
 def DimensionsTy : DXILOpParamType;
+def Fouri32s : DXILOpParamType;
 
 class DXILOpClass;
 
@@ -212,13 +213,12 @@ defset list<DXILOpClass> OpClasses = {
   def unpack4x8 : DXILOpClass;
   def viewID : DXILOpClass;
   def waveActiveAllEqual : DXILOpClass;
-  def waveActiveBallot : DXILOpClass;
   def waveActiveBit : DXILOpClass;
   def waveActiveOp : DXILOpClass;
   def waveAllOp : DXILOpClass;
   def waveAllTrue : DXILOpClass;
   def waveAnyTrue : DXILOpClass;
-  def waveBallot : DXILOpClass;
+  def waveActiveBallot : DXILOpClass;
   def waveGetLaneCount : DXILOpClass;
   def waveGetLaneIndex : DXILOpClass;
   def waveIsFirstLane : DXILOpClass;
@@ -1072,11 +1072,11 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
-def WaveActiveBallot : DXILOp<118, waveBallot> {
+def WaveActiveBallot : DXILOp<116, waveActiveBallot> {
   let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
   let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
   let arguments = [Int1Ty];
-  let result = OverloadTy;
+  let result = Fouri32s;
   let stages = [Stages<DXIL1_0, [all_stages]>];
 }
 
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 944b2e6433988..1f41d2457e5bc 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -261,10 +261,18 @@ static StructType *getBinaryWithCarryType(LLVMContext &Context) {
   return StructType::create({Int32Ty, Int1Ty}, "dx.types.i32c");
 }
 
-static StructType *getDimensionsType(LLVMContext &Ctx) {
-  Type *Int32Ty = Type::getInt32Ty(Ctx);
+static StructType *getDimensionsType(LLVMContext &Context) {
+  Type *Int32Ty = Type::getInt32Ty(Context);
   return getOrCreateStructType("dx.types.Dimensions",
-                               {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Ctx);
+                               {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
+}
+
+static StructType *getFouri32sType(LLVMContext &Context) {
+  if (auto *ST = StructType::getTypeByName(Context, "dx.types.fouri32"))
+    return ST;
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  return getOrCreateStructType("dx.types.fouri32",
+                               {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
 }
 
 static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
@@ -326,7 +334,10 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
     return getBinaryWithCarryType(Ctx);
   case OpParamType::DimensionsTy:
     return getDimensionsType(Ctx);
+  case OpParamType::Fouri32s:
+    return getFouri32sType(Ctx);
   }
+
   llvm_unreachable("Invalid parameter kind");
   return nullptr;
 }
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
index cf6255de3a734..31a64cbcf061e 100644
--- a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
@@ -1,10 +1,12 @@
 ; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
 
-define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
+%dx.types.fouri32 = type { i32, i32, i32, i32 }
+
+define noundef %dx.types.fouri32 @wave_ballot_simple(i1 noundef %p1) {
 entry:
-; CHECK: call <4 x i32> @dx.op.waveBallot.void(i32 118, i1 %p1)
-  %ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1)
-  ret <4 x i32> %ret
+; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1)
+  %ret = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1)
+  ret %dx.types.fouri32 %ret
 }
 
-declare <4 x i32> @llvm.dx.wave.ballot(i1)
+declare %dx.types.fouri32 @llvm.dx.wave.ballot(i1)

>From b54b8d5a525cb9817603488ba7c8a2220bc6baab Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Thu, 8 Jan 2026 19:53:00 -0800
Subject: [PATCH 2/6] update codegen to use emitruntimecall to force use of
 convergence token

---
 clang/lib/CodeGen/CGHLSLBuiltins.cpp             | 16 +++++++---------
 .../CodeGenHLSL/builtins/WaveActiveBallot.hlsl   | 11 ++++++++---
 2 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index c5a072bfa3974..1e3f5611e69d1 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -164,25 +164,23 @@ static Value *handleHlslWaveActiveBallot(const CallExpr *E,
                                          CodeGenFunction *CGF) {
   Value *Cond = CGF->EmitScalarExpr(E->getArg(0));
   llvm::Type *I32 = CGF->Int32Ty;
-  llvm::StructType *RetTy = llvm::StructType::get(I32, I32, I32, I32);
 
   if (CGF->CGM.getTarget().getTriple().isDXIL()) {
-    // dx.op.waveActiveBallot(opcode, i1)
-    return CGF->Builder.CreateIntrinsic(RetTy, Intrinsic::dx_wave_ballot,
-                                        {Cond}, nullptr, "wave.active.ballot");
+    return CGF->EmitRuntimeCall(
+        CGF->CGM.getIntrinsic(Intrinsic::dx_wave_ballot, {I32}), Cond);
   }
 
   if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
-    // spv.wave.ballot(i1) -> <4 x i32>, then bitcast to struct
     llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4);
-    return CGF->Builder.CreateIntrinsic(VecTy, Intrinsic::spv_wave_ballot,
-                                        {Cond}, nullptr, "spv.wave.ballot");
+
+    return CGF->EmitRuntimeCall(
+        CGF->CGM.getIntrinsic(Intrinsic::spv_wave_ballot), Cond);
   }
 
   CGF->CGM.Error(E->getExprLoc(),
-                 "waveActiveBallot is not supported for this target");
+                 "WaveActiveBallot is not supported for this target");
 
-  return llvm::UndefValue::get(RetTy);
+  return llvm::PoisonValue::get(llvm::FixedVectorType::get(I32, 4));
 }
 
 static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF,
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
index 61b077eb1fead..ceee9eb015512 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
@@ -10,8 +10,13 @@
 // CHECK-LABEL: define {{.*}}test
 uint4 test(bool p1) {
   // CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
-  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
-  // CHECK-DXIL:  %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 %{{[a-zA-Z0-9]+}})
-  // CHECK:  ret <4 x i32> %[[RET]]
+  // CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
+  // CHECK-DXIL: %[[RETVAL:.*]] = alloca <4 x i32>, align 16
+  // CHECK-DXIL: %[[WAB:.*]] = call { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1 %{{[a-zA-Z0-9]+}})
+  // CHECK-DXIL: store { i32, i32, i32, i32 } %[[WAB]], ptr %[[RETVAL]], align 16
+  // CHECK-DXIL: %[[LOAD:.*]] = load <4 x i32>, ptr %[[RETVAL]], align 16
+  // CHECK-DXIL: ret <4 x i32> %[[LOAD]]
+  // CHECK-SPIRV: ret <4 x i32> %[[RET]]
+
   return WaveActiveBallot(p1);
 }

>From 4b7b4a552c6d83e385c16eee0275878e73dc36c1 Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Fri, 9 Jan 2026 00:25:27 -0800
Subject: [PATCH 3/6] remove unused var

---
 clang/lib/CodeGen/CGHLSLBuiltins.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 1e3f5611e69d1..13c6dadf86ca7 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -171,8 +171,6 @@ static Value *handleHlslWaveActiveBallot(const CallExpr *E,
   }
 
   if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
-    llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4);
-
     return CGF->EmitRuntimeCall(
         CGF->CGM.getIntrinsic(Intrinsic::spv_wave_ballot), Cond);
   }

>From 1cc1eaced861ac697b6836b2b98a42c97fbc3147 Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Fri, 9 Jan 2026 00:27:10 -0800
Subject: [PATCH 4/6] clangformat

---
 clang/lib/CodeGen/CGHLSLBuiltins.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 13c6dadf86ca7..f5ca49749dc6d 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -170,10 +170,9 @@ static Value *handleHlslWaveActiveBallot(const CallExpr *E,
         CGF->CGM.getIntrinsic(Intrinsic::dx_wave_ballot, {I32}), Cond);
   }
 
-  if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
+  if (CGF->CGM.getTarget().getTriple().isSPIRV())
     return CGF->EmitRuntimeCall(
         CGF->CGM.getIntrinsic(Intrinsic::spv_wave_ballot), Cond);
-  }
 
   CGF->CGM.Error(E->getExprLoc(),
                  "WaveActiveBallot is not supported for this target");

>From 7b72cf11dd6321f3186c312acd97886ebdbc01ad Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Fri, 9 Jan 2026 14:00:08 -0800
Subject: [PATCH 5/6] address Finn

---
 clang/include/clang/Basic/Builtins.td | 2 +-
 clang/lib/CodeGen/CGHLSLBuiltins.cpp  | 6 ++----
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index ccbc0abe3f0b4..0ab50b06e11cf 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5058,7 +5058,7 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
 def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_ballot"];
   let Attributes = [NoThrow, Const];
-  let Prototype = "void(bool)";
+  let Prototype = "_ExtVector<4, unsigned int>(bool)";
 }
 
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index f5ca49749dc6d..d75b5c0c6b17f 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -174,10 +174,8 @@ static Value *handleHlslWaveActiveBallot(const CallExpr *E,
     return CGF->EmitRuntimeCall(
         CGF->CGM.getIntrinsic(Intrinsic::spv_wave_ballot), Cond);
 
-  CGF->CGM.Error(E->getExprLoc(),
-                 "WaveActiveBallot is not supported for this target");
-
-  return llvm::PoisonValue::get(llvm::FixedVectorType::get(I32, 4));
+  llvm_unreachable(
+      "WaveActiveBallot is only supported for DXIL and SPIRV targets");
 }
 
 static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF,

>From b64387b47995d072dc0574f8736a45f15a55d23d Mon Sep 17 00:00:00 2001
From: Joshua Batista <[email protected]>
Date: Mon, 12 Jan 2026 15:47:54 -0800
Subject: [PATCH 6/6] address Farzon

---
 clang/lib/CodeGen/CGHLSLBuiltins.cpp          | 40 ++++++++++++++-----
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 -
 clang/lib/Sema/SemaHLSL.cpp                   |  5 ---
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  2 +-
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  2 +-
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  2 +-
 llvm/test/CodeGen/DirectX/WaveActiveBallot.ll | 16 ++++++--
 llvm/test/tools/dxil-dis/waveactiveballot.ll  | 31 ++++++++++++++
 8 files changed, 76 insertions(+), 23 deletions(-)
 create mode 100644 llvm/test/tools/dxil-dis/waveactiveballot.ll

diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index d75b5c0c6b17f..8e491ee318bb9 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -160,19 +160,37 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
   return LastInst;
 }
 
-static Value *handleHlslWaveActiveBallot(const CallExpr *E,
-                                         CodeGenFunction *CGF) {
-  Value *Cond = CGF->EmitScalarExpr(E->getArg(0));
-  llvm::Type *I32 = CGF->Int32Ty;
+static Value *handleHlslWaveActiveBallot(CodeGenFunction &CGF,
+                                         const CallExpr *E) {
+  Value *Cond = CGF.EmitScalarExpr(E->getArg(0));
+  llvm::Type *I32 = CGF.Int32Ty;
+
+  llvm::Type *Vec4I32 = llvm::FixedVectorType::get(I32, 4);
+  llvm::StructType *Struct4I32 =
+      llvm::StructType::get(CGF.getLLVMContext(), {I32, I32, I32, I32});
+
+  if (CGF.CGM.getTarget().getTriple().isDXIL()) {
+    // Call DXIL intrinsic: returns { i32, i32, i32, i32 }
+    llvm::Function *Fn = CGF.CGM.getIntrinsic(Intrinsic::dx_wave_ballot, {I32});
+
+    Value *StructVal = CGF.EmitRuntimeCall(Fn, Cond);
+    assert(StructVal->getType() == Struct4I32 &&
+           "dx.wave.ballot must return {i32,i32,i32,i32}");
+
+    // Reassemble struct to <4 x i32>
+    llvm::Value *VecVal = llvm::PoisonValue::get(Vec4I32);
+    for (unsigned i = 0; i < 4; ++i) {
+      Value *Elt = CGF.Builder.CreateExtractValue(StructVal, i);
+      VecVal =
+          CGF.Builder.CreateInsertElement(VecVal, Elt, CGF.Builder.getInt32(i));
+    }
 
-  if (CGF->CGM.getTarget().getTriple().isDXIL()) {
-    return CGF->EmitRuntimeCall(
-        CGF->CGM.getIntrinsic(Intrinsic::dx_wave_ballot, {I32}), Cond);
+    return VecVal;
   }
 
-  if (CGF->CGM.getTarget().getTriple().isSPIRV())
-    return CGF->EmitRuntimeCall(
-        CGF->CGM.getIntrinsic(Intrinsic::spv_wave_ballot), Cond);
+  if (CGF.CGM.getTarget().getTriple().isSPIRV())
+    return CGF.EmitRuntimeCall(
+        CGF.CGM.getIntrinsic(Intrinsic::spv_subgroup_ballot), Cond);
 
   llvm_unreachable(
       "WaveActiveBallot is only supported for DXIL and SPIRV targets");
@@ -852,7 +870,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     assert(Op->getType()->isIntegerTy(1) &&
            "Intrinsic WaveActiveBallot operand must be a bool");
 
-    return handleHlslWaveActiveBallot(E, this);
+    return handleHlslWaveActiveBallot(*this, E);
   }
   case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 7a5643052ed84..ba2ca2c358388 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -146,7 +146,6 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
-  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBallot, wave_ballot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 51f74c10677a9..a6de1cd550212 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3507,11 +3507,6 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
-  case Builtin::BI__builtin_hlsl_wave_active_ballot: {
-    if (SemaRef.checkArgCount(TheCall, 1))
-      return true;
-    break;
-  }
   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index f79945785566c..3c2c7477d8c7b 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -153,7 +153,7 @@ def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
 def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index bcb533780b58c..da4031fadd8e9 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,7 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
   def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-  def int_spv_wave_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
+  def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
     DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 100057f6d1a39..fcce327fa8b76 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3815,7 +3815,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
   case Intrinsic::spv_wave_any:
     return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
-  case Intrinsic::spv_wave_ballot:
+  case Intrinsic::spv_subgroup_ballot:
     return selectWaveOpInst(ResVReg, ResType, I,
                             SPIRV::OpGroupNonUniformBallot);
   case Intrinsic::spv_wave_is_first_lane:
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
index 31a64cbcf061e..a7186427c6d09 100644
--- a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
@@ -2,11 +2,21 @@
 
 %dx.types.fouri32 = type { i32, i32, i32, i32 }
 
-define noundef %dx.types.fouri32 @wave_ballot_simple(i1 noundef %p1) {
+define <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
 entry:
 ; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1)
-  %ret = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1)
-  ret %dx.types.fouri32 %ret
+; CHECK-NOT: ret %dx.types.fouri32
+; CHECK: ret <4 x i32>
+  %s = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1)
+  %v0 = extractvalue %dx.types.fouri32 %s, 0
+  %v1 = extractvalue %dx.types.fouri32 %s, 1
+  %v2 = extractvalue %dx.types.fouri32 %s, 2
+  %v3 = extractvalue %dx.types.fouri32 %s, 3
+  %vec = insertelement <4 x i32> poison, i32 %v0, i32 0
+  %vec1 = insertelement <4 x i32> %vec, i32 %v1, i32 1
+  %vec2 = insertelement <4 x i32> %vec1, i32 %v2, i32 2
+  %vec3 = insertelement <4 x i32> %vec2, i32 %v3, i32 3
+  ret <4 x i32> %vec3
 }
 
 declare %dx.types.fouri32 @llvm.dx.wave.ballot(i1)
diff --git a/llvm/test/tools/dxil-dis/waveactiveballot.ll b/llvm/test/tools/dxil-dis/waveactiveballot.ll
new file mode 100644
index 0000000000000..2bdb4ec98a3db
--- /dev/null
+++ b/llvm/test/tools/dxil-dis/waveactiveballot.ll
@@ -0,0 +1,31 @@
+; RUN: llc %s --filetype=obj -o - | dxil-dis -o - | FileCheck %s
+
+; CHECK-NOT: llvm.dx.wave.ballot
+
+; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1)
+; CHECK-NOT: ret %dx.types.fouri32
+; CHECK: ret <4 x i32>
+
+
+target triple = "dxil-unknown-shadermodel6.3-library"
+
+%dx.types.fouri32 = type { i32, i32, i32, i32 }
+
+define <4 x i32> @wave_ballot_simple(i1 %p1) {
+entry:
+  %s = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1)
+
+  %v0 = extractvalue %dx.types.fouri32 %s, 0
+  %v1 = extractvalue %dx.types.fouri32 %s, 1
+  %v2 = extractvalue %dx.types.fouri32 %s, 2
+  %v3 = extractvalue %dx.types.fouri32 %s, 3
+
+  %vec0 = insertelement <4 x i32> poison, i32 %v0, i32 0
+  %vec1 = insertelement <4 x i32> %vec0, i32 %v1, i32 1
+  %vec2 = insertelement <4 x i32> %vec1, i32 %v2, i32 2
+  %vec3 = insertelement <4 x i32> %vec2, i32 %v3, i32 3
+
+  ret <4 x i32> %vec3
+}
+
+declare %dx.types.fouri32 @llvm.dx.wave.ballot(i1)

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to