[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder][MLIR] Add support for target 'if' clause (PR #122478)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 10 07:50:20 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-openmp
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
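This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates the MLIR-to-LLVM-IR translation of the `omp.target` MLIR operation to make use of this new feature. When the 'if' condition evaluates to false, the host fallback version of the target region is called instead of launching the offloaded kernel, and the clause is no longer reported as unsupported during translation.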
---
Full diff: https://github.com/llvm/llvm-project/pull/122478.diff
6 Files Affected:
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15-11)
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+77-53)
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+6-5)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-3)
- (added) mlir/test/Target/LLVMIR/omptarget-if.mlir (+68)
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-11)
``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4ce47b1c05d9b0..b1a23996c7bdd2 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2965,21 +2965,25 @@ class OpenMPIRBuilder {
/// \param NumThreads Number of teams specified in the thread_limit clause.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
+ /// \param IfCond value of the `if` clause.
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
/// instructions for passed in target arguments where neccessary
/// \param Dependencies A vector of DependData objects that carry
- // dependency information as passed in the depend clause
- // \param HasNowait Whether the target construct has a `nowait` clause or not.
- InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
- OpenMPIRBuilder::InsertPointTy AllocaIP,
- OpenMPIRBuilder::InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
- GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
- TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
- SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
+ /// dependency information as passed in the depend clause
+ /// \param HasNowait Whether the target construct has a `nowait` clause or
+ /// not.
+ InsertPointOrErrorTy
+ createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+ ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+ Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB,
+ TargetBodyGenCallbackTy BodyGenCB,
+ TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+ SmallVector<DependData> Dependencies = {},
+ bool HasNowait = false);
/// Returns __kmpc_for_static_init_* runtime function for the specified
/// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index db77c6a5869764..0e190f4c64a8b3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+ Value *IfCond,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
bool HasNoWait = false) {
@@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
return Error::success();
};
- // If we don't have an ID for the target region, it means an offload entry
- // wasn't created. In this case we just run the host fallback directly.
- if (!OutlinedFnID) {
+ auto &&EmitTargetCallElse =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}());
Builder.restoreIP(AfterIP);
- return;
- }
+ return Error::success();
+ };
+
+ auto &&EmitTargetCallThen =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+ OpenMPIRBuilder::TargetDataInfo Info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
+
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+ OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+ RTArgs, MapInfo,
+ /*IsNonContiguous=*/true,
+ /*ForEndCall=*/false);
+
+ SmallVector<Value *, 3> NumTeamsC;
+ SmallVector<Value *, 3> NumThreadsC;
+ for (auto V : NumTeams)
+ NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ for (auto V : NumThreads)
+ NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+
+ unsigned NumTargetItems = Info.NumberOfPtrs;
+ // TODO: Use correct device ID
+ Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+ Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+ llvm::omp::IdentFlag(0), 0);
+ // TODO: Use correct NumIterations
+ Value *NumIterations = Builder.getInt64(0);
+ // TODO: Use correct DynCGGroupMem
+ Value *DynCGGroupMem = Builder.getInt32(0);
+
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
+
+ // Assume no error was returned because TaskBodyCB and
+ // EmitTargetCallFallbackCB don't produce any.
+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+ // The presence of certain clauses on the target directive require the
+ // explicit generation of the target task.
+ if (RequiresOuterTargetTask)
+ return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+ Dependencies, HasNoWait);
+
+ return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+ EmitTargetCallFallbackCB, KArgs,
+ DeviceID, RTLoc, AllocaIP);
+ }());
- OpenMPIRBuilder::TargetDataInfo Info(
- /*RequiresDevicePointerInfo=*/false,
- /*SeparateBeginEndCalls=*/true);
+ Builder.restoreIP(AfterIP);
+ return Error::success();
+ };
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
- OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
- RTArgs, MapInfo,
- /*IsNonContiguous=*/true,
- /*ForEndCall=*/false);
+ // If we don't have an ID for the target region, it means an offload entry
+ // wasn't created. In this case we just run the host fallback directly.
+ if (!OutlinedFnID) {
+ cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+ return;
+ }
- SmallVector<Value *, 3> NumTeamsC;
- SmallVector<Value *, 3> NumThreadsC;
- for (auto V : NumTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : NumThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ // If there's no IF clause, only generate the kernel launch code path.
+ if (!IfCond) {
+ cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+ return;
+ }
- unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
- uint32_t SrcLocStrSize;
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
- Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
- llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
- // TODO: Use correct DynCGGroupMem
- Value *DynCGGroupMem = Builder.getInt32(0);
-
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
-
- // Assume no error was returned because TaskBodyCB and
- // EmitTargetCallFallbackCB don't produce any.
- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
- // The presence of certain clauses on the target directive require the
- // explicit generation of the target task.
- if (RequiresOuterTargetTask)
- return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
- Dependencies, HasNoWait);
-
- return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
- EmitTargetCallFallbackCB, KArgs,
- DeviceID, RTLoc, AllocaIP);
- }());
-
- Builder.restoreIP(AfterIP);
+ cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+ EmitTargetCallElse, AllocaIP));
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
- SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
+ SmallVectorImpl<Value *> &Args, Value *IfCond,
+ GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
@@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+ NumThreads, Args, IfCond, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index cdca725b147436..94dce5243d7004 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, -1, 0, Inputs,
+ Builder.saveIP(), EntryInfo, /*NumTeams=*/-1,
+ /*NumThreads=*/0, Inputs, /*IfCond=*/nullptr,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
OMPBuilder.finalize();
@@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+ EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+ CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
@@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+ EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+ CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a364098e0bd8a6..0c637bd32ab3f3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
- checkIf(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
// Privatization clauses are supported, except on some situations, so we
@@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::Value *ifCond = nullptr;
+ if (Value targetIfCond = targetOp.getIfExpr())
+ ifCond = moduleTranslation.lookupValue(targetIfCond);
+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
- defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, dds, targetOp.getNowait());
+ defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB,
+ bodyCB, argAccessorCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @target_if_variable(%x : i1) {
+ omp.target if(%x) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_variable(
+ // CHECK-SAME: i1 %[[IF_COND:.*]])
+ // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+ // CHECK: [[THEN_LABEL]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ // CHECK: [[OFFLOAD_CONT_LABEL]]:
+ // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+ // CHECK: [[ELSE_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+ // CHECK-NEXT: br label %[[END_LABEL]]
+
+ llvm.func @target_if_true() {
+ %0 = llvm.mlir.constant(true) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_true()
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ llvm.func @target_if_false() {
+ %0 = llvm.mlir.constant(false) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_false()
+ // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 83a0990d631620..4e0925c833c3b7 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
// -----
-llvm.func @target_if(%x : i1) {
- // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
- omp.target if(%x) {
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
``````````
</details>
https://github.com/llvm/llvm-project/pull/122478
More information about the llvm-branch-commits
mailing list