[Mlir-commits] [mlir] [MLIR][OpenMP] Introduce the LoopWrapperInterface (PR #87232)

Sergio Afonso llvmlistbot at llvm.org
Mon Apr 15 02:24:55 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/87232

>From 281121e682cdf5df7914f8b8b0a3b77c773d51cb Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 29 Mar 2024 16:07:03 +0000
Subject: [PATCH 1/4] [MLIR][OpenMP] Add omp.loop_nest operation

This patch introduces an operation intended to hold loop information associated
with the `omp.distribute`, `omp.simdloop`, `omp.taskloop` and `omp.wsloop`
operations. This is a stopgap solution to unblock work on transitioning these
operations to becoming wrappers, as discussed in
[this RFC](https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986).

Long-term, this operation will likely be replaced by `omp.canonical_loop`,
which is being designed to address missing support for loop transformations,
etc.
---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 65 ++++++++++++++-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 71 +++++++++++++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 37 +++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 79 +++++++++++++++++++
 4 files changed, 251 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f33942b3c7c02d..ffd00948915153 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -511,6 +511,69 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Loop Nest
+//===----------------------------------------------------------------------===//
+
+def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
+                        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                        ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp",
+                                     "WsloopOp"]>,
+                        RecursiveMemoryEffects]> {
+  let summary = "rectangular loop nest";
+  let description = [{
+    This operation represents a collapsed rectangular loop nest. For each
+    rectangular loop of the nest represented by an instance of this operation,
+    lower and upper bounds, as well as a step variable, must be defined.
+
+    The lower and upper bounds specify a half-open range: the range includes the
+    lower bound but does not include the upper bound. If the `inclusive`
+    attribute is specified then the upper bound is also included.
+
+    The body region can contain any number of blocks. The region is terminated
+    by an `omp.yield` instruction without operands. The induction variables,
+    represented as entry block arguments to the loop nest operation's single
+    region, match the types of the `lowerBound`, `upperBound` and `step`
+    arguments.
+
+    ```mlir
+    omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+      %a = load %arrA[%i1, %i2] : memref<?x?xf32>
+      %b = load %arrB[%i1, %i2] : memref<?x?xf32>
+      %sum = arith.addf %a, %b : f32
+      store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+      omp.yield
+    }
+    ```
+
+    This is a temporary simplified definition of a loop based on existing OpenMP
+    loop operations intended to serve as a stopgap solution until the long-term
+    representation of canonical loops is defined. Specifically, this operation
+    is intended to serve as a unique source for loop information during the
+    transition to making `omp.distribute`, `omp.simdloop`, `omp.taskloop` and
+    `omp.wsloop` wrapper operations. It is not intended to help with the
+    addition of support for loop transformations.
+  }];
+
+  let arguments = (ins Variadic<IntLikeType>:$lowerBound,
+                       Variadic<IntLikeType>:$upperBound,
+                       Variadic<IntLikeType>:$step,
+                       UnitAttr:$inclusive);
+
+  let regions = (region AnyRegion:$region);
+
+  let extraClassDeclaration = [{
+    /// Returns the number of loops in the loop nest.
+    unsigned getNumLoops() { return getLowerBound().size(); }
+
+    /// Returns the induction variables of the loop nest.
+    ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.9.2 Workshare Loop Construct
 //===----------------------------------------------------------------------===//
@@ -724,7 +787,7 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
 
 def YieldOp : OpenMP_Op<"yield",
     [Pure, ReturnLike, Terminator,
-     ParentOneOf<["WsloopOp", "DeclareReductionOp",
+     ParentOneOf<["LoopNestOp", "WsloopOp", "DeclareReductionOp",
      "AtomicUpdateOp", "SimdLoopOp", "PrivateClauseOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bf5875071e0dc4..796df1d13e6564 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1662,6 +1662,77 @@ LogicalResult TaskloopOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopNestOp
+//===----------------------------------------------------------------------===//
+
+ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::Argument> ivs;
+  SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs;
+  Type loopVarType;
+  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
+      parser.parseColonType(loopVarType) ||
+      // Parse loop bounds.
+      parser.parseEqual() ||
+      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
+      parser.parseKeyword("to") ||
+      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  for (auto &iv : ivs)
+    iv.type = loopVarType;
+
+  // Parse "inclusive" flag.
+  if (succeeded(parser.parseOptionalKeyword("inclusive")))
+    result.addAttribute("inclusive",
+                        UnitAttr::get(parser.getBuilder().getContext()));
+
+  // Parse step values.
+  SmallVector<OpAsmParser::UnresolvedOperand> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse the body.
+  Region *region = result.addRegion();
+  if (parser.parseRegion(*region, ivs))
+    return failure();
+
+  // Resolve operands.
+  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
+      parser.resolveOperands(ubs, loopVarType, result.operands) ||
+      parser.resolveOperands(steps, loopVarType, result.operands))
+    return failure();
+
+  // Parse the optional attribute list.
+  return parser.parseOptionalAttrDict(result.attributes);
+}
+
+void LoopNestOp::print(OpAsmPrinter &p) {
+  Region &region = getRegion();
+  auto args = region.getArguments();
+  p << " (" << args << ") : " << args[0].getType() << " = (" << getLowerBound()
+    << ") to (" << getUpperBound() << ") ";
+  if (getInclusive())
+    p << "inclusive ";
+  p << "step (" << getStep() << ") ";
+  p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
+LogicalResult LoopNestOp::verify() {
+  if (getLowerBound().size() != getIVs().size())
+    return emitOpError() << "number of range arguments and IVs do not match";
+
+  for (auto [lb, iv] : llvm::zip_equal(getLowerBound(), getIVs())) {
+    if (lb.getType() != iv.getType())
+      return emitOpError()
+             << "range argument type does not match corresponding IV type";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index a00383cf44057c..760ebb14d94121 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -87,6 +87,43 @@ func.func @proc_bind_once() {
 
 // -----
 
+func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
+  // expected-error@+1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}}
+  omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+    omp.yield
+  }
+}
+
+// -----
+
+func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // expected-error@+1 {{range argument type does not match corresponding IV type}}
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv2: i32):
+      omp.yield
+    }) : (index, index, index) -> ()
+    omp.yield
+  }
+}
+
+// -----
+
+func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // expected-error@+1 {{number of range arguments and IVs do not match}}
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv1 : index, %iv2 : index):
+      omp.yield
+    }) : (index, index, index) -> ()
+    omp.yield
+  }
+}
+
+// -----
+
 func.func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {
   // expected-error @below {{expected 'for'}}
   omp.wsloop nowait inclusive
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 30ce77423005ac..8d9acab67e0358 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -133,6 +133,85 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
   return
 }
 
+// CHECK-LABEL: omp_loop_nest
+func.func @omp_loop_nest(%lb : index, %ub : index, %step : index) -> () {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv2: index):
+      omp.yield
+    }) : (index, index, index) -> ()
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}})
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv2: index):
+      omp.yield
+    }) {inclusive} : (index, index, index) -> ()
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
+    "omp.loop_nest" (%lb, %lb, %ub, %ub, %step, %step) ({
+    ^bb0(%iv2: index, %iv3: index):
+      omp.yield
+    }) : (index, index, index, index, index, index) -> ()
+    omp.yield
+  }
+
+  return
+}
+
+// CHECK-LABEL: omp_loop_nest_pretty
+func.func @omp_loop_nest_pretty(%lb : index, %ub : index, %step : index) -> () {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+    omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}})
+    omp.loop_nest (%iv2) : index = (%lb) to (%ub) inclusive step (%step) {
+      omp.yield
+    }
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
+    omp.loop_nest (%iv2, %iv3) : index = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
+      omp.yield
+    }
+    omp.yield
+  }
+
+  return
+}
+
 // CHECK-LABEL: omp_wsloop
 func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref<i32>, %linear_var : i32, %chunk_var : i32) -> () {
 

>From 2452bc75a7f2efb67a0522bbe8b0e7ba5bc3365b Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Mon, 1 Apr 2024 13:04:14 +0100
Subject: [PATCH 2/4] [MLIR][OpenMP] Introduce the LoopWrapperInterface

This patch defines a common interface to be shared by all OpenMP loop wrapper
operations. The main restrictions these operations must meet in order to be
considered a wrapper are:

- They contain a single region.
- Their region contains a single block.
- Their block only contains another loop wrapper or `omp.loop_nest` and a
terminator.

The new interface is attached to the `omp.parallel`, `omp.wsloop`,
`omp.simdloop`, `omp.distribute` and `omp.taskloop` operations. It is not
currently enforced that these operations meet the wrapper restrictions, because
doing so would break existing OpenMP loop-generating code. Rather, enforcement
will be introduced progressively in subsequent patches.
---
 .../mlir/Dialect/OpenMP/OpenMPInterfaces.h    |  3 +
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 16 +++--
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 68 +++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 19 ++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 16 ++++-
 5 files changed, 117 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
index b3184db8852161..787c48b05c5c5c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
@@ -21,6 +21,9 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#define GET_OP_FWD_DEFINES
+#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
+
 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.h.inc"
 
 namespace mlir::omp {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ffd00948915153..a7bf93deae2fb3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -236,6 +236,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove]> {
 
 def ParallelOp : OpenMP_Op<"parallel", [
                  AutomaticAllocationScope, AttrSizedOperandSegments,
+                 DeclareOpInterfaceMethods<LoopWrapperInterface>,
                  DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
                  RecursiveMemoryEffects, ReductionClauseInterface]> {
   let summary = "parallel construct";
@@ -517,8 +518,6 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
 
 def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
                         AllTypesMatch<["lowerBound", "upperBound", "step"]>,
-                        ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp",
-                                     "WsloopOp"]>,
                         RecursiveMemoryEffects]> {
   let summary = "rectangular loop nest";
   let description = [{
@@ -568,6 +567,10 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
 
     /// Returns the induction variables of the loop nest.
     ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
+
+    /// Returns the list of wrapper operations around this loop nest. Wrappers
+    /// in the resulting vector will be sorted from innermost to outermost.
+    SmallVector<LoopWrapperInterface> getWrappers();
   }];
 
   let hasCustomAssemblyFormat = 1;
@@ -580,6 +583,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
 
 def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
                          AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                         DeclareOpInterfaceMethods<LoopWrapperInterface>,
                          RecursiveMemoryEffects, ReductionClauseInterface]> {
   let summary = "worksharing-loop construct";
   let description = [{
@@ -700,7 +704,9 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
 //===----------------------------------------------------------------------===//
 
 def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
-                         AllTypesMatch<["lowerBound", "upperBound", "step"]>]> {
+                         AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                         DeclareOpInterfaceMethods<LoopWrapperInterface>,
+                         RecursiveMemoryEffects]> {
  let summary = "simd loop construct";
   let description = [{
     The simd construct can be applied to a loop to indicate that the loop can be
@@ -809,7 +815,8 @@ def YieldOp : OpenMP_Op<"yield",
 // Distribute construct [2.9.4.1]
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
-                              MemoryEffects<[MemWrite]>]> {
+                             DeclareOpInterfaceMethods<LoopWrapperInterface>,
+                             RecursiveMemoryEffects]> {
   let summary = "distribute construct";
   let description = [{
     The distribute construct specifies that the iterations of one or more loops
@@ -980,6 +987,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
 def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
                            AutomaticAllocationScope, RecursiveMemoryEffects,
                            AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                           DeclareOpInterfaceMethods<LoopWrapperInterface>,
                            ReductionClauseInterface]> {
   let summary = "taskloop construct";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 2e37384ce3eb71..b6a3560b7da56a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -69,6 +69,74 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
   ];
 }
 
+def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
+  let description = [{
+    OpenMP operations that can wrap a single loop nest. When taking a wrapper
+    role, these operations must only contain a single region with a single block
+    in which there's a single operation and a terminator. That nested operation
+    must be another loop wrapper or an `omp.loop_nest`.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/[{
+        Tell whether the operation could be taking the role of a loop wrapper.
+        That is, it has a single region with a single block in which there are
+        two operations: another wrapper or `omp.loop_nest` operation and a
+        terminator.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isWrapper",
+      (ins ), [{}], [{
+        if ($_op->getNumRegions() != 1)
+          return false;
+
+        ::mlir::Region &r = $_op->getRegion(0);
+        if (!r.hasOneBlock())
+          return false;
+
+        if (std::distance(r.op_begin(), r.op_end()) != 2)
+          return false;
+
+        ::mlir::Operation &firstOp = *r.op_begin();
+        ::mlir::Operation &secondOp = *(++r.op_begin());
+        return ::llvm::isa<::mlir::omp::LoopNestOp,
+                           ::mlir::omp::LoopWrapperInterface>(firstOp) &&
+               secondOp.hasTrait<::mlir::OpTrait::IsTerminator>();
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        If there is another loop wrapper immediately nested inside, return that
+        operation. Assumes this operation is taking a loop wrapper role.
+      }],
+      /*retTy=*/"::mlir::omp::LoopWrapperInterface",
+      /*methodName=*/"getNestedWrapper",
+      (ins), [{}], [{
+        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin();
+        return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested);
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        Return the loop nest nested directly or indirectly inside of this loop
+        wrapper. Assumes this operation is taking a loop wrapper role.
+      }],
+      /*retTy=*/"::mlir::Operation *",
+      /*methodName=*/"getWrappedLoop",
+      (ins), [{}], [{
+        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper())
+          return nested.getWrappedLoop();
+        return &*$_op->getRegion(0).op_begin();
+      }]
+    >
+  ];
+}
+
 def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
   let description = [{
     OpenMP operations that support declare target have this interface.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 796df1d13e6564..564c23201db4fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1730,9 +1730,28 @@ LogicalResult LoopNestOp::verify() {
              << "range argument type does not match corresponding IV type";
   }
 
+  auto wrapper =
+      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
+
+  if (!wrapper || !wrapper.isWrapper())
+    return emitOpError() << "expects parent op to be a valid loop wrapper";
+
   return success();
 }
 
+SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
+  SmallVector<LoopWrapperInterface> wrappers;
+  Operation *parent = (*this)->getParentOp();
+  while (auto wrapper =
+             llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
+    if (!wrapper.isWrapper())
+      break;
+    wrappers.push_back(wrapper);
+    parent = parent->getParentOp();
+  }
+  return wrappers;
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 760ebb14d94121..8f4103dabee5df 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -88,7 +88,7 @@ func.func @proc_bind_once() {
 // -----
 
 func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
-  // expected-error@+1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}}
+  // expected-error@+1 {{op expects parent op to be a valid loop wrapper}}
   omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
     omp.yield
   }
@@ -96,6 +96,20 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
 
 // -----
 
+func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    %0 = arith.constant 0 : i32
+    // expected-error@+1 {{op expects parent op to be a valid loop wrapper}}
+    omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+    omp.yield
+  }
+}
+
+// -----
+
 func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
   // TODO Remove induction variables from omp.wsloop.
   omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {

>From e3c440c9df93eedad1e3a96d58a759c869574c20 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 2 Apr 2024 14:15:25 +0100
Subject: [PATCH 3/4] Update op description according to review comments

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ffd00948915153..3d87585d52847c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -552,7 +552,8 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
     is intended to serve as a unique source for loop information during the
     transition to making `omp.distribute`, `omp.simdloop`, `omp.taskloop` and
     `omp.wsloop` wrapper operations. It is not intended to help with the
-    addition of support for loop transformations.
+    addition of support for loop transformations, non-rectangular loops and
+    non-perfectly nested loops.
   }];
 
   let arguments = (ins Variadic<IntLikeType>:$lowerBound,

>From 904f27489b0d3c27e773f085b18a9b85cb548f44 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 2 Apr 2024 15:36:41 +0100
Subject: [PATCH 4/4] Address review comments

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  4 ++--
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 19 +++++++++----------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  5 ++---
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a7bf93deae2fb3..50627712ea3109 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -568,9 +568,9 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
     /// Returns the induction variables of the loop nest.
     ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
 
-    /// Returns the list of wrapper operations around this loop nest. Wrappers
+    /// Fills a list of wrapper operations around this loop nest. Wrappers
     /// in the resulting vector will be sorted from innermost to outermost.
-    SmallVector<LoopWrapperInterface> getWrappers();
+    void gatherWrappers(SmallVectorImpl<LoopWrapperInterface> &wrappers);
   }];
 
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index b6a3560b7da56a..ab9b78e755d9d5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -93,18 +93,17 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
         if ($_op->getNumRegions() != 1)
           return false;
 
-        ::mlir::Region &r = $_op->getRegion(0);
+        Region &r = $_op->getRegion(0);
         if (!r.hasOneBlock())
           return false;
 
-        if (std::distance(r.op_begin(), r.op_end()) != 2)
+        if (::llvm::range_size(r.getOps()) != 2)
           return false;
 
-        ::mlir::Operation &firstOp = *r.op_begin();
-        ::mlir::Operation &secondOp = *(++r.op_begin());
-        return ::llvm::isa<::mlir::omp::LoopNestOp,
-                           ::mlir::omp::LoopWrapperInterface>(firstOp) &&
-               secondOp.hasTrait<::mlir::OpTrait::IsTerminator>();
+        Operation &firstOp = *r.op_begin();
+        Operation &secondOp = *(std::next(r.op_begin()));
+        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
+               secondOp.hasTrait<OpTrait::IsTerminator>();
       }]
     >,
     InterfaceMethod<
@@ -116,8 +115,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*methodName=*/"getNestedWrapper",
       (ins), [{}], [{
         assert($_op.isWrapper() && "Unexpected non-wrapper op");
-        ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin();
-        return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested);
+        Operation *nested = &*$_op->getRegion(0).op_begin();
+        return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
       }]
     >,
     InterfaceMethod<
@@ -129,7 +128,7 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*methodName=*/"getWrappedLoop",
       (ins), [{}], [{
         assert($_op.isWrapper() && "Unexpected non-wrapper op");
-        if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper())
+        if (LoopWrapperInterface nested = $_op.getNestedWrapper())
           return nested.getWrappedLoop();
         return &*$_op->getRegion(0).op_begin();
       }]
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 564c23201db4fd..a7d265328df6ef 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1739,8 +1739,8 @@ LogicalResult LoopNestOp::verify() {
   return success();
 }
 
-SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
-  SmallVector<LoopWrapperInterface> wrappers;
+void LoopNestOp::gatherWrappers(
+    SmallVectorImpl<LoopWrapperInterface> &wrappers) {
   Operation *parent = (*this)->getParentOp();
   while (auto wrapper =
              llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
@@ -1749,7 +1749,6 @@ SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
     wrappers.push_back(wrapper);
     parent = parent->getParentOp();
   }
-  return wrappers;
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list