[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder][MLIR] Add support for target 'if' clause (PR #122478)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 10 07:49:49 PST 2025
https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/122478
This patch implements support for handling the 'if' clause of OpenMP 'target' constructs in the OMPIRBuilder and updates the MLIR to LLVM IR translation of the `omp.target` MLIR operation to make use of this new feature.
From 3aea11b2d7857784d442d97925561ea54a5a8095 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 10 Jan 2025 15:40:05 +0000
Subject: [PATCH] [OMPIRBuilder][MLIR] Add support for target 'if' clause
This patch implements support for handling the 'if' clause of OpenMP 'target'
constructs in the OMPIRBuilder and updates the MLIR to LLVM IR translation of
the `omp.target` MLIR operation to make use of this new feature.
---
.../llvm/Frontend/OpenMP/OMPIRBuilder.h | 26 ++--
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 130 +++++++++++-------
.../Frontend/OpenMPIRBuilderTest.cpp | 11 +-
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 +-
mlir/test/Target/LLVMIR/omptarget-if.mlir | 68 +++++++++
mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 --
6 files changed, 172 insertions(+), 83 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/omptarget-if.mlir
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4ce47b1c05d9b0..b1a23996c7bdd2 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2965,21 +2965,25 @@ class OpenMPIRBuilder {
/// \param NumThreads Number of teams specified in the thread_limit clause.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
+ /// \param IfCond Value of the `if` clause, or nullptr if not present.
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
/// instructions for passed in target arguments where neccessary
/// \param Dependencies A vector of DependData objects that carry
- // dependency information as passed in the depend clause
- // \param HasNowait Whether the target construct has a `nowait` clause or not.
- InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
- OpenMPIRBuilder::InsertPointTy AllocaIP,
- OpenMPIRBuilder::InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
- GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
- TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
- SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
+ /// dependency information as passed in the depend clause
+ /// \param HasNowait Whether the target construct has a `nowait` clause or
+ /// not.
+ InsertPointOrErrorTy
+ createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+ ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+ Value *IfCond, GenMapInfoCallbackTy GenMapInfoCB,
+ TargetBodyGenCallbackTy BodyGenCB,
+ TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+ SmallVector<DependData> Dependencies = {},
+ bool HasNowait = false);
/// Returns __kmpc_for_static_init_* runtime function for the specified
/// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index db77c6a5869764..0e190f4c64a8b3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7310,6 +7310,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+ Value *IfCond,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
bool HasNoWait = false) {
@@ -7354,9 +7355,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
return Error::success();
};
- // If we don't have an ID for the target region, it means an offload entry
- // wasn't created. In this case we just run the host fallback directly.
- if (!OutlinedFnID) {
+ auto &&EmitTargetCallElse =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7372,65 +7373,87 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}());
Builder.restoreIP(AfterIP);
- return;
- }
+ return Error::success();
+ };
+
+ auto &&EmitTargetCallThen =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+ OpenMPIRBuilder::TargetDataInfo Info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
+
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+ OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+ RTArgs, MapInfo,
+ /*IsNonContiguous=*/true,
+ /*ForEndCall=*/false);
+
+ SmallVector<Value *, 3> NumTeamsC;
+ SmallVector<Value *, 3> NumThreadsC;
+ for (auto V : NumTeams)
+ NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ for (auto V : NumThreads)
+ NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+
+ unsigned NumTargetItems = Info.NumberOfPtrs;
+ // TODO: Use correct device ID
+ Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+ Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+ llvm::omp::IdentFlag(0), 0);
+ // TODO: Use correct NumIterations
+ Value *NumIterations = Builder.getInt64(0);
+ // TODO: Use correct DynCGGroupMem
+ Value *DynCGGroupMem = Builder.getInt32(0);
+
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(
+ NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
+
+ // Assume no error was returned because TaskBodyCB and
+ // EmitTargetCallFallbackCB don't produce any.
+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+ // The presence of certain clauses on the target directive require the
+ // explicit generation of the target task.
+ if (RequiresOuterTargetTask)
+ return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+ Dependencies, HasNoWait);
+
+ return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+ EmitTargetCallFallbackCB, KArgs,
+ DeviceID, RTLoc, AllocaIP);
+ }());
- OpenMPIRBuilder::TargetDataInfo Info(
- /*RequiresDevicePointerInfo=*/false,
- /*SeparateBeginEndCalls=*/true);
+ Builder.restoreIP(AfterIP);
+ return Error::success();
+ };
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
- OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
- RTArgs, MapInfo,
- /*IsNonContiguous=*/true,
- /*ForEndCall=*/false);
+ // If we don't have an ID for the target region, it means an offload entry
+ // wasn't created. In this case we just run the host fallback directly.
+ if (!OutlinedFnID) {
+ cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+ return;
+ }
- SmallVector<Value *, 3> NumTeamsC;
- SmallVector<Value *, 3> NumThreadsC;
- for (auto V : NumTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : NumThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ // If there's no IF clause, only generate the kernel launch code path.
+ if (!IfCond) {
+ cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+ return;
+ }
- unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
- uint32_t SrcLocStrSize;
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
- Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
- llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
- // TODO: Use correct DynCGGroupMem
- Value *DynCGGroupMem = Builder.getInt32(0);
-
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
-
- // Assume no error was returned because TaskBodyCB and
- // EmitTargetCallFallbackCB don't produce any.
- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
- // The presence of certain clauses on the target directive require the
- // explicit generation of the target task.
- if (RequiresOuterTargetTask)
- return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
- Dependencies, HasNoWait);
-
- return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
- EmitTargetCallFallbackCB, KArgs,
- DeviceID, RTLoc, AllocaIP);
- }());
-
- Builder.restoreIP(AfterIP);
+ cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+ EmitTargetCallElse, AllocaIP));
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
- SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
+ SmallVectorImpl<Value *> &Args, Value *IfCond,
+ GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
@@ -7455,7 +7478,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+ NumThreads, Args, IfCond, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index cdca725b147436..94dce5243d7004 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6232,7 +6232,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, -1, 0, Inputs,
+ Builder.saveIP(), EntryInfo, /*NumTeams=*/-1,
+ /*NumThreads=*/0, Inputs, /*IfCond=*/nullptr,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
OMPBuilder.finalize();
@@ -6343,8 +6344,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+ EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+ CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
@@ -6500,8 +6501,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
ASSERT_EXPECTED_INIT(
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+ EntryInfo, /*NumTeams=*/-1, /*NumThreads=*/0,
+ CapturedArgs, /*IfCond=*/nullptr, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a364098e0bd8a6..0c637bd32ab3f3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -285,7 +285,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
- checkIf(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
// Privatization clauses are supported, except on some situations, so we
@@ -4112,11 +4111,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::Value *ifCond = nullptr;
+ if (Value targetIfCond = targetOp.getIfExpr())
+ ifCond = moduleTranslation.lookupValue(targetIfCond);
+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
- defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
- argAccessorCB, dds, targetOp.getNowait());
+ defaultValTeams, defaultValThreads, kernelInput, ifCond, genMapInfoCB,
+ bodyCB, argAccessorCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @target_if_variable(%x : i1) {
+ omp.target if(%x) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_variable(
+ // CHECK-SAME: i1 %[[IF_COND:.*]])
+ // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+ // CHECK: [[THEN_LABEL]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ // CHECK: [[OFFLOAD_CONT_LABEL]]:
+ // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+ // CHECK: [[ELSE_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+ // CHECK-NEXT: br label %[[END_LABEL]]
+
+ llvm.func @target_if_true() {
+ %0 = llvm.mlir.constant(true) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_true()
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ llvm.func @target_if_false() {
+ %0 = llvm.mlir.constant(false) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_false()
+ // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 83a0990d631620..4e0925c833c3b7 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -266,17 +266,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
// -----
-llvm.func @target_if(%x : i1) {
- // expected-error at below {{not yet implemented: Unhandled clause if in omp.target operation}}
- // expected-error at below {{LLVM Translation failed for operation: omp.target}}
- omp.target if(%x) {
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
More information about the llvm-branch-commits
mailing list