[llvm] [NVPTX] Allow the ctor/dtor lowering pass to emit kernels (PR #71549)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 9 11:49:46 PST 2023


================
@@ -42,11 +48,152 @@ static std::string getHash(StringRef Str) {
   return llvm::utohexstr(Hash.low(), /*LowerCase=*/true);
 }
 
-static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
-                                   bool IsCtor) {
-  GlobalVariable *GV = M.getGlobalVariable(GlobalName);
-  if (!GV || !GV->hasInitializer())
-    return false;
+/// Marks \p GV as an NVPTX kernel by appending entries to the module's
+/// "nvvm.annotations" named metadata.
+///
+/// Each appended entry is an MDNode of the form {GV, "<key>", i32 1}, one per
+/// key: "kernel", "maxntidx", "maxntidy", "maxntidz" and "maxclusterrank".
+/// Entries are appended unconditionally (no lookup for existing ones), so
+/// calling this more than once for the same \p GV adds duplicate annotations.
+static void addKernelMetadata(Module &M, GlobalValue *GV) {
+  llvm::LLVMContext &Ctx = M.getContext();
+
+  // Get "nvvm.annotations" metadata node.
+  llvm::NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
+
+  // {GV, "kernel", 1}: tags GV with the "kernel" annotation.
+  llvm::Metadata *KernelMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "kernel"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+
+  // This kernel is only to be called single-threaded.
+  // Hence maxntidx, maxntidy and maxntidz are each set to 1 below.
+  llvm::Metadata *ThreadXMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidx"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+  llvm::Metadata *ThreadYMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidy"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+  llvm::Metadata *ThreadZMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV), llvm::MDString::get(Ctx, "maxntidz"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+
+  // {GV, "maxclusterrank", 1}. NOTE(review): presumably limits a cluster to a
+  // single block, matching the single-thread limits above — confirm against
+  // how the NVPTX backend consumes this key (and which targets honor it).
+  llvm::Metadata *BlockMDVals[] = {
+      llvm::ConstantAsMetadata::get(GV),
+      llvm::MDString::get(Ctx, "maxclusterrank"),
+      llvm::ConstantAsMetadata::get(
+          llvm::ConstantInt::get(llvm::Type::getInt32Ty(Ctx), 1))};
+
+  // Append metadata to nvvm.annotations.
+  // NOTE(review): the five arrays above differ only in their key string; a
+  // small helper taking the key would remove the duplication.
+  MD->addOperand(llvm::MDNode::get(Ctx, KernelMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, ThreadXMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, ThreadYMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, ThreadZMDVals));
+  MD->addOperand(llvm::MDNode::get(Ctx, BlockMDVals));
+}
+
+static Function *createInitOrFiniKernelFunction(Module &M, bool IsCtor) {
----------------
Artem-B wrote:

Got it. SGTM.

https://github.com/llvm/llvm-project/pull/71549


More information about the llvm-commits mailing list