[llvm] [OpenMP][OMPIRBuilder] Use OMPKinds.def to specify callback metadata (PR #142753)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 4 02:41:00 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Ivan R. Ivanov (ivanradanov)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/142753.diff
3 Files Affected:
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPKinds.def (+23)
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+21-15)
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+51)
``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index f974cfc78c8dd..e01f30912f891 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -1418,3 +1418,26 @@ __OMP_ASSUME_CLAUSE(llvm::StringLiteral("no_parallelism"), false, false, false)
#undef __OMP_ASSUME_CLAUSE
#undef OMP_ASSUME_CLAUSE
///}
+
+/// Callback specification
+///
+/// OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...): runtime function
+/// Enum invokes its argument number CallbackArgNo, passing it the listed
+/// arguments (-1 = unknown) and, if VarArgsArePassed, its variadic arguments.
+///{
+
+#ifndef OMP_CALLBACK
+#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...)
+#endif
+
+#define __OMP_CALLBACK(Name, VarArgsArePassed, CallbackArgNo, ...) \
+ OMP_CALLBACK(OMPRTL_##Name, VarArgsArePassed, CallbackArgNo, __VA_ARGS__)
+
+__OMP_CALLBACK(__kmpc_fork_call, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_fork_call_if, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_fork_teams, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_omp_task_alloc, true, 5, -1, -1)
+
+#undef __OMP_CALLBACK
+#undef OMP_CALLBACK
+///}
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ca3d8438654dc..ac42cd9ab3297 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -614,21 +614,27 @@ OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
- // Add information if the runtime function takes a callback function
- if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
- if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
- LLVMContext &Ctx = Fn->getContext();
- MDBuilder MDB(Ctx);
- // Annotate the callback behavior of the runtime function:
- // - The callback callee is argument number 2 (microtask).
- // - The first two arguments of the callback callee are unknown (-1).
- // - All variadic arguments to the runtime function are passed to the
- // callback callee.
- Fn->addMetadata(
- LLVMContext::MD_callback,
- *MDNode::get(Ctx, {MDB.createCallbackEncoding(
- 2, {-1, -1}, /* VarArgsArePassed */ true)}));
- }
+ // Annotate the callback behavior of the runtime function, as specified by
+ // OMP_CALLBACK in OMPKinds.def:
+ // - the argument number of the callback callee,
+ // - the arguments passed on to the callback callee (-1 for unknown),
+ // - whether the variadic arguments are passed on to the callback callee.
+ LLVMContext &Ctx = Fn->getContext();
+ MDBuilder MDB(Ctx);
+ switch (FnID) {
+#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...) \
+ case Enum: { \
+ if (!Fn->hasMetadata(LLVMContext::MD_callback)) { \
+ Fn->addMetadata(LLVMContext::MD_callback, \
+ *MDNode::get(Ctx, {MDB.createCallbackEncoding( \
+ CallbackArgNo, {__VA_ARGS__}, \
+ VarArgsArePassed)})); \
+ } \
+ break; \
+ }
+#include "llvm/Frontend/OpenMP/OMPKinds.def"
+ default:
+ break;
}
LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index be98be260c9dc..6cb7b8664c64f 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7702,4 +7702,55 @@ TEST_F(OpenMPIRBuilderTest, splitBB) {
EXPECT_TRUE(DL == AllocaBB->getTerminator()->getStableDebugLoc());
}
+TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
+ OpenMPIRBuilder OMPBuilder(*M);
+ OMPBuilder.initialize();
+
+ FunctionCallee ForkCall = OMPBuilder.getOrCreateRuntimeFunction(
+ *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call);
+ FunctionCallee ForkCallIf = OMPBuilder.getOrCreateRuntimeFunction(
+ *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call_if);
+ FunctionCallee ForkTeam = OMPBuilder.getOrCreateRuntimeFunction(
+ *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_teams);
+ FunctionCallee TaskAlloc = OMPBuilder.getOrCreateRuntimeFunction(
+ *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_omp_task_alloc);
+
+ // Each runtime function must carry exactly one !callback encoding whose
+ // callee is argument ArgNo, as listed by OMP_CALLBACK in OMPKinds.def.
+ for (auto [FC, ArgNo] : zip(SmallVector<FunctionCallee>(
+ {ForkCall, ForkCallIf, ForkTeam, TaskAlloc}),
+ SmallVector<unsigned>({2, 2, 2, 5}))) {
+ auto *Fn = cast<Function>(FC.getCallee());
+ MDNode *CallbackMD = Fn->getMetadata(LLVMContext::MD_callback);
+ // ASSERT rather than EXPECT: the checks below dereference CallbackMD.
+ ASSERT_NE(CallbackMD, nullptr);
+ ASSERT_EQ(CallbackMD->getNumOperands(), 1u);
+
+ // Encoding layout: callee argument number, the arguments passed to the
+ // callee (-1 = unknown), then the "variadic args are passed" flag.
+ MDNode *OpMD = cast<MDNode>(CallbackMD->getOperand(0).get());
+ ASSERT_EQ(OpMD->getNumOperands(), 4u);
+ // Reads operand Idx of the encoding as a ConstantInt.
+ auto GetInt = [OpMD](unsigned Idx) {
+ return cast<ConstantInt>(
+ cast<ConstantAsMetadata>(OpMD->getOperand(Idx))->getValue());
+ };
+
+ uint64_t CBCalleeIdx = GetInt(0)->getZExtValue();
+ EXPECT_EQ(CBCalleeIdx, ArgNo);
+ // The callee must be one of the runtime function's parameters.
+ EXPECT_LT(CBCalleeIdx, Fn->arg_size());
+
+ // The unknown-argument markers are -1: read them sign-extended so the
+ // comparison is signed/signed (getZExtValue() would give UINT64_MAX).
+ int64_t Arg0 = GetInt(1)->getSExtValue();
+ int64_t Arg1 = GetInt(2)->getSExtValue();
+ EXPECT_EQ(Arg0, -1);
+ EXPECT_EQ(Arg1, -1);
+
+ // Every OMP_CALLBACK entry in OMPKinds.def sets VarArgsArePassed.
+ EXPECT_TRUE(GetInt(3)->isOne());
+ }
+}
+
} // namespace
``````````
</details>
https://github.com/llvm/llvm-project/pull/142753
More information about the llvm-commits
mailing list