[llvm] [NVPTX] Allow the ctor/dtor lowering pass to emit kernels (PR #71549)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 9 11:47:31 PST 2023


================
@@ -42,11 +48,152 @@ static std::string getHash(StringRef Str) {
   return llvm::utohexstr(Hash.low(), /*LowerCase=*/true);
 }
 
-static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
-                                   bool IsCtor) {
-  GlobalVariable *GV = M.getGlobalVariable(GlobalName);
-  if (!GV || !GV->hasInitializer())
-    return false;
+static void addKernelMetadata(Module &M, GlobalValue *GV) {
+  llvm::LLVMContext &Ctx = M.getContext();
+
+  // Get "nvvm.annotations" metadata node.
+  llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
+
+  llvm::Metadata *KernelMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "kernel"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+
+  // This kernel is only to be called single-threaded.
+  llvm::Metadata *ThreadXMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidx"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+  llvm::Metadata *ThreadYMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidy"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+  llvm::Metadata *ThreadZMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidz"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+
+  llvm::Metadata *BlockMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV),
+      llvm::MDString::get(Ctx, "maxclusterrank"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+
+  // Append metadata to nvvm.annotations.
+  MD->addOperand(llvm::MDNode::get(Ctx, KernelMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, ThreadXMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, ThreadYMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, ThreadZMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, BlockMDVals));
+}
+
+static Function *createInitOrFiniKernelFunction(Module &M, bool IsCtor) {
----------------
jhuber6 wrote:

So, with a normal linker, it would collect everything in the `.init_array.N` sections, sort it by priority, and then give us symbols for the beginning and the end of that sorted array.
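To make the linker's role concrete, here is a simplified C++ model. It is not LLVM or ld code. Each `.init_array.N` input section is reduced to its priority N and a list of constructor names. The sections are stably sorted by N, lowest first, and concatenated, so the result can be described by a single [start, end) range. `InitArraySection` and `linkInitArray` are hypothetical names, and the unsuffixed `.init_array` section that real linkers also handle is ignored here.

```cpp
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for one `.init_array.N` input section.
struct InitArraySection {
  uint32_t Priority;                 // the N in `.init_array.N`
  std::vector<std::string> Entries;  // constructor names, in input order
};

// Model of the linker step: order sections by priority (lower runs first),
// then lay their entries out back to back in one contiguous array.
std::vector<std::string> linkInitArray(std::vector<InitArraySection> Sections) {
  // stable_sort keeps input order among sections with equal priority.
  std::stable_sort(Sections.begin(), Sections.end(),
                   [](const InitArraySection &A, const InitArraySection &B) {
                     return A.Priority < B.Priority;
                   });
  std::vector<std::string> Array;
  for (const InitArraySection &S : Sections)
    Array.insert(Array.end(), S.Entries.begin(), S.Entries.end());
  // Array.data() and Array.data() + Array.size() play the roles of
  // __init_array_start and __init_array_end.
  return Array;
}
```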

Here, because PTX has no concept of sections or proper linking, we emit every global constructor as a symbol whose name is a fixed prefix with a semi-unique hash appended. We, the runtime, then pull out all of these symbols and do what the linker would have done for us: build the sorted array.
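The runtime side can be sketched in the same style. The exact symbol layout is an assumption: the sketch supposes each constructor global is named `<prefix><ctor>_<hash>_<priority>`, with the priority as the last `_`-separated field. The helper names (`CtorEntry`, `parsePriority`, `buildInitArray`) are hypothetical and do not come from LLVM or the offloading runtime.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical record: one discovered constructor symbol and its priority.
struct CtorEntry {
  uint32_t Priority;
  std::string Symbol;
};

// Assumed naming scheme: "<Prefix><ctor>_<hash>_<priority>". Returns the
// trailing decimal priority, or nullopt if Name is not one of ours.
std::optional<uint32_t> parsePriority(const std::string &Name,
                                      const std::string &Prefix) {
  if (Name.compare(0, Prefix.size(), Prefix) != 0)
    return std::nullopt;
  std::size_t Pos = Name.rfind('_');
  // The last '_' must come after the prefix and be followed by digits.
  if (Pos == std::string::npos || Pos < Prefix.size() || Pos + 1 == Name.size())
    return std::nullopt;
  uint32_t Priority = 0;
  for (std::size_t I = Pos + 1; I < Name.size(); ++I) {
    if (Name[I] < '0' || Name[I] > '9')
      return std::nullopt;
    Priority = Priority * 10 + static_cast<uint32_t>(Name[I] - '0');
  }
  return Priority;
}

// Do the linker's job by hand: keep only our symbols and order them by
// priority (lowest first), preserving discovery order for ties.
std::vector<std::string> buildInitArray(const std::vector<std::string> &Symbols,
                                        const std::string &Prefix) {
  std::vector<CtorEntry> Entries;
  for (const std::string &S : Symbols)
    if (std::optional<uint32_t> P = parsePriority(S, Prefix))
      Entries.push_back({*P, S});
  std::stable_sort(Entries.begin(), Entries.end(),
                   [](const CtorEntry &A, const CtorEntry &B) {
                     return A.Priority < B.Priority;
                   });
  std::vector<std::string> Sorted;
  for (const CtorEntry &E : Entries)
    Sorted.push_back(E.Symbol);
  return Sorted;
}
```

The resulting order is what the runtime would then write into device memory and expose through the start/end symbols.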

The generated kernel then reads two symbols, `__init_array_start` and `__init_array_end`, to traverse this list. These symbols are `weak`, so they are merged when there are multiple TUs. Similarly, the `nvptx$device$init` kernel is `weak_odr`, because in every TU it does the same thing: traverse this list.

So, there is a single init kernel and init global, and a copy of each is emitted into every TU that contains global constructors. These definitions are weak and should all be identical, so any of them will do. If the user did not initialize the `__init_array_start` and `__init_array_end` variables, they are both nullptr, so the kernel exits without doing anything.
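The traversal, including the null case, can be modelled on the host in plain C++. `InitArrayStart` and `InitArrayEnd` are hypothetical stand-ins for the weak `__init_array_start`/`__init_array_end` globals (identifiers beginning with `__` are reserved in user C++ code), and `runInitArray` stands in for the body of the `nvptx$device$init` kernel. The real kernel is emitted as LLVM IR by the lowering pass, not written like this.

```cpp
// Signature assumed for constructor entries: no arguments, no result.
using CtorFn = void (*)();

// Hypothetical stand-ins for the weak `__init_array_start` and
// `__init_array_end` globals. Both start out null, i.e. "nobody filled
// them in".
CtorFn *InitArrayStart = nullptr;
CtorFn *InitArrayEnd = nullptr;

// Host-side model of the init kernel body: call each constructor in
// [InitArrayStart, InitArrayEnd) in order. When both bounds are null the
// loop condition is false immediately, so nothing runs.
void runInitArray() {
  for (CtorFn *It = InitArrayStart; It != InitArrayEnd; ++It)
    (*It)();
}
```

Because the loop only compares the two bounds before touching any entry, a pair of null pointers yields zero iterations, which matches the "exit without doing anything" behaviour described above.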

https://github.com/llvm/llvm-project/pull/71549

