[llvm] 9bc8828 - [OMPIRBuilder][MLIR] Add support for target 'if' clause (#122478)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 15 02:16:22 PST 2025
Author: Sergio Afonso
Date: 2025-01-15T10:16:19Z
New Revision: 9bc88280931e3b08adfab6951047191dfe12392b
URL: https://github.com/llvm/llvm-project/commit/9bc88280931e3b08adfab6951047191dfe12392b
DIFF: https://github.com/llvm/llvm-project/commit/9bc88280931e3b08adfab6951047191dfe12392b.diff
LOG: [OMPIRBuilder][MLIR] Add support for target 'if' clause (#122478)
This patch implements support for handling the 'if' clause of OpenMP
'target' constructs in the OMPIRBuilder and updates MLIR to LLVM IR
translation of the `omp.target` MLIR operation to make use of this new
feature.
Added:
mlir/test/Target/LLVMIR/omptarget-if.mlir
Modified:
llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/openmp-todo.mlir
Removed:
################################################################################
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 7eceec3d8cf8f5..6b6e5bc19d95a4 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2994,27 +2994,29 @@ class OpenMPIRBuilder {
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
/// \param CodeGenIP The insertion point where the call to the outlined
- /// function should be emitted.
+ /// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default attributes, including
/// numbers of threads and teams to launch the kernel with.
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
/// and teams to launch the kernel with.
+ /// \param IfCond value of the `if` clause.
/// \param Inputs The input values to the region that will be passed.
- /// as arguments to the outlined function.
+ /// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
/// \param ArgAccessorFuncCB Callback that will generate accessors
- /// instructions for passed in target arguments where neccessary
+ /// instructions for passed in target arguments where neccessary
/// \param Dependencies A vector of DependData objects that carry
- // dependency information as passed in the depend clause
- // \param HasNowait Whether the target construct has a `nowait` clause or not.
+ /// dependency information as passed in the depend clause
+ /// \param HasNowait Whether the target construct has a `nowait` clause or
+ /// not.
InsertPointOrErrorTy createTarget(
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
- const TargetKernelRuntimeAttrs &RuntimeAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 3d461f0ad4228c..c6603635d5e281 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5308,8 +5308,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
Value *Alignment = AlignedItem.second;
Instruction *loadInst = dyn_cast<Instruction>(AlignedPtr);
Builder.SetInsertPoint(loadInst->getNextNode());
- Builder.CreateAlignmentAssumption(F->getDataLayout(),
- AlignedPtr, Alignment);
+ Builder.CreateAlignmentAssumption(F->getDataLayout(), AlignedPtr,
+ Alignment);
}
Builder.restoreIP(IP);
}
@@ -5457,16 +5457,16 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
Loop *L = LI.getLoopFor(CLI->getHeader());
assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
- TargetTransformInfo::UnrollingPreferences UP =
- gatherUnrollingPreferences(L, SE, TTI,
- /*BlockFrequencyInfo=*/nullptr,
- /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
- /*UserThreshold=*/std::nullopt,
- /*UserCount=*/std::nullopt,
- /*UserAllowPartial=*/true,
- /*UserAllowRuntime=*/true,
- /*UserUpperBound=*/std::nullopt,
- /*UserFullUnrollMaxCount=*/std::nullopt);
+ TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
+ L, SE, TTI,
+ /*BlockFrequencyInfo=*/nullptr,
+ /*ProfileSummaryInfo=*/nullptr, ORE, static_cast<int>(OptLevel),
+ /*UserThreshold=*/std::nullopt,
+ /*UserCount=*/std::nullopt,
+ /*UserAllowPartial=*/true,
+ /*UserAllowRuntime=*/true,
+ /*UserUpperBound=*/std::nullopt,
+ /*UserFullUnrollMaxCount=*/std::nullopt);
UP.Force = true;
@@ -7340,7 +7340,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
- Function *OutlinedFn, Constant *OutlinedFnID,
+ Value *IfCond, Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
@@ -7386,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
return Error::success();
};
- // If we don't have an ID for the target region, it means an offload entry
- // wasn't created. In this case we just run the host fallback directly.
- if (!OutlinedFnID) {
+ auto &&EmitTargetCallElse =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
// Assume no error was returned because EmitTargetCallFallbackCB doesn't
// produce any.
OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
@@ -7404,102 +7404,126 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}());
Builder.restoreIP(AfterIP);
- return;
- }
-
- OpenMPIRBuilder::TargetDataInfo Info(
- /*RequiresDevicePointerInfo=*/false,
- /*SeparateBeginEndCalls=*/true);
-
- OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
- OpenMPIRBuilder::TargetDataRTArgs RTArgs;
- OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
- RTArgs, MapInfo,
- /*IsNonContiguous=*/true,
- /*ForEndCall=*/false);
-
- SmallVector<Value *, 3> NumTeamsC;
- for (auto [DefaultVal, RuntimeVal] :
- zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
- NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
-
- // Calculate number of threads: 0 if no clauses specified, otherwise it is the
- // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
- auto InitMaxThreadsClause = [&Builder](Value *Clause) {
- if (Clause)
- Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
- /*isSigned=*/false);
- return Clause;
+ return Error::success();
};
- auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
- if (Clause)
- Result = Result
- ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+
+ auto &&EmitTargetCallThen =
+ [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+ OpenMPIRBuilder::InsertPointTy CodeGenIP) -> Error {
+ OpenMPIRBuilder::TargetDataInfo Info(
+ /*RequiresDevicePointerInfo=*/false,
+ /*SeparateBeginEndCalls=*/true);
+
+ OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
+ OpenMPIRBuilder::TargetDataRTArgs RTArgs;
+ OMPBuilder.emitOffloadingArraysAndArgs(AllocaIP, Builder.saveIP(), Info,
+ RTArgs, MapInfo,
+ /*IsNonContiguous=*/true,
+ /*ForEndCall=*/false);
+
+ SmallVector<Value *, 3> NumTeamsC;
+ for (auto [DefaultVal, RuntimeVal] :
+ zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+ NumTeamsC.push_back(RuntimeVal ? RuntimeVal
+ : Builder.getInt32(DefaultVal));
+
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is
+ // the minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+ if (Clause)
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+ /*isSigned=*/false);
+ return Clause;
+ };
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+ if (Clause)
+ Result =
+ Result ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
Result, Clause)
: Clause;
- };
+ };
- // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
- // the NUM_THREADS clause is overriden by THREAD_LIMIT.
- SmallVector<Value *, 3> NumThreadsC;
- Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
- ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
- : nullptr;
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
+ SmallVector<Value *, 3> NumThreadsC;
+ Value *MaxThreadsClause =
+ RuntimeAttrs.TeamsThreadLimit.size() == 1
+ ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+ : nullptr;
- for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
- RuntimeAttrs.TargetThreadLimit)) {
- Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
- Value *NumThreads = InitMaxThreadsClause(TargetVal);
+ for (auto [TeamsVal, TargetVal] : zip_equal(
+ RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+ Value *NumThreads = InitMaxThreadsClause(TargetVal);
- CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
- CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
- NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
- }
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+ }
- unsigned NumTargetItems = Info.NumberOfPtrs;
- // TODO: Use correct device ID
- Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
- uint32_t SrcLocStrSize;
- Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
- Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
- llvm::omp::IdentFlag(0), 0);
+ unsigned NumTargetItems = Info.NumberOfPtrs;
+ // TODO: Use correct device ID
+ Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
+ Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
+ llvm::omp::IdentFlag(0), 0);
- Value *TripCount = RuntimeAttrs.LoopTripCount
- ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
- Builder.getInt64Ty(),
- /*isSigned=*/false)
- : Builder.getInt64(0);
+ Value *TripCount = RuntimeAttrs.LoopTripCount
+ ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+ Builder.getInt64Ty(),
+ /*isSigned=*/false)
+ : Builder.getInt64(0);
- // TODO: Use correct DynCGGroupMem
- Value *DynCGGroupMem = Builder.getInt32(0);
+ // TODO: Use correct DynCGGroupMem
+ Value *DynCGGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
- NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+ NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
- // Assume no error was returned because TaskBodyCB and
- // EmitTargetCallFallbackCB don't produce any.
- OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
- // The presence of certain clauses on the target directive require the
- // explicit generation of the target task.
- if (RequiresOuterTargetTask)
- return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
- Dependencies, HasNoWait);
+ // Assume no error was returned because TaskBodyCB and
+ // EmitTargetCallFallbackCB don't produce any.
+ OpenMPIRBuilder::InsertPointTy AfterIP = cantFail([&]() {
+ // The presence of certain clauses on the target directive require the
+ // explicit generation of the target task.
+ if (RequiresOuterTargetTask)
+ return OMPBuilder.emitTargetTask(TaskBodyCB, DeviceID, RTLoc, AllocaIP,
+ Dependencies, HasNoWait);
+
+ return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
+ EmitTargetCallFallbackCB, KArgs,
+ DeviceID, RTLoc, AllocaIP);
+ }());
+
+ Builder.restoreIP(AfterIP);
+ return Error::success();
+ };
- return OMPBuilder.emitKernelLaunch(Builder, OutlinedFnID,
- EmitTargetCallFallbackCB, KArgs,
- DeviceID, RTLoc, AllocaIP);
- }());
+ // If we don't have an ID for the target region, it means an offload entry
+ // wasn't created. In this case we just run the host fallback directly and
+ // ignore any potential 'if' clauses.
+ if (!OutlinedFnID) {
+ cantFail(EmitTargetCallElse(AllocaIP, Builder.saveIP()));
+ return;
+ }
+
+ // If there's no 'if' clause, only generate the kernel launch code path.
+ if (!IfCond) {
+ cantFail(EmitTargetCallThen(AllocaIP, Builder.saveIP()));
+ return;
+ }
- Builder.restoreIP(AfterIP);
+ cantFail(OMPBuilder.emitIfClause(IfCond, EmitTargetCallThen,
+ EmitTargetCallElse, AllocaIP));
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
- const TargetKernelRuntimeAttrs &RuntimeAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs, Value *IfCond,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7524,7 +7548,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs, IfCond,
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
HasNowait);
return Builder.saveIP();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 3b571cce09a4f8..a7b513bdfdc667 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6243,8 +6243,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultAttrs,
- RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
- SimpleArgAccessorCB));
+ RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
OMPBuilder.finalize();
@@ -6402,11 +6402,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
- OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+ ASSERT_EXPECTED_INIT(
+ OpenMPIRBuilder::InsertPointTy, AfterIP,
+ OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs,
+ /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
Builder.CreateRetVoid();
@@ -6561,8 +6562,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
OpenMPIRBuilder::InsertPointTy, AfterIP,
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
Builder.saveIP(), EntryInfo, DefaultAttrs,
- RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
- SimpleArgAccessorCB));
+ RuntimeAttrs, /*IfCond=*/nullptr, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
OMPBuilder.finalize();
@@ -6660,11 +6661,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
- OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+ ASSERT_EXPECTED_INIT(
+ OpenMPIRBuilder::InsertPointTy, AfterIP,
+ OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs,
+ /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
Builder.CreateRetVoid();
@@ -6774,11 +6776,12 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
/*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
- OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+ ASSERT_EXPECTED_INIT(
+ OpenMPIRBuilder::InsertPointTy, AfterIP,
+ OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs,
+ /*IfCond=*/nullptr, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB));
Builder.restoreIP(AfterIP);
Builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0be515e63b470c..abef2cb7411aaf 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -183,10 +183,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = op.emitError("not yet implemented: host evaluation of loop "
"bounds in omp.target operation");
};
- auto checkIf = [&todo](auto op, LogicalResult &result) {
- if (op.getIfExpr())
- result = todo("if");
- };
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
op.getInReductionSyms())
@@ -306,7 +302,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkDevice(op, result);
checkHasDeviceAddr(op, result);
checkHostEval(op, result);
- checkIf(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
checkPrivate(op, result);
@@ -4378,10 +4373,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::Value *ifCond = nullptr;
+ if (Value targetIfCond = targetOp.getIfExpr())
+ ifCond = moduleTranslation.lookupValue(targetIfCond);
+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
- defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB,
+ defaultAttrs, runtimeAttrs, ifCond, kernelInput, genMapInfoCB, bodyCB,
argAccessorCB, dds, targetOp.getNowait());
if (failed(handleError(afterIP, opInst)))
diff --git a/mlir/test/Target/LLVMIR/omptarget-if.mlir b/mlir/test/Target/LLVMIR/omptarget-if.mlir
new file mode 100644
index 00000000000000..706ad4411438ba
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-if.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @target_if_variable(%x : i1) {
+ omp.target if(%x) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_variable(
+ // CHECK-SAME: i1 %[[IF_COND:.*]])
+ // CHECK: br i1 %[[IF_COND]], label %[[THEN_LABEL:.*]], label %[[ELSE_LABEL:.*]]
+
+ // CHECK: [[THEN_LABEL]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:__omp_offloading_.*_.*_target_if_variable_l.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ // CHECK: [[OFFLOAD_CONT_LABEL]]:
+ // CHECK-NEXT: br label %[[END_LABEL:.*]]
+
+ // CHECK: [[ELSE_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN]]()
+ // CHECK-NEXT: br label %[[END_LABEL]]
+
+ llvm.func @target_if_true() {
+ %0 = llvm.mlir.constant(true) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_true()
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NOT: {{^.*}}:
+ // CHECK: %[[RC:.*]] = call i32 @__tgt_target_kernel
+ // CHECK-NEXT: %[[OFFLOAD_SUCCESS:.*]] = icmp ne i32 %[[RC]], 0
+ // CHECK-NEXT: br i1 %[[OFFLOAD_SUCCESS]], label %[[OFFLOAD_FAIL_LABEL:.*]], label %[[OFFLOAD_CONT_LABEL:.*]]
+
+ // CHECK: [[OFFLOAD_FAIL_LABEL]]:
+ // CHECK-NEXT: call void @[[FALLBACK_FN:.*]]()
+ // CHECK-NEXT: br label %[[OFFLOAD_CONT_LABEL]]
+
+ llvm.func @target_if_false() {
+ %0 = llvm.mlir.constant(false) : i1
+ omp.target if(%0) {
+ omp.terminator
+ }
+ llvm.return
+ }
+
+ // CHECK-LABEL: define void @target_if_false()
+ // CHECK-NEXT: br label %[[ENTRY:.*]]
+
+ // CHECK: [[ENTRY]]:
+ // CHECK-NEXT: call void @__omp_offloading_{{.*}}_{{.*}}_target_if_false_l{{.*}}()
+}
+
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 392a6558dcfa69..c1e30964b25078 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -271,17 +271,6 @@ llvm.func @target_host_eval(%x : i32) {
// -----
-llvm.func @target_if(%x : i1) {
- // expected-error@below {{not yet implemented: Unhandled clause if in omp.target operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.target}}
- omp.target if(%x) {
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
More information about the llvm-commits
mailing list