[clang] [HLSL][SPIRV] Add convergence tokens to entry point wrapper (PR #112757)

Steven Perron via cfe-commits cfe-commits at lists.llvm.org
Mon Oct 28 07:54:13 PDT 2024


https://github.com/s-perron updated https://github.com/llvm/llvm-project/pull/112757

>From 24e84865fe8cd9684c6e5b2763970ba278038039 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Wed, 16 Oct 2024 13:20:29 -0400
Subject: [PATCH 1/2] [HLSL][SPIRV] Add convergence tokens to entry point
 wrapper

Inlining currently assumes that either all functions use controlled
convergence or none of them do. This is why we need to have the entry
point wrapper use controlled convergence.

https://github.com/llvm/llvm-project/blob/c85611e8583e6392d56075ebdfa60893b6284813/llvm/lib/Transforms/Utils/InlineFunction.cpp#L2431-L2439
---
 clang/lib/CodeGen/CGHLSLRuntime.cpp           | 41 +++++++++++++++++--
 clang/lib/CodeGen/CGHLSLRuntime.h             |  1 +
 .../CodeGenHLSL/convergence/entry.point.hlsl  | 11 +++++
 3 files changed, 49 insertions(+), 4 deletions(-)
 create mode 100644 clang/test/CodeGenHLSL/convergence/entry.point.hlsl

diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index 3237d93ca31ceb..d006f5d8f3c1cd 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -399,6 +399,17 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
   IRBuilder<> B(BB);
   llvm::SmallVector<Value *> Args;
+
+  SmallVector<OperandBundleDef, 1> OB;
+  if (CGM.shouldEmitConvergenceTokens()) {
+    assert(EntryFn->isConvergent());
+    llvm::Value *
+        I = B.CreateIntrinsic(llvm::Intrinsic::experimental_convergence_entry, {},
+                              {});
+    llvm::Value *bundleArgs[] = {I};
+    OB.emplace_back("convergencectrl", bundleArgs);
+  }
+
   // FIXME: support struct parameters where semantics are on members.
   // See: https://github.com/llvm/llvm-project/issues/57874
   unsigned SRetOffset = 0;
@@ -414,7 +425,7 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
     Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
   }
 
-  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
+  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
   CI->setCallingConv(Fn->getCallingConv());
   // FIXME: Handle codegen for return type semantics.
   // See: https://github.com/llvm/llvm-project/issues/57875
@@ -469,14 +480,21 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
   for (auto &F : M.functions()) {
     if (!F.hasFnAttribute("hlsl.shader"))
       continue;
-    IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
+    auto* Token = getConvergenceToken(F.getEntryBlock());
+    Instruction* IP = Token ? Token : &*F.getEntryBlock().begin();
+    IRBuilder<> B(IP);
+    std::vector<OperandBundleDef> OB;
+    if (Token) {
+      llvm::Value *bundleArgs[] = {Token};
+      OB.emplace_back("convergencectrl", bundleArgs);
+    }
     for (auto *Fn : CtorFns)
-      B.CreateCall(FunctionCallee(Fn));
+      B.CreateCall(FunctionCallee(Fn), {}, OB);
 
     // Insert global dtors before the terminator of the last instruction
     B.SetInsertPoint(F.back().getTerminator());
     for (auto *Fn : DtorFns)
-      B.CreateCall(FunctionCallee(Fn));
+      B.CreateCall(FunctionCallee(Fn), {}, OB);
   }
 
   // No need to keep global ctors/dtors for non-lib profile after call to
@@ -489,3 +507,18 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
       GV->eraseFromParent();
   }
 }
+
+llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
+  if (!CGM.shouldEmitConvergenceTokens())
+    return nullptr;
+
+  auto E = BB.end();
+  for(auto I = BB.begin(); I != E; ++I) {
+    auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
+    if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
+      return II;
+    }
+  }
+  llvm_unreachable("Convergence token should have been emitted.");
+  return nullptr;
+}
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index f7621ee20b1243..3eb56cd5449704 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -137,6 +137,7 @@ class CGHLSLRuntime {
 
   void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn);
   void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn);
+  llvm::Instruction *getConvergenceToken(llvm::BasicBlock &BB);
 
 private:
   void addBufferResourceAnnotation(llvm::GlobalVariable *GV,
diff --git a/clang/test/CodeGenHLSL/convergence/entry.point.hlsl b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
new file mode 100644
index 00000000000000..a848c834da3535
--- /dev/null
+++ b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
@@ -0,0 +1,11 @@
+// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -fnative-half-type -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
+
+// CHECK-LABEL: define void @main()
+// CHECK-NEXT: entry:
+// CHECK-NEXT: [[token:%[0-9]+]] = call token @llvm.experimental.convergence.entry()
+// CHECK-NEXT: call spir_func void @_Z4mainv() [ "convergencectrl"(token [[token]]) ]
+
+[numthreads(1,1,1)]
+void main() {
+}
+

>From 90512c62a8e7f92857c3c74c54c1a2dd7d160602 Mon Sep 17 00:00:00 2001
From: Steven Perron <stevenperron at google.com>
Date: Mon, 28 Oct 2024 10:41:09 -0400
Subject: [PATCH 2/2] Fix nits from code review.

---
 clang/lib/CodeGen/CGHLSLRuntime.cpp              | 16 ++++++++--------
 .../CodeGenHLSL/convergence/entry.point.hlsl     |  2 +-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp
index d006f5d8f3c1cd..fcde0d20f2b543 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.cpp
+++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp
@@ -403,9 +403,8 @@ void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
   SmallVector<OperandBundleDef, 1> OB;
   if (CGM.shouldEmitConvergenceTokens()) {
     assert(EntryFn->isConvergent());
-    llvm::Value *
-        I = B.CreateIntrinsic(llvm::Intrinsic::experimental_convergence_entry, {},
-                              {});
+    llvm::Value *I = B.CreateIntrinsic(
+        llvm::Intrinsic::experimental_convergence_entry, {}, {});
     llvm::Value *bundleArgs[] = {I};
     OB.emplace_back("convergencectrl", bundleArgs);
   }
@@ -480,14 +479,15 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
   for (auto &F : M.functions()) {
     if (!F.hasFnAttribute("hlsl.shader"))
       continue;
-    auto* Token = getConvergenceToken(F.getEntryBlock());
-    Instruction* IP = Token ? Token : &*F.getEntryBlock().begin();
-    IRBuilder<> B(IP);
-    std::vector<OperandBundleDef> OB;
+    auto *Token = getConvergenceToken(F.getEntryBlock());
+    Instruction *IP = &*F.getEntryBlock().begin();
+    SmallVector<OperandBundleDef, 1> OB;
     if (Token) {
       llvm::Value *bundleArgs[] = {Token};
       OB.emplace_back("convergencectrl", bundleArgs);
+      IP = Token->getNextNode();
     }
+    IRBuilder<> B(IP);
     for (auto *Fn : CtorFns)
       B.CreateCall(FunctionCallee(Fn), {}, OB);
 
@@ -513,7 +513,7 @@ llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
     return nullptr;
 
   auto E = BB.end();
-  for(auto I = BB.begin(); I != E; ++I) {
+  for (auto I = BB.begin(); I != E; ++I) {
     auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
     if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
       return II;
diff --git a/clang/test/CodeGenHLSL/convergence/entry.point.hlsl b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
index a848c834da3535..337a9ad5026c16 100644
--- a/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
+++ b/clang/test/CodeGenHLSL/convergence/entry.point.hlsl
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -fnative-half-type -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
+// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute -finclude-default-header -disable-llvm-passes -emit-llvm -o - %s | FileCheck %s
 
 // CHECK-LABEL: define void @main()
 // CHECK-NEXT: entry:



More information about the cfe-commits mailing list