[clang] [llvm] [clang][HLSL][SPIR-V] Add convergence intrinsics (PR #80680)

Nathan Gauër via cfe-commits cfe-commits at lists.llvm.org
Tue Mar 12 07:15:35 PDT 2024


https://github.com/Keenuts updated https://github.com/llvm/llvm-project/pull/80680

>From 818ccfd0258602fdd0630823bb2b8af0507749d5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Fri, 2 Feb 2024 16:38:46 +0100
Subject: [PATCH 1/4] [clang][HLSL][SPIR-V] Add convergence intrinsics
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

HLSL has wave operations and other kinds of functions which require the
control flow to either be converged, or to respect certain constraints as
to where and how to re-converge.

At the HLSL level, the convergence points are mostly obvious: the control
flow is expected to re-converge at the end of a scope.
Once translated to IR, HLSL scopes disappear. This means we need a way to
communicate convergence restrictions down to the backend.

For this, the SPIR-V backend uses convergence intrinsics. So this commit
adds some code to generate convergence intrinsics when required.

This commit is not to be submitted as-is (lacks testing), but
should serve as a basis for an upcoming RFC.

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CGBuiltin.cpp      | 102 +++++++++++++++++++++++++++
 clang/lib/CodeGen/CGCall.cpp         |   4 ++
 clang/lib/CodeGen/CGLoopInfo.h       |   8 ++-
 clang/lib/CodeGen/CodeGenFunction.h  |  19 +++++
 llvm/include/llvm/IR/IntrinsicInst.h |  13 ++++
 5 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 20c35757939152..ba5e27a5d4668c 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1130,8 +1130,97 @@ struct BitTest {
 
   static BitTest decodeBitTestBuiltin(unsigned BuiltinID);
 };
+
+// Returns the first convergence entry/loop/anchor instruction found in |BB|.
+// std::nullopt otherwise.
+std::optional<llvm::IntrinsicInst *> getConvergenceToken(llvm::BasicBlock *BB) {
+  for (auto &I : *BB) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
+    if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
+      return II;
+  }
+  return std::nullopt;
+}
+
 } // namespace
 
+llvm::CallBase *
+CodeGenFunction::AddConvergenceControlAttr(llvm::CallBase *Input,
+                                           llvm::Value *ParentToken) {
+  llvm::Value *bundleArgs[] = {ParentToken};
+  llvm::OperandBundleDef OB("convergencectrl", bundleArgs);
+  auto Output = llvm::CallBase::addOperandBundle(
+      Input, llvm::LLVMContext::OB_convergencectrl, OB, Input);
+  Input->replaceAllUsesWith(Output);
+  Input->eraseFromParent();
+  return Output;
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::EmitConvergenceLoop(llvm::BasicBlock *BB,
+                                     llvm::Value *ParentToken) {
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  auto CB = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_loop, {}, {});
+  Builder.restoreIP(IP);
+
+  auto I = AddConvergenceControlAttr(CB, ParentToken);
+  // Controlled convergence is incompatible with uncontrolled convergence.
+  // Removing any old attributes.
+  I->setNotConvergent();
+
+  assert(isa<llvm::IntrinsicInst>(I));
+  return dyn_cast<llvm::IntrinsicInst>(I);
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
+  auto *BB = &F->getEntryBlock();
+  auto token = getConvergenceToken(BB);
+  if (token.has_value())
+    return token.value();
+
+  // Adding a convergence token requires the function to be marked as
+  // convergent.
+  F->setConvergent();
+
+  CGBuilderTy::InsertPoint IP = Builder.saveIP();
+  Builder.SetInsertPoint(&BB->front());
+  auto I = Builder.CreateIntrinsic(
+      llvm::Intrinsic::experimental_convergence_entry, {}, {});
+  assert(isa<llvm::IntrinsicInst>(I));
+  Builder.restoreIP(IP);
+
+  return dyn_cast<llvm::IntrinsicInst>(I);
+}
+
+llvm::IntrinsicInst *
+CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
+  assert(LI != nullptr);
+
+  auto token = getConvergenceToken(LI->getHeader());
+  if (token.has_value())
+    return *token;
+
+  llvm::IntrinsicInst *PII =
+      LI->getParent()
+          ? getOrEmitConvergenceLoopToken(LI->getParent())
+          : getOrEmitConvergenceEntryToken(LI->getHeader()->getParent());
+
+  // PII (outer loop or entry token) parents this header's only loop token.
+  return EmitConvergenceLoop(LI->getHeader(), PII);
+}
+
+llvm::CallBase *
+CodeGenFunction::AddControlledConvergenceAttr(llvm::CallBase *Input) {
+  llvm::Value *ParentToken =
+      LoopStack.hasInfo()
+          ? getOrEmitConvergenceLoopToken(&LoopStack.getInfo())
+          : getOrEmitConvergenceEntryToken(Input->getFunction());
+  return AddConvergenceControlAttr(Input, ParentToken);
+}
+
 BitTest BitTest::decodeBitTestBuiltin(unsigned BuiltinID) {
   switch (BuiltinID) {
     // Main portable variants.
@@ -5698,6 +5787,19 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         {NDRange, Kernel, Block}));
   }
 
+  case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
+    llvm::Type *BoolTy = llvm::IntegerType::get(getLLVMContext(), 1);
+    llvm::Value *Src0 = EmitScalarExpr(E->getArg(0));
+    auto *CI =
+        EmitRuntimeCall(CGM.CreateRuntimeFunction(
+                            llvm::FunctionType::get(IntTy, {BoolTy}, false),
+                            "__hlsl_wave_active_count_bits", {}),
+                        {Src0});
+    if (getTarget().getTriple().isSPIRVLogical())
+      CI = dyn_cast<CallInst>(AddControlledConvergenceAttr(CI));
+    return RValue::get(CI);
+  }
+
   case Builtin::BI__builtin_store_half:
   case Builtin::BI__builtin_store_halff: {
     Value *Val = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index a28d7888715d85..4b24367a8e19d9 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -5686,6 +5686,10 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   if (!CI->getType()->isVoidTy())
     CI->setName("call");
 
+  if (getTarget().getTriple().isSPIRVLogical() &&
+      CI->getCalledFunction()->isConvergent())
+    CI = AddControlledConvergenceAttr(CI);
+
   // Update largest vector width from the return type.
   LargestVectorWidth =
       std::max(LargestVectorWidth, getMaxVectorWidth(CI->getType()));
diff --git a/clang/lib/CodeGen/CGLoopInfo.h b/clang/lib/CodeGen/CGLoopInfo.h
index a1c8c7e5307fd9..7c2f7443bd3c99 100644
--- a/clang/lib/CodeGen/CGLoopInfo.h
+++ b/clang/lib/CodeGen/CGLoopInfo.h
@@ -110,6 +110,10 @@ class LoopInfo {
   /// been processed.
   void finish();
 
+  /// Returns the first outer loop containing this loop if any, nullptr
+  /// otherwise.
+  const LoopInfo *getParent() const { return Parent; }
+
 private:
   /// Loop ID metadata.
   llvm::TempMDTuple TempLoopID;
@@ -291,12 +295,14 @@ class LoopInfoStack {
   /// Set no progress for the next loop pushed.
   void setMustProgress(bool P) { StagedAttrs.MustProgress = P; }
 
-private:
   /// Returns true if there is LoopInfo on the stack.
   bool hasInfo() const { return !Active.empty(); }
+
   /// Return the LoopInfo for the current loop. HasInfo should be called
   /// first to ensure LoopInfo is present.
   const LoopInfo &getInfo() const { return *Active.back(); }
+
+private:
   /// The set of attributes that will be applied to the next pushed loop.
   LoopAttributes StagedAttrs;
   /// Stack of active loops.
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 6c825a302913df..2763f1d3e7e327 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4173,6 +4173,25 @@ class CodeGenFunction : public CodeGenTypeCache {
   void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl);
   void checkTargetFeatures(SourceLocation Loc, const FunctionDecl *TargetDecl);
 
+  // Adds a convergence_ctrl attribute to |Input| and emits the required parent
+  // convergence instructions.
+  llvm::CallBase *AddControlledConvergenceAttr(llvm::CallBase *Input);
+
+  // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
+  // as it's parent convergence instr.
+  llvm::IntrinsicInst *EmitConvergenceLoop(llvm::BasicBlock *BB,
+                                           llvm::Value *ParentToken);
+  // Adds a convergence_ctrl attribute with |ParentToken| as parent convergence
+  // instr to the call |Input|.
+  llvm::CallBase *AddConvergenceControlAttr(llvm::CallBase *Input,
+                                            llvm::Value *ParentToken);
+  // Find the convergence_entry instruction |F|, or emits ones if none exists.
+  // Returns the convergence instruction.
+  llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
+  // Find the convergence_loop instruction for the loop defined by |LI|, or
+  // emits one if none exists. Returns the convergence instruction.
+  llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
+
   llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,
                                   const Twine &name = "");
   llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index c07b83a81a63e1..4f22720f1c558d 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -1782,6 +1782,19 @@ class ConvergenceControlInst : public IntrinsicInst {
   static bool classof(const Value *V) {
     return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
   }
+
+  // Returns the intrinsic named by |I|'s convergencectrl bundle, or nullptr.
+  static IntrinsicInst *getParentConvergenceToken(Instruction *I) {
+    auto *CI = dyn_cast<llvm::CallInst>(I);
+    if (!CI)
+      return nullptr;
+    auto Bundle = CI->getOperandBundle(llvm::LLVMContext::OB_convergencectrl);
+    if (!Bundle) // No convergencectrl bundle: no parent token to return.
+      return nullptr;
+    assert(Bundle->Inputs.size() == 1 &&
+           Bundle->Inputs[0]->getType()->isTokenTy());
+    return dyn_cast<llvm::IntrinsicInst>(Bundle->Inputs[0].get());
+  }
 };
 
 } // end namespace llvm

>From 010928a42d018812afd1ed9a552cb7ad823333e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Thu, 22 Feb 2024 20:39:40 +0100
Subject: [PATCH 2/4] change implemented builtin
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/include/clang/Basic/Builtins.td    |  7 +++++++
 clang/lib/CodeGen/CGBuiltin.cpp          | 17 ++++++-----------
 clang/lib/Headers/hlsl/hlsl_intrinsics.h |  5 +++++
 3 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 9c703377ca8d3e..11c857cfa3f374 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4554,6 +4554,13 @@ def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "unsigned int(bool)";
 }
 
+// HLSL
+def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_wave_get_lane_index"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "unsigned int()";
+}
+
 def HLSLCreateHandle : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_create_handle"];
   let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index ba5e27a5d4668c..12fc855fb92bb8 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1170,8 +1170,7 @@ CodeGenFunction::EmitConvergenceLoop(llvm::BasicBlock *BB,
   // Removing any old attributes.
   I->setNotConvergent();
 
-  assert(isa<llvm::IntrinsicInst>(I));
-  return dyn_cast<llvm::IntrinsicInst>(I);
+  return cast<llvm::IntrinsicInst>(I);
 }
 
 llvm::IntrinsicInst *
@@ -1192,7 +1191,7 @@ CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
   assert(isa<llvm::IntrinsicInst>(I));
   Builder.restoreIP(IP);
 
-  return dyn_cast<llvm::IntrinsicInst>(I);
+  return cast<llvm::IntrinsicInst>(I);
 }
 
 llvm::IntrinsicInst *
@@ -5787,14 +5786,10 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         {NDRange, Kernel, Block}));
   }
 
-  case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
-    llvm::Type *BoolTy = llvm::IntegerType::get(getLLVMContext(), 1);
-    llvm::Value *Src0 = EmitScalarExpr(E->getArg(0));
-    auto *CI =
-        EmitRuntimeCall(CGM.CreateRuntimeFunction(
-                            llvm::FunctionType::get(IntTy, {BoolTy}, false),
-                            "__hlsl_wave_active_count_bits", {}),
-                        {Src0});
+  case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
+    auto *CI = EmitRuntimeCall(CGM.CreateRuntimeFunction(
+        llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",
+        {}, false, true));
     if (getTarget().getTriple().isSPIRVLogical())
       CI = dyn_cast<CallInst>(AddControlledConvergenceAttr(CI));
     return RValue::get(CI);
diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
index 45f8544392584e..108588e5e0af60 100644
--- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -1297,5 +1297,10 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
 _HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_count_bits)
 uint WaveActiveCountBits(bool Val);
 
+/// \brief Returns the index of the current lane within the current wave.
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_get_lane_index)
+uint WaveGetLaneIndex();
+
 } // namespace hlsl
 #endif //_HLSL_HLSL_INTRINSICS_H_

>From 5fbfecb2f436a500c90a1abbb5363c7b85e21a24 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Mon, 11 Mar 2024 16:48:18 +0100
Subject: [PATCH 3/4] add tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Nathan Gauër <brioche at google.com>
---
 clang/lib/CodeGen/CodeGenFunction.h           | 38 +++++++++---------
 .../wave_get_lane_index_do_while.hlsl         | 40 +++++++++++++++++++
 .../builtins/wave_get_lane_index_simple.hlsl  | 14 +++++++
 .../builtins/wave_get_lane_index_subcall.hlsl | 21 ++++++++++
 4 files changed, 94 insertions(+), 19 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/builtins/wave_get_lane_index_do_while.hlsl
 create mode 100644 clang/test/CodeGenHLSL/builtins/wave_get_lane_index_simple.hlsl
 create mode 100644 clang/test/CodeGenHLSL/builtins/wave_get_lane_index_subcall.hlsl

diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 2763f1d3e7e327..c475b80db0fc41 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4173,25 +4173,6 @@ class CodeGenFunction : public CodeGenTypeCache {
   void checkTargetFeatures(const CallExpr *E, const FunctionDecl *TargetDecl);
   void checkTargetFeatures(SourceLocation Loc, const FunctionDecl *TargetDecl);
 
-  // Adds a convergence_ctrl attribute to |Input| and emits the required parent
-  // convergence instructions.
-  llvm::CallBase *AddControlledConvergenceAttr(llvm::CallBase *Input);
-
-  // Emits a convergence_loop instruction for the given |BB|, with |ParentToken|
-  // as it's parent convergence instr.
-  llvm::IntrinsicInst *EmitConvergenceLoop(llvm::BasicBlock *BB,
-                                           llvm::Value *ParentToken);
-  // Adds a convergence_ctrl attribute with |ParentToken| as parent convergence
-  // instr to the call |Input|.
-  llvm::CallBase *AddConvergenceControlAttr(llvm::CallBase *Input,
-                                            llvm::Value *ParentToken);
-  // Find the convergence_entry instruction |F|, or emits ones if none exists.
-  // Returns the convergence instruction.
-  llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
-  // Find the convergence_loop instruction for the loop defined by |LI|, or
-  // emits one if none exists. Returns the convergence instruction.
-  llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
-
   llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,
                                   const Twine &name = "");
   llvm::CallInst *EmitRuntimeCall(llvm::FunctionCallee callee,
@@ -4887,6 +4868,25 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *emitBoolVecConversion(llvm::Value *SrcVec,
                                      unsigned NumElementsDst,
                                      const llvm::Twine &Name = "");
+  // Adds a convergencectrl operand bundle to |Input| and emits the required
+  // parent convergence tokens. |Input| is erased; use the returned call.
+  llvm::CallBase *AddControlledConvergenceAttr(llvm::CallBase *Input);
+
+private:
+  // Emits a convergence.loop intrinsic at the front of |BB| with |ParentToken|
+  // as its parent token.
+  llvm::IntrinsicInst *EmitConvergenceLoop(llvm::BasicBlock *BB,
+                                           llvm::Value *ParentToken);
+  // Replaces the call |Input| with a copy carrying a convergencectrl bundle
+  // on |ParentToken|. |Input| is erased; returns the new call.
+  llvm::CallBase *AddConvergenceControlAttr(llvm::CallBase *Input,
+                                            llvm::Value *ParentToken);
+  // Finds the first convergence token in |F|'s entry block, or emits a
+  // convergence.entry (marking |F| convergent). Returns the token.
+  llvm::IntrinsicInst *getOrEmitConvergenceEntryToken(llvm::Function *F);
+  // Finds the convergence token in the header of the loop |LI|, or emits a
+  // convergence.loop there if none exists. Returns the token.
+  llvm::IntrinsicInst *getOrEmitConvergenceLoopToken(const LoopInfo *LI);
 
 private:
   llvm::MDNode *getRangeForLoadFromType(QualType Ty);
diff --git a/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_do_while.hlsl b/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_do_while.hlsl
new file mode 100644
index 00000000000000..9481b0d60a2723
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_do_while.hlsl
@@ -0,0 +1,40 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+// CHECK: define spir_func void @main() [[A0:#[0-9]+]] {
+void main() {
+// CHECK: entry:
+// CHECK:   %[[CT_ENTRY:[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK:   br label %[[LABEL_WHILE_COND:.+]]
+  int cond = 0;
+
+// CHECK: [[LABEL_WHILE_COND]]:
+// CHECK:   %[[CT_LOOP:[0-9]+]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %[[CT_ENTRY]]) ]
+// CHECK:   br label %[[LABEL_WHILE_BODY:.+]]
+  while (true) {
+
+// CHECK: [[LABEL_WHILE_BODY]]:
+// CHECK:   br i1 {{%.+}}, label %[[LABEL_IF_THEN:.+]], label %[[LABEL_IF_END:.+]]
+
+// CHECK: [[LABEL_IF_THEN]]:
+// CHECK:   call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[CT_LOOP]]) ]
+// CHECK:   br label %[[LABEL_WHILE_END:.+]]
+    if (cond == 2) {
+      uint index = WaveGetLaneIndex();
+      break;
+    }
+
+// CHECK: [[LABEL_IF_END]]:
+// CHECK:   br label %[[LABEL_WHILE_COND]]
+    cond++;
+  }
+
+// CHECK: [[LABEL_WHILE_END]]:
+// CHECK:   ret void
+}
+
+// CHECK-DAG: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]
+
+// CHECK-DAG: attributes [[A0]] = {{{.*}}convergent{{.*}}}
+// CHECK-DAG: attributes [[A1]] = {{{.*}}convergent{{.*}}}
+
diff --git a/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_simple.hlsl b/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_simple.hlsl
new file mode 100644
index 00000000000000..8f52d81091c180
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_simple.hlsl
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+// CHECK: define spir_func noundef i32 @_Z6test_1v() [[A0:#[0-9]+]] {
+// CHECK: %[[CI:[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[CI]]) ]
+uint test_1() {
+  return WaveGetLaneIndex();
+}
+
+// CHECK: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]
+
+// CHECK-DAG: attributes [[A0]] = { {{.*}}convergent{{.*}} }
+// CHECK-DAG: attributes [[A1]] = { {{.*}}convergent{{.*}} }
diff --git a/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_subcall.hlsl b/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_subcall.hlsl
new file mode 100644
index 00000000000000..379c8f118f52f3
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/wave_get_lane_index_subcall.hlsl
@@ -0,0 +1,21 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
+// RUN:   spirv-pc-vulkan-library %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+// CHECK: define spir_func noundef i32 @_Z6test_1v() [[A0:#[0-9]+]] {
+// CHECK: %[[C1:[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: call i32 @__hlsl_wave_get_lane_index() [ "convergencectrl"(token %[[C1]]) ]
+uint test_1() {
+  return WaveGetLaneIndex();
+}
+
+// CHECK-DAG: declare i32 @__hlsl_wave_get_lane_index() [[A1:#[0-9]+]]
+
+// CHECK: define spir_func noundef i32 @_Z6test_2v() [[A0]] {
+// CHECK: %[[C2:[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK: call spir_func noundef i32 @_Z6test_1v() [ "convergencectrl"(token %[[C2]]) ]
+uint test_2() {
+  return test_1();
+}
+
+// CHECK-DAG: attributes [[A0]] = {{{.*}}convergent{{.*}}}
+// CHECK-DAG: attributes [[A1]] = {{{.*}}convergent{{.*}}}

>From e399de32d884a1c7cf79eb756b5ec513ffd62bad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?= <brioche at google.com>
Date: Tue, 12 Mar 2024 15:14:55 +0100
Subject: [PATCH 4/4] review feedback

---
 clang/lib/CodeGen/CGBuiltin.cpp | 22 +++++++++-------------
 clang/lib/CodeGen/CGCall.cpp    |  3 +--
 clang/lib/CodeGen/CGLoopInfo.h  |  1 -
 3 files changed, 10 insertions(+), 16 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 12fc855fb92bb8..17d7975acfdcb5 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -1132,14 +1132,14 @@ struct BitTest {
 };
 
 // Returns the first convergence entry/loop/anchor instruction found in |BB|.
-// std::nullopt otherwise.
-std::optional<llvm::IntrinsicInst *> getConvergenceToken(llvm::BasicBlock *BB) {
+// Returns nullptr if |BB| contains none.
+llvm::IntrinsicInst *getConvergenceToken(llvm::BasicBlock *BB) {
   for (auto &I : *BB) {
     auto *II = dyn_cast<llvm::IntrinsicInst>(&I);
     if (II && isConvergenceControlIntrinsic(II->getIntrinsicID()))
       return II;
   }
-  return std::nullopt;
+  return nullptr;
 }
 
 } // namespace
@@ -1166,19 +1166,15 @@ CodeGenFunction::EmitConvergenceLoop(llvm::BasicBlock *BB,
   Builder.restoreIP(IP);
 
   auto I = AddConvergenceControlAttr(CB, ParentToken);
-  // Controlled convergence is incompatible with uncontrolled convergence.
-  // Removing any old attributes.
-  I->setNotConvergent();
-
   return cast<llvm::IntrinsicInst>(I);
 }
 
 llvm::IntrinsicInst *
 CodeGenFunction::getOrEmitConvergenceEntryToken(llvm::Function *F) {
   auto *BB = &F->getEntryBlock();
-  auto token = getConvergenceToken(BB);
-  if (token.has_value())
-    return token.value();
+  auto *token = getConvergenceToken(BB);
+  if (token)
+    return token;
 
   // Adding a convergence token requires the function to be marked as
   // convergent.
@@ -1198,9 +1194,9 @@ llvm::IntrinsicInst *
 CodeGenFunction::getOrEmitConvergenceLoopToken(const LoopInfo *LI) {
   assert(LI != nullptr);
 
-  auto token = getConvergenceToken(LI->getHeader());
-  if (token.has_value())
-    return *token;
+  auto *token = getConvergenceToken(LI->getHeader());
+  if (token)
+    return token;
 
   llvm::IntrinsicInst *PII =
       LI->getParent()
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 4b24367a8e19d9..32ee976a370707 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -5686,8 +5686,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
   if (!CI->getType()->isVoidTy())
     CI->setName("call");
 
-  if (getTarget().getTriple().isSPIRVLogical() &&
-      CI->getCalledFunction()->isConvergent())
+  if (getTarget().getTriple().isSPIRVLogical() && CI->isConvergent())
     CI = AddControlledConvergenceAttr(CI);
 
   // Update largest vector width from the return type.
diff --git a/clang/lib/CodeGen/CGLoopInfo.h b/clang/lib/CodeGen/CGLoopInfo.h
index 7c2f7443bd3c99..0fe33b28913063 100644
--- a/clang/lib/CodeGen/CGLoopInfo.h
+++ b/clang/lib/CodeGen/CGLoopInfo.h
@@ -297,7 +297,6 @@ class LoopInfoStack {
 
   /// Returns true if there is LoopInfo on the stack.
   bool hasInfo() const { return !Active.empty(); }
-
   /// Return the LoopInfo for the current loop. HasInfo should be called
   /// first to ensure LoopInfo is present.
   const LoopInfo &getInfo() const { return *Active.back(); }



More information about the cfe-commits mailing list