[Mlir-commits] [mlir] [MLIR][OpenMP] Fix standalone distribute on the device (PR #133094)

Sergio Afonso llvmlistbot at llvm.org
Wed Mar 26 07:33:29 PDT 2025


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/133094

This patch updates the handling of target regions to set trip counts and kernel execution modes properly, based on clang's behavior. This fixes a race condition on `target teams distribute` constructs with no `parallel do` loop inside.

This is how kernels are classified after the changes introduced in this patch:

```f90
! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel do private(idx, y)
  do j=...
  end do
end do

! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do

! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do

! Exec mode: Generic.
! Trip count: Not set.
!$omp target
  ...
!$omp end target
```
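
The classification above can also be read in terms of the `TargetRegionFlags` bits added by this patch. The following is a minimal standalone sketch, not the MLIR-generated enum: the `k`-prefixed names, the `ExecMode` enum, and the `toExecMode` and `needsHostTripCount` helpers are hypothetical stand-ins. Only the bit positions and the selection logic are taken from the diff.

```cpp
#include <cstdint>

// Stand-in for the `TargetRegionFlags` bit enum (bit positions copied from
// OpenMPEnums.td in this patch). Illustration only, not the generated type.
enum TargetRegionFlags : std::uint32_t {
  kNone = 0,
  kGeneric = 1u << 0,
  kSpmd = 1u << 1,
  kTripCount = 1u << 2,
};

// Hypothetical stand-in for the execution modes passed to the OpenMPIRBuilder.
enum class ExecMode { Generic, Spmd, GenericSpmd };

// Mirrors the selection in initTargetDefaultAttrs(): generic + spmd means
// Generic-SPMD, generic alone means Generic, otherwise SPMD.
// Precondition (asserted in the patch): at least one of generic/spmd is set.
constexpr ExecMode toExecMode(std::uint32_t flags) {
  return (flags & kGeneric)
             ? ((flags & kSpmd) ? ExecMode::GenericSpmd : ExecMode::Generic)
             : ExecMode::Spmd;
}

// True when the host has to evaluate the loop trip count before the launch.
constexpr bool needsHostTripCount(std::uint32_t flags) {
  return (flags & kTripCount) != 0;
}

// The five cases listed above, in the same order.
static_assert(toExecMode(kSpmd | kTripCount) == ExecMode::Spmd,
              "target teams distribute parallel do");
static_assert(toExecMode(kGeneric | kSpmd | kTripCount) == ExecMode::GenericSpmd,
              "target teams distribute + nested parallel do");
static_assert(toExecMode(kGeneric | kTripCount) == ExecMode::Generic,
              "standalone target teams distribute");
static_assert(toExecMode(kSpmd) == ExecMode::Spmd, "target parallel do");
static_assert(toExecMode(kGeneric) == ExecMode::Generic, "plain target");
```

Keeping the trip count as its own bit is what lets the standalone `target teams distribute` case be Generic while still having its trip count set.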

For the split `target teams distribute + parallel do` case, clang produces a Generic kernel which gets promoted to Generic-SPMD by the openmp-opt pass. We can't currently replicate that behavior in flang because our codegen for these constructs introduces calls to the `kmpc_distribute_static_loop` family of functions instead of `kmpc_distribute_static_init`, and those calls currently prevent promotion of the kernel to Generic-SPMD.

For the time being, instead of relying on the openmp-opt pass, we look at the MLIR representation to find the Generic-SPMD pattern and directly tag the kernel as such during codegen. This is what we were already doing, but the previous matching also incorrectly tagged other kinds of kernels as Generic-SPMD.
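
To illustrate the kind of structural match described above, here is a rough sketch over a toy operation tree. It assumes a heavily simplified model: `Op`, `op`, `onlyChild`, `followChain`, `hasTeamsDistributeLoop` and `isGenericSpmd` are hypothetical names, and the sketch ignores terminators, allowed sibling operations and the optional `simd` wrapper that the real check in `TargetOp::getKernelExecFlags()` has to deal with.

```cpp
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for an MLIR operation: a name plus the ops nested inside it.
struct Op {
  std::string name;
  std::vector<Op> body;
};

// Convenience builder for the toy tree.
Op op(std::string name, std::vector<Op> body = {}) {
  return Op{std::move(name), std::move(body)};
}

// Returns the single nested op if it is called `name`, or nullptr if `parent`
// is null or holds anything other than exactly that one op.
const Op *onlyChild(const Op *parent, const std::string &name) {
  if (!parent || parent->body.size() != 1 || parent->body.front().name != name)
    return nullptr;
  return &parent->body.front();
}

// Walks down a chain of single nested ops; nullptr as soon as one is missing.
const Op *followChain(const Op *root, const std::vector<std::string> &names) {
  const Op *current = root;
  for (const std::string &name : names)
    current = onlyChild(current, name);
  return current;
}

// target -> teams -> distribute -> loop_nest (trip count can be set).
bool hasTeamsDistributeLoop(const Op &target) {
  return followChain(&target, {"teams", "distribute", "loop_nest"}) != nullptr;
}

// Same as above, with a single parallel -> wsloop -> loop_nest nested
// directly in the distribute loop body (the Generic-SPMD pattern).
bool isGenericSpmd(const Op &target) {
  const Op *outerLoop =
      followChain(&target, {"teams", "distribute", "loop_nest"});
  return followChain(outerLoop, {"parallel", "wsloop", "loop_nest"}) != nullptr;
}
```

With this model, a tree built as `op("target", {op("teams", {op("distribute", {op("loop_nest")})})})` matches `hasTeamsDistributeLoop` but not `isGenericSpmd`, which corresponds to the standalone `distribute` case that should stay Generic.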

From 48cf32cdc4c1c4e89f93ca6407cf21f810e18bc9 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 25 Mar 2025 10:40:47 +0000
Subject: [PATCH] [MLIR][OpenMP] Fix standalone distribute on the device

This patch updates the handling of target regions to set trip counts and kernel
execution modes properly, based on clang's behavior. This fixes a race
condition on `target teams distribute` constructs with no `parallel do` loop
inside.

This is how kernels are classified, after changes introduced in this patch:

```f90
! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel do private(idx, y)
  do j=...
  end do
end do

! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do

! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do

! Exec mode: Generic.
! Trip count: Not set.
!$omp target
  ...
!$omp end target
```

For the split `target teams distribute + parallel do` case, clang produces a
Generic kernel which gets promoted to Generic-SPMD by the openmp-opt pass. We
can't currently replicate that behavior in flang because our codegen for these
constructs results in the introduction of calls to the
`kmpc_distribute_static_loop` family of functions, instead of
`kmpc_distribute_static_init`, which currently prevent promotion of the kernel
to Generic-SPMD.

For the time being, instead of relying on the openmp-opt pass, we look at the
MLIR representation to find the Generic-SPMD pattern and directly tag the
kernel as such during codegen. This is what we were already doing, but
incorrectly matching other kinds of kernels as such in the process.
---
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        |  18 ++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 187 +++++++++++-------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  16 +-
 mlir/test/Dialect/OpenMP/invalid.mlir         |   4 +-
 mlir/test/Dialect/OpenMP/ops.mlir             |  17 ++
 .../LLVMIR/openmp-target-generic-spmd.mlir    | 111 +++++++++++
 7 files changed, 282 insertions(+), 73 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
 
 def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
 
+//===----------------------------------------------------------------------===//
+// target_region_flags enum.
+//===----------------------------------------------------------------------===//
+
+def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
+def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
+
+def TargetRegionFlags : OpenMP_BitEnumAttr<
+    "TargetRegionFlags",
+    "target region property flags", [
+      TargetRegionFlagsNone,
+      TargetRegionFlagsGeneric,
+      TargetRegionFlagsSpmd,
+      TargetRegionFlagsTripCount
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // variable_capture_kind enum.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 65095932be627..11530c0fa3620 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
     ///
     /// \param capturedOp result of a still valid (no modifications made to any
     /// nested operations) previous call to `getInnermostCapturedOmpOp()`.
-    static llvm::omp::OMPTgtExecModeFlags
+    static ::mlir::omp::TargetRegionFlags
     getKernelExecFlags(Operation *capturedOp);
   }] # clausesExtraClassDeclaration;
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 882bc4071482f..5b46cab96dd88 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
     return emitError("target containing multiple 'omp.teams' nested ops");
 
   // Check that host_eval values are only used in legal ways.
-  llvm::omp::OMPTgtExecModeFlags execFlags =
-      getKernelExecFlags(getInnermostCapturedOmpOp());
+  Operation *capturedOp = getInnermostCapturedOmpOp();
+  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
   for (Value hostEvalArg :
        cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
     for (Operation *user : hostEvalArg.getUsers()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
                                 "and 'thread_limit' in 'omp.teams'";
       }
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
-        if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
+        if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
+            parallelOp->isAncestor(capturedOp) &&
             hostEvalArg == parallelOp.getNumThreads())
           continue;
 
@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
                   "'omp.parallel' when representing target SPMD";
       }
       if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
-        if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
+        if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
+            loopNestOp.getOperation() == capturedOp &&
             (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
           continue;
 
         return emitOpError() << "host_eval argument only legal as loop bounds "
-                                "and steps in 'omp.loop_nest' when "
-                                "representing target SPMD or Generic-SPMD";
+                                "and steps in 'omp.loop_nest' when trip count "
+                                "must be evaluated in the host";
       }
 
       return emitOpError() << "host_eval argument illegal use in '"
@@ -1951,33 +1953,12 @@ LogicalResult TargetOp::verifyRegions() {
   return success();
 }
 
-/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
-/// effects, but don't include a memory write effect.
-static bool siblingAllowedInCapture(Operation *op) {
-  if (!op)
-    return false;
+static Operation *
+findCapturedOmpOp(Operation *rootOp,
+                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
+  assert(rootOp && "expected valid operation");
 
-  bool isOmpDialect =
-      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
-      op->getDialect();
-
-  if (isOmpDialect)
-    return op->hasTrait<OpTrait::IsTerminator>();
-
-  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
-    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
-    memOp.getEffects(effects);
-    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
-      return isa<MemoryEffects::Write>(effect.getEffect()) &&
-             isa<SideEffects::AutomaticAllocationScopeResource>(
-                 effect.getResource());
-    });
-  }
-  return true;
-}
-
-Operation *TargetOp::getInnermostCapturedOmpOp() {
-  Dialect *ompDialect = (*this)->getDialect();
+  Dialect *ompDialect = rootOp->getDialect();
   Operation *capturedOp = nullptr;
   DominanceInfo domInfo;
 
@@ -1985,8 +1966,8 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   // ensuring we only enter the region of an operation if it meets the criteria
   // for being captured. We stop the exploration of nested operations as soon as
   // we process a region holding no operations to be captured.
-  walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (op == *this)
+  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op == rootOp)
       return WalkResult::advance();
 
     // Ignore operations of other dialects or omp operations with no regions,
@@ -2016,7 +1997,7 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
     for (Operation &sibling : op->getParentRegion()->getOps())
-      if (&sibling != op && !siblingAllowedInCapture(&sibling))
+      if (&sibling != op && !siblingAllowedFn(&sibling))
         return WalkResult::interrupt();
 
     // Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2010,33 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   return capturedOp;
 }
 
-llvm::omp::OMPTgtExecModeFlags
-TargetOp::getKernelExecFlags(Operation *capturedOp) {
-  using namespace llvm::omp;
+Operation *TargetOp::getInnermostCapturedOmpOp() {
+  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
 
+  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
+  // effects, but don't include a memory write effect.
+  return findCapturedOmpOp(*this, [&](Operation *sibling) {
+    if (!sibling)
+      return false;
+
+    if (ompDialect == sibling->getDialect())
+      return sibling->hasTrait<OpTrait::IsTerminator>();
+
+    if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+          effects;
+      memOp.getEffects(effects);
+      return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+        return isa<MemoryEffects::Write>(effect.getEffect()) &&
+               isa<SideEffects::AutomaticAllocationScopeResource>(
+                   effect.getResource());
+      });
+    }
+    return true;
+  });
+}
+
+TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
   // A non-null captured op is only valid if it resides inside of a TargetOp
   // and is the result of calling getInnermostCapturedOmpOp() on it.
   TargetOp targetOp =
@@ -2041,57 +2045,106 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
           (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
          "unexpected captured op");
 
-  // Make sure this region is capturing a loop. Otherwise, it's a generic
-  // kernel.
+  // If it's not capturing a loop, it's a default target region.
   if (!isa_and_present<LoopNestOp>(capturedOp))
-    return OMP_TGT_EXEC_MODE_GENERIC;
-
-  SmallVector<LoopWrapperInterface> wrappers;
-  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
-  assert(!wrappers.empty());
+    return TargetRegionFlags::generic;
 
-  // Ignore optional SIMD leaf construct.
-  auto *innermostWrapper = wrappers.begin();
-  if (isa<SimdOp>(innermostWrapper))
-    innermostWrapper = std::next(innermostWrapper);
+  auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
+    SmallVector<LoopWrapperInterface> wrappers;
+    loopOp.gatherWrappers(wrappers);
+    assert(!wrappers.empty());
 
-  long numWrappers = std::distance(innermostWrapper, wrappers.end());
+    // Ignore optional SIMD leaf construct.
+    auto *wrapper = wrappers.begin();
+    if (isa<SimdOp>(wrapper))
+      wrapper = std::next(wrapper);
 
-  // Detect Generic-SPMD: target-teams-distribute[-simd].
-  if (numWrappers == 1) {
-    if (!isa<DistributeOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+    numWrappers = static_cast<int>(std::distance(wrapper, wrappers.end()));
+    return wrapper;
+  };
 
-    Operation *teamsOp = (*innermostWrapper)->getParentOp();
-    if (!isa_and_present<TeamsOp>(teamsOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+  int numWrappers;
+  LoopWrapperInterface *innermostWrapper =
+      getInnermostWrapper(cast<LoopNestOp>(capturedOp), numWrappers);
 
-    if (teamsOp->getParentOp() == targetOp.getOperation())
-      return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
-  }
+  if (numWrappers != 1 && numWrappers != 2)
+    return TargetRegionFlags::generic;
 
-  // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
+  // Detect target-teams-distribute-parallel-wsloop[-simd].
   if (numWrappers == 2) {
     if (!isa<WsloopOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     innermostWrapper = std::next(innermostWrapper);
     if (!isa<DistributeOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     Operation *parallelOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<ParallelOp>(parallelOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     Operation *teamsOp = parallelOp->getParentOp();
     if (!isa_and_present<TeamsOp>(teamsOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     if (teamsOp->getParentOp() == targetOp.getOperation())
-      return OMP_TGT_EXEC_MODE_SPMD;
+      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+  }
+  // Detect target-teams-distribute[-simd].
+  else if (isa<DistributeOp>(innermostWrapper)) {
+    Operation *teamsOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<TeamsOp>(teamsOp))
+      return TargetRegionFlags::generic;
+
+    if (teamsOp->getParentOp() != targetOp.getOperation())
+      return TargetRegionFlags::generic;
+
+    TargetRegionFlags result =
+        TargetRegionFlags::generic | TargetRegionFlags::trip_count;
+
+    // Find single nested parallel-do and add spmd flag (generic-spmd case).
+    // TODO: This shouldn't have to be done here, as it is too easy to break.
+    // The openmp-opt pass should be updated to be able to promote kernels like
+    // this from "Generic" to "Generic-SPMD". However, the use of the
+    // `kmpc_distribute_static_loop` family of functions produced by the
+    // OMPIRBuilder for these kernels prevents that from working.
+    Dialect *ompDialect = targetOp->getDialect();
+    Operation *nestedCapture =
+        findCapturedOmpOp(capturedOp, [&](Operation *sibling) {
+          return sibling && (ompDialect != sibling->getDialect() ||
+                             sibling->hasTrait<OpTrait::IsTerminator>());
+        });
+
+    if (!isa_and_present<LoopNestOp>(nestedCapture))
+      return result;
+
+    int numNestedWrappers;
+    LoopWrapperInterface *nestedWrapper =
+        getInnermostWrapper(cast<LoopNestOp>(nestedCapture), numNestedWrappers);
+
+    if (numNestedWrappers != 1 || !isa<WsloopOp>(nestedWrapper))
+      return result;
+
+    Operation *parallelOp = (*nestedWrapper)->getParentOp();
+    if (!isa_and_present<ParallelOp>(parallelOp))
+      return result;
+
+    if (parallelOp->getParentOp() != capturedOp)
+      return result;
+
+    return result | TargetRegionFlags::spmd;
+  }
+  // Detect target-parallel-wsloop[-simd].
+  else if (isa<WsloopOp>(innermostWrapper)) {
+    Operation *parallelOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<ParallelOp>(parallelOp))
+      return TargetRegionFlags::generic;
+
+    if (parallelOp->getParentOp() == targetOp.getOperation())
+      return TargetRegionFlags::spmd;
   }
 
-  return OMP_TGT_EXEC_MODE_GENERIC;
+  return TargetRegionFlags::generic;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..4d610d6e2656d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     combinedMaxThreadsVal = maxThreadsVal;
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
-  attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
+  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
+  assert(
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
+                                               omp::TargetRegionFlags::spmd) &&
+      "invalid kernel flags");
+  attrs.ExecFlags =
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
+          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
+                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
+          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
 
-  if (targetOp.getKernelExecFlags(capturedOp) !=
-      llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
+                              omp::TargetRegionFlags::trip_count)) {
     llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
     attrs.LoopTripCount = nullptr;
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 403128bb2300e..bd0541987339a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
 // -----
 
 func.func @omp_target_host_eval_loop1(%x : i32) {
-  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.wsloop {
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
 // -----
 
 func.func @omp_target_host_eval_loop2(%x : i32) {
-  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.teams {
     ^bb0:
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..e3d2f8bd01018 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
     omp.terminator
   }
 
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest
+  omp.target host_eval(%x -> %arg0 : i32) {
+    %y = arith.constant 2 : i32
+    omp.parallel num_threads(%arg0 : i32) {
+      omp.wsloop {
+        omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+
   // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
   // CHECK: omp.teams {
   // CHECK: omp.distribute {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
new file mode 100644
index 0000000000000..8101660e571e4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -0,0 +1,111 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @main(%arg0 : !llvm.ptr) {
+    %x = llvm.load %arg0 : !llvm.ptr -> i32
+    %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
+      %x.map = llvm.load %ptr : !llvm.ptr -> i32
+      omp.teams {
+        omp.distribute {
+          omp.loop_nest (%iv1) : i32 = (%lb) to (%ub) step (%step) {
+            omp.parallel {
+              omp.wsloop {
+                omp.loop_nest (%iv2) : i32 = (%x.map) to (%x.map) step (%x.map) {
+                  omp.yield
+                }
+              }
+              omp.terminator
+            }
+            omp.yield
+          }
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// HOST-LABEL: define void @main
+// HOST:         %omp_loop.tripcount = {{.*}}
+// HOST-NEXT:    br label %[[ENTRY:.*]]
+// HOST:       [[ENTRY]]:
+// HOST:         %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST:       [[OFFLOAD_FAILED]]:
+// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[TARGET_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[TEAMS_OUTLINE]]
+// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// HOST:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[PARALLEL_OUTLINE]]
+// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+
+//--- device.mlir
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @main(%arg0 : !llvm.ptr) {
+    %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+    omp.target map_entries(%0 -> %ptr : !llvm.ptr) {
+      %x = llvm.load %ptr : !llvm.ptr -> i32
+      omp.teams {
+        omp.distribute {
+          omp.loop_nest (%iv1) : i32 = (%x) to (%x) step (%x) {
+            omp.parallel {
+              omp.wsloop {
+                omp.loop_nest (%iv2) : i32 = (%x) to (%x) step (%x) {
+                  omp.yield
+                }
+              }
+              omp.terminator
+            }
+            omp.yield
+          }
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
+// DEVICE:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// DEVICE:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},
+// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// DEVICE:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// DEVICE:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// DEVICE:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// DEVICE:        call void @__kmpc_target_deinit()
+
+// DEVICE:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// DEVICE:        call void @[[TEAMS_OUTLINE:.*]]({{.*}})
+
+// DEVICE:      define internal void @[[TEAMS_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_distribute_static_loop{{.*}}({{.*}}, ptr @[[DISTRIBUTE_OUTLINE:[^,]*]], {{.*}})
+
+// DEVICE:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// DEVICE:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_for_static_loop{{.*}}({{.*}})


