[llvm] [OpenMP][OMPIRBuilder] Use OMPKinds.def to specify callback metadata (PR #142753)

Ivan R. Ivanov via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 4 02:40:24 PDT 2025


https://github.com/ivanradanov created https://github.com/llvm/llvm-project/pull/142753

None

>From 852ca95c185fdeffeffcaed1ca1931a05330f52a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 4 Jun 2025 18:33:40 +0900
Subject: [PATCH] [OpenMP][OMPIRBuilder] Use OMPKinds.def to specify callback
 metadata

---
 .../include/llvm/Frontend/OpenMP/OMPKinds.def | 23 +++++++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 36 +++++++------
 .../Frontend/OpenMPIRBuilderTest.cpp          | 51 +++++++++++++++++++
 3 files changed, 95 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index f974cfc78c8dd..e01f30912f891 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -1418,3 +1418,26 @@ __OMP_ASSUME_CLAUSE(llvm::StringLiteral("no_parallelism"), false, false, false)
 #undef __OMP_ASSUME_CLAUSE
 #undef OMP_ASSUME_CLAUSE
 ///}
+
+
+/// Callback specification
+///
+/// OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, <callback args>...)
+/// where <callback args> are the arguments passed on (-1 for unknown).
+///{
+
+#ifndef OMP_CALLBACK
+#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...)
+#endif
+
+#define __OMP_CALLBACK(Name, VarArgsArePassed, CallbackArgNo, ...)             \
+  OMP_CALLBACK(OMPRTL_##Name, VarArgsArePassed, CallbackArgNo, __VA_ARGS__)
+
+__OMP_CALLBACK(__kmpc_fork_call, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_fork_call_if, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_fork_teams, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_omp_task_alloc, true, 5, -1, -1)
+
+#undef __OMP_CALLBACK
+#undef OMP_CALLBACK
+///}
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ca3d8438654dc..ac42cd9ab3297 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -614,21 +614,27 @@ OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
 #include "llvm/Frontend/OpenMP/OMPKinds.def"
     }
 
-    // Add information if the runtime function takes a callback function
-    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
-      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
-        LLVMContext &Ctx = Fn->getContext();
-        MDBuilder MDB(Ctx);
-        // Annotate the callback behavior of the runtime function:
-        //  - The callback callee is argument number 2 (microtask).
-        //  - The first two arguments of the callback callee are unknown (-1).
-        //  - All variadic arguments to the runtime function are passed to the
-        //    callback callee.
-        Fn->addMetadata(
-            LLVMContext::MD_callback,
-            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
-                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
-      }
+    // Annotate the callback behavior of the runtime function:
+    //  - First the callback callee argument number
+    //  - Then the arguments passed on to the callback (-1 for unknown),
+    //    variadic
+    //  - Finally, whether variadic args are passed on to the callback.
+    LLVMContext &Ctx = Fn->getContext();
+    MDBuilder MDB(Ctx);
+    switch (FnID) {
+#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...)               \
+  case Enum: {                                                                 \
+    if (!Fn->hasMetadata(LLVMContext::MD_callback)) {                          \
+      Fn->addMetadata(LLVMContext::MD_callback,                                \
+                      *MDNode::get(Ctx, {MDB.createCallbackEncoding(           \
+                                            CallbackArgNo, {__VA_ARGS__},      \
+                                            VarArgsArePassed)}));              \
+    }                                                                          \
+    break;                                                                     \
+  }
+#include "llvm/Frontend/OpenMP/OMPKinds.def"
+    default:
+      break;
     }
 
     LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index be98be260c9dc..6cb7b8664c64f 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7702,4 +7702,55 @@ TEST_F(OpenMPIRBuilderTest, splitBB) {
     EXPECT_TRUE(DL == AllocaBB->getTerminator()->getStableDebugLoc());
 }
 
+// Checks the !callback metadata attached to callback-taking runtime calls.
+TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+
+  FunctionCallee ForkCall = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call);
+  FunctionCallee ForkCallIf = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call_if);
+  FunctionCallee ForkTeam = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_teams);
+  FunctionCallee TaskAlloc = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_omp_task_alloc);
+
+  // Each runtime function is paired with its expected callback-callee index.
+  for (auto [FC, ArgNo] : zip(SmallVector<FunctionCallee>(
+                                  {ForkCall, ForkCallIf, ForkTeam, TaskAlloc}),
+                              SmallVector<unsigned>({2, 2, 2, 5}))) {
+    MDNode *CallbackMD =
+        cast<Function>(FC.getCallee())->getMetadata(LLVMContext::MD_callback);
+    // Fatal assertion: the node is dereferenced immediately below.
+    ASSERT_NE(CallbackMD, nullptr);
+    unsigned Num = 0;
+    for (const MDOperand &Op : CallbackMD->operands()) {
+      Num++;
+      MDNode *OpMD = cast<MDNode>(Op.get());
+      auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
+      uint64_t CBCalleeIdx =
+          cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
+      EXPECT_EQ(CBCalleeIdx, ArgNo);
+      // -1 marks an unknown argument, so read the payload entries as signed.
+      int64_t Arg0 =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(OpMD->getOperand(1))->getValue())
+              ->getSExtValue();
+      int64_t Arg1 =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(OpMD->getOperand(2))->getValue())
+              ->getSExtValue();
+      uint64_t VarArg =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(OpMD->getOperand(3))->getValue())
+              ->getZExtValue();
+      EXPECT_EQ(Arg0, -1);
+      EXPECT_EQ(Arg1, -1);
+      EXPECT_EQ(VarArg, true);
+    }
+    EXPECT_EQ(Num, 1u);
+  }
+}
+
 } // namespace



More information about the llvm-commits mailing list