[Mlir-commits] [llvm] [mlir] [MLIR][OpenMP] Remove Generic-SPMD early detection (PR #150922)
Sergio Afonso
llvmlistbot at llvm.org
Fri Oct 3 07:50:06 PDT 2025
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/150922
>From ad2a460ea1a8a2deaec2062cba9b1bf8d6977b97 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 21 May 2025 16:29:23 +0100
Subject: [PATCH 1/2] [MLIR][OpenMP] Remove Generic-SPMD early detection
This patch removes the logic in MLIR that attempts to identify Generic kernels
that could be executed in SPMD mode.
This optimization is done by the OpenMPOpt pass for Clang and is only required
here to circumvent that pass's missing support for the new DeviceRTL APIs used
in MLIR to LLVM IR translation, which Clang doesn't currently use (e.g.
`kmpc_distribute_static_loop`). Removing the checks in MLIR avoids duplicating
logic that should be centralized in the OpenMPOpt pass.
Additionally, offloading kernels currently compiled through the OpenMP dialect
fail to run parallel regions properly when in Generic mode. By disabling early
detection, this issue becomes apparent for a range of kernels where this was
masked by having them run in SPMD mode.
---
.../mlir/Dialect/OpenMP/OpenMPEnums.td | 13 ++---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 48 +++++--------------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 23 +++------
.../LLVMIR/openmp-target-generic-spmd.mlir | 2 +-
4 files changed, 23 insertions(+), 63 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index f693a0737e0fc..9666ffd92fc52 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -227,24 +227,21 @@ def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
//===----------------------------------------------------------------------===//
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
-def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
-def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
-def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
-def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 3>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
+def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 2>;
def TargetRegionFlags : OpenMP_BitEnumAttr<
"TargetRegionFlags",
"These flags describe properties of the target kernel. "
- "TargetRegionFlagsGeneric - denotes generic kernel. "
"TargetRegionFlagsSpmd - denotes SPMD kernel. "
"TargetRegionFlagsNoLoop - denotes kernel where "
"num_teams * num_threads >= loop_trip_count. It allows the conversion "
"of loops into sequential code by ensuring that each team/thread "
"executes at most one iteration. "
- "TargetRegionFlagsTripCount - checks if the loop trip count should be "
- "calculated.", [
+ "TargetRegionFlagsTripCount - checks if a singular loop trip count should "
+ "be calculated for the target region.", [
TargetRegionFlagsNone,
- TargetRegionFlagsGeneric,
TargetRegionFlagsSpmd,
TargetRegionFlagsTripCount,
TargetRegionFlagsNoLoop
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 32ebe06e240db..add17cac519da 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2373,7 +2373,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2386,25 +2386,25 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
if (!wsloopOp)
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
if (!teamsOp)
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
if (teamsOp->getParentOp() == targetOp.getOperation()) {
TargetRegionFlags result =
@@ -2418,53 +2418,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
if (teamsOp->getParentOp() != targetOp.getOperation())
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
if (isa<LoopOp>(innermostWrapper))
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
- // Find single immediately nested captured omp.parallel and add spmd flag
- // (generic-spmd case).
- //
- // TODO: This shouldn't have to be done here, as it is too easy to break.
- // The openmp-opt pass should be updated to be able to promote kernels like
- // this from "Generic" to "Generic-SPMD". However, the use of the
- // `kmpc_distribute_static_loop` family of functions produced by the
- // OMPIRBuilder for these kernels prevents that from working.
- Dialect *ompDialect = targetOp->getDialect();
- Operation *nestedCapture = findCapturedOmpOp(
- capturedOp, /*checkSingleMandatoryExec=*/false,
- [&](Operation *sibling) {
- return sibling && (ompDialect != sibling->getDialect() ||
- sibling->hasTrait<OpTrait::IsTerminator>());
- });
-
- TargetRegionFlags result =
- TargetRegionFlags::generic | TargetRegionFlags::trip_count;
-
- if (!nestedCapture)
- return result;
-
- while (nestedCapture->getParentOp() != capturedOp)
- nestedCapture = nestedCapture->getParentOp();
-
- return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
- : result;
+ return TargetRegionFlags::trip_count;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
if (parallelOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd;
}
- return TargetRegionFlags::generic;
+ return TargetRegionFlags::none;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 53209a40665ae..b71df68d90199 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2605,9 +2605,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
targetOp.getKernelExecFlags(targetCapturedOp);
if (omp::bitEnumContainsAll(kernelFlags,
omp::TargetRegionFlags::spmd |
- omp::TargetRegionFlags::no_loop) &&
- !omp::bitEnumContainsAny(kernelFlags,
- omp::TargetRegionFlags::generic))
+ omp::TargetRegionFlags::no_loop))
noLoopMode = true;
}
}
@@ -5438,21 +5436,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
- assert(
- omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
- omp::TargetRegionFlags::spmd) &&
- "invalid kernel flags");
attrs.ExecFlags =
- omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
- ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
- ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
- : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
- : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
- if (omp::bitEnumContainsAll(kernelFlags,
- omp::TargetRegionFlags::spmd |
- omp::TargetRegionFlags::no_loop) &&
- !omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic))
- attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
+ omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+ ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::no_loop)
+ ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP
+ : llvm::omp::OMP_TGT_EXEC_MODE_SPMD
+ : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
index 504d91b1f6198..6084a33fac8aa 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
}
}
-// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]]
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
>From e3a67f629ee53dcb0ab5dbbe60dc480474c05670 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 13 Aug 2025 12:53:34 +0100
Subject: [PATCH 2/2] Update TargetRegionFlags to mirror OMPTgtExecModeFlags
---
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 3 +-
.../mlir/Dialect/OpenMP/OpenMPEnums.td | 33 +++++-------
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 12 +++--
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 51 +++++++++++--------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 35 +++++++------
5 files changed, 73 insertions(+), 61 deletions(-)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 5980ee35a5cd2..fb501983ca938 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6775,7 +6775,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
- Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
+ Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD &&
+ Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP);
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 9666ffd92fc52..95f4e8b0d5673 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -223,28 +223,21 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
//===----------------------------------------------------------------------===//
-// target_region_flags enum.
+// target_exec_mode enum.
//===----------------------------------------------------------------------===//
-def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
-def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
-def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
-def TargetRegionFlagsNoLoop : I32BitEnumAttrCaseBit<"no_loop", 2>;
-
-def TargetRegionFlags : OpenMP_BitEnumAttr<
- "TargetRegionFlags",
- "These flags describe properties of the target kernel. "
- "TargetRegionFlagsSpmd - denotes SPMD kernel. "
- "TargetRegionFlagsNoLoop - denotes kernel where "
- "num_teams * num_threads >= loop_trip_count. It allows the conversion "
- "of loops into sequential code by ensuring that each team/thread "
- "executes at most one iteration. "
- "TargetRegionFlagsTripCount - checks if a singular loop trip count should "
- "be calculated for the target region.", [
- TargetRegionFlagsNone,
- TargetRegionFlagsSpmd,
- TargetRegionFlagsTripCount,
- TargetRegionFlagsNoLoop
+def TargetExecModeBare : I32EnumAttrCase<"bare", 0>;
+def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>;
+def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>;
+def TargetExecModeSpmdNoLoop : I32EnumAttrCase<"no_loop", 3>;
+
+def TargetExecMode : OpenMP_I32EnumAttr<
+ "TargetExecMode",
+ "target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [
+ TargetExecModeBare,
+ TargetExecModeGeneric,
+ TargetExecModeSpmd,
+ TargetExecModeSpmdNoLoop,
]>;
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5c77e215467e4..9003fb2ef7959 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1522,13 +1522,17 @@ def TargetOp : OpenMP_Op<"target", traits = [
/// operations, the top level one will be the one captured.
Operation *getInnermostCapturedOmpOp();
- /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
- /// contents of the target region.
+ /// Infers the kernel type (Bare, Generic or SPMD) based on the contents of
+ /// the target region.
///
/// \param capturedOp result of a still valid (no modifications made to any
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
- static ::mlir::omp::TargetRegionFlags
- getKernelExecFlags(Operation *capturedOp);
+ /// \param hostEvalTripCount output argument to store whether this kernel
+ /// wraps a loop whose bounds must be evaluated on the host prior to
+ /// launching it.
+ static ::mlir::omp::TargetExecMode
+ getKernelExecFlags(Operation *capturedOp,
+ bool *hostEvalTripCount = nullptr);
}] # clausesExtraClassDeclaration;
let assemblyFormat = clausesAssemblyFormat # [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index add17cac519da..8640c4ba0b757 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2205,8 +2205,9 @@ LogicalResult TargetOp::verifyRegions() {
return emitError("target containing multiple 'omp.teams' nested ops");
// Check that host_eval values are only used in legal ways.
+ bool hostEvalTripCount;
Operation *capturedOp = getInnermostCapturedOmpOp();
- TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
+ TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
@@ -2221,7 +2222,7 @@ LogicalResult TargetOp::verifyRegions() {
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
- if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
+ if (execMode == TargetExecMode::spmd &&
parallelOp->isAncestor(capturedOp) &&
hostEvalArg == parallelOp.getNumThreads())
continue;
@@ -2231,8 +2232,7 @@ LogicalResult TargetOp::verifyRegions() {
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
- if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
- loopNestOp.getOperation() == capturedOp &&
+ if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
@@ -2362,7 +2362,9 @@ static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
ompFlags.getAssumeThreadsOversubscription();
}
-TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
+TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
+ bool *hostEvalTripCount) {
+ // TODO: Support detection of bare kernel mode.
// A non-null captured op is only valid if it resides inside of a TargetOp
// and is the result of calling getInnermostCapturedOmpOp() on it.
TargetOp targetOp =
@@ -2371,9 +2373,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
"unexpected captured op");
+ if (hostEvalTripCount)
+ *hostEvalTripCount = false;
+
// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
@@ -2386,31 +2391,32 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
if (!wsloopOp)
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
if (!teamsOp)
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
if (teamsOp->getParentOp() == targetOp.getOperation()) {
- TargetRegionFlags result =
- TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+ TargetExecMode result = TargetExecMode::spmd;
if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
- result = result | TargetRegionFlags::no_loop;
+ result = TargetExecMode::no_loop;
+ if (hostEvalTripCount)
+ *hostEvalTripCount = true;
return result;
}
}
@@ -2418,27 +2424,30 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
if (teamsOp->getParentOp() != targetOp.getOperation())
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
+
+ if (hostEvalTripCount)
+ *hostEvalTripCount = true;
if (isa<LoopOp>(innermostWrapper))
- return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+ return TargetExecMode::spmd;
- return TargetRegionFlags::trip_count;
+ return TargetExecMode::generic;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
if (parallelOp->getParentOp() == targetOp.getOperation())
- return TargetRegionFlags::spmd;
+ return TargetExecMode::spmd;
}
- return TargetRegionFlags::none;
+ return TargetExecMode::generic;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b71df68d90199..172029196905d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2601,11 +2601,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
// for every omp.wsloop nested inside a no-loop SPMD target region, even if
// that loop is not the top-level SPMD one.
if (loopOp == targetCapturedOp) {
- omp::TargetRegionFlags kernelFlags =
- targetOp.getKernelExecFlags(targetCapturedOp);
- if (omp::bitEnumContainsAll(kernelFlags,
- omp::TargetRegionFlags::spmd |
- omp::TargetRegionFlags::no_loop))
+ if (targetOp.getKernelExecFlags(targetCapturedOp) ==
+ omp::TargetExecMode::no_loop)
noLoopMode = true;
}
}
@@ -5435,14 +5432,21 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
- omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
- attrs.ExecFlags =
- omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
- ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::no_loop)
- ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP
- : llvm::omp::OMP_TGT_EXEC_MODE_SPMD
- : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
-
+ omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
+ switch (execMode) {
+ case omp::TargetExecMode::bare:
+ attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
+ break;
+ case omp::TargetExecMode::generic:
+ attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
+ break;
+ case omp::TargetExecMode::spmd:
+ attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
+ break;
+ case omp::TargetExecMode::no_loop:
+ attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
+ break;
+ }
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@@ -5492,8 +5496,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
- if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
- omp::TargetRegionFlags::trip_count)) {
+ bool hostEvalTripCount;
+ targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
+ if (hostEvalTripCount) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;
More information about the Mlir-commits
mailing list