[llvm-branch-commits] [llvm] [mlir] [OpenMP][OMPIRBuilder] Support parallel in Generic kernels (PR #150926)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Aug 14 05:50:26 PDT 2025
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/150926
>From 632223908c734c16b4f01e0a2a44257a1cc153d7 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 4 Jul 2025 16:32:03 +0100
Subject: [PATCH 1/2] [OpenMP][OMPIRBuilder] Support parallel in Generic
kernels
This patch introduces codegen logic to produce the wrapper function argument of
the `__kmpc_parallel_51` DeviceRTL function, which is needed to handle arguments
passed through device shared memory in Generic mode.
---
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 100 ++++++++++++++++--
.../LLVMIR/omptarget-parallel-llvm.mlir | 25 ++++-
2 files changed, 116 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index c573c782d3958..1e1c5dab8a40c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1334,6 +1334,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
return Error::success();
}
+// Create wrapper function used to gather the outlined function's argument
+// structure from a shared buffer and to forward them to it when running in
+// Generic mode.
+//
+// The outlined function is expected to receive 2 integer arguments followed by
+// an optional pointer argument to an argument structure holding the rest.
+static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
+ Function &OutlinedFn) {
+ size_t NumArgs = OutlinedFn.arg_size();
+ assert((NumArgs == 2 || NumArgs == 3) &&
+ "expected a 2-3 argument parallel outlined function");
+ bool UseArgStruct = NumArgs == 3;
+
+ IRBuilder<> &Builder = OMPIRBuilder->Builder;
+ IRBuilder<>::InsertPointGuard IPG(Builder);
+ auto *FnTy = FunctionType::get(Builder.getVoidTy(),
+ {Builder.getInt16Ty(), Builder.getInt32Ty()},
+ /*isVarArg=*/false);
+ auto *WrapperFn =
+ Function::Create(FnTy, GlobalValue::InternalLinkage,
+ OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M);
+
+ WrapperFn->addParamAttr(0, Attribute::NoUndef);
+ WrapperFn->addParamAttr(0, Attribute::ZExt);
+ WrapperFn->addParamAttr(1, Attribute::NoUndef);
+
+ BasicBlock *EntryBB =
+ BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn);
+ Builder.SetInsertPoint(EntryBB);
+
+ // Allocate tid, bound tid and optional args slots; cast them to addrspace 0.
+ Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+ /*ArraySize=*/nullptr, "addr");
+ AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+ AddrAlloca->getName() + ".ascast");
+
+ Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(),
+ /*ArraySize=*/nullptr, "zero");
+ ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+ ZeroAlloca->getName() + ".ascast");
+
+ Value *ArgsAlloca = nullptr;
+ if (UseArgStruct) {
+ ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(),
+ /*ArraySize=*/nullptr, "global_args");
+ ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast(
+ ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0),
+ ArgsAlloca->getName() + ".ascast");
+ }
+
+ // Store the tid and a zero bound tid; fetch the shared args if present.
+ Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca);
+ Builder.CreateStore(Builder.getInt32(0), ZeroAlloca);
+ if (UseArgStruct) {
+ Builder.CreateCall(
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr(
+ llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables),
+ {ArgsAlloca});
+ }
+
+ SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca};
+
+ // Load structArg from element 0 of the shared arguments array.
+ if (UseArgStruct) {
+ Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca);
+ StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg,
+ {Builder.getInt64(0)});
+ StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg");
+ Args.push_back(StructArg);
+ }
+
+ // Forward (&tid, &bound tid[, structArg]) to the outlined parallel body.
+ Builder.CreateCall(&OutlinedFn, Args);
+ Builder.CreateRetVoid();
+
+ return WrapperFn;
+}
+
// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace call to the OutlinedFn in OuterFn
@@ -1343,6 +1423,10 @@ static void targetParallelCallback(
BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
+ assert(OutlinedFn.arg_size() >= 2 &&
+ "Expected at least tid and bounded tid as arguments");
+ unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
// Add some known attributes.
IRBuilder<> &Builder = OMPIRBuilder->Builder;
OutlinedFn.addParamAttr(0, Attribute::NoAlias);
@@ -1351,17 +1435,12 @@ static void targetParallelCallback(
OutlinedFn.addParamAttr(1, Attribute::NoUndef);
OutlinedFn.addFnAttr(Attribute::NoUnwind);
- assert(OutlinedFn.arg_size() >= 2 &&
- "Expected at least tid and bounded tid as arguments");
- unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
-
CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
assert(CI && "Expected call instruction to outlined function");
CI->getParent()->setName("omp_parallel");
Builder.SetInsertPoint(CI);
Type *PtrTy = OMPIRBuilder->VoidPtr;
- Value *NullPtrValue = Constant::getNullValue(PtrTy);
// Add alloca for kernel args
OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP();
@@ -1387,6 +1466,15 @@ static void targetParallelCallback(
IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
: Builder.getInt32(1);
+ // If this is not a Generic kernel, we can skip generating the wrapper.
+ std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+ getTargetKernelExecMode(*OuterFn);
+ Value *WrapperFn;
+ if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
+ WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
+ else
+ WrapperFn = Constant::getNullValue(PtrTy);
+
// Build kmpc_parallel_51 call
Value *Parallel51CallArgs[] = {
/* identifier*/ Ident,
@@ -1395,7 +1483,7 @@ static void targetParallelCallback(
/* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
/* Proc bind */ Builder.getInt32(-1),
/* outlined function */ &OutlinedFn,
- /* wrapper function */ NullPtrValue,
+ /* wrapper function */ WrapperFn,
/* arguments of the outlined funciton*/ Args,
/* number of arguments */ Builder.getInt64(NumCapturedVars)};
diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
index 504e39c96f008..ca998b4672ba0 100644
--- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir
@@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8
// CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0
// CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8
-// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1)
+// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1)
// CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8)
// CHECK: call void @__kmpc_target_deinit()
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
// CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr),
// CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156,
-// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
+// CHECK-SAME: i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1)
// One of the arguments of kmpc_parallel_51 function is responsible for handling if clause
// of omp parallel construct for target region. If this argument is nonzero,
@@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (
// CHECK-SAME: ptr addrspace(1) {{.*}} to ptr),
// CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1,
-// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1)
+// CHECK-SAME: i32 -1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1)
+
+// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]])
+// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr
+// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5)
+// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr
+// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5)
+// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr
+// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]]
+// CHECK: store i32 0, ptr %[[ZERO_ASCAST]]
+// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]])
+// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8
+// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0
+// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8
+// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr %[[ZERO_ASCAST]], ptr %[[STRUCTARG]])
+
+// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}})
+// CHECK-NOT: define
+// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}})
>From 4f751f7a6caebb36d4698f186907c489c7a4b82f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Thu, 14 Aug 2025 13:41:26 +0100
Subject: [PATCH 2/2] Address review comments
---
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 1e1c5dab8a40c..b3760a79a5a2c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -1334,12 +1334,12 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl(
return Error::success();
}
-// Create wrapper function used to gather the outlined function's argument
-// structure from a shared buffer and to forward them to it when running in
-// Generic mode.
-//
-// The outlined function is expected to receive 2 integer arguments followed by
-// an optional pointer argument to an argument structure holding the rest.
+/// Create wrapper function used to gather the outlined function's argument
+/// structure from a shared buffer and to forward them to it when running in
+/// Generic mode.
+///
+/// The outlined function is expected to receive 2 integer arguments followed by
+/// an optional pointer argument to an argument structure holding the rest.
static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder,
Function &OutlinedFn) {
size_t NumArgs = OutlinedFn.arg_size();
@@ -1470,7 +1470,7 @@ static void targetParallelCallback(
std::optional<omp::OMPTgtExecModeFlags> ExecMode =
getTargetKernelExecMode(*OuterFn);
Value *WrapperFn;
- if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC)
+ if (ExecMode && (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC))
WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn);
else
WrapperFn = Constant::getNullValue(PtrTy);
More information about the llvm-branch-commits
mailing list