[clang] [llvm] [DXIL][SPIRV] Lower `WaveActiveCountBits` intrinsic (PR #113382)
Finn Plummer via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 7 14:51:08 PST 2024
https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/113382
>From 35731658c1769453f86dde6063b137a2c5aeca32 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 18 Oct 2024 15:48:29 -0700
Subject: [PATCH 1/4] [DXIL][SPIRV] Lower WaveActiveCountBits intrinsic
- add codegen from the clang builtin to the SPIR-V/DirectX intrinsic
in CGBuiltin.cpp
- add lowering of the SPIR-V intrinsic in the SPIR-V backend in
SPIRVInstructionSelector.cpp
- add lowering of the DirectX intrinsic to the DXIL op in DXIL.td
- add test cases exercising each lowering step
- add test case for semantic analysis
---
clang/lib/CodeGen/CGBuiltin.cpp | 7 ++++
clang/lib/CodeGen/CGHLSLRuntime.h | 1 +
.../builtins/WaveActiveCountBits.hlsl | 22 +++++++++++
.../BuiltIns/WaveActiveCountBits-errors.hlsl | 18 +++++++++
llvm/include/llvm/IR/IntrinsicsDirectX.td | 1 +
llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 +
llvm/lib/Target/DirectX/DXIL.td | 9 +++++
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 37 +++++++++++++++++++
.../CodeGen/DirectX/WaveActiveCountBits.ll | 10 +++++
.../hlsl-intrinsics/WaveActiveCountBits.ll | 19 ++++++++++
10 files changed, 125 insertions(+)
create mode 100755 clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 0ef9058640db6a..db6b8f80195691 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -19056,6 +19056,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
}
+ case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
+ Value *OpExpr = EmitScalarExpr(E->getArg(0));
+ Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
+ return EmitRuntimeCall(
+ Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
+ ArrayRef{OpExpr});
+ }
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
// We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
// defined in SPIRVBuiltins.td. So instead we manually get the matching name
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index caf8777fd95a9f..167cc04baf159f 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -91,6 +91,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddI8Packed, dot4add_i8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
GENERATE_HLSL_INTRINSIC_FUNCTION(FirstBitUHigh, firstbituhigh)
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
new file mode 100755
index 00000000000000..3e1f8fcaace9c2
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_bool
+int test_bool(bool expr) {
+ // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ]
+ // CHECK-DXIL: %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}})
+ // CHECK: ret i32 %[[RET]]
+ return WaveActiveCountBits(expr);
+}
+
+// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]]
+// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
new file mode 100644
index 00000000000000..02f45eb30b377a
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+ return __builtin_hlsl_wave_active_count_bits();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+int test_too_many_arg(bool x) {
+ return __builtin_hlsl_wave_active_count_bits(x, x);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+struct S { float f; };
+
+int test_bad_conversion(S x) {
+ return __builtin_hlsl_wave_active_count_bits(x);
+ // expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 43267033f024a7..191dc8ad208f93 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -86,6 +86,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index e93d6fa83de61b..b9b1e6ab89ddcc 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -85,6 +85,7 @@ let TargetPrefix = "spv" in {
[IntrNoMem, Commutative] >;
def int_spv_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
def int_spv_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
+ def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 1aabff90e5ec6e..b8de926f4be017 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -873,3 +873,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
+
+def WaveAllBitCount : DXILOp<135, waveAllOp> {
+ let Doc = "returns the count of bits set to 1 across the wave";
+ let LLVMIntrinsic = int_dx_wave_active_countbits;
+ let arguments = [Int1Ty];
+ let result = Int32Ty;
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 414583aea91e64..e5cbc82fc5ab95 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -256,6 +256,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1917,6 +1920,38 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}
+bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+ SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
+
+ bool Result =
+ BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformBallot))
+ .addDef(BallotReg)
+ .addUse(GR.getSPIRVTypeID(BallotType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+ .addUse(I.getOperand(2).getReg());
+
+ Result |=
+ BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+ .addImm(0)
+ .addUse(BallotReg)
+ .constrainAllUses(TII, TRI, RBI);
+
+ return Result;
+}
+
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -2739,6 +2774,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
} break;
case Intrinsic::spv_saturate:
return selectSaturate(ResVReg, ResType, I);
+ case Intrinsic::spv_wave_active_countbits:
+ return selectWaveActiveCountBits(ResVReg, ResType, I);
case Intrinsic::spv_wave_is_first_lane: {
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
new file mode 100644
index 00000000000000..5d321372433198
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define void @main(i1 %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
+ %0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
+ ret void
+}
+
+declare i32 @llvm.dx.wave.active.countbits(i1)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
new file mode 100644
index 00000000000000..29944054111ac0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
@@ -0,0 +1,19 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG: %[[#bool:]] = OpTypeBool
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_fun
+; CHECK: %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
+define i32 @test_fun(i1 %expr) {
+entry:
+; CHECK: %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
+; CHECK: %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
+ %0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
+ ret i32 %0
+}
+
+declare i32 @llvm.spv.wave.active.countbits(i1)
>From e3683d55c8dc64bb57e45a8f772332310fc37b5d Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Tue, 22 Oct 2024 21:06:50 +0000
Subject: [PATCH 2/4] clang-format
---
.../Target/SPIRV/SPIRVInstructionSelector.cpp | 26 +++++++++----------
1 file changed, 12 insertions(+), 14 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index e5cbc82fc5ab95..85a425aa4ae025 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -257,7 +257,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
MachineInstr &I) const;
bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
- MachineInstr &I) const;
+ MachineInstr &I) const;
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
@@ -1920,9 +1920,8 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
return Result;
}
-bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
- const SPIRVType *ResType,
- MachineInstr &I) const {
+bool SPIRVInstructionSelector::selectWaveActiveCountBits(
+ Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
assert(I.getNumOperands() == 3);
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();
@@ -1932,22 +1931,21 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
bool Result =
- BuildMI(BB, I, I.getDebugLoc(),
- TII.get(SPIRV::OpGroupNonUniformBallot))
+ BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
.addDef(BallotReg)
.addUse(GR.getSPIRVTypeID(BallotType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
.addUse(I.getOperand(2).getReg());
Result |=
- BuildMI(BB, I, I.getDebugLoc(),
- TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
- .addDef(ResVReg)
- .addUse(GR.getSPIRVTypeID(ResType))
- .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
- .addImm(0)
- .addUse(BallotReg)
- .constrainAllUses(TII, TRI, RBI);
+ BuildMI(BB, I, I.getDebugLoc(),
+ TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+ .addImm(0)
+ .addUse(BallotReg)
+ .constrainAllUses(TII, TRI, RBI);
return Result;
}
>From 3180f0fcd9afa5951c8f49078a8bf70dac318652 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Thu, 24 Oct 2024 22:07:50 +0000
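Subject: [PATCH 3/4] address review comments:
- add constrainAllUses to the first SPIR-V op (OpGroupNonUniformBallot)
- simplify the CodeGenHLSL test by sharing check lines via -DTARGET
- use the SPIRV::GroupOperation::Reduce enum instead of its integer value, for readability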
---
.../CodeGenHLSL/builtins/WaveActiveCountBits.hlsl | 12 ++++--------
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 5 +++--
2 files changed, 7 insertions(+), 10 deletions(-)
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
index 3e1f8fcaace9c2..086dd295ba938d 100755
--- a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
@@ -1,22 +1,18 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
-// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: FileCheck %s -DTARGET=dx
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
-// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+// RUN: FileCheck %s -DTARGET=spv
// Test basic lowering to runtime function call.
// CHECK-LABEL: test_bool
int test_bool(bool expr) {
- // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
- // CHECK-SPIRV: %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ]
- // CHECK-DXIL: %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}})
- // CHECK: ret i32 %[[RET]]
+ // CHECK: call {{.*}} @llvm.[[TARGET]].wave.active.countbits
return WaveActiveCountBits(expr);
}
-// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]]
-// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]]
+// CHECK: declare i32 @llvm.[[TARGET]].wave.active.countbits(i1) #[[#attr:]]
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 85a425aa4ae025..04dff2a5a08b6b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1935,7 +1935,8 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
.addDef(BallotReg)
.addUse(GR.getSPIRVTypeID(BallotType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
- .addUse(I.getOperand(2).getReg());
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
Result |=
BuildMI(BB, I, I.getDebugLoc(),
@@ -1943,7 +1944,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
- .addImm(0)
+ .addImm(SPIRV::GroupOperation::Reduce)
.addUse(BallotReg)
.constrainAllUses(TII, TRI, RBI);
>From cff8387169d07ae082af71e92511e43f5d092144 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Thu, 7 Nov 2024 22:26:03 +0000
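Subject: [PATCH 4/4] address review comments:
- create the ballot register with the register class of its vector type
- combine the instruction results with `&=` instead of `|=`, so a failure in either is reported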
---
llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 04dff2a5a08b6b..c17bbfc60b3954 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1926,9 +1926,9 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();
- Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
+ Register BallotReg = MRI->createVirtualRegister(GR.getRegClass(BallotType));
bool Result =
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
@@ -1938,7 +1938,7 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);
- Result |=
+ Result &=
BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
.addDef(ResVReg)
More information about the llvm-commits
mailing list