[clang] [llvm] [HLSL] Implement WaveActiveAnyTrue intrinsic (PR #115902)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 08:48:32 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-directx

Author: Ashley Coleman (V-FEXrt)

<details>
<summary>Changes</summary>

Resolves https://github.com/llvm/llvm-project/issues/99160

- [x]  Implement `WaveActiveAnyTrue` clang builtin
- [x]  Link `WaveActiveAnyTrue` clang builtin with `hlsl_intrinsics.h`
- [x]  Add sema checks for `WaveActiveAnyTrue` to `SemaHLSL::CheckBuiltinFunctionCall` in `SemaHLSL.cpp`
- [x]  Add codegen for `WaveActiveAnyTrue` to `EmitHLSLBuiltinExpr` in `CGBuiltin.cpp`
- [x]  Add codegen tests to `clang/test/CodeGenHLSL/builtins/WaveActiveAnyTrue.hlsl`
- [x]  Add sema tests to `clang/test/SemaHLSL/BuiltIns/WaveActiveAnyTrue-errors.hlsl`
- [x]  Create the `int_dx_wave_activeanytrue` intrinsic in `IntrinsicsDirectX.td`
- [x]  Create the `DXILOpMapping` of `int_dx_wave_activeanytrue` to `113` in `DXIL.td`
- [x]  Create the `WaveActiveAnyTrue.ll` and `WaveActiveAnyTrue_errors.ll` tests in `llvm/test/CodeGen/DirectX/`
- [x]  Create the `int_spv_wave_activeanytrue` intrinsic in `IntrinsicsSPIRV.td`
- [x]  In SPIRVInstructionSelector.cpp create the `WaveActiveAnyTrue` lowering and map it to `int_spv_wave_activeanytrue` in `SPIRVInstructionSelector::selectIntrinsic`.
- [x]  Create SPIR-V backend test case in `llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAnyTrue.ll`

---
Full diff: https://github.com/llvm/llvm-project/pull/115902.diff


13 Files Affected:

- (modified) clang/include/clang/Basic/Builtins.td (+6) 
- (modified) clang/lib/CodeGen/CGBuiltin.cpp (+16) 
- (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1) 
- (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+9) 
- (modified) clang/lib/Sema/SemaHLSL.cpp (+6) 
- (added) clang/test/CodeGenHLSL/builtins/WaveActiveAnyTrue.hlsl (+17) 
- (added) clang/test/SemaHLSL/BuiltIns/WaveActiveAnyTrue-errors.hlsl (+21) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+1) 
- (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+10) 
- (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+23) 
- (added) llvm/test/CodeGen/DirectX/WaveActiveAnyTrue.ll (+10) 
- (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAnyTrue.ll (+17) 


``````````diff
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 9bd67e0cefebc3..496b6739724295 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4744,6 +4744,12 @@ def HLSLAny : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "bool(...)";
 }
 
+def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_active_any_true"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "bool(bool)";
+}
+
 def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 65d7f5c54a1913..cc1995cd2382c8 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18960,6 +18960,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
   }
+  case Builtin::BI__builtin_hlsl_wave_active_any_true: {
+    IntegerType *Int1Ty =
+        llvm::Type::getInt1Ty(CGM.getTypes().getLLVMContext());
+    Value *Op = EmitScalarExpr(E->getArg(0));
+    assert(Op->getType() == Int1Ty &&
+           "wave_active_any_true operand must be a bool");
+
+    llvm::FunctionType *FT =
+        llvm::FunctionType::get(Int1Ty, {Int1Ty}, /*isVarArg=*/false);
+    llvm::StringRef Name = Intrinsic::getName(
+        CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic());
+    return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
+                                                     /*Local=*/false,
+                                                     /*AssumeConvergent=*/true),
+                           {Op}, "hlsl.wave.activeanytrue");
+  }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
     // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
     // defined in SPIRVBuiltins.td. So instead we manually get the matching name
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index cd533cad84e9fb..c3b689ea44fcb8 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -89,6 +89,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
   GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_activeanytrue)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 8ade4b27f360fb..90991e95d6565a 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -2096,6 +2096,15 @@ float4 trunc(float4);
 // Wave* builtins
 //===----------------------------------------------------------------------===//
 
+/// \brief Returns true if the expression is true in any active lane in the
+/// current wave.
+///
+/// \param Val The boolean expression to evaluate.
+/// \return True if the expression is true in any lane.
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_any_true)
+__attribute__((convergent)) bool WaveActiveAnyTrue(bool Val);
+
 /// \brief Counts the number of boolean variables which evaluate to true across
 /// all active lanes in the current wave.
 ///
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index a472538236e2d9..0377f46dcf1991 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2066,6 +2066,12 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_wave_active_any_true: {
+    if (SemaRef.checkArgCount(TheCall, 1))
+      return true;
+
+    break;
+  }
   case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
     if (SemaRef.checkArgCount(TheCall, 2))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveAnyTrue.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveAnyTrue.hlsl
new file mode 100644
index 00000000000000..d8233c45e0ff2d
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveAnyTrue.hlsl
@@ -0,0 +1,17 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -fnative-half-type -triple \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -fnative-half-type -triple \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call for bool values.
+
+// CHECK-LABEL: test
+bool test(bool p1) {
+  // CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
+  // CHECK-SPIRV:  %[[RET:.*]] = call spir_func i1 @llvm.spv.wave.activeanytrue(i1 %{{[a-zA-Z0-9]+}}) [ "convergencectrl"(token %[[#entry_tok0]]) ]
+  // CHECK-DXIL:  %[[RET:.*]] = call i1 @llvm.dx.wave.activeanytrue(i1 %{{[a-zA-Z0-9]+}})
+  // CHECK:  ret i1 %[[RET]]
+  return WaveActiveAnyTrue(p1);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveAnyTrue-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveAnyTrue-errors.hlsl
new file mode 100644
index 00000000000000..48fda7b571fca7
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveAnyTrue-errors.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+bool test_too_few_arg() {
+  return __builtin_hlsl_wave_active_any_true();
+  // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+bool test_too_many_arg(bool p0) {
+  return __builtin_hlsl_wave_active_any_true(p0, p0);
+  // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+struct Foo
+{
+  int a;
+};
+
+bool test_type_check_2(Foo p0) {
+  return __builtin_hlsl_wave_active_any_true(p0);
+  // expected-error@-1 {{no viable conversion from 'Foo' to 'bool'}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index e30d37f69f781e..7c21dc8f41f169 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -86,6 +86,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
 def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
+def int_dx_wave_activeanytrue : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent]>;
 def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
 def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
 def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 6df2eb156a0774..c5e93aa4cfe01d 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -85,6 +85,7 @@ let TargetPrefix = "spv" in {
     [IntrNoMem, Commutative] >;
   def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
   def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
+  def int_spv_wave_activeanytrue : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent]>;
   def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
   def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
 
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 68ae5de06423c2..9657d249f9170a 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -202,6 +202,7 @@ defset list<DXILOpClass> OpClasses = {
   def unpack4x8 : DXILOpClass;
   def viewID : DXILOpClass;
   def waveActiveAllEqual : DXILOpClass;
+  def waveActiveAnyTrue : DXILOpClass;
   def waveActiveBallot : DXILOpClass;
   def waveActiveBit : DXILOpClass;
   def waveActiveOp : DXILOpClass;
@@ -803,6 +804,15 @@ def CreateHandleFromBinding : DXILOp<218, createHandleFromBinding> {
   let stages = [Stages<DXIL1_6, [all_stages]>];
 }
 
+def WaveActiveAnyTrue : DXILOp<113, waveActiveAnyTrue> {
+  let Doc = "returns true if the expression is true in any of the active lanes in the current wave";
+  let LLVMIntrinsic = int_dx_wave_activeanytrue;
+  let arguments = [Int1Ty];
+  let result = Int1Ty;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
+
 def WaveIsFirstLane :  DXILOp<110, waveIsFirstLane> {
   let Doc = "returns 1 for the first lane in the wave";
   let LLVMIntrinsic = int_dx_wave_is_first_lane;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index d9377fe4b91a1a..47538da068ec0a 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -230,6 +230,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
                          MachineInstr &I) const;
 
+  bool selectWaveActiveAnyTrue(Register ResVReg, const SPIRVType *ResType,
+                               MachineInstr &I) const;
+
   bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
                             MachineInstr &I) const;
 
@@ -1762,6 +1765,24 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
   return Result;
 }
 
+bool SPIRVInstructionSelector::selectWaveActiveAnyTrue(Register ResVReg,
+                                                       const SPIRVType *ResType,
+                                                       MachineInstr &I) const {
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+
+  // IntTy is used to define the execution scope, set to 3 to denote a
+  // cross-lane interaction equivalent to a SPIR-V subgroup.
+  MachineBasicBlock &BB = *I.getParent();
+  SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpGroupNonUniformAny))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addUse(GR.getOrCreateConstInt(3, I, IntTy, TII))
+      .addUse(I.getOperand(2).getReg());
+}
+
 bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
                                                     const SPIRVType *ResType,
                                                     MachineInstr &I) const {
@@ -2567,6 +2588,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
         .addUse(GR.getSPIRVTypeID(ResType))
         .addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
   }
+  case Intrinsic::spv_wave_activeanytrue:
+    return selectWaveActiveAnyTrue(ResVReg, ResType, I);
   case Intrinsic::spv_wave_readlane:
     return selectWaveReadLaneAt(ResVReg, ResType, I);
   case Intrinsic::spv_step:
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveAnyTrue.ll b/llvm/test/CodeGen/DirectX/WaveActiveAnyTrue.ll
new file mode 100644
index 00000000000000..cc97eb2eedcd60
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveAnyTrue.ll
@@ -0,0 +1,10 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define noundef i1 @wave_aat_simple(i1 noundef %p1) {
+entry:
+; CHECK: call i1 @dx.op.waveActiveAnyTrue(i32 113, i1 %p1)
+  %ret = call i1 @llvm.dx.wave.activeanytrue(i1 %p1)
+  ret i1 %ret
+}
+
+declare i1 @llvm.dx.wave.activeanytrue(i1)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAnyTrue.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAnyTrue.ll
new file mode 100644
index 00000000000000..8b56b60f1da4da
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveAnyTrue.ll
@@ -0,0 +1,17 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#bool:]] = OpTypeBool
+; CHECK: %[[#uint:]] = OpTypeInt 32 0
+; CHECK: %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_wave_aat
+define i1 @test_wave_aat(i1 %p1) {
+entry:
+; CHECK: %[[#param:]] = OpFunctionParameter %[[#bool]]
+; CHECK: %[[#ret:]] = OpGroupNonUniformAny %[[#bool]] %[[#scope]] %[[#param]]
+  %ret = call i1 @llvm.spv.wave.activeanytrue(i1 %p1)
+  ret i1 %ret
+}
+
+declare i1 @llvm.spv.wave.activeanytrue(i1)

``````````

</details>


https://github.com/llvm/llvm-project/pull/115902


More information about the llvm-commits mailing list