[clang] [llvm] [DXIL][SPIRV] Lower WaveActiveCountBits intrinsic (PR #113382)

Finn Plummer via cfe-commits cfe-commits at lists.llvm.org
Tue Oct 22 14:07:22 PDT 2024


https://github.com/inbelic updated https://github.com/llvm/llvm-project/pull/113382

>From 68c16dc2e21d3a78a388fdb883551f32b902db11 Mon Sep 17 00:00:00 2001
From: Finn Plummer <canadienfinn at gmail.com>
Date: Fri, 18 Oct 2024 15:48:29 -0700
Subject: [PATCH 1/2] [DXIL][SPIRV] Lower WaveActiveCountBits intrinsic

  - add codegen for llvm builtin to spirv/directx intrinsic
    in CGBuiltin.cpp
  - add lowering of spirv intrinsic to spirv backend in
    SPIRVInstructionSelector.cpp
  - add lowering of directx intrinsic to dxil op in DXIL.td

  - add test cases to illustrate passes
  - add test case for semantic analysis
---
 clang/lib/CodeGen/CGBuiltin.cpp               |  7 ++++
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 .../builtins/WaveActiveCountBits.hlsl         | 22 +++++++++++
 .../BuiltIns/WaveActiveCountBits-errors.hlsl  | 18 +++++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  1 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  1 +
 llvm/lib/Target/DirectX/DXIL.td               |  9 +++++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 37 +++++++++++++++++++
 .../CodeGen/DirectX/WaveActiveCountBits.ll    | 10 +++++
 .../hlsl-intrinsics/WaveActiveCountBits.ll    | 19 ++++++++++
 10 files changed, 125 insertions(+)
 create mode 100755 clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
 create mode 100644 clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 28f28c70b5ae52..f178499156eb1f 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18879,6 +18879,13 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
   }
+  case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
+    Value *OpExpr = EmitScalarExpr(E->getArg(0));
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
+    return EmitRuntimeCall(
+        Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
+        ArrayRef{OpExpr});
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
     // defined in SPIRVBuiltins.td. So instead we manually get the matching name
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..ada57b070783c6 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -89,6 +89,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
new file mode 100755
index 00000000000000..3e1f8fcaace9c2
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveCountBits.hlsl
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_bool
+int test_bool(bool expr) {
+  // CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i32 @llvm.spv.wave.active.countbits(i1 %{{.*}}) [ "convergencectrl"(token %[[#entry_tok]]) ]
+  // CHECK-DXIL:  %[[RET:.*]] = call i32 @llvm.dx.wave.active.countbits(i1 %{{.*}})
+  // CHECK:  ret i32 %[[RET]]
+  return WaveActiveCountBits(expr);
+}
+
+// CHECK-DXIL: declare i32 @llvm.dx.wave.active.countbits(i1) #[[#attr:]]
+// CHECK-SPIRV: declare i32 @llvm.spv.wave.active.countbits(i1) #[[#attr:]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
new file mode 100644
index 00000000000000..02f45eb30b377a
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveCountBits-errors.hlsl
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+int test_too_few_arg() {
+  return __builtin_hlsl_wave_active_count_bits();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+int test_too_many_arg(bool x) {
+  return __builtin_hlsl_wave_active_count_bits(x, x);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+struct S { float f; };
+
+int test_bad_conversion(S x) {
+  return __builtin_hlsl_wave_active_count_bits(x);
+  // expected-error@-1 {{no viable conversion from 'S' to 'bool'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..306829f6d1375d 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -84,6 +84,7 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
+def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 6df2eb156a0774..28bbd72b122208 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -83,6 +83,7 @@ let TargetPrefix = "spv" in {
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
+  def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 147b32b1ca9903..7cb6b0a1af9a33 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -820,3 +820,12 @@ def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
   let stages = [Stages<DXIL1_0, [all_stages]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
+
+def WaveAllBitCount : DXILOp<135, waveAllOp> {
+  let Doc = "returns the count of bits set to 1 across the wave";
+  let LLVMIntrinsic = int_dx_wave_active_countbits;
+  let arguments = [Int1Ty];
+  let result = Int32Ty;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d9377fe4b91a1a..5be6d63eae1d6f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -230,6 +230,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
+  bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
+                                MachineInstr &I) const;
+
   bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
                             MachineInstr &I) const;
 
@@ -1762,6 +1765,38 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
+                                                         const SPIRVType *ResType,
+                                                         MachineInstr &I) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  Register BallotReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+  SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+  SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
+
+  bool Result =
+      BuildMI(BB, I, I.getDebugLoc(),
+              TII.get(SPIRV::OpGroupNonUniformBallot))
+          .addDef(BallotReg)
+          .addUse(GR.getSPIRVTypeID(BallotType))
+          .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+          .addUse(I.getOperand(2).getReg());
+
+  Result |=
+    BuildMI(BB, I, I.getDebugLoc(),
+            TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+        .addImm(0)
+        .addUse(BallotReg)
+        .constrainAllUses(TII, TRI, RBI);
+
+  return Result;
+}
+
 bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
                                                     const SPIRVType *ResType,
                                                     MachineInstr &I) const {
@@ -2559,6 +2594,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   } break;
   case Intrinsic::spv_saturate:
     return selectSaturate(ResVReg, ResType, I);
+  case Intrinsic::spv_wave_active_countbits:
+    return selectWaveActiveCountBits(ResVReg, ResType, I);
   case Intrinsic::spv_wave_is_first_lane: {
     SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
     return BuildMI(BB, I, I.getDebugLoc(),
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
new file mode 100644
index 00000000000000..5d321372433198
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveCountBits.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define void @main(i1 %expr) {
+entry:
+; CHECK: call i32 @dx.op.waveAllOp(i32 135, i1 %expr)
+  %0 = call i32 @llvm.dx.wave.active.countbits(i1 %expr)
+  ret void
+}
+
+declare i32 @llvm.dx.wave.active.countbits(i1)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
new file mode 100644
index 00000000000000..29944054111ac0
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveCountBits.ll
@@ -0,0 +1,19 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG:   %[[#ballot_type:]] = OpTypeVector %[[#uint]] 4
+; CHECK-DAG:   %[[#bool:]] = OpTypeBool
+; CHECK-DAG:   %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_fun
+; CHECK:   %[[#bexpr:]] = OpFunctionParameter %[[#bool]]
+define i32 @test_fun(i1 %expr) {
+entry:
+; CHECK:   %[[#ballot:]] = OpGroupNonUniformBallot %[[#ballot_type]] %[[#scope]] %[[#bexpr]]
+; CHECK:   %[[#ret:]] = OpGroupNonUniformBallotBitCount %[[#uint]] %[[#scope]] Reduce %[[#ballot]]
+  %0 = call i32 @llvm.spv.wave.active.countbits(i1 %expr)
+  ret i32 %0
+}
+
+declare i32 @llvm.dx.wave.active.countbits(i1)

>From 6e58d116fbe601256765ef948ad4738309809f73 Mon Sep 17 00:00:00 2001
From: Finn Plummer <finnplummer at microsoft.com>
Date: Tue, 22 Oct 2024 21:06:50 +0000
Subject: [PATCH 2/2] clang-format

---
 .../Target/SPIRV/SPIRVInstructionSelector.cpp | 26 +++++++++----------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5be6d63eae1d6f..9f670f793574ec 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -231,7 +231,7 @@ class SPIRVInstructionSelector : public InstructionSelector {
                          MachineInstr &I) const;
 
   bool selectWaveActiveCountBits(Register ResVReg, const SPIRVType *ResType,
-                                MachineInstr &I) const;
+                                 MachineInstr &I) const;
 
   bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
                             MachineInstr &I) const;
@@ -1765,9 +1765,8 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
   return Result;
 }
 
-bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
-                                                         const SPIRVType *ResType,
-                                                         MachineInstr &I) const {
+bool SPIRVInstructionSelector::selectWaveActiveCountBits(
+    Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
   assert(I.getNumOperands() == 3);
   assert(I.getOperand(2).isReg());
   MachineBasicBlock &BB = *I.getParent();
@@ -1777,22 +1776,21 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(Register ResVReg,
   SPIRVType *BallotType = GR.getOrCreateSPIRVVectorType(IntTy, 4, I, TII);
 
   bool Result =
-      BuildMI(BB, I, I.getDebugLoc(),
-              TII.get(SPIRV::OpGroupNonUniformBallot))
+      BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformBallot))
           .addDef(BallotReg)
           .addUse(GR.getSPIRVTypeID(BallotType))
           .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
           .addUse(I.getOperand(2).getReg());
 
   Result |=
-    BuildMI(BB, I, I.getDebugLoc(),
-            TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
-        .addDef(ResVReg)
-        .addUse(GR.getSPIRVTypeID(ResType))
-        .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
-        .addImm(0)
-        .addUse(BallotReg)
-        .constrainAllUses(TII, TRI, RBI);
+      BuildMI(BB, I, I.getDebugLoc(),
+              TII.get(SPIRV::OpGroupNonUniformBallotBitCount))
+          .addDef(ResVReg)
+          .addUse(GR.getSPIRVTypeID(ResType))
+          .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII))
+          .addImm(0)
+          .addUse(BallotReg)
+          .constrainAllUses(TII, TRI, RBI);
 
   return Result;
 }



More information about the cfe-commits mailing list