[llvm-branch-commits] [mlir] [MLIR][OpenMP] Explicit tagging of combined constructs (PR #198782)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed May 20 06:36:53 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

Combined OpenMP constructs, such as `parallel do`, which represent nests of constructs where each one contains a single other construct without any other directives or statements in between, are currently not marked in any way in the MLIR representation.

This works because they don't usually require any specific handling other than what would be done for the included operations. However, the handling of `target` regions needs to know whether the region is part of a combined construct in order to properly optimize for the SPMD case and detect when certain clauses must be unconditionally evaluated on the host.

So far, this has been achieved by having some MLIR pattern-matching logic infer whether a nest of operations could have been produced for a combined construct. This approach is error-prone and computationally expensive, and it can't really work in the general case. On the other hand, a compiler frontend can easily tell the difference and tag MLIR operations accordingly.

This patch extends the `ComposableOpInterface` of the OpenMP dialect to handle a new `omp.combined` attribute that must be set for all leaves (except for the innermost one) of a combined construct. Verification logic is added for this interface, which is added to all operations that can be used as part of a combined construct, and the previous `target`-related pattern-matching logic is removed.
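As a rough illustration of the tagging rule described above, here is a toy C++ model that does not use the MLIR API. The `Leaf` struct and the `tagCombined` and `followsTaggingRule` helpers are hypothetical names introduced only for this sketch; the `combined` flag stands in for the presence of the `omp.combined` attribute, and the nest is assumed to be listed from outermost to innermost leaf.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for one leaf of a combined construct (hypothetical, not an
// MLIR type). `combined` models whether `omp.combined` would be present.
struct Leaf {
  std::string name;
  bool combined;
};

// Tag every leaf except the innermost one. The nest is ordered from the
// outermost leaf (index 0) to the innermost leaf (last index).
inline void tagCombined(std::vector<Leaf> &nest) {
  for (std::size_t i = 0; i < nest.size(); ++i)
    nest[i].combined = (i + 1 < nest.size());
}

// Check the rule as modeled here: every non-innermost leaf is tagged and
// the innermost leaf is not. This is an assumption of the toy model, not
// the verifier added by the patch.
inline bool followsTaggingRule(const std::vector<Leaf> &nest) {
  for (std::size_t i = 0; i < nest.size(); ++i)
    if (nest[i].combined != (i + 1 < nest.size()))
      return false;
  return true;
}
```

For a nest modeling `target teams distribute`, `tagCombined` would set the flag on `target` and `teams` and leave `distribute` untagged.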

This patch has to be followed up with Flang lowering changes to pass all unit tests.

---

Patch is 72.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/198782.diff


41 Files Affected:

- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+29-33) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+43-2) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+137-134) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+4-2) 
- (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+2-2) 
- (added) mlir/test/Dialect/OpenMP/invalid-interface.mlir (+106) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+113-66) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+33-33) 
- (modified) mlir/test/Dialect/OpenMP/stack-to-shared.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/allocatable_gpu_reduction.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/allocatable_gpu_reduction_teams.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/omptarget-debug-loop-loc.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/omptarget-memcpy-align-metadata.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/omptarget-multi-block-reduction.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/omptarget-multi-reduction.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/omptarget-private-llvm.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/omptarget-teams-distribute-reduction-array-descriptor.mlir (+4-4) 
- (modified) mlir/test/Target/LLVMIR/openmp-data-target-device.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-nested-task-target-parallel.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir (+4-4) 
- (modified) mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-target-spmd.mlir (+4-4) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-bounds-cast.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-cancel.mlir (+3-3) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-cancellation-point.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-collapse.mlir (+6-6) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-context-alloca.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-final.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-grainsize.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-if.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-local-bounds.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-mergeable.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-no-context-struct.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-nogroup.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-num_tasks.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-outer-bounds.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-priority.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop-untied.mlir (+2-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-taskloop.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/openmp-teams-clauses-trunc-ext.mlir (+24-24) 
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 21b1df3a7e608..3d64bfb21602c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -238,7 +238,8 @@ def TerminatorOp : OpenMP_Op<"terminator", [Terminator, Pure]> {
 // 2.7 teams Construct
 //===----------------------------------------------------------------------===//
 def TeamsOp : OpenMP_Op<"teams", traits = [
-    AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
+    AttrSizedOperandSegments, DeclareOpInterfaceMethods<ComposableOpInterface>,
+    RecursiveMemoryEffects, OutlineableOpenMPOpInterface
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
     OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
@@ -293,7 +294,7 @@ def SectionOp : OpenMP_Op<"section", traits = [
 }
 
 def SectionsOp : OpenMP_Op<"sections", traits = [
-    AttrSizedOperandSegments
+    AttrSizedOperandSegments, DeclareOpInterfaceMethods<ComposableOpInterface>
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_NowaitClause, OpenMP_PrivateClause,
     OpenMP_ReductionClause
@@ -331,7 +332,7 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 //===----------------------------------------------------------------------===//
 
 def SingleOp : OpenMP_Op<"single", traits = [
-    AttrSizedOperandSegments
+    AttrSizedOperandSegments, DeclareOpInterfaceMethods<ComposableOpInterface>
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_CopyprivateClause, OpenMP_NowaitClause,
     OpenMP_PrivateClause
@@ -612,7 +613,7 @@ def FuseOp
 //===----------------------------------------------------------------------===//
 
 def WorkshareOp : OpenMP_Op<"workshare", traits = [
-    RecursiveMemoryEffects,
+    DeclareOpInterfaceMethods<ComposableOpInterface>, RecursiveMemoryEffects
   ], clauses = [
     OpenMP_NowaitClause,
   ], singleRegion = true> {
@@ -631,6 +632,7 @@ def WorkshareOp : OpenMP_Op<"workshare", traits = [
   let builders = [
     OpBuilder<(ins CArg<"const WorkshareOperands &">:$clauses)>
   ];
+  let hasVerifier = 1;
 }
 
 def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
@@ -957,18 +959,16 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [
 // 2.10.1 task Construct
 //===----------------------------------------------------------------------===//
 
-def TaskOp
-    : OpenMP_Op<"task",
-                traits = [AttrSizedOperandSegments, AutomaticAllocationScope,
-                          OutlineableOpenMPOpInterface],
-                clauses = [
-                    // TODO: Complete clause list (detach).
-                    OpenMP_AffinityClause, OpenMP_AllocateClause,
-                    OpenMP_DependClause, OpenMP_FinalClause, OpenMP_IfClause,
-                    OpenMP_InReductionClause, OpenMP_MergeableClause,
-                    OpenMP_PriorityClause, OpenMP_PrivateClause,
-                    OpenMP_UntiedClause, OpenMP_DetachClause],
-                singleRegion = true> {
+def TaskOp : OpenMP_Op<"task", traits = [
+    AttrSizedOperandSegments, AutomaticAllocationScope,
+    DeclareOpInterfaceMethods<ComposableOpInterface>,
+    OutlineableOpenMPOpInterface
+  ], clauses = [
+    OpenMP_AffinityClause, OpenMP_AllocateClause, OpenMP_DependClause,
+    OpenMP_FinalClause, OpenMP_IfClause, OpenMP_InReductionClause,
+    OpenMP_MergeableClause, OpenMP_PriorityClause, OpenMP_PrivateClause,
+    OpenMP_UntiedClause, OpenMP_DetachClause
+  ], singleRegion = true> {
   let summary = "task construct";
   let description = [{
     The task construct defines an explicit task.
@@ -1007,6 +1007,7 @@ def TaskOp
 def TaskloopContextOp : OpenMP_Op<"taskloop.context", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
     RecursiveMemoryEffects, SingleBlock,
+    DeclareOpInterfaceMethods<ComposableOpInterface>,
     DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_FinalClause, OpenMP_GrainsizeClause,
@@ -1412,7 +1413,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
 //===---------------------------------------------------------------------===//
 
 def TargetDataOp: OpenMP_Op<"target_data", traits = [
-    AttrSizedOperandSegments
+    AttrSizedOperandSegments, DeclareOpInterfaceMethods<ComposableOpInterface>
   ], clauses = [
     OpenMP_DeviceClause, OpenMP_IfClause, OpenMP_MapClause,
     OpenMP_UseDeviceAddrClause, OpenMP_UseDevicePtrClause
@@ -1574,7 +1575,8 @@ def TargetUpdateOp: OpenMP_Op<"target_update", traits = [
 //===----------------------------------------------------------------------===//
 
 def TargetOp : OpenMP_Op<"target", traits = [
-    AttrSizedOperandSegments, BlockArgOpenMPOpInterface, IsolatedFromAbove,
+    AttrSizedOperandSegments, BlockArgOpenMPOpInterface,
+    DeclareOpInterfaceMethods<ComposableOpInterface>, IsolatedFromAbove,
     OutlineableOpenMPOpInterface
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
@@ -1632,18 +1634,6 @@ def TargetOp : OpenMP_Op<"target", traits = [
       return getMapVars()[mapInfoOpIdx];
     }
 
-    /// Returns the innermost OpenMP dialect operation captured by this target
-    /// construct. For an operation to be detected as captured, it must be
-    /// inside a (possibly multi-level) nest of OpenMP dialect operation's
-    /// regions where none of these levels contain other operations considered
-    /// not-allowed for these purposes (i.e. only terminator operations are
-    /// allowed from the OpenMP dialect, and other dialect's operations are
-    /// allowed as long as they don't have a memory write effect).
-    ///
-    /// If there are omp.loop_nest operations in the sequence of nested
-    /// operations, the top level one will be the one captured.
-    Operation *getInnermostCapturedOmpOp();
-
     /// Returns whether this kernel requires host evaluation of loop trip count.
     bool hasHostEvalTripCount();
   }] # clausesExtraClassDeclaration;
@@ -1667,7 +1657,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
 //===----------------------------------------------------------------------===//
 // 2.16 master Construct
 //===----------------------------------------------------------------------===//
-def MasterOp : OpenMP_Op<"master", singleRegion = true> {
+def MasterOp : OpenMP_Op<"master", traits = [
+    DeclareOpInterfaceMethods<ComposableOpInterface>
+  ], singleRegion = true> {
   let summary = "master construct";
   let description = [{
     The master construct specifies a structured block that is executed by
@@ -2218,7 +2210,9 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
 //===----------------------------------------------------------------------===//
 // [Spec 5.2] 10.5 masked Construct
 //===----------------------------------------------------------------------===//
-def MaskedOp : OpenMP_Op<"masked", clauses = [
+def MaskedOp : OpenMP_Op<"masked", traits = [
+    DeclareOpInterfaceMethods<ComposableOpInterface>
+  ], clauses = [
     OpenMP_FilterClause
   ], singleRegion = 1> {
   let summary = "masked construct";
@@ -2429,7 +2423,9 @@ def FreeSharedMemOp : OpenMP_Op<"free_shared_mem", traits = [
 // workdistribute Construct
 //===----------------------------------------------------------------------===//
 
-def WorkdistributeOp : OpenMP_Op<"workdistribute"> {
+def WorkdistributeOp : OpenMP_Op<"workdistribute", traits = [
+    DeclareOpInterfaceMethods<ComposableOpInterface>
+  ]> {
   let summary = "workdistribute directive";
   let description = [{
     workdistribute divides execution of the enclosed structured block into
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index fd500134e10f9..51f925b17f47e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -281,8 +281,9 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
 
 def ComposableOpInterface : OpInterface<"ComposableOpInterface"> {
   let description = [{
-    OpenMP operations that can represent a single leaf of a composite OpenMP
-    construct.
+    OpenMP operations that can represent a single leaf of a compound OpenMP
+    construct, or one leaf of a combined construct or equivalent immediate
+    nesting of constructs.
   }];
 
   let cppNamespace = "::mlir::omp";
@@ -311,8 +312,48 @@ def ComposableOpInterface : OpInterface<"ComposableOpInterface"> {
         else
           $_op->removeDiscardableAttr("omp.composite");
       }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        Check whether the operation is representing a non-innermost child leaf
+        of a combined OpenMP construct or an equivalent immediate nesting of
+        constructs.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isCombined",
+      (ins ), [{}], [{
+        return $_op->hasAttr("omp.combined");
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        Mark the operation as a non-innermost child leaf of an OpenMP combined
+        construct or an equivalent immediate nesting of constructs.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"setCombined",
+      (ins "bool":$val), [{}], [{
+        if (val)
+          $_op->setDiscardableAttr("omp.combined", mlir::UnitAttr::get($_op->getContext()));
+        else
+          $_op->removeDiscardableAttr("omp.combined");
+      }]
     >
   ];
+
+  let extraClassDeclaration = [{
+    /// Follows the combined-composite chain of nested operations and returns
+    /// the deepest-nested one.
+    Operation *findCapturedOp();
+
+    /// Interface verifier implementation.
+    llvm::LogicalResult verifyImpl();
+  }];
+
+  let verify = [{
+    return ::llvm::cast<::mlir::omp::ComposableOpInterface>($_op).verifyImpl();
+  }];
+  let verifyWithRegions = 1;
 }
 
 def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6a3a4bcf6a274..cb1efd49666c6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2446,7 +2446,8 @@ bool TargetOp::hasHostEvalTripCount() {
 
   // If it represents a `target teams distribute` construct, also evaluate the
   // `distribute` trip count on the host.
-  Operation *capturedOp = getInnermostCapturedOmpOp();
+  Operation *capturedOp =
+      cast<ComposableOpInterface>(getOperation()).findCapturedOp();
   if (auto loopNestOp = dyn_cast_if_present<LoopNestOp>(capturedOp)) {
     SmallVector<LoopWrapperInterface> loopWrappers;
     loopNestOp.gatherWrappers(loopWrappers);
@@ -2472,6 +2473,9 @@ bool TargetOp::hasHostEvalTripCount() {
 }
 
 LogicalResult TargetOp::verify() {
+  if (getKernelType() == TargetExecMode::bare && !isCombined())
+    return emitOpError() << "bare kernel requires 'omp.combined'";
+
   if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
                                  getDependIteratedKinds(),
                                  getDependIterated())))
@@ -2493,21 +2497,17 @@ LogicalResult TargetOp::verifyRegions() {
   if (numNestedTeams > 1)
     return emitError("target containing multiple 'omp.teams' nested ops");
 
-  if (numNestedTeams == 0 && getKernelType() == TargetExecMode::bare)
+  if (getKernelType() == TargetExecMode::bare && numNestedTeams == 0)
     return emitOpError()
            << "bare kernel must contain a nested 'omp.teams' operation";
 
-  if (getKernelType() == TargetExecMode::spmd ||
-      getKernelType() == TargetExecMode::spmd_no_loop) {
-    bool containsLoop = getRegion()
-                            .walk<WalkOrder::PreOrder>([](LoopNestOp loopOp) {
-                              return WalkResult::interrupt();
-                            })
-                            .wasInterrupted();
-    if (!containsLoop)
-      return emitOpError()
-             << "SPMD kernel must contain a nested 'omp.loop_nest' operation";
-  }
+  Operation *capturedOp =
+      cast<ComposableOpInterface>(getOperation()).findCapturedOp();
+  if ((getKernelType() == TargetExecMode::spmd ||
+       getKernelType() == TargetExecMode::spmd_no_loop) &&
+      !isa_and_present<LoopNestOp>(capturedOp))
+    return emitOpError()
+           << "SPMD kernel must capture an 'omp.loop_nest' operation";
 
   bool isTargetDevice = false;
   if (auto offloadMod = (*this)->getParentOfType<OffloadModuleInterface>())
@@ -2518,9 +2518,6 @@ LogicalResult TargetOp::verifyRegions() {
   llvm::ArrayRef<BlockArgument> hostEvalBlockArgs =
       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs();
 
-  if (!hostEvalBlockArgs.empty() && isTargetDevice)
-    emitOpError() << "'host_eval' is only supported during host compilation";
-
   bool hostEvalTripCount = hasHostEvalTripCount();
   for (Value hostEvalArg : hostEvalBlockArgs) {
     for (Operation *user : hostEvalArg.getUsers()) {
@@ -2560,121 +2557,19 @@ LogicalResult TargetOp::verifyRegions() {
   }
 
   if (hostEvalTripCount && !isTargetDevice) {
-    if (auto loopOp = dyn_cast<LoopNestOp>(getInnermostCapturedOmpOp())) {
-      for (auto arg : llvm::concat<Value>(loopOp.getLoopLowerBounds(),
-                                          loopOp.getLoopUpperBounds(),
-                                          loopOp.getLoopSteps())) {
-        if (!llvm::is_contained(hostEvalBlockArgs, arg))
-          return emitOpError() << "nested 'omp.loop_nest' bounds expected to "
-                                  "be host-evaluated";
-      }
+    auto loopOp = cast<LoopNestOp>(capturedOp);
+    for (auto arg : llvm::concat<Value>(loopOp.getLoopLowerBounds(),
+                                        loopOp.getLoopUpperBounds(),
+                                        loopOp.getLoopSteps())) {
+      if (!llvm::is_contained(hostEvalBlockArgs, arg))
+        return emitOpError() << "nested 'omp.loop_nest' bounds expected to "
+                                "be host-evaluated";
     }
   }
 
   return success();
 }
 
-static Operation *
-findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
-                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
-  assert(rootOp && "expected valid operation");
-
-  Dialect *ompDialect = rootOp->getDialect();
-  Operation *capturedOp = nullptr;
-  DominanceInfo domInfo;
-
-  // Process in pre-order to check operations from outermost to innermost,
-  // ensuring we only enter the region of an operation if it meets the criteria
-  // for being captured. We stop the exploration of nested operations as soon as
-  // we process a region holding no operations to be captured.
-  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (op == rootOp)
-      return WalkResult::advance();
-
-    // Ignore operations of other dialects or omp operations with no regions,
-    // because these will only be checked if they are siblings of an omp
-    // operation that can potentially be captured.
-    bool isOmpDialect = op->getDialect() == ompDialect;
-    bool hasRegions = op->getNumRegions() > 0;
-    if (!isOmpDialect || !hasRegions)
-      return WalkResult::skip();
-
-    // This operation cannot be captured if it can be executed more than once
-    // (i.e. its block's successors can reach it) or if it's not guaranteed to
-    // be executed before all exits of the region (i.e. it doesn't dominate all
-    // blocks with no successors reachable from the entry block).
-    if (checkSingleMandatoryExec) {
-      Region *parentRegion = op->getParentRegion();
-      Block *parentBlock = op->getBlock();
-
-      for (Block *successor : parentBlock->getSuccessors())
-        if (successor->isReachable(parentBlock))
-          return WalkResult::interrupt();
-
-      for (Block &block : *parentRegion)
-        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
-            !domInfo.dominates(parentBlock, &block))
-          return WalkResult::interrupt();
-    }
-
-    // Don't capture this op if it has a not-allowed sibling, and stop recursing
-    // into nested operations.
-    for (Operation &sibling : op->getParentRegion()->getOps())
-      if (&sibling != op && !siblingAllowedFn(&sibling))
-        return WalkResult::interrupt();
-
-    // Don't continue capturing nested operations if we reach an omp.loop_nest.
-    // Otherwise, process the contents of this operation.
-    capturedOp = op;
-    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
-                                     : WalkResult::advance();
-  });
-
-  return capturedOp;
-}
-
-Operation *TargetOp::getInnermostCapturedOmpOp() {
-  // If this is an SPMD kernel, then just attempt to find the first available
-  // omp.loop_nest. If the kernel type has been properly set, that must be the
-  // captured loop.
-  if (getKernelType() == TargetExecMode::spmd ||
-      getKernelType() == TargetExecMode::spmd_no_loop) {
-    Operation *spmdLoop = nullptr;
-    getRegion().walk<WalkOrder::PreOrder>([&spmdLoop](LoopNestOp loopOp) {
-      spmdLoop = loopOp.getOperation();
-      return WalkResult::interrupt();
-    });
-    assert(spmdLoop && "SPMD target regions must contain a loop");
-    return spmdLoop;
-  }
-
-  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
-
-  // Only allow OpenMP terminators and non-OpenMP ops that either have known
-  // memory effects excluding memory write effects, or are pure.
-  return findCapturedOmpOp(
-      *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
-        if (!sibling)
-          return false;
-
-        if (ompDialect == sibling->getDialect())
-          return sibling->hasTrait<OpTrait::IsTerminator>();
-
-        if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
-          SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
-              effects;
-          memOp.getEffects(effects);
-          return !llvm::any_of(
-              effects, [&](MemoryEffects::EffectInstance &effect) {
-                return isa<MemoryEffects::Write>(effect.getEffect()) &&
-                       isa<SideEffects::AutomaticAllocationScopeResource>(
-                           effect.getResource());
-              });
-        }
-        return isPure(sibling);
-      });
-}
-
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
@@ -2892,6 +2787,9 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult SectionsOp::verify() {
+  if (isCombined())
+    return emitOpError() << "cannot be a non-innermost combined construct leaf";
+
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
@@ -2973,6 +2871,13 @@ void WorkshareOp::build(OpBuilder &builder, OperationState &state,
   WorkshareOp::build(builder, state, clauses.nowait);
 }
 
+LogicalResult WorkshareOp::verify() {
+  if (isCombined())
+    return emitOpError() << "cannot be a non-innermost combined construct leaf";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WorkshareLoopWrapperOp
 //===----------------------------------------------------------------------===//
@@ -3018,6 +2923,98 @@ LogicalResult LoopWrapperInterface::verifyImpl() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ComposableOpInterface
+//===----------------------------------------------------------------------===//
+
+Operation *ComposableOpInterface::findCapturedOp() {
+  Operation *op = this->getOperation();
+
+  // Handle the composite case by returning the wrapped omp.loop_nest.
+  if (auto wrapperOp = dyn_cast<LoopWrapperInterface>(op))
+    return wrapperOp.getWrappedLoop();
+
+  // Do not look further if this op is not combined with any of its children.
+  // Need to check for composite for the omp.parallel case, which is not a loop
+  // wrapper itself.
+  if (!isCombined() && !isComposite())
+    return op;
+
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/198782

