[clang] [HLSL][SPIRV] Add convergence tokens to entry point wrapper (PR #112757)

Steven Perron via cfe-commits cfe-commits at lists.llvm.org
Mon Oct 28 08:05:02 PDT 2024


https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/112757

>From f113230de7412cc2440a800be6d3d3640742adbe Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 16 Oct 2024 13:20:29 -0400
Subject: [PATCH 1/2] [HLSL][SPIRV] Add convergence tokens to entry point
 wrapper

Inlining currently assumes that either all functions use controlled
convergence or none of them do. This is why we need to have the entry
point wrapper use controlled convergence.

https://github.com/llvm/llvm-project/blob/c85611e8583e6392d56075ebdfa60893b6284813/llvm/lib/Transforms/Utils/InlineFunction.cpp#L2431-L2439
---
 clang/lib/CodeGen/CGHLSLRuntime.cpp           | 41 +++++++++++++++++--
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 .../CodeGenHLSL/convergence/entry.point.hlsl  | 11 +++++
 3 files changed, 49 insertions(+), 4 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/convergence/entry.point.hlsl

diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 2cce2936fe5aee..d786d804c7cb68 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -404,6 +404,17 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
   IRBuilder<> B(BB);
   llvm::SmallVector<Value *> Args;
+
+  SmallVector<OperandBundleDef, 1> OB;
+  if (CGM.shouldEmitConvergenceTokens()) {
+    assert(EntryFn->isConvergent());
+    llvm::Value *
+        I = B.CreateIntrinsic(llvm::Intrinsic::experimental_convergence_entry, {},
+                              {});
+    llvm::Value *bundleArgs[] = {I};
+    OB.emplace_back("convergencectrl", bundleArgs);
+  }
+
   // FIXME: support struct parameters where semantics are on members.
   // See: https://github.com/llvm/llvm-project/issues/57874
   unsigned SRetOffset = 0;
@@ -419,7 +430,7 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
     Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
   }
 
-  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
+  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
   CI->setCallingConv(Fn->getCallingConv());
   // FIXME: Handle codegen for return type semantics.
   // See: https://github.com/llvm/llvm-project/issues/57875
@@ -474,14 +485,21 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
   for (auto &F : M.functions()) {
     if (!F.hasFnAttribute("hlsl.shader"))
       continue;
-    IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
+    auto* Token = getConvergenceToken(F.getEntryBlock());
+    Instruction* IP = Token ? Token : &*F.getEntryBlock().begin();
+    IRBuilder<> B(IP);
+    std::vector<OperandBundleDef> OB;
+    if (Token) {
+      llvm::Value *bundleArgs[] = {Token};
+      OB.emplace_back("convergencectrl", bundleArgs);
+    }
     for (auto *Fn : CtorFns)
-      B.CreateCall(FunctionCallee(Fn));
+      B.CreateCall(FunctionCallee(Fn), {}, OB);
 
     // Insert global dtors before the terminator of the last instruction
     B.SetInsertPoint(F.back().getTerminator());
     for (auto *Fn : DtorFns)
-      B.CreateCall(FunctionCallee(Fn));
+      B.CreateCall(FunctionCallee(Fn), {}, OB);
   }
 
   // No need to keep global ctors/dtors for non-lib profile after call to
@@ -579,3 +597,18 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
   Builder.CreateRetVoid();
   return InitResBindingsFunc;
 }
+
+llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
+  if (!CGM.shouldEmitConvergenceTokens())
+    return nullptr;
+
+  auto E = BB.end();
+  for(auto I = BB.begin(); I != E; ++I) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
+    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
+      return II;
+    }
+  }
+  llvm_unreachable("Convergence token should have been emitted.");
+  return nullptr;
+}
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..cd533cad84e9fb 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -143,6 +143,7 @@ class CGHLSLRuntime {
 
   bool needsResourceBindingInitFn();
   llvm::Function *createResourceBindingInitFn();
+  llvm::Instruction *getConvergenceToken(llvm::BasicBlock &BB);
 
 private:
   void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
diff --git a/clang/test/CodeGenHLSL/convergence/entry.point.hlsl b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
new file mode 100644
index 00000000000000..a848c834da3535
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -fnative-half-type -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
+
+// CHECK-LABEL: define void @main()
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[token:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK-NEXT: call spir_func void @_Z4mainv() [ "convergencectrl"(token [[token]]) ]
+
+[numthreads(1,1,1)]
+void main() {
+}
+

>From ee1eec91196eccc87248b16a8cc41dcb6f1e032a Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 28 Oct 2024 10:41:09 -0400
Subject: [PATCH 2/2] Fix nits from code review.

---
 clang/lib/CodeGen/CGHLSLRuntime.cpp              | 16 ++++++++--------
 .../CodeGenHLSL/convergence/entry.point.hlsl     |  2 +-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index d786d804c7cb68..06558ce796f2e4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -408,9 +408,8 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
   SmallVector<OperandBundleDef, 1> OB;
   if (CGM.shouldEmitConvergenceTokens()) {
     assert(EntryFn->isConvergent());
-    llvm::Value *
-        I = B.CreateIntrinsic(llvm::Intrinsic::experimental_convergence_entry, {},
-                              {});
+    llvm::Value *I = B.CreateIntrinsic(
+        llvm::Intrinsic::experimental_convergence_entry, {}, {});
     llvm::Value *bundleArgs[] = {I};
     OB.emplace_back("convergencectrl", bundleArgs);
   }
@@ -485,14 +484,15 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
   for (auto &F : M.functions()) {
     if (!F.hasFnAttribute("hlsl.shader"))
       continue;
-    auto* Token = getConvergenceToken(F.getEntryBlock());
-    Instruction* IP = Token ? Token : &*F.getEntryBlock().begin();
-    IRBuilder<> B(IP);
-    std::vector<OperandBundleDef> OB;
+    auto *Token = getConvergenceToken(F.getEntryBlock());
+    Instruction *IP = &*F.getEntryBlock().begin();
+    SmallVector<OperandBundleDef, 1> OB;
     if (Token) {
       llvm::Value *bundleArgs[] = {Token};
       OB.emplace_back("convergencectrl", bundleArgs);
+      IP = Token->getNextNode();
     }
+    IRBuilder<> B(IP);
     for (auto *Fn : CtorFns)
       B.CreateCall(FunctionCallee(Fn), {}, OB);
 
@@ -603,7 +603,7 @@ llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
     return nullptr;
 
   auto E = BB.end();
-  for(auto I = BB.begin(); I != E; ++I) {
+  for (auto I = BB.begin(); I != E; ++I) {
     auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
     if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
       return II;
diff --git a/clang/test/CodeGenHLSL/convergence/entry.point.hlsl b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
index a848c834da3535..337a9ad5026c16 100644
--- a/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
+++ b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -fnative-half-type -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
 
 // CHECK-LABEL: define void @main()
 // CHECK-NEXT: entry:



More information about the cfe-commits mailing list