[clang] [llvm] [HLSL] Add wave prefix count bits function (PR #178059)
Joshua Batista via cfe-commits
cfe-commits at lists.llvm.org
Tue Jan 27 12:59:22 PST 2026
https://github.com/bob80905 updated https://github.com/llvm/llvm-project/pull/178059
>From 6e20bc6d92b14abd20085589d63eb89136cbf4a6 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Mon, 26 Jan 2026 13:39:12 -0800
Subject: [PATCH 1/3] first attempt
---
clang/include/clang/Basic/Builtins.td | 6 ++
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 47 ++++++++++++++++
.../lib/Headers/hlsl/hlsl_alias_intrinsics.h | 10 ++++
clang/lib/Sema/SemaHLSL.cpp | 24 ++++++++
.../builtins/WavePrefixCountBits.hlsl | 27 +++++++++
.../BuiltIns/WavePrefixCountBits-errors.hlsl | 30 ++++++++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 +
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 +
llvm/lib/Target/DirectX/DXIL.td | 26 +++++++++
.../DirectX/DirectXTargetTransformInfo.cpp | 1 +
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 55 +++++++++++++++++++
.../CodeGen/DirectX/WavePrefixBitCount.ll | 10 ++++
.../hlsl-intrinsics/WavePrefixCountBits.ll | 17 ++++++
13 files changed, 255 insertions(+)
create mode 100644 clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl
create mode 100644 clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl
create mode 100644 llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index bc8f1474493b0..0ef28ae16c301 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -5127,6 +5127,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool()";
}
+def HLSLWavePrefixCountBits : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_prefix_count_bits"];
+ let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Prototype = "int(bool)";
+}
+
def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 36691c7b72efe..c6998a343f496 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -331,6 +331,40 @@ static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
}
}
+// select and return a specific wave prefix op intrinsic,
+// based on the provided op kind.
+// OpKinds:
+// CountBits = 136, count all bits set in previous threads
+// This is the only operation in DXIL so far under this class
+static Intrinsic::ID getPrefixOpIntrinsic(int OpKind,
+ llvm::Triple::ArchType Arch,
+ CGHLSLRuntime &RT, QualType QT) {
+ switch (Arch) {
+ case llvm::Triple::spirv:
+ switch (OpKind) {
+ case 136: {
+ return Intrinsic::spv_subgroup_prefix_bit_count;
+ }
+ default: {
+ llvm_unreachable("Unexpected SubOp ID");
+ }
+ }
+ case llvm::Triple::dxil: {
+ switch (OpKind) {
+ case 136: {
+ return Intrinsic::dx_wave_prefix_bit_count;
+ }
+ default: {
+ llvm_unreachable("Unexpected SubOp ID");
+ }
+ }
+ }
+ default:
+ llvm_unreachable(
+ "WavePrefixOp instruction not supported by target architecture");
+ }
+}
+
// Returns the mangled name for a builtin function that the SPIR-V backend
// will expand into a spec Constant.
static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
@@ -808,6 +842,19 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
nullptr, "hlsl.saturate");
}
+ case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
+ Value *Op = EmitScalarExpr(E->getArg(0));
+ assert(Op->getType()->isIntegerTy(1) &&
+ "WavePrefixBitCount operand must be a boolean type");
+
+ Intrinsic::ID IID = getPrefixOpIntrinsic(
+ /* OpKind */ 136, getTarget().getTriple().getArch(),
+ CGM.getHLSLRuntime(), E->getArg(0)->getType());
+
+ return EmitRuntimeCall(
+ Intrinsic::getOrInsertDeclaration(&CGM.getModule(), IID), ArrayRef{Op},
+ "hlsl.wave.prefix.bit.count");
+ }
case Builtin::BI__builtin_hlsl_select: {
Value *OpCond = EmitScalarExpr(E->getArg(0));
RValue RValTrue = EmitAnyExpr(E->getArg(1));
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index e9a41b94d6c03..656fa7c7dea82 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2445,6 +2445,16 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_get_lane_count)
__attribute__((convergent)) uint WaveGetLaneCount();
+//===----------------------------------------------------------------------===//
+// WavePrefixOp builtins
+//===----------------------------------------------------------------------===//
+/// \brief Returns the count of bits of Expr set to 1 on prior lanes.
+/// \param Expr The boolean expression to evaluate.
+/// \return the count of bits set to 1 on prior lanes.
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_prefix_count_bits)
+__attribute__((convergent)) int WavePrefixCountBits(bool Expr);
+
//===----------------------------------------------------------------------===//
// WaveReadLaneAt builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 4d31e26d56e6b..cd9e77f913800 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3593,6 +3593,30 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
+ case Builtin::BI__builtin_hlsl_wave_prefix_count_bits: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ // Ensure input expr type is a scalar/vector and then
+ // set the return type to the arg type
+ QualType ArgType = TheCall->getArg(0)->getType();
+ // not the scalar or vector<scalar>
+ if (!(ArgType->isScalarType())) {
+ SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+ diag::err_typecheck_expect_any_scalar_or_vector)
+ << ArgType << 0;
+ return true;
+ }
+
+ if (!(ArgType->isBooleanType())) {
+ SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+ diag::err_typecheck_expect_any_scalar_or_vector)
+ << ArgType << 0;
+ return true;
+ }
+
+ break;
+ }
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
diff --git a/clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl
new file mode 100644
index 0000000000000..135507b60c8fc
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WavePrefixCountBits.hlsl
@@ -0,0 +1,27 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm \
+// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm \
+// RUN: -disable-llvm-passes -o - | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+int test_int(bool expr) {
+// CHECK-DXIL: define hidden noundef i32 {{.*}}(i1 noundef %[[EXPR:.*]]) #[[CONVATTR:.*]] {
+// CHECK-SPIRV: define hidden spir_func noundef i32 {{.*}}(i1 noundef %[[EXPR:.*]]) #[[CONVATTR:.*]] {
+ // CHECK: entry:
+ // CHECK: %[[EXPRADDR:.*]] = alloca i32, align 4
+ // CHECK: %[[STOREDVAL:.*]] = zext i1 %[[EXPR]] to i32
+ // CHECK: store i32 %[[STOREDVAL]], ptr %[[EXPRADDR]], align 4
+ // CHECK: %[[LOADEDVAL:.*]] = load i32, ptr %[[EXPRADDR]], align 4
+ // CHECK: %[[TRUNCLOADEDVAL:.*]] = trunc i32 %[[LOADEDVAL]] to i1
+
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.subgroup.prefix.bit.count(i1 %[[TRUNCLOADEDVAL]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.prefix.bit.count(i1 %[[TRUNCLOADEDVAL]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WavePrefixCountBits(expr);
+}
+
+// CHECK: attributes #[[CONVATTR]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl
new file mode 100644
index 0000000000000..2b2bca82c6b17
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WavePrefixCountBits-errors.hlsl
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -finclude-default-header -fnative-int16-type -fnative-half-type \
+// RUN: -fmath-errno -ffp-contract=on -fno-rounding-math -finclude-default-header \
+// RUN: -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+ return __builtin_hlsl_wave_prefix_count_bits();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+int test_too_many_arg(bool p0) {
+ return __builtin_hlsl_wave_prefix_count_bits(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+float test_expr_bool_type_check(float p0) {
+ return __builtin_hlsl_wave_prefix_count_bits(p0);
+ // expected-error@-1 {{invalid operand of type 'float'}}
+}
+
+float2 test_expr_bool_vec_type_check(float2 p0) {
+ return __builtin_hlsl_wave_prefix_count_bits(p0);
+ // expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+ return __builtin_hlsl_wave_prefix_count_bits(p0);
+ // expected-error@-1 {{invalid operand of type 'S'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 23627848b6214..2aa20ddd5d434 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -160,6 +160,7 @@ def int_dx_lerp : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, L
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
+def int_dx_wave_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index a93e8ad0ce964..1e54bf394c984 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -117,6 +117,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
[IntrNoMem, Commutative] >;
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
+ def int_spv_subgroup_prefix_bit_count : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrNoMem]>;
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 3a40d2c36139d..bbce6bc082ac8 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -235,6 +235,7 @@ defset list<DXILOpClass> OpClasses = {
def waveMatrix_StoreRawBuf : DXILOpClass;
def waveMultiPrefixBitCount : DXILOpClass;
def waveMultiPrefixOp : DXILOpClass;
+ def wavePrefixBitCount : DXILOpClass;
def wavePrefixOp : DXILOpClass;
def waveReadLaneAt : DXILOpClass;
def waveReadLaneFirst : DXILOpClass;
@@ -317,6 +318,8 @@ defvar WaveOpKind_Product = 1;
defvar WaveOpKind_Min = 2;
defvar WaveOpKind_Max = 3;
+defvar WavePrefixOpKind_BitCount = 136;
+
defvar SignedOpKind_Signed = 0;
defvar SignedOpKind_Unsigned = 1;
@@ -1124,6 +1127,29 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}
+def WavePrefixOp : DXILOp<121, wavePrefixOp> {
+ let Doc = "returns the result of the operation on prior lanes";
+
+ let intrinsics = [
+ IntrinSelect<int_dx_wave_prefix_bit_count,
+ [
+ IntrinArgI32<WavePrefixOpKind_BitCount>, IntrinArgIndex<0>,
+ IntrinArgI8<SignedOpKind_Unsigned>
+ ]>
+ ];
+
+ let arguments = [
+ Int32Ty, // prefix op kind
+ Int1Ty, // value
+ Int8Ty // signedness
+ ];
+
+ let result = Int32Ty;
+
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, []>];
+}
+
def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
let Doc = "returns the float16 stored in the low-half of the uint converted "
"to a float";
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index f54b48b91265e..b885b459b5d72 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -55,6 +55,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_saturate:
case Intrinsic::dx_splitdouble:
+ case Intrinsic::dx_wave_prefix_bit_count:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_min:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 23cfb326bc8d9..a443acceca824 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -224,6 +224,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectWavePrefixBitCount(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, bool IsUnsigned) const;
@@ -2715,6 +2718,56 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}
+bool SPIRVInstructionSelector::selectWavePrefixBitCount(
+ Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+
+ auto Op = I.getOperand(2);
+ assert(Op.isReg());
+
+ MachineBasicBlock &BB = *I.getParent();
+ DebugLoc DL = I.getDebugLoc();
+
+ Register InputRegister = Op.getReg();
+ SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+
+ if (!InputType)
+ report_fatal_error("Input Type could not be determined.");
+
+ if (!GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool))
+ report_fatal_error("WavePrefixBitCount requires boolean input");
+
+ // Types
+ SPIRVType *Int32Ty = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+
+ // Ballot result type: vector<uint32>
+ // Match DXC: %v4uint for Subgroup size
+ SPIRVType *BallotTy = GR.getOrCreateSPIRVVectorType(Int32Ty, 4, I, TII);
+
+ // Create a vreg for the ballot result
+ Register BallotVReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+
+ // 1. OpGroupNonUniformBallot
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformBallot))
+ .addDef(BallotVReg)
+ .addUse(GR.getSPIRVTypeID(BallotTy))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, Int32Ty, TII))
+ .addUse(InputRegister)
+ .constrainAllUses(TII, TRI, RBI);
+
+ // 2. OpGroupNonUniformBallotBitCount
+ BuildMI(BB, I, DL, TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, Int32Ty, TII))
+ .addImm(SPIRV::GroupOperation::ExclusiveScan)
+ .addUse(BallotVReg)
+ .constrainAllUses(TII, TRI, RBI);
+
+ return true;
+}
+
bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
@@ -3859,6 +3912,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectExtInst(ResVReg, ResType, I, CL::u_clamp, GL::UClamp);
case Intrinsic::spv_sclamp:
return selectExtInst(ResVReg, ResType, I, CL::s_clamp, GL::SClamp);
+ case Intrinsic::spv_subgroup_prefix_bit_count:
+ return selectWavePrefixBitCount(ResVReg, ResType, I);
case Intrinsic::spv_wave_active_countbits:
return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_wave_all:
diff --git a/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll
new file mode 100644
index 0000000000000..67432aa8e3696
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library < %s | FileCheck %s
+
+; Test that WavePrefixCountBits maps down to the DirectX op
+
+define noundef i32 @wave_prefix_count_bits(i1 noundef %expr) {
+entry:
+; CHECK: call i32 @dx.op.wavePrefixOp(i32 121, i32 136, i1 %expr, i8 1)
+ %ret = call i32 @llvm.dx.wave.prefix.bit.count(i1 %expr)
+ ret i32 %ret
+}
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll
new file mode 100644
index 0000000000000..321123ab5a617
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WavePrefixCountBits.ll
@@ -0,0 +1,17 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend
+
+define noundef i32 @wave_prefix_count_bits(i1 noundef %expr) {
+entry:
+ ; CHECK: %[[UINT:.*]] = OpTypeInt 32 0
+ ; CHECK: %[[UINT4:.*]] = OpTypeVector %[[UINT]] 4
+ ; CHECK: %[[UINT3:.*]] = OpConstant %[[UINT]] 3
+ ; CHECK: %[[INPUTREG:.*]] = OpFunctionParameter
+ ; CHECK: %[[BALLOTRESULT:.*]] = OpGroupNonUniformBallot %[[UINT4]] %[[UINT3]] %[[INPUTREG]]
+ ; CHECK: %[[RET:.*]] = OpGroupNonUniformBallotBitCount %[[UINT]] %[[UINT3]] ExclusiveScan %[[BALLOTRESULT]]
+ %ret = call i32 @llvm.spv.subgroup.prefix.bit.count(i1 %expr)
+ ; CHECK: OpReturnValue %[[RET]]
+ ret i32 %ret
+}
>From a0085c2ed1817d3f20ac22bc33b2bb3dbb511151 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Tue, 27 Jan 2026 12:51:33 -0800
Subject: [PATCH 2/3] address Tex, remove waveprefixop intrinsic and make
countbits independent, but still named the same
---
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 33 ++++---------------
clang/lib/Sema/SemaHLSL.cpp | 2 +-
llvm/lib/Target/DirectX/DXIL.td | 20 +++--------
.../CodeGen/DirectX/WavePrefixBitCount.ll | 2 +-
4 files changed, 12 insertions(+), 45 deletions(-)
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 54259a3f6aba1..33196b66c576e 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -384,33 +384,13 @@ static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
}
}
-// select and return a specific wave prefix op intrinsic,
-// based on the provided op kind.
-// OpKinds:
-// CountBits = 136, count all bits set in previous threads
-// This is the only operation in DXIL so far under this class
-static Intrinsic::ID getPrefixOpIntrinsic(int OpKind,
- llvm::Triple::ArchType Arch,
- CGHLSLRuntime &RT, QualType QT) {
+static Intrinsic::ID getPrefixCountBitsIntrinsic(
+ llvm::Triple::ArchType Arch) {
switch (Arch) {
case llvm::Triple::spirv:
- switch (OpKind) {
- case 136: {
- return Intrinsic::spv_subgroup_prefix_bit_count;
- }
- default: {
- llvm_unreachable("Unexpected SubOp ID");
- }
- }
+ return Intrinsic::spv_subgroup_prefix_bit_count;
case llvm::Triple::dxil: {
- switch (OpKind) {
- case 136: {
- return Intrinsic::dx_wave_prefix_bit_count;
- }
- default: {
- llvm_unreachable("Unexpected SubOp ID");
- }
- }
+ return Intrinsic::dx_wave_prefix_bit_count;
}
default:
llvm_unreachable(
@@ -903,9 +883,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(Op->getType()->isIntegerTy(1) &&
"WavePrefixBitCount operand must be a boolean type");
- Intrinsic::ID IID = getPrefixOpIntrinsic(
- /* OpKind */ 136, getTarget().getTriple().getArch(),
- CGM.getHLSLRuntime(), E->getArg(0)->getType());
+ Intrinsic::ID IID = getPrefixCountBitsIntrinsic(
+ getTarget().getTriple().getArch());
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), IID), ArrayRef{Op},
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 5d65c19bab234..e6ba492549b12 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3614,7 +3614,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
// Ensure input expr type is a scalar/vector and then
// set the return type to the arg type
QualType ArgType = TheCall->getArg(0)->getType();
- // not the scalar or vector<scalar>
+
if (!(ArgType->isScalarType())) {
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
diag::err_typecheck_expect_any_scalar_or_vector)
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 3843257ba26d3..d230e3daec55e 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -318,8 +318,6 @@ defvar WaveOpKind_Product = 1;
defvar WaveOpKind_Min = 2;
defvar WaveOpKind_Max = 3;
-defvar WavePrefixOpKind_BitCount = 136;
-
defvar SignedOpKind_Signed = 0;
defvar SignedOpKind_Unsigned = 1;
@@ -1127,21 +1125,11 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}
-def WavePrefixOp : DXILOp<121, wavePrefixOp> {
- let Doc = "returns the result of the operation on prior lanes";
-
- let intrinsics = [
- IntrinSelect<int_dx_wave_prefix_bit_count,
- [
- IntrinArgI32<WavePrefixOpKind_BitCount>, IntrinArgIndex<0>,
- IntrinArgI8<SignedOpKind_Unsigned>
- ]>
- ];
-
- let arguments = [ Int32Ty, Int1Ty, Int8Ty ];
-
+def WavePrefixCountBits : DXILOp<136, wavePrefixOp> {
+ let Doc = "returns the count of bits of Expr set to 1 on prior lanes";
+ let intrinsics = [IntrinSelect<int_dx_wave_prefix_bit_count>];
+ let arguments = [Int1Ty];
let result = Int32Ty;
-
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, []>];
}
diff --git a/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll
index 67432aa8e3696..406bfa44a3f47 100644
--- a/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll
+++ b/llvm/test/CodeGen/DirectX/WavePrefixBitCount.ll
@@ -4,7 +4,7 @@
define noundef i32 @wave_prefix_count_bits(i1 noundef %expr) {
entry:
-; CHECK: call i32 @dx.op.wavePrefixOp(i32 121, i32 136, i1 %expr, i8 1)
+; CHECK: call i32 @dx.op.wavePrefixOp(i32 136, i1 %expr)
%ret = call i32 @llvm.dx.wave.prefix.bit.count(i1 %expr)
ret i32 %ret
}
>From 7be2fafd1d4accb947a02f2d7101e3eb5e1dbf31 Mon Sep 17 00:00:00 2001
From: Joshua Batista <jbatista at microsoft.com>
Date: Tue, 27 Jan 2026 12:59:03 -0800
Subject: [PATCH 3/3] try manual clang-format
---
clang/lib/CodeGen/CGHLSLBuiltins.cpp | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 33196b66c576e..14cacb59e229a 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -384,8 +384,7 @@ static Intrinsic::ID getWaveActiveMinIntrinsic(llvm::Triple::ArchType Arch,
}
}
-static Intrinsic::ID getPrefixCountBitsIntrinsic(
- llvm::Triple::ArchType Arch) {
+static Intrinsic::ID getPrefixCountBitsIntrinsic(llvm::Triple::ArchType Arch) {
switch (Arch) {
case llvm::Triple::spirv:
return Intrinsic::spv_subgroup_prefix_bit_count;
@@ -883,8 +882,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
assert(Op->getType()->isIntegerTy(1) &&
"WavePrefixBitCount operand must be a boolean type");
- Intrinsic::ID IID = getPrefixCountBitsIntrinsic(
- getTarget().getTriple().getArch());
+ Intrinsic::ID IID =
+ getPrefixCountBitsIntrinsic(getTarget().getTriple().getArch());
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), IID), ArrayRef{Op},
More information about the cfe-commits
mailing list