[llvm] [NVPTX] Allow the ctor/dtor lowering pass to emit kernels (PR #71549)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 9 11:49:46 PST 2023
================
@@ -42,11 +48,152 @@ static std::string getHash(StringRef Str) {
return llvm::utohexstr(Hash.low(), /*LowerCase=*/true);
}
-static bool createInitOrFiniGlobls(Module &M, StringRef GlobalName,
- bool IsCtor) {
- GlobalVariable *GV = M.getGlobalVariable(GlobalName);
- if (!GV || !GV->hasInitializer())
- return false;
+/// Registers \p GV as an NVPTX kernel by appending annotation nodes to the
+/// module's "nvvm.annotations" named metadata.
+///
+/// Five (GV, key, i32 1) triples are appended, in this order: "kernel",
+/// "maxntidx", "maxntidy", "maxntidz", "maxclusterrank".
+static void addKernelMetadata(Module &M, GlobalValue *GV) {
+  llvm::LLVMContext &Context = M.getContext();
+  llvm::NamedMDNode *Annotations =
+      M.getOrInsertNamedMetadata("nvvm.annotations");
+
+  // Every annotation emitted here carries the same i32 value of 1.
+  llvm::Metadata *OneMD = llvm::ConstantAsMetadata::get(
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(Context), 1));
+
+  // Appends one (GV, Key, 1) node to "nvvm.annotations".
+  auto Annotate = [&](StringRef Key) {
+    llvm::Metadata *Operands[] = {llvm::ConstantAsMetadata::get(GV),
+                                  llvm::MDString::get(Context, Key), OneMD};
+    Annotations->addOperand(llvm::MDNode::get(Context, Operands));
+  };
+
+  Annotate("kernel");
+
+  // This kernel is only to be called single-threaded.
+  Annotate("maxntidx");
+  Annotate("maxntidy");
+  Annotate("maxntidz");
+
+  Annotate("maxclusterrank");
+}
+
+static Function *createInitOrFiniKernelFunction(Module &M, bool IsCtor) {
----------------
Artem-B wrote:
Got it. SGTM.
https://github.com/llvm/llvm-project/pull/71549
More information about the llvm-commits
mailing list