[clang] [llvm] [HLSL] Handle WaveActiveBallot struct return type appropriately (PR #175105)
Joshua Batista via cfe-commits
cfe-commits at lists.llvm.org
Thu Jan 8 19:53:15 PST 2026
https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/175105
>From c74bf28df75691dfdc3537462a8f1d735b51f865 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Thu, 8 Jan 2026 17:06:34 -0800
Subject: [PATCH 1/2] Handle WaveActiveBallot struct return type
---
clang/include/clang/Basic/Builtins.td | 2 +-
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 29 +++++++++++++++++--
clang/lib/Sema/SemaHLSL.cpp | 5 ++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +-
llvm/lib/Target/DirectX/DXIL.td | 8 ++---
llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 17 +++++++++--
llvm/test/CodeGen/DirectX/WaveActiveBallot.ll | 12 ++++----
7 files changed, 58 insertions(+), 17 deletions(-)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 0ab50b06e11cf..ccbc0abe3f0b4 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5058,7 +5058,7 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
def HLSLWaveActiveBallot : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_ballot"];
let Attributes = [NoThrow, Const];
- let Prototype = "_ExtVector<4, unsigned int>(bool)";
+ let Prototype = "void(bool)";
}
def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 1b6c3714f7821..c5a072bfa3974 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -160,6 +160,31 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
return LastInst;
}
+static Value *handleHlslWaveActiveBallot(const CallExpr *E,
+ CodeGenFunction *CGF) {
+ Value *Cond = CGF->EmitScalarExpr(E->getArg(0));
+ llvm::Type *I32 = CGF->Int32Ty;
+ llvm::StructType *RetTy = llvm::StructType::get(I32, I32, I32, I32);
+
+ if (CGF->CGM.getTarget().getTriple().isDXIL()) {
+ // dx.op.waveActiveBallot(opcode, i1)
+ return CGF->Builder.CreateIntrinsic(RetTy, Intrinsic::dx_wave_ballot,
+ {Cond}, nullptr, "wave.active.ballot");
+ }
+
+ if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
+ // spv.wave.ballot(i1) -> <4 x i32>, then bitcast to struct
+ llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4);
+ return CGF->Builder.CreateIntrinsic(VecTy, Intrinsic::spv_wave_ballot,
+ {Cond}, nullptr, "spv.wave.ballot");
+ }
+
+ CGF->CGM.Error(E->getExprLoc(),
+ "waveActiveBallot is not supported for this target");
+
+ return llvm::UndefValue::get(RetTy);
+}
+
static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF,
const CallExpr *E) {
Value *Op0 = CGF.EmitScalarExpr(E->getArg(0));
@@ -834,9 +859,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(Op->getType()->isIntegerTy(1) &&
"Intrinsic WaveActiveBallot operand must be a bool");
- Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic();
- return EmitRuntimeCall(
- Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+ return handleHlslWaveActiveBallot(E, this);
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a6de1cd550212..51f74c10677a9 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3507,6 +3507,11 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_wave_active_ballot: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+ break;
+ }
case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 6e6eb2d0ece9d..f79945785566c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -153,7 +153,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 6d04732d92ecf..23701e2218e57 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -58,6 +58,7 @@ def ResPropsTy : DXILOpParamType;
def SplitDoubleTy : DXILOpParamType;
def BinaryWithCarryTy : DXILOpParamType;
def DimensionsTy : DXILOpParamType;
+def Fouri32s : DXILOpParamType;
class DXILOpClass;
@@ -212,13 +213,12 @@ defset list<DXILOpClass> OpClasses = {
def unpack4x8 : DXILOpClass;
def viewID : DXILOpClass;
def waveActiveAllEqual : DXILOpClass;
- def waveActiveBallot : DXILOpClass;
def waveActiveBit : DXILOpClass;
def waveActiveOp : DXILOpClass;
def waveAllOp : DXILOpClass;
def waveAllTrue : DXILOpClass;
def waveAnyTrue : DXILOpClass;
- def waveBallot : DXILOpClass;
+ def waveActiveBallot : DXILOpClass;
def waveGetLaneCount : DXILOpClass;
def waveGetLaneIndex : DXILOpClass;
def waveIsFirstLane : DXILOpClass;
@@ -1072,11 +1072,11 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
-def WaveActiveBallot : DXILOp<118, waveBallot> {
+def WaveActiveBallot : DXILOp<116, waveActiveBallot> {
let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
let arguments = [Int1Ty];
- let result = OverloadTy;
+ let result = Fouri32s;
let stages = [Stages<DXIL1_0, [all_stages]>];
}
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 944b2e6433988..1f41d2457e5bc 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -261,10 +261,18 @@ static StructType *getBinaryWithCarryType(LLVMContext &Context) {
return StructType::create({Int32Ty, Int1Ty}, "dx.types.i32c");
}
-static StructType *getDimensionsType(LLVMContext &Ctx) {
- Type *Int32Ty = Type::getInt32Ty(Ctx);
+static StructType *getDimensionsType(LLVMContext &Context) {
+ Type *Int32Ty = Type::getInt32Ty(Context);
return getOrCreateStructType("dx.types.Dimensions",
- {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Ctx);
+ {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
+}
+
+static StructType *getFouri32sType(LLVMContext &Context) {
+ if (auto *ST = StructType::getTypeByName(Context, "dx.types.fouri32"))
+ return ST;
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ return getOrCreateStructType("dx.types.fouri32",
+ {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
}
static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
@@ -326,7 +334,10 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return getBinaryWithCarryType(Ctx);
case OpParamType::DimensionsTy:
return getDimensionsType(Ctx);
+ case OpParamType::Fouri32s:
+ return getFouri32sType(Ctx);
}
+
llvm_unreachable("Invalid parameter kind");
return nullptr;
}
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
index cf6255de3a734..31a64cbcf061e 100644
--- a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
@@ -1,10 +1,12 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
-define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
+%dx.types.fouri32 = type { i32, i32, i32, i32 }
+
+define noundef %dx.types.fouri32 @wave_ballot_simple(i1 noundef %p1) {
entry:
-; CHECK: call <4 x i32> @dx.op.waveBallot.void(i32 118, i1 %p1)
- %ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1)
- ret <4 x i32> %ret
+; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1)
+ %ret = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1)
+ ret %dx.types.fouri32 %ret
}
-declare <4 x i32> @llvm.dx.wave.ballot(i1)
+declare %dx.types.fouri32 @llvm.dx.wave.ballot(i1)
>From b54b8d5a525cb9817603488ba7c8a2220bc6baab Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Thu, 8 Jan 2026 19:53:00 -0800
Subject: [PATCH 2/2] Use EmitRuntimeCall in WaveActiveBallot codegen to attach
the convergence token
---
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 40 +++++++++++++++++++++++-----------------
.../CodeGenHLSL/builtins/WaveActiveBallot.hlsl | 8 ++++++--
2 files changed, 29 insertions(+), 19 deletions(-)
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index c5a072bfa3974..1e3f5611e69d1 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -160,29 +160,35 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
   return LastInst;
 }
 
-static Value *handleHlslWaveActiveBallot(const CallExpr *E,
-                                         CodeGenFunction *CGF) {
-  Value *Cond = CGF->EmitScalarExpr(E->getArg(0));
+/// Emits WaveActiveBallot for the current target.
+/// \p Cond is the i1 operand, already emitted by the caller, so the argument
+/// expression is evaluated exactly once. The result is always a <4 x i32>
+/// value, matching the uint4 type of the call expression.
+static Value *handleHlslWaveActiveBallot(const CallExpr *E, Value *Cond,
+                                         CodeGenFunction *CGF) {
   llvm::Type *I32 = CGF->Int32Ty;
-  llvm::StructType *RetTy = llvm::StructType::get(I32, I32, I32, I32);
+  llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4);
 
   if (CGF->CGM.getTarget().getTriple().isDXIL()) {
-    // dx.op.waveActiveBallot(opcode, i1)
-    return CGF->Builder.CreateIntrinsic(RetTy, Intrinsic::dx_wave_ballot,
-                                        {Cond}, nullptr, "wave.active.ballot");
+    // llvm.dx.wave.ballot returns { i32, i32, i32, i32 }, while the uint4
+    // result is a <4 x i32>; repack the struct so the returned value has the
+    // type that callers of EmitScalarExpr expect.
+    Value *Ballot = CGF->EmitRuntimeCall(
+        CGF->CGM.getIntrinsic(Intrinsic::dx_wave_ballot, {I32}), Cond);
+    Value *Result = llvm::PoisonValue::get(VecTy);
+    for (unsigned I = 0; I < 4; ++I)
+      Result = CGF->Builder.CreateInsertElement(
+          Result, CGF->Builder.CreateExtractValue(Ballot, I), I);
+    return Result;
   }
 
-  if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
-    // spv.wave.ballot(i1) -> <4 x i32>, then bitcast to struct
-    llvm::Type *VecTy = llvm::FixedVectorType::get(I32, 4);
-    return CGF->Builder.CreateIntrinsic(VecTy, Intrinsic::spv_wave_ballot,
-                                        {Cond}, nullptr, "spv.wave.ballot");
-  }
+  if (CGF->CGM.getTarget().getTriple().isSPIRV())
+    return CGF->EmitRuntimeCall(
+        CGF->CGM.getIntrinsic(Intrinsic::spv_wave_ballot), Cond);
 
   CGF->CGM.Error(E->getExprLoc(),
-                 "waveActiveBallot is not supported for this target");
-
-  return llvm::UndefValue::get(RetTy);
+                 "WaveActiveBallot is not supported for this target");
+  return llvm::PoisonValue::get(VecTy);
 }
 
 static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF,
@@ -859,7 +865,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     assert(Op->getType()->isIntegerTy(1) &&
            "Intrinsic WaveActiveBallot operand must be a bool");
 
-    return handleHlslWaveActiveBallot(E, this);
+    return handleHlslWaveActiveBallot(E, Op, this);
   }
   case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
     Value *OpExpr = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
index 61b077eb1fead..ceee9eb015512 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
@@ -11,7 +11,11 @@
 uint4 test(bool p1) {
   // CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
   // CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
-  // CHECK-DXIL: %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 %{{[a-zA-Z0-9]+}})
-  // CHECK: ret <4 x i32> %[[RET]]
+  // CHECK-DXIL: %[[WAB:.*]] = call { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1 %{{[a-zA-Z0-9]+}})
+  // CHECK-DXIL: extractvalue { i32, i32, i32, i32 } %[[WAB]], 0
+  // CHECK-DXIL: extractvalue { i32, i32, i32, i32 } %[[WAB]], 3
+  // CHECK-DXIL: %[[VEC:.*]] = insertelement <4 x i32> %{{.*}}, i32 %{{.*}}, i64 3
+  // CHECK-DXIL: ret <4 x i32> %[[VEC]]
+  // CHECK-SPIRV: ret <4 x i32> %[[RET]]
   return WaveActiveBallot(p1);
 }
More information about the cfe-commits mailing list