[Mlir-commits] [mlir] f59b5b8 - [MLIR][OpenMP] Fix standalone distribute on the device (#133094)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 3 07:41:03 PDT 2025
Author: Sergio Afonso
Date: 2025-04-03T15:41:00+01:00
New Revision: f59b5b8d597d52336a59d2c0555212242e29a45b
URL: https://github.com/llvm/llvm-project/commit/f59b5b8d597d52336a59d2c0555212242e29a45b
DIFF: https://github.com/llvm/llvm-project/commit/f59b5b8d597d52336a59d2c0555212242e29a45b.diff
LOG: [MLIR][OpenMP] Fix standalone distribute on the device (#133094)
This patch updates the handling of target regions to set trip counts and
kernel execution modes properly, based on clang's behavior. This fixes a
race condition on `target teams distribute` constructs with no `parallel
do` loop inside.
This is how kernels are classified, after changes introduced in this
patch:
```f90
! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do
! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
!$omp parallel do private(idx, y)
do j=...
end do
end do
! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
!$omp parallel
...
!$omp end parallel
end do
! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do
! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do
! Exec mode: Generic.
! Trip count: Not set.
!$omp target
...
!$omp end target
```
For the split `target teams distribute + parallel do` case, clang
produces a Generic kernel which gets promoted to Generic-SPMD by the
openmp-opt pass. We can't currently replicate that behavior in flang
because our codegen for these constructs introduces calls to the
`kmpc_distribute_static_loop` family of functions instead of
`kmpc_distribute_static_init`, and those calls currently prevent
promotion of the kernel to Generic-SPMD.
For the time being, instead of relying on the openmp-opt pass, we look
at the MLIR representation to find the Generic-SPMD pattern and directly
tag the kernel as such during codegen. This is what we were already
doing, but the previous matching also incorrectly classified other kinds
of kernels as Generic-SPMD.
Added:
mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
Modified:
mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
+//===----------------------------------------------------------------------===//
+// target_region_flags enum.
+//===----------------------------------------------------------------------===//
+
+def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
+def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
+
+def TargetRegionFlags : OpenMP_BitEnumAttr<
+ "TargetRegionFlags",
+ "target region property flags", [
+ TargetRegionFlagsNone,
+ TargetRegionFlagsGeneric,
+ TargetRegionFlagsSpmd,
+ TargetRegionFlagsTripCount
+ ]>;
+
//===----------------------------------------------------------------------===//
// variable_capture_kind enum.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 65095932be627..11530c0fa3620 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
///
/// \param capturedOp result of a still valid (no modifications made to any
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
- static llvm::omp::OMPTgtExecModeFlags
+ static ::mlir::omp::TargetRegionFlags
getKernelExecFlags(Operation *capturedOp);
}] # clausesExtraClassDeclaration;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 4ac9f49f12161..ecadf16e1e9f6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
return emitError("target containing multiple 'omp.teams' nested ops");
// Check that host_eval values are only used in legal ways.
- llvm::omp::OMPTgtExecModeFlags execFlags =
- getKernelExecFlags(getInnermostCapturedOmpOp());
+ Operation *capturedOp = getInnermostCapturedOmpOp();
+ TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
- if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
+ if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
+ parallelOp->isAncestor(capturedOp) &&
hostEvalArg == parallelOp.getNumThreads())
continue;
@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
- if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
+ if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
+ loopNestOp.getOperation() == capturedOp &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
continue;
return emitOpError() << "host_eval argument only legal as loop bounds "
- "and steps in 'omp.loop_nest' when "
- "representing target SPMD or Generic-SPMD";
+ "and steps in 'omp.loop_nest' when trip count "
+ "must be evaluated in the host";
}
return emitOpError() << "host_eval argument illegal use in '"
@@ -1951,33 +1953,12 @@ LogicalResult TargetOp::verifyRegions() {
return success();
}
-/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
-/// effects, but don't include a memory write effect.
-static bool siblingAllowedInCapture(Operation *op) {
- if (!op)
- return false;
+static Operation *
+findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
+ llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
+ assert(rootOp && "expected valid operation");
- bool isOmpDialect =
- op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
- op->getDialect();
-
- if (isOmpDialect)
- return op->hasTrait<OpTrait::IsTerminator>();
-
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
- memOp.getEffects(effects);
- return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
- return isa<MemoryEffects::Write>(effect.getEffect()) &&
- isa<SideEffects::AutomaticAllocationScopeResource>(
- effect.getResource());
- });
- }
- return true;
-}
-
-Operation *TargetOp::getInnermostCapturedOmpOp() {
- Dialect *ompDialect = (*this)->getDialect();
+ Dialect *ompDialect = rootOp->getDialect();
Operation *capturedOp = nullptr;
DominanceInfo domInfo;
@@ -1985,8 +1966,8 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
// ensuring we only enter the region of an operation if it meets the criteria
// for being captured. We stop the exploration of nested operations as soon as
// we process a region holding no operations to be captured.
- walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (op == *this)
+ rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
+ if (op == rootOp)
return WalkResult::advance();
// Ignore operations of other dialects or omp operations with no regions,
@@ -2001,22 +1982,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
// (i.e. its block's successors can reach it) or if it's not guaranteed to
// be executed before all exits of the region (i.e. it doesn't dominate all
// blocks with no successors reachable from the entry block).
- Region *parentRegion = op->getParentRegion();
- Block *parentBlock = op->getBlock();
-
- for (Block *successor : parentBlock->getSuccessors())
- if (successor->isReachable(parentBlock))
- return WalkResult::interrupt();
-
- for (Block &block : *parentRegion)
- if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
- !domInfo.dominates(parentBlock, &block))
- return WalkResult::interrupt();
+ if (checkSingleMandatoryExec) {
+ Region *parentRegion = op->getParentRegion();
+ Block *parentBlock = op->getBlock();
+
+ for (Block *successor : parentBlock->getSuccessors())
+ if (successor->isReachable(parentBlock))
+ return WalkResult::interrupt();
+
+ for (Block &block : *parentRegion)
+ if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
+ !domInfo.dominates(parentBlock, &block))
+ return WalkResult::interrupt();
+ }
// Don't capture this op if it has a not-allowed sibling, and stop recursing
// into nested operations.
for (Operation &sibling : op->getParentRegion()->getOps())
- if (&sibling != op && !siblingAllowedInCapture(&sibling))
+ if (&sibling != op && !siblingAllowedFn(&sibling))
return WalkResult::interrupt();
// Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2012,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
return capturedOp;
}
-llvm::omp::OMPTgtExecModeFlags
-TargetOp::getKernelExecFlags(Operation *capturedOp) {
- using namespace llvm::omp;
+Operation *TargetOp::getInnermostCapturedOmpOp() {
+ auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
+
+ // Only allow OpenMP terminators and non-OpenMP ops that have known memory
+ // effects, but don't include a memory write effect.
+ return findCapturedOmpOp(
+ *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
+ if (!sibling)
+ return false;
+
+ if (ompDialect == sibling->getDialect())
+ return sibling->hasTrait<OpTrait::IsTerminator>();
+
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+ effects;
+ memOp.getEffects(effects);
+ return !llvm::any_of(
+ effects, [&](MemoryEffects::EffectInstance &effect) {
+ return isa<MemoryEffects::Write>(effect.getEffect()) &&
+ isa<SideEffects::AutomaticAllocationScopeResource>(
+ effect.getResource());
+ });
+ }
+ return true;
+ });
+}
+TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
// A non-null captured op is only valid if it resides inside of a TargetOp
// and is the result of calling getInnermostCapturedOmpOp() on it.
TargetOp targetOp =
@@ -2041,60 +2049,94 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
"unexpected captured op");
- // Make sure this region is capturing a loop. Otherwise, it's a generic
- // kernel.
+ // If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
- SmallVector<LoopWrapperInterface> wrappers;
- cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
- assert(!wrappers.empty());
+ // Get the innermost non-simd loop wrapper.
+ SmallVector<LoopWrapperInterface> loopWrappers;
+ cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
+ assert(!loopWrappers.empty());
- // Ignore optional SIMD leaf construct.
- auto *innermostWrapper = wrappers.begin();
+ LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
if (isa<SimdOp>(innermostWrapper))
innermostWrapper = std::next(innermostWrapper);
- long numWrappers = std::distance(innermostWrapper, wrappers.end());
-
- // Detect Generic-SPMD: target-teams-distribute[-simd].
- // Detect SPMD: target-teams-loop.
- if (numWrappers == 1) {
- if (!isa<DistributeOp, LoopOp>(innermostWrapper))
- return OMP_TGT_EXEC_MODE_GENERIC;
-
- Operation *teamsOp = (*innermostWrapper)->getParentOp();
- if (!isa_and_present<TeamsOp>(teamsOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
+ if (numWrappers != 1 && numWrappers != 2)
+ return TargetRegionFlags::generic;
- if (teamsOp->getParentOp() == targetOp.getOperation())
- return isa<DistributeOp>(innermostWrapper)
- ? OMP_TGT_EXEC_MODE_GENERIC_SPMD
- : OMP_TGT_EXEC_MODE_SPMD;
- }
-
- // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
+ // Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
if (teamsOp->getParentOp() == targetOp.getOperation())
- return OMP_TGT_EXEC_MODE_SPMD;
+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+ }
+ // Detect target-teams-distribute[-simd] and target-teams-loop.
+ else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
+ Operation *teamsOp = (*innermostWrapper)->getParentOp();
+ if (!isa_and_present<TeamsOp>(teamsOp))
+ return TargetRegionFlags::generic;
+
+ if (teamsOp->getParentOp() != targetOp.getOperation())
+ return TargetRegionFlags::generic;
+
+ if (isa<LoopOp>(innermostWrapper))
+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+
+ // Find single immediately nested captured omp.parallel and add spmd flag
+ // (generic-spmd case).
+ //
+ // TODO: This shouldn't have to be done here, as it is too easy to break.
+ // The openmp-opt pass should be updated to be able to promote kernels like
+ // this from "Generic" to "Generic-SPMD". However, the use of the
+ // `kmpc_distribute_static_loop` family of functions produced by the
+ // OMPIRBuilder for these kernels prevents that from working.
+ Dialect *ompDialect = targetOp->getDialect();
+ Operation *nestedCapture = findCapturedOmpOp(
+ capturedOp, /*checkSingleMandatoryExec=*/false,
+ [&](Operation *sibling) {
+ return sibling && (ompDialect != sibling->getDialect() ||
+ sibling->hasTrait<OpTrait::IsTerminator>());
+ });
+
+ TargetRegionFlags result =
+ TargetRegionFlags::generic | TargetRegionFlags::trip_count;
+
+ if (!nestedCapture)
+ return result;
+
+ while (nestedCapture->getParentOp() != capturedOp)
+ nestedCapture = nestedCapture->getParentOp();
+
+ return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
+ : result;
+ }
+ // Detect target-parallel-wsloop[-simd].
+ else if (isa<WsloopOp>(innermostWrapper)) {
+ Operation *parallelOp = (*innermostWrapper)->getParentOp();
+ if (!isa_and_present<ParallelOp>(parallelOp))
+ return TargetRegionFlags::generic;
+
+ if (parallelOp->getParentOp() == targetOp.getOperation())
+ return TargetRegionFlags::spmd;
}
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..4d610d6e2656d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
combinedMaxThreadsVal = maxThreadsVal;
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
- attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
+ omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
+ assert(
+ omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
+ omp::TargetRegionFlags::spmd) &&
+ "invalid kernel flags");
+ attrs.ExecFlags =
+ omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
+ ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+ ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
+ : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
+ : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
- if (targetOp.getKernelExecFlags(capturedOp) !=
- llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+ if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
+ omp::TargetRegionFlags::trip_count)) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 403128bb2300e..bd0541987339a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
// -----
func.func @omp_target_host_eval_loop1(%x : i32) {
- // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+ // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
// -----
func.func @omp_target_host_eval_loop2(%x : i32) {
- // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+ // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.teams {
^bb0:
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 0a10626cd4877..6bc2500471997 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
omp.terminator
}
+ // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+ // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest
+ omp.target host_eval(%x -> %arg0 : i32) {
+ %y = arith.constant 2 : i32
+ omp.parallel num_threads(%arg0 : i32) {
+ omp.wsloop {
+ omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
// CHECK: omp.teams {
// CHECK: omp.distribute {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
new file mode 100644
index 0000000000000..8101660e571e4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -0,0 +1,111 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @main(%arg0 : !llvm.ptr) {
+ %x = llvm.load %arg0 : !llvm.ptr -> i32
+ %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+ omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
+ %x.map = llvm.load %ptr : !llvm.ptr -> i32
+ omp.teams {
+ omp.distribute {
+ omp.loop_nest (%iv1) : i32 = (%lb) to (%ub) step (%step) {
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%iv2) : i32 = (%x.map) to (%x.map) step (%x.map) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// HOST-LABEL: define void @main
+// HOST: %omp_loop.tripcount = {{.*}}
+// HOST-NEXT: br label %[[ENTRY:.*]]
+// HOST: [[ENTRY]]:
+// HOST: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST: [[OFFLOAD_FAILED]]:
+// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[TARGET_OUTLINE]]
+// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[TEAMS_OUTLINE]]
+// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// HOST: call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[PARALLEL_OUTLINE]]
+// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+
+//--- device.mlir
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+ llvm.func @main(%arg0 : !llvm.ptr) {
+ %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+ omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
+ %x = llvm.load %ptr : !llvm.ptr -> i32
+ omp.teams {
+ omp.distribute {
+ omp.loop_nest (%iv1) : i32 = (%x) to (%x) step (%x) {
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%iv2) : i32 = (%x) to (%x) step (%x) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// DEVICE: call void @__kmpc_target_deinit()
+
+// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}})
+// DEVICE: call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+
+// DEVICE: define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// DEVICE: call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+
+// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// DEVICE: call void @__kmpc_for_static_loop{{.*}}({{.*}})
More information about the Mlir-commits
mailing list