[Mlir-commits] [mlir] d463a27 - [OpenMP][OMPIRBuilder] Support parallel in Generic kernels (#150926)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 27 05:10:20 PDT 2026


Author: Sergio Afonso
Date: 2026-04-27T13:10:15+01:00
New Revision: d463a276fcd68b3f0a070e69600fbfcb3aa2fa6e

URL: https://github.com/llvm/llvm-project/commit/d463a276fcd68b3f0a070e69600fbfcb3aa2fa6e
DIFF: https://github.com/llvm/llvm-project/commit/d463a276fcd68b3f0a070e69600fbfcb3aa2fa6e.diff

LOG: [OpenMP][OMPIRBuilder] Support parallel in Generic kernels (#150926)

This patch introduces codegen logic to produce the wrapper function that is
passed as an argument to the `__kmpc_parallel_60` DeviceRTL function. The
wrapper is needed to handle arguments passed using device shared memory in
Generic mode.

Added: 
    

Modified: 
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ade1e2c8c3dc6..5dbb8aba2403e 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1511,6 +1511,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
   return Error::success();
 }
 
+/// Create the wrapper function that fetches the outlined function's argument
+/// structure via __kmpc_get_shared_variables and forwards it to the outlined
+/// function. The caller only creates it for Generic-mode kernels.
+///
+/// \p OutlinedFn must have 2 or 3 arguments: it is passed pointers to two i32
+/// allocas and, when there are 3, the loaded argument-structure pointer.
+static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
+                                             Function &OutlinedFn) {
+  size_t NumArgs = OutlinedFn.arg_size();
+  assert((NumArgs == 2 || NumArgs == 3) &&
+         "expected a 2-3 argument parallel outlined function");
+  bool UseArgStruct = NumArgs == 3;
+
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  IRBuilder<>::InsertPointGuard IPG(Builder);
+  auto *FnTy = FunctionType::get(Builder.getVoidTy(),
+                                 {Builder.getInt16Ty(), Builder.getInt32Ty()},
+                                 /*isVarArg=*/false);
+  auto *WrapperFn =
+      Function::Create(FnTy, GlobalValue::InternalLinkage,
+                       OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
+
+  WrapperFn->addParamAttr(0, Attribute::NoUndef);
+  WrapperFn->addParamAttr(0, Attribute::ZExt);
+  WrapperFn->addParamAttr(1, Attribute::NoUndef);
+
+  BasicBlock *EntryBB =
+      BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
+  Builder.SetInsertPoint(EntryBB);
+
+  // Allocas, each cast to an address space 0 pointer before use.
+  Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+                                           /*ArraySize=*/nullptr, "addr");
+  AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+      AddrAlloca->getName() + ".ascast");
+
+  Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+                                           /*ArraySize=*/nullptr, "zero");
+  ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+      ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+      ZeroAlloca->getName() + ".ascast");
+
+  Value *ArgsAlloca = nullptr;
+  if (UseArgStruct) {
+    ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(),
+                                      /*ArraySize=*/nullptr, "global_args");
+    ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+        ArgsAlloca->getName() + ".ascast");
+  }
+
+  // Store the i32 argument and zero; pass global_args to the runtime call.
+  Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
+  Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
+  if (UseArgStruct) {
+    Builder.CreateCall(
+        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
+            llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
+        {ArgsAlloca});
+  }
+
+  SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
+
+  // Load the pointer in global_args, then load structArg from its element 0.
+  if (UseArgStruct) {
+    Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
+    StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
+                                          {Builder.getInt64(0)});
+    StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
+    Args.push_back(StructArg);
+  }
+
+  // Call the outlined function holding the parallel body.
+  Builder.CreateCall(&OutlinedFn, Args);
+  Builder.CreateRetVoid();
+
+  return WrapperFn;
+}
+
 // Callback used to create OpenMP runtime calls to support
 // omp parallel clause for the device.
 // We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1520,6 +1600,10 @@ static void targetParallelCallback(
     BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
     Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
     Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
+  assert(OutlinedFn.arg_size() >= 2 &&
+         "Expected at least tid and bounded tid as arguments");
+  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
   // Add some known attributes.
   IRBuilder<> &Builder = OMPIRBuilder->Builder;
   OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1528,17 +1612,12 @@ static void targetParallelCallback(
   OutlinedFn.addParamAttr(1, Attribute::NoUndef);
   OutlinedFn.addFnAttr(Attribute::NoUnwind);
 
-  assert(OutlinedFn.arg_size() >= 2 &&
-         "Expected at least tid and bounded tid as arguments");
-  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
-
   CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
   assert(CI && "Expected call instruction to outlined function");
   CI->getParent()->setName("omp_parallel");
 
   Builder.SetInsertPoint(CI);
   Type *PtrTy = OMPIRBuilder->VoidPtr;
-  Value *NullPtrValue = Constant::getNullValue(PtrTy);
 
   // Add alloca for kernel args
   OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1564,6 +1643,15 @@ static void targetParallelCallback(
       IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                   : Builder.getInt32(1);
 
+  // If this is not a Generic kernel, we can skip generating the wrapper.
+  std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+      getTargetKernelExecMode(*OuterFn);
+  Value *WrapperFn;
+  if (ExecMode && (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC))
+    WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
+  else
+    WrapperFn = Constant::getNullValue(PtrTy);
+
   // Build kmpc_parallel_60 call
   Value *Parallel60CallArgs[] = {
       /* identifier*/ Ident,
@@ -1572,7 +1660,7 @@ static void targetParallelCallback(
       /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
       /* Proc bind */ Builder.getInt32(-1),
       /* outlined function */ &OutlinedFn,
-      /* wrapper function */ NullPtrValue,
+      /* wrapper function */ WrapperFn,
       /* arguments of the outlined funciton*/ Args,
       /* number of arguments */ Builder.getInt64(NumCapturedVars),
       /* strict for number of threads */ Builder.getInt32(0)};

diff  --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 72652bc14cde4..8b4b2120006e3 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         store ptr %[[TMP6]], ptr %[[GEP_]], align 8
 // CHECK:         %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
 // CHECK:         store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
-// CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1, i32 0)
+// CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1, i32 0)
 // CHECK:         call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
 // CHECK:         call void @__kmpc_target_deinit()
 
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (
 // CHECK-SAME:  ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
 // CHECK-SAME:  i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
-// CHECK-SAME:  i32 -1,  ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1, i32 0)
+// CHECK-SAME:  i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1, i32 0)
 
 // One of the arguments of  kmpc_parallel_60 function is responsible for handling if clause
 // of omp parallel construct for target region. If this  argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
 // CHECK:         call void @__kmpc_parallel_60(ptr addrspacecast (
 // CHECK-SAME:  ptr addrspace(1) {{.*}} to ptr),
 // CHECK-SAME:  i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
-// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr null, ptr {{.*}}, i64 1, i32 0)
+// CHECK-SAME:  i32 -1,  ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1, i32 0)
+
+// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
+// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
+// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
+// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
+// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
+// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
+// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
+// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
+// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
+// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
+// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
+// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
+
+// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
+// CHECK-NOT: define
+// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})


        


More information about the Mlir-commits mailing list