[clang] 11b1836 - [HLSL] Handle WaveActiveBallot struct return type appropriately (#175105)
via cfe-commits
cfe-commits at lists.llvm.org
Tue Jan 20 10:08:31 PST 2026
Author: Joshua Batista
Date: 2026-01-20T10:08:26-08:00
New Revision: 11b18362822759ac1592cee5b857943fa2320f8c
URL: https://github.com/llvm/llvm-project/commit/11b18362822759ac1592cee5b857943fa2320f8c
DIFF: https://github.com/llvm/llvm-project/commit/11b18362822759ac1592cee5b857943fa2320f8c.diff
LOG: [HLSL] Handle WaveActiveBallot struct return type appropriately (#175105)
The previous WaveActiveBallot implementation did not account for the
fact that DXC's implementation of the intrinsic returns a struct of
4 uints rather than a vector of 4 uints. This must be respected;
otherwise the validator rejects uses of WaveActiveBallot that
return a vector of 4 uints.
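This PR changes the return type of `llvm.dx.wave.ballot` to
`{ i32, i32, i32, i32 }` and lowers it to DXIL op 116
(`waveActiveBallot`, previously mapped to 118). That op returns the
DXC-specific struct type `%dx.types.fouri32`. It also renames the
SPIR-V intrinsic `llvm.spv.wave.ballot` to `llvm.spv.subgroup.ballot`.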
Added:
llvm/test/tools/dxil-dis/waveactiveballot.ll
Modified:
clang/lib/CodeGen/CGHLSLBuiltins.cpp
clang/lib/CodeGen/CGHLSLRuntime.h
clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
clang/test/CodeGenSPIRV/Builtins/subgroup.c
clang/test/Headers/gpuintrin.c
llvm/include/llvm/IR/IntrinsicsDirectX.td
llvm/include/llvm/IR/IntrinsicsSPIRV.td
llvm/lib/Target/DirectX/DXIL.td
llvm/lib/Target/DirectX/DXILOpBuilder.cpp
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
Removed:
################################################################################
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 1b6c3714f7821..75995ff940bc4 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -160,6 +160,42 @@ static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
return LastInst;
}
+static Value *handleHlslWaveActiveBallot(CodeGenFunction &CGF,
+ const CallExpr *E) {
+ Value *Cond = CGF.EmitScalarExpr(E->getArg(0));
+ llvm::Type *I32 = CGF.Int32Ty;
+
+ llvm::Type *Vec4I32 = llvm::FixedVectorType::get(I32, 4);
+ llvm::StructType *Struct4I32 =
+ llvm::StructType::get(CGF.getLLVMContext(), {I32, I32, I32, I32});
+
+ if (CGF.CGM.getTarget().getTriple().isDXIL()) {
+ // Call DXIL intrinsic: returns { i32, i32, i32, i32 }
+ llvm::Function *Fn = CGF.CGM.getIntrinsic(Intrinsic::dx_wave_ballot, {I32});
+
+ Value *StructVal = CGF.EmitRuntimeCall(Fn, Cond);
+ assert(StructVal->getType() == Struct4I32 &&
+ "dx.wave.ballot must return {i32,i32,i32,i32}");
+
+ // Reassemble struct to <4 x i32>
+ llvm::Value *VecVal = llvm::PoisonValue::get(Vec4I32);
+ for (unsigned I = 0; I < 4; ++I) {
+ Value *Elt = CGF.Builder.CreateExtractValue(StructVal, I);
+ VecVal =
+ CGF.Builder.CreateInsertElement(VecVal, Elt, CGF.Builder.getInt32(I));
+ }
+
+ return VecVal;
+ }
+
+ if (CGF.CGM.getTarget().getTriple().isSPIRV())
+ return CGF.EmitRuntimeCall(
+ CGF.CGM.getIntrinsic(Intrinsic::spv_subgroup_ballot), Cond);
+
+ llvm_unreachable(
+ "WaveActiveBallot is only supported for DXIL and SPIRV targets");
+}
+
static Value *handleElementwiseF16ToF32(CodeGenFunction &CGF,
const CallExpr *E) {
Value *Op0 = CGF.EmitScalarExpr(E->getArg(0));
@@ -834,9 +870,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(Op->getType()->isIntegerTy(1) &&
"Intrinsic WaveActiveBallot operand must be a bool");
- Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBallotIntrinsic();
- return EmitRuntimeCall(
- Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+ return handleHlslWaveActiveBallot(*this, E);
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 7a5643052ed84..ba2ca2c358388 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -146,7 +146,6 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
- GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBallot, wave_ballot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
index 61b077eb1fead..df2d854a64247 100644
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBallot.hlsl
@@ -10,8 +10,18 @@
// CHECK-LABEL: define {{.*}}test
uint4 test(bool p1) {
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
- // CHECK-SPIRV: %[[RET:.*]] = call spir_func <4 x i32> @llvm.spv.wave.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
- // CHECK-DXIL: %[[RET:.*]] = call <4 x i32> @llvm.dx.wave.ballot(i1 %{{[a-zA-Z0-9]+}})
- // CHECK: ret <4 x i32> %[[RET]]
+ // CHECK-SPIRV: %[[SPIRVRET:.*]] = call spir_func <4 x i32> @llvm.spv.subgroup.ballot(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
+ // CHECK-DXIL: %[[WAB:.*]] = call { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1 %{{[a-zA-Z0-9]+}})
+ // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 0
+ // CHECK-DXIL-NEXT: insertelement <4 x i32> poison, i32 {{.*}}, i32 0
+ // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 1
+ // CHECK-DXIL-NEXT: insertelement <4 x i32> {{.*}}, i32 {{.*}}, i32 1
+ // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 2
+ // CHECK-DXIL-NEXT: insertelement <4 x i32> {{.*}}, i32 {{.*}}, i32 2
+ // CHECK-DXIL-NEXT: extractvalue { i32, i32, i32, i32 } {{.*}} 3
+ // CHECK-DXIL-NEXT: %[[DXILRET:.*]] = insertelement <4 x i32> {{.*}}, i32 {{.*}}, i32 3
+ // CHECK-DXIL-NEXT: ret <4 x i32> %[[DXILRET]]
+ // CHECK-SPIRV: ret <4 x i32> %[[SPIRVRET]]
+
return WaveActiveBallot(p1);
}
diff --git a/clang/test/CodeGenSPIRV/Builtins/subgroup.c b/clang/test/CodeGenSPIRV/Builtins/subgroup.c
index 78d41b7933f1f..ba6b48e3f3848 100644
--- a/clang/test/CodeGenSPIRV/Builtins/subgroup.c
+++ b/clang/test/CodeGenSPIRV/Builtins/subgroup.c
@@ -9,7 +9,7 @@ typedef unsigned __attribute__((ext_vector_type(4))) int4;
// CHECK: @{{.*}}test_subgroup_ballot{{.*}}(
// CHECK-NEXT: [[ENTRY:.*:]]
-// CHECK-NEXT: tail call <4 x i32> @llvm.spv.wave.ballot(i1 %i)
+// CHECK-NEXT: tail call <4 x i32> @llvm.spv.subgroup.ballot(i1 %i)
[[clang::sycl_external]] int4 test_subgroup_ballot(_Bool i) {
return __builtin_spirv_subgroup_ballot(i);
}
diff --git a/clang/test/Headers/gpuintrin.c b/clang/test/Headers/gpuintrin.c
index c8fe721c8c37c..891a5abf7a72a 100644
--- a/clang/test/Headers/gpuintrin.c
+++ b/clang/test/Headers/gpuintrin.c
@@ -1267,7 +1267,7 @@ __gpu_kernel void foo() {
// SPIRV-NEXT: [[ENTRY:.*:]]
// SPIRV-NEXT: [[__MASK:%.*]] = alloca <4 x i32>, align 16
// SPIRV-NEXT: [[REF_TMP:%.*]] = alloca <2 x i32>, align 8
-// SPIRV-NEXT: [[TMP0:%.*]] = call <4 x i32> @llvm.spv.wave.ballot(i1 true)
+// SPIRV-NEXT: [[TMP0:%.*]] = call <4 x i32> @llvm.spv.subgroup.ballot(i1 true)
// SPIRV-NEXT: store <4 x i32> [[TMP0]], ptr [[__MASK]], align 16
// SPIRV-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[__MASK]], align 16
// SPIRV-NEXT: [[TMP2:%.*]] = load <4 x i32>, ptr [[__MASK]], align 16
@@ -1335,7 +1335,7 @@ __gpu_kernel void foo() {
// SPIRV-NEXT: store i8 [[STOREDV]], ptr [[__X_ADDR]], align 1
// SPIRV-NEXT: [[TMP0:%.*]] = load i8, ptr [[__X_ADDR]], align 1
// SPIRV-NEXT: [[LOADEDV:%.*]] = trunc i8 [[TMP0]] to i1
-// SPIRV-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.spv.wave.ballot(i1 [[LOADEDV]])
+// SPIRV-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.spv.subgroup.ballot(i1 [[LOADEDV]])
// SPIRV-NEXT: store <4 x i32> [[TMP1]], ptr [[__MASK]], align 16
// SPIRV-NEXT: [[TMP2:%.*]] = load i64, ptr [[__LANE_MASK_ADDR]], align 8
// SPIRV-NEXT: [[TMP3:%.*]] = load <4 x i32>, ptr [[__MASK]], align 16
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 6e6eb2d0ece9d..f79945785566c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -153,7 +153,7 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
-def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_ballot : DefaultAttrsIntrinsic<[llvm_anyint_ty, LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index d782d4f5fae0b..293cb750cea98 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,7 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
- def int_spv_wave_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
+ def int_spv_subgroup_ballot : ClangBuiltin<"__builtin_spirv_subgroup_ballot">,
DefaultAttrsIntrinsic<[llvm_v4i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 6d04732d92ecf..3a40d2c36139d 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -58,6 +58,7 @@ def ResPropsTy : DXILOpParamType;
def SplitDoubleTy : DXILOpParamType;
def BinaryWithCarryTy : DXILOpParamType;
def DimensionsTy : DXILOpParamType;
+def Fouri32s : DXILOpParamType;
class DXILOpClass;
@@ -212,13 +213,12 @@ defset list<DXILOpClass> OpClasses = {
def unpack4x8 : DXILOpClass;
def viewID : DXILOpClass;
def waveActiveAllEqual : DXILOpClass;
- def waveActiveBallot : DXILOpClass;
def waveActiveBit : DXILOpClass;
def waveActiveOp : DXILOpClass;
def waveAllOp : DXILOpClass;
def waveAllTrue : DXILOpClass;
def waveAnyTrue : DXILOpClass;
- def waveBallot : DXILOpClass;
+ def waveActiveBallot : DXILOpClass;
def waveGetLaneCount : DXILOpClass;
def waveGetLaneIndex : DXILOpClass;
def waveIsFirstLane : DXILOpClass;
@@ -1062,6 +1062,14 @@ def WaveActiveAllTrue : DXILOp<114, waveAllTrue> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
+def WaveActiveBallot : DXILOp<116, waveActiveBallot> {
+ let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
+ let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
+ let arguments = [Int1Ty];
+ let result = Fouri32s;
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
let Doc = "returns the value from the specified lane";
let intrinsics = [IntrinSelect<int_dx_wave_readlane>];
@@ -1072,14 +1080,6 @@ def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
let stages = [Stages<DXIL1_0, [all_stages]>];
}
-def WaveActiveBallot : DXILOp<118, waveBallot> {
- let Doc = "returns uint4 containing a bitmask of the evaluation of the boolean expression for all active lanes in the current wave.";
- let intrinsics = [IntrinSelect<int_dx_wave_ballot>];
- let arguments = [Int1Ty];
- let result = OverloadTy;
- let stages = [Stages<DXIL1_0, [all_stages]>];
-}
-
def WaveActiveOp : DXILOp<119, waveActiveOp> {
let Doc = "returns the result of the operation across waves";
let intrinsics = [
diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
index 944b2e6433988..1f41d2457e5bc 100644
--- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
@@ -261,10 +261,18 @@ static StructType *getBinaryWithCarryType(LLVMContext &Context) {
return StructType::create({Int32Ty, Int1Ty}, "dx.types.i32c");
}
-static StructType *getDimensionsType(LLVMContext &Ctx) {
- Type *Int32Ty = Type::getInt32Ty(Ctx);
+static StructType *getDimensionsType(LLVMContext &Context) {
+ Type *Int32Ty = Type::getInt32Ty(Context);
return getOrCreateStructType("dx.types.Dimensions",
- {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Ctx);
+ {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
+}
+
+static StructType *getFouri32sType(LLVMContext &Context) {
+ if (auto *ST = StructType::getTypeByName(Context, "dx.types.fouri32"))
+ return ST;
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ return getOrCreateStructType("dx.types.fouri32",
+ {Int32Ty, Int32Ty, Int32Ty, Int32Ty}, Context);
}
static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
@@ -326,7 +334,10 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
return getBinaryWithCarryType(Ctx);
case OpParamType::DimensionsTy:
return getDimensionsType(Ctx);
+ case OpParamType::Fouri32s:
+ return getFouri32sType(Ctx);
}
+
llvm_unreachable("Invalid parameter kind");
return nullptr;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 98b5bfd678135..626393d4ecb40 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -3815,7 +3815,7 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAll);
case Intrinsic::spv_wave_any:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
- case Intrinsic::spv_wave_ballot:
+ case Intrinsic::spv_subgroup_ballot:
return selectWaveOpInst(ResVReg, ResType, I,
SPIRV::OpGroupNonUniformBallot);
case Intrinsic::spv_wave_is_first_lane:
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
index cf6255de3a734..f0440cb4e6183 100644
--- a/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBallot.ll
@@ -1,10 +1,37 @@
-; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
-
-define noundef <4 x i32> @wave_ballot_simple(i1 noundef %p1) {
-entry:
-; CHECK: call <4 x i32> @dx.op.waveBallot.void(i32 118, i1 %p1)
- %ret = call <4 x i32> @llvm.dx.wave.ballot(i1 %p1)
- ret <4 x i32> %ret
-}
-
-declare <4 x i32> @llvm.dx.wave.ballot(i1)
+; RUN: opt -S -scalarizer -dxil-op-lower %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64-v48:16:16-v96:32:32-v192:64:64"
+target triple = "dxilv1.3-pc-shadermodel6.3-compute"
+
+; The definition of the custom type should be added
+; CHECK: %dx.types.fouri32 = type { i32, i32, i32, i32 }
+
+; Function Attrs: alwaysinline convergent mustprogress norecurse nounwind
+define hidden noundef <4 x i32> @_Z4testb(i1 noundef %p1) {
+entry:
+ %p1.addr = alloca i32, align 4
+ %storedv = zext i1 %p1 to i32
+ store i32 %storedv, ptr %p1.addr, align 4
+ %0 = load i32, ptr %p1.addr, align 4
+ %loadedv = trunc i32 %0 to i1
+ %1 = load i32, ptr %p1.addr, align 4
+ %loadedv1 = trunc i32 %1 to i1
+
+ ; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %loadedv1)
+
+ %2 = call { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1 %loadedv1)
+ %3 = extractvalue { i32, i32, i32, i32 } %2, 0
+ %4 = insertelement <4 x i32> poison, i32 %3, i32 0
+ %5 = extractvalue { i32, i32, i32, i32 } %2, 1
+ %6 = insertelement <4 x i32> %4, i32 %5, i32 1
+ %7 = extractvalue { i32, i32, i32, i32 } %2, 2
+ %8 = insertelement <4 x i32> %6, i32 %7, i32 2
+ %9 = extractvalue { i32, i32, i32, i32 } %2, 3
+ %10 = insertelement <4 x i32> %8, i32 %9, i32 3
+
+ ; CHECK-NOT: ret %dx.types.fouri32
+ ; CHECK: ret <4 x i32>
+ ret <4 x i32> %10
+}
+
+declare { i32, i32, i32, i32 } @llvm.dx.wave.ballot.i32(i1)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
index 6831888f038fd..e38d77360631b 100644
--- a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBallot.ll
@@ -13,10 +13,10 @@ entry:
; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
; CHECK: %{{.+}} = OpGroupNonUniformBallot %[[#bitmask]] %[[#scope]] %[[#param]]
%0 = call token @llvm.experimental.convergence.entry()
- %ret = call <4 x i32> @llvm.spv.wave.ballot(i1 %p1) [ "convergencectrl"(token %0) ]
+ %ret = call <4 x i32> @llvm.spv.subgroup.ballot(i1 %p1) [ "convergencectrl"(token %0) ]
ret <4 x i32> %ret
}
-declare <4 x i32> @llvm.spv.wave.ballot(i1) #0
+declare <4 x i32> @llvm.spv.subgroup.ballot(i1) #0
attributes #0 = { convergent }
diff --git a/llvm/test/tools/dxil-dis/waveactiveballot.ll b/llvm/test/tools/dxil-dis/waveactiveballot.ll
new file mode 100644
index 0000000000000..2bdb4ec98a3db
--- /dev/null
+++ b/llvm/test/tools/dxil-dis/waveactiveballot.ll
@@ -0,0 +1,31 @@
+; RUN: llc %s --filetype=obj -o - | dxil-dis -o - | FileCheck %s
+
+; CHECK-NOT: llvm.dx.wave.ballot
+
+; CHECK: call %dx.types.fouri32 @dx.op.waveActiveBallot(i32 116, i1 %p1)
+; CHECK-NOT: ret %dx.types.fouri32
+; CHECK: ret <4 x i32>
+
+
+target triple = "dxil-unknown-shadermodel6.3-library"
+
+%dx.types.fouri32 = type { i32, i32, i32, i32 }
+
+define <4 x i32> @wave_ballot_simple(i1 %p1) {
+entry:
+ %s = call %dx.types.fouri32 @llvm.dx.wave.ballot(i1 %p1)
+
+ %v0 = extractvalue %dx.types.fouri32 %s, 0
+ %v1 = extractvalue %dx.types.fouri32 %s, 1
+ %v2 = extractvalue %dx.types.fouri32 %s, 2
+ %v3 = extractvalue %dx.types.fouri32 %s, 3
+
+ %vec0 = insertelement <4 x i32> poison, i32 %v0, i32 0
+ %vec1 = insertelement <4 x i32> %vec0, i32 %v1, i32 1
+ %vec2 = insertelement <4 x i32> %vec1, i32 %v2, i32 2
+ %vec3 = insertelement <4 x i32> %vec2, i32 %v3, i32 3
+
+ ret <4 x i32> %vec3
+}
+
+declare %dx.types.fouri32 @llvm.dx.wave.ballot(i1)
More information about the cfe-commits mailing list