[llvm] [OpenMP][OMPIRBuilder] Use OMPKinds.def to specify callback metadata (PR #142753)

Ivan R. Ivanov via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 13 07:23:32 PDT 2025


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/142753

>From 852ca95c185fdeffeffcaed1ca1931a05330f52a Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 4 Jun 2025 18:33:40 +0900
Subject: [PATCH 1/3] [OpenMP][OMPIRBuilder] Use OMPKinds.def to specify
 callback metadata

---
 .../include/llvm/Frontend/OpenMP/OMPKinds.def | 23 +++++++++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 36 +++++++------
 .../Frontend/OpenMPIRBuilderTest.cpp          | 51 +++++++++++++++++++
 3 files changed, 95 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index f974cfc78c8dd..e01f30912f891 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -1418,3 +1418,26 @@ __OMP_ASSUME_CLAUSE(llvm::StringLiteral("no_parallelism"), false, false, false)
 #undef __OMP_ASSUME_CLAUSE
 #undef OMP_ASSUME_CLAUSE
 ///}
+
+
+/// Callback specification
+///
+///{
+// Entry: (Name, VarArgsArePassed, CallbackArgNo, CalleeArgNos; -1 = unknown)
+#ifndef OMP_CALLBACK
+#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...)
+#endif
+
+#define __OMP_CALLBACK(Name, VarArgsArePassed, CallbackArgNo, ...)             \
+  OMP_CALLBACK(OMPRTL_##Name, VarArgsArePassed, CallbackArgNo, __VA_ARGS__)
+
+__OMP_CALLBACK(__kmpc_fork_call, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_fork_call_if, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_fork_teams, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_omp_task_alloc, true, 5, -1, -1)
+
+#undef __OMP_PTR_TYPE
+
+#undef __OMP_CALLBACK
+#undef OMP_CALLBACK
+///}
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ca3d8438654dc..ac42cd9ab3297 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -614,21 +614,27 @@ OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
 #include "llvm/Frontend/OpenMP/OMPKinds.def"
     }

-    // Add information if the runtime function takes a callback function
-    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
-      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
-        LLVMContext &Ctx = Fn->getContext();
-        MDBuilder MDB(Ctx);
-        // Annotate the callback behavior of the runtime function:
-        //  - The callback callee is argument number 2 (microtask).
-        //  - The first two arguments of the callback callee are unknown (-1).
-        //  - All variadic arguments to the runtime function are passed to the
-        //    callback callee.
-        Fn->addMetadata(
-            LLVMContext::MD_callback,
-            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
-                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
-      }
+    // Annotate the callback behavior of the runtime function, driven by the
+    // OMP_CALLBACK entries in OMPKinds.def:
+    //  - CallbackArgNo: index of the argument holding the callback callee.
+    //  - Variadic macro args: the callee's leading arguments (-1 = unknown).
+    //  - VarArgsArePassed: whether the runtime's varargs reach the callee.
+    LLVMContext &Ctx = Fn->getContext();
+    MDBuilder MDB(Ctx);
+    switch (FnID) {
+#define OMP_CALLBACK(Enum, VarArgsArePassed, CallbackArgNo, ...)               \
+  case Enum: {                                                                 \
+    if (!Fn->hasMetadata(LLVMContext::MD_callback)) {                          \
+      Fn->addMetadata(LLVMContext::MD_callback,                                \
+                      *MDNode::get(Ctx, {MDB.createCallbackEncoding(           \
+                                            CallbackArgNo, {__VA_ARGS__},      \
+                                            VarArgsArePassed)}));              \
+    }                                                                          \
+    break;                                                                     \
+  }
+#include "llvm/Frontend/OpenMP/OMPKinds.def"
+    default:
+      break;
     }

     LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index be98be260c9dc..6cb7b8664c64f 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7702,4 +7702,55 @@ TEST_F(OpenMPIRBuilderTest, splitBB) {
     EXPECT_TRUE(DL == AllocaBB->getTerminator()->getStableDebugLoc());
 }

+TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  // Create the runtime functions that OMPKinds.def annotates with callbacks.
+  FunctionCallee ForkCall = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call);
+  FunctionCallee ForkCallIf = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_call_if);
+  FunctionCallee ForkTeam = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_fork_teams);
+  FunctionCallee TaskAlloc = OMPBuilder.getOrCreateRuntimeFunction(
+      *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_omp_task_alloc);
+
+  M->dump();
+  for (auto [FC, ArgNo] : zip(SmallVector<FunctionCallee>(
+                                  {ForkCall, ForkCallIf, ForkTeam, TaskAlloc}),
+                              SmallVector<unsigned>({2, 2, 2, 5}))) {
+    MDNode *CallbackMD =
+        cast<Function>(FC.getCallee())->getMetadata(LLVMContext::MD_callback);
+    EXPECT_NE(CallbackMD, nullptr);
+    unsigned Num = 0;
+    CallbackMD->dump();
+    M->dump();
+    for (const MDOperand &Op : CallbackMD->operands()) {
+      Num++;
+      MDNode *OpMD = cast<MDNode>(Op.get());
+      auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
+      uint64_t CBCalleeIdx =
+          cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
+      EXPECT_EQ(CBCalleeIdx, ArgNo);
+      // Operands 1-2: callee args (-1 = unknown); operand 3: VarArgsArePassed.
+      uint64_t Arg0 =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(OpMD->getOperand(1))->getValue())
+              ->getZExtValue();
+      uint64_t Arg1 =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(OpMD->getOperand(2))->getValue())
+              ->getZExtValue();
+      uint64_t VarArg =
+          cast<ConstantInt>(
+              cast<ConstantAsMetadata>(OpMD->getOperand(3))->getValue())
+              ->getZExtValue();
+      EXPECT_EQ(Arg0, -1);
+      EXPECT_EQ(Arg1, -1);
+      EXPECT_EQ(VarArg, true);
+    }
+    EXPECT_EQ(Num, 1);
+  }
+}
+
 } // namespace

>From 3e704c34346b5f5ab78751aa8a8590b4cbadf316 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 4 Jun 2025 19:01:01 +0900
Subject: [PATCH 2/3] Do not forward varargs for __kmpc_fork_call_if and __kmpc_omp_task_alloc

---
 llvm/include/llvm/Frontend/OpenMP/OMPKinds.def  |  4 ++--
 llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 12 +++++++-----
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
index e01f30912f891..1d5ae099c76b4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -1432,9 +1432,9 @@ __OMP_ASSUME_CLAUSE(llvm::StringLiteral("no_parallelism"), false, false, false)
   OMP_CALLBACK(OMPRTL_##Name, VarArgsArePassed, CallbackArgNo, __VA_ARGS__)

 __OMP_CALLBACK(__kmpc_fork_call, true, 2, -1, -1)
-__OMP_CALLBACK(__kmpc_fork_call_if, true, 2, -1, -1)
+__OMP_CALLBACK(__kmpc_fork_call_if, false, 2, -1, -1)
 __OMP_CALLBACK(__kmpc_fork_teams, true, 2, -1, -1)
-__OMP_CALLBACK(__kmpc_omp_task_alloc, true, 5, -1, -1)
+__OMP_CALLBACK(__kmpc_omp_task_alloc, false, 5, -1, -1)

 #undef __OMP_PTR_TYPE

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 6cb7b8664c64f..3281c4cd32e16 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7716,9 +7716,11 @@ TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
       *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_omp_task_alloc);

   M->dump();
-  for (auto [FC, ArgNo] : zip(SmallVector<FunctionCallee>(
-                                  {ForkCall, ForkCallIf, ForkTeam, TaskAlloc}),
-                              SmallVector<unsigned>({2, 2, 2, 5}))) {
+  for (auto [FC, VarArg, ArgNo] :
+       zip(SmallVector<FunctionCallee>(
+               {ForkCall, ForkCallIf, ForkTeam, TaskAlloc}),
+           SmallVector<bool>({true, false, true, false}),
+           SmallVector<unsigned>({2, 2, 2, 5}))) {
     MDNode *CallbackMD =
         cast<Function>(FC.getCallee())->getMetadata(LLVMContext::MD_callback);
     EXPECT_NE(CallbackMD, nullptr);
@@ -7741,13 +7743,13 @@ TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
           cast<ConstantInt>(
               cast<ConstantAsMetadata>(OpMD->getOperand(2))->getValue())
               ->getZExtValue();
-      uint64_t VarArg =
+      uint64_t ActualVarArg =
           cast<ConstantInt>(
               cast<ConstantAsMetadata>(OpMD->getOperand(3))->getValue())
               ->getZExtValue();
       EXPECT_EQ(Arg0, -1);
       EXPECT_EQ(Arg1, -1);
-      EXPECT_EQ(VarArg, true);
+      EXPECT_EQ(ActualVarArg, VarArg);
     }
     EXPECT_EQ(Num, 1);
   }

>From 9fb46efeef35fd7b7d197cb7655ea17decb715e8 Mon Sep 17 00:00:00 2001
From: "Ivan R. Ivanov" <ivanov.i.aa at m.titech.ac.jp>
Date: Fri, 13 Jun 2025 23:23:23 +0900
Subject: [PATCH 3/3] Remove stray debug dumps

---
 llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 3281c4cd32e16..c04e2e399f52c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7715,7 +7715,6 @@ TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
   FunctionCallee TaskAlloc = OMPBuilder.getOrCreateRuntimeFunction(
       *M, llvm::omp::RuntimeFunction::OMPRTL___kmpc_omp_task_alloc);

-  M->dump();
   for (auto [FC, VarArg, ArgNo] :
        zip(SmallVector<FunctionCallee>(
                {ForkCall, ForkCallIf, ForkTeam, TaskAlloc}),
@@ -7725,8 +7724,6 @@ TEST_F(OpenMPIRBuilderTest, createCallbackMetadata) {
         cast<Function>(FC.getCallee())->getMetadata(LLVMContext::MD_callback);
     EXPECT_NE(CallbackMD, nullptr);
     unsigned Num = 0;
-    CallbackMD->dump();
-    M->dump();
     for (const MDOperand &Op : CallbackMD->operands()) {
       Num++;
       MDNode *OpMD = cast<MDNode>(Op.get());



More information about the llvm-commits mailing list