[clang] [llvm] [clang][HLSL] Add WaveIsFirstLane() intrinsic (PR #103299)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Tue Aug 13 08:25:13 PDT 2024


https://github.com/Keenuts created https://github.com/llvm/llvm-project/pull/103299

This commit adds the WaveIsFirstLane() HLSL intrinsic. This intrinsic uses the convergence intrinsics for the SPIR-V backend. On the DXIL side, I'm not sure what the strategy is (DXC didn't use convergence intrinsics for DXIL).

>From 3c65e014ff038d20fe1fb8229157737306bb89e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 13 Aug 2024 14:39:03 +0200
Subject: [PATCH] [clang][HLSL] Add WaveIsFirstLane() intrinsic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit adds the WaveIsFirstLane() HLSL intrinsic.
This intrinsic uses the convergence intrinsics for the SPIR-V backend.
On the DXIL side, I'm not sure what the strategy is (DXC didn't use
convergence intrinsics for DXIL).

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/include/clang/Basic/Builtins.td         |  6 +++
 clang/lib/CodeGen/CGBuiltin.cpp               |  4 ++
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 clang/lib/Headers/hlsl/hlsl_intrinsics.h      |  4 ++
 .../builtins/wave_is_first_lane.hlsl          | 34 ++++++++++++
 llvm/include/llvm/IR/IntrinsicsDirectX.td     |  2 +
 llvm/include/llvm/IR/IntrinsicsSPIRV.td       |  2 +
 llvm/lib/Target/DirectX/DXIL.td               |  9 ++++
 .../Target/SPIRV/SPIRVInstructionSelector.cpp |  8 +++
 .../SPIRV/SPIRVStripConvergentIntrinsics.cpp  | 53 +++++++++++--------
 .../CodeGen/DirectX/wave_is_first_lane.ll     | 13 +++++
 .../SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll  | 27 ++++++++++
 12 files changed, 141 insertions(+), 22 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl
 create mode 100644 llvm/test/CodeGen/DirectX/wave_is_first_lane.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index b025a7681bfac3..b047669ff3c53f 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4677,6 +4677,12 @@ def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int()";
 }
 
+def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_is_first_lane"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "bool()";
+}
+
 def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_elementwise_clamp"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 7fe80b0cbdfbfa..0b96fe9d29b595 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18660,6 +18660,10 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
         {}, false, true));
   }
+  case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
+    Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
+    return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 527e73a0e21fc4..d856b03debc063 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -79,6 +79,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index e35a5262f92809..d7b5d8c40a0889 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1725,5 +1725,9 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_get_lane_index)
 __attribute__((convergent)) uint WaveGetLaneIndex();
 
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
+__attribute__((convergent)) bool WaveIsFirstLane();
+
 } // namespace hlsl
 #endif //_HLSL_HLSL_INTRINSICS_H_
diff --git a/clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl b/clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl
new file mode 100644
index 00000000000000..18860c321eb912
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/wave_is_first_lane.hlsl
@@ -0,0 +1,34 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple   \
+// RUN:   dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN:   FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+
+[numthreads(1, 1, 1)]
+void main() {
+// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
+
+// CHECK-SPIRV: %[[#loop_tok:]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[#entry_tok]]) ]
+  while (true) {
+
+// CHECK-DXIL:  %[[#]] = call i1 @llvm.dx.wave.is.first.lane()
+// CHECK-SPIRV: %[[#]] = call i1 @llvm.spv.wave.is.first.lane()
+// CHECK-SPIRV-SAME: [ "convergencectrl"(token %[[#loop_tok]]) ]
+    if (WaveIsFirstLane()) {
+      break;
+    }
+  }
+
+// CHECK-DXIL:  %[[#]] = call i1 @llvm.dx.wave.is.first.lane()
+// CHECK-SPIRV: %[[#]] = call i1 @llvm.spv.wave.is.first.lane()
+// CHECK-SPIRV-SAME: [ "convergencectrl"(token %[[#entry_tok]]) ]
+  if (WaveIsFirstLane()) {
+    return;
+  }
+}
+
+// CHECK-DXIL:  i1 @llvm.dx.wave.is.first.lane() #[[#attr:]]
+// CHECK-SPIRV: i1 @llvm.spv.wave.is.first.lane() #[[#attr:]]
+
+// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 312c3862f240d8..1eea4d25c0ac50 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -60,4 +60,6 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
 def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
 def int_dx_rcp  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
 def int_dx_rsqrt  : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
+
+def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 3f77ef6bfcdbe2..ea8b58caa6b193 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -65,4 +65,6 @@ let TargetPrefix = "spv" in {
     [IntrNoMem, IntrWillReturn] >;
   def int_spv_length : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty]>;
   def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
+
+  def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
 }
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 67015cff78a79a..a4e7aeae883fbc 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -703,3 +703,12 @@ def FlattenedThreadIdInGroup :  DXILOp<96, flattenedThreadIdInGroup> {
   let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>];
   let attributes = [Attributes<DXIL1_0, [ReadNone]>];
 }
+
+def WaveIsFirstLane :  DXILOp<110, waveIsFirstLane> {
+  let Doc = "returns 1 for the first lane in the wave";
+  let LLVMIntrinsic = int_dx_wave_is_first_lane;
+  let arguments = [];
+  let result = i1Ty;
+  let stages = [Stages<DXIL1_0, [all_stages]>];
+  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
+}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c55235a04a607f..d014d90e31fb9d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2132,6 +2132,14 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
       Size = 0;
     BuildMI(BB, I, I.getDebugLoc(), TII.get(Op)).addUse(PtrReg).addImm(Size);
   } break;
+  case Intrinsic::spv_wave_is_first_lane: {
+    SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+    return BuildMI(BB, I, I.getDebugLoc(),
+                   TII.get(SPIRV::OpGroupNonUniformElect))
+        .addDef(ResVReg)
+        .addUse(GR.getSPIRVTypeID(ResType))
+        .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII));
+  }
   default: {
     std::string DiagMsg;
     raw_string_ostream OS(DiagMsg);
diff --git a/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp
index dca30535acfa1a..b632d784977678 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStripConvergentIntrinsics.cpp
@@ -41,31 +41,40 @@ class SPIRVStripConvergentIntrinsics : public FunctionPass {
   virtual bool runOnFunction(Function &F) override {
     DenseSet<Instruction *> ToRemove;
 
+    // If II is a convergence entry/loop/anchor intrinsic, replace its uses
+    // with undef, add it to ToRemove, and return true; otherwise return false.
+    auto CleanupIntrinsic = [&](IntrinsicInst *II) {
+      if (II->getIntrinsicID() != Intrinsic::experimental_convergence_entry &&
+          II->getIntrinsicID() != Intrinsic::experimental_convergence_loop &&
+          II->getIntrinsicID() != Intrinsic::experimental_convergence_anchor)
+        return false;
+
+      II->replaceAllUsesWith(UndefValue::get(II->getType()));
+      ToRemove.insert(II);
+      return true;
+    };
+
+    // Replace the given CallInst with an equivalent call that lacks the
+    // convergencectrl operand bundle; no-op if the call has no such bundle.
+    auto CleanupCall = [&](CallInst *CI) {
+      auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
+      if (!OB.has_value())
+        return;
+
+      auto *NewCall = CallBase::removeOperandBundle(
+          CI, LLVMContext::OB_convergencectrl, CI);
+      NewCall->copyMetadata(*CI);
+      CI->replaceAllUsesWith(NewCall);
+      ToRemove.insert(CI);
+    };
+
     for (BasicBlock &BB : F) {
       for (Instruction &I : BB) {
-        if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
-          if (II->getIntrinsicID() !=
-                  Intrinsic::experimental_convergence_entry &&
-              II->getIntrinsicID() !=
-                  Intrinsic::experimental_convergence_loop &&
-              II->getIntrinsicID() !=
-                  Intrinsic::experimental_convergence_anchor) {
+        if (auto *II = dyn_cast<IntrinsicInst>(&I))
+          if (CleanupIntrinsic(II))
             continue;
-          }
-
-          II->replaceAllUsesWith(UndefValue::get(II->getType()));
-          ToRemove.insert(II);
-        } else if (auto *CI = dyn_cast<CallInst>(&I)) {
-          auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
-          if (!OB.has_value())
-            continue;
-
-          auto *NewCall = CallBase::removeOperandBundle(
-              CI, LLVMContext::OB_convergencectrl, CI);
-          NewCall->copyMetadata(*CI);
-          CI->replaceAllUsesWith(NewCall);
-          ToRemove.insert(CI);
-        }
+        if (auto *CI = dyn_cast<CallInst>(&I))
+          CleanupCall(CI);
       }
     }
 
diff --git a/llvm/test/CodeGen/DirectX/wave_is_first_lane.ll b/llvm/test/CodeGen/DirectX/wave_is_first_lane.ll
new file mode 100644
index 00000000000000..b9a63bb0f14722
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/wave_is_first_lane.ll
@@ -0,0 +1,13 @@
+; RUN: opt -S  -dxil-op-lower  -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define void @main() #0 {
+entry:
+; CHECK: call i1 @dx.op.waveIsFirstLane.i1(i32 110)
+  %0 = call i1 @llvm.dx.wave.is.first.lane()
+  ret void
+}
+
+declare i1 @llvm.dx.wave.is.first.lane() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent nocallback nofree nosync nounwind willreturn }
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll
new file mode 100644
index 00000000000000..94597b37cc7eb1
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll
@@ -0,0 +1,27 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-linux %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
+target triple = "spirv-unknown-vulkan-compute"
+
+; CHECK-DAG:   %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_3:]] = OpConstant %[[#uint]] 3
+; CHECK-DAG:   %[[#bool:]] = OpTypeBool
+
+define spir_func void @main() #0 {
+entry:
+  %0 = call token @llvm.experimental.convergence.entry()
+; CHECK:   %[[#]] = OpGroupNonUniformElect %[[#bool]] %[[#uint_3]]
+  %1 = call i1 @llvm.spv.wave.is.first.lane() [ "convergencectrl"(token %0) ]
+  ret void
+}
+
+declare i1 @llvm.spv.wave.is.first.lane() #1
+
+attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
+attributes #1 = { convergent }
+
+!llvm.module.flags = !{!0, !1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{i32 4, !"dx.disable_optimizations", i32 1}



More information about the cfe-commits mailing list