[Mlir-commits] [mlir] [MLIR][OpenMP] Make omp.distribute into a loop wrapper (PR #87239)

Sergio Afonso llvmlistbot at llvm.org
Tue Apr 16 02:15:01 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/87239

>From 281121e682cdf5df7914f8b8b0a3b77c773d51cb Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 29 Mar 2024 16:07:03 +0000
Subject: [PATCH 1/5] [MLIR][OpenMP] Add omp.loop_nest operation

This patch introduces an operation intended to hold loop information associated
with the `omp.distribute`, `omp.simdloop`, `omp.taskloop` and `omp.wsloop`
operations. This is a stopgap solution to unblock the work of turning these
operations into wrappers, as discussed in
[this RFC](https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986).

Long-term, this operation will likely be replaced by `omp.canonical_loop`,
which is being designed to address missing support for loop transformations,
etc.
---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 65 ++++++++++++++-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 71 +++++++++++++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 37 +++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 79 +++++++++++++++++++
 4 files changed, 251 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f33942b3c7c02d..ffd00948915153 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -511,6 +511,69 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Loop Nest
+//===----------------------------------------------------------------------===//
+
+def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
+                        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                        ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp",
+                                     "WsloopOp"]>,
+                        RecursiveMemoryEffects]> {
+  let summary = "rectangular loop nest";
+  let description = [{
+    This operation represents a collapsed rectangular loop nest. For each
+    rectangular loop of the nest represented by an instance of this operation,
+    lower and upper bounds, as well as a step variable, must be defined.
+
+    The lower and upper bounds specify a half-open range: the range includes the
+    lower bound but does not include the upper bound. If the `inclusive`
+    attribute is specified then the upper bound is also included.
+
+    The body region can contain any number of blocks. The region is terminated
+    by an `omp.yield` instruction without operands. The induction variables,
+    represented as entry block arguments to the loop nest operation's single
+    region, match the types of the `lowerBound`, `upperBound` and `step`
+    arguments.
+
+    ```mlir
+    omp.loop_nest (%i1, %i2) : i32 = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+      %a = load %arrA[%i1, %i2] : memref<?x?xf32>
+      %b = load %arrB[%i1, %i2] : memref<?x?xf32>
+      %sum = arith.addf %a, %b : f32
+      store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+      omp.yield
+    }
+    ```
+
+    This is a temporary simplified definition of a loop based on existing OpenMP
+    loop operations intended to serve as a stopgap solution until the long-term
+    representation of canonical loops is defined. Specifically, this operation
+    is intended to serve as a unique source for loop information during the
+    transition to making `omp.distribute`, `omp.simdloop`, `omp.taskloop` and
+    `omp.wsloop` wrapper operations. It is not intended to help with the
+    addition of support for loop transformations.
+  }];
+
+  let arguments = (ins Variadic<IntLikeType>:$lowerBound,
+                       Variadic<IntLikeType>:$upperBound,
+                       Variadic<IntLikeType>:$step,
+                       UnitAttr:$inclusive);
+
+  let regions = (region AnyRegion:$region);
+
+  let extraClassDeclaration = [{
+    /// Returns the number of loops in the loop nest.
+    unsigned getNumLoops() { return getLowerBound().size(); }
+
+    /// Returns the induction variables of the loop nest.
+    ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.9.2 Workshare Loop Construct
 //===----------------------------------------------------------------------===//
@@ -724,7 +787,7 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
 
 def YieldOp : OpenMP_Op<"yield",
     [Pure, ReturnLike, Terminator,
-     ParentOneOf<["WsloopOp", "DeclareReductionOp",
+     ParentOneOf<["LoopNestOp", "WsloopOp", "DeclareReductionOp",
      "AtomicUpdateOp", "SimdLoopOp", "PrivateClauseOp"]>]> {
   let summary = "loop yield and termination operation";
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bf5875071e0dc4..796df1d13e6564 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1662,6 +1662,92 @@ LogicalResult TaskloopOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopNestOp
+//===----------------------------------------------------------------------===//
+
+/// Parses the custom assembly format of `omp.loop_nest`:
+///   `(` ivs `)` `:` type `=` `(` lbs `)` `to` `(` ubs `)` [`inclusive`]
+///   `step` `(` steps `)` region [attr-dict]
+ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Parse an opening `(` followed by induction variables followed by `)`
+  SmallVector<OpAsmParser::Argument> ivs;
+  SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs;
+  Type loopVarType;
+  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
+      parser.parseColonType(loopVarType) ||
+      // Parse loop bounds.
+      parser.parseEqual() ||
+      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
+      parser.parseKeyword("to") ||
+      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  for (auto &iv : ivs)
+    iv.type = loopVarType;
+
+  // Parse "inclusive" flag.
+  if (succeeded(parser.parseOptionalKeyword("inclusive")))
+    result.addAttribute("inclusive",
+                        UnitAttr::get(parser.getBuilder().getContext()));
+
+  // Parse step values.
+  SmallVector<OpAsmParser::UnresolvedOperand> steps;
+  if (parser.parseKeyword("step") ||
+      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
+    return failure();
+
+  // Parse the body.
+  Region *region = result.addRegion();
+  if (parser.parseRegion(*region, ivs))
+    return failure();
+
+  // Resolve operands.
+  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
+      parser.resolveOperands(ubs, loopVarType, result.operands) ||
+      parser.resolveOperands(steps, loopVarType, result.operands))
+    return failure();
+
+  // Parse the optional attribute list.
+  return parser.parseOptionalAttrDict(result.attributes);
+}
+
+/// Prints the custom assembly format accepted by `LoopNestOp::parse`.
+void LoopNestOp::print(OpAsmPrinter &p) {
+  Region &region = getRegion();
+  auto args = region.getArguments();
+  // Indexing `args[0]` is safe: the verifier rejects empty loop nests.
+  p << " (" << args << ") : " << args[0].getType() << " = (" << getLowerBound()
+    << ") to (" << getUpperBound() << ") ";
+  if (getInclusive())
+    p << "inclusive ";
+  p << "step (" << getStep() << ") ";
+  p.printRegion(region, /*printEntryBlockArgs=*/false);
+  // The parser accepts a trailing attribute dictionary, so print it too. The
+  // `inclusive` attribute is elided since it was printed as a keyword above.
+  p.printOptionalAttrDict((*this)->getAttrs(),
+                          /*elidedAttrs=*/{getInclusiveAttrName()});
+}
+
+/// Checks that the loop nest represents at least one loop and that each loop
+/// has an induction variable whose type matches its lower bound (bounds and
+/// steps are already required to share a type by `AllTypesMatch`).
+LogicalResult LoopNestOp::verify() {
+  if (getLowerBound().empty())
+    return emitOpError() << "must represent at least one loop";
+
+  if (getLowerBound().size() != getIVs().size())
+    return emitOpError() << "number of range arguments and IVs do not match";
+
+  for (auto [lb, iv] : llvm::zip_equal(getLowerBound(), getIVs())) {
+    if (lb.getType() != iv.getType())
+      return emitOpError()
+             << "range argument type does not match corresponding IV type";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index a00383cf44057c..760ebb14d94121 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -87,6 +87,43 @@ func.func @proc_bind_once() {
 
 // -----
 
+func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
+  // expected-error at +1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}}
+  omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+    omp.yield
+  }
+}
+
+// -----
+
+func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // expected-error@+1 {{range argument type does not match corresponding IV type}}
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv2: i32):
+      omp.yield
+    }) : (index, index, index) -> ()
+    omp.yield
+  }
+}
+
+// -----
+
+func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // expected-error@+1 {{number of range arguments and IVs do not match}}
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv1 : index, %iv2 : index):
+      omp.yield
+    }) : (index, index, index) -> ()
+    omp.yield
+  }
+}
+
+// -----
+
 func.func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {
   // expected-error @below {{expected 'for'}}
   omp.wsloop nowait inclusive
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 30ce77423005ac..8d9acab67e0358 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -133,6 +133,85 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
   return
 }
 
+// CHECK-LABEL: omp_loop_nest
+func.func @omp_loop_nest(%lb : index, %ub : index, %step : index) -> () {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv2: index):
+      omp.yield
+    }) : (index, index, index) -> ()
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}})
+    "omp.loop_nest" (%lb, %ub, %step) ({
+    ^bb0(%iv2: index):
+      omp.yield
+    }) {inclusive} : (index, index, index) -> ()
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
+    "omp.loop_nest" (%lb, %lb, %ub, %ub, %step, %step) ({
+    ^bb0(%iv2: index, %iv3: index):
+      omp.yield
+    }) : (index, index, index, index, index, index) -> ()
+    omp.yield
+  }
+
+  return
+}
+
+// CHECK-LABEL: omp_loop_nest_pretty
+func.func @omp_loop_nest_pretty(%lb : index, %ub : index, %step : index) -> () {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+    omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}})
+    omp.loop_nest (%iv2) : index = (%lb) to (%ub) inclusive step (%step) {
+      omp.yield
+    }
+    omp.yield
+  }
+
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    // CHECK: omp.loop_nest
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) : index =
+    // CHECK-SAME: (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
+    omp.loop_nest (%iv2, %iv3) : index = (%lb, %lb) to (%ub, %ub) step (%step, %step) {
+      omp.yield
+    }
+    omp.yield
+  }
+
+  return
+}
+
 // CHECK-LABEL: omp_wsloop
 func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memref<i32>, %linear_var : i32, %chunk_var : i32) -> () {
 

>From 2452bc75a7f2efb67a0522bbe8b0e7ba5bc3365b Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Mon, 1 Apr 2024 13:04:14 +0100
Subject: [PATCH 2/5] [MLIR][OpenMP] Introduce the LoopWrapperInterface

This patch defines a common interface to be shared by all OpenMP loop wrapper
operations. The main restrictions these operations must meet in order to be
considered a wrapper are:

- They contain a single region.
- Their region contains a single block.
- Their block only contains another loop wrapper or `omp.loop_nest` and a
terminator.

The new interface is attached to the `omp.parallel`, `omp.wsloop`,
`omp.simdloop`, `omp.distribute` and `omp.taskloop` operations. The wrapper
restrictions are not yet enforced on these operations, because doing so would
break existing OpenMP loop-generating code. Instead, enforcement will be
introduced progressively in subsequent patches.
---
 .../mlir/Dialect/OpenMP/OpenMPInterfaces.h    |  3 +
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 16 +++--
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 68 +++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 19 ++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 16 ++++-
 5 files changed, 117 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
index b3184db8852161..787c48b05c5c5c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
@@ -21,6 +21,11 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+// Forward-declare the OpenMP operation classes, since the default interface
+// method implementations included below refer to them (e.g. `LoopNestOp`).
+#define GET_OP_FWD_DEFINES
+#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
+
 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.h.inc"
 
 namespace mlir::omp {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ffd00948915153..a7bf93deae2fb3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -236,6 +236,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove]> {
 
 def ParallelOp : OpenMP_Op<"parallel", [
                  AutomaticAllocationScope, AttrSizedOperandSegments,
+                 DeclareOpInterfaceMethods<LoopWrapperInterface>,
                  DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
                  RecursiveMemoryEffects, ReductionClauseInterface]> {
   let summary = "parallel construct";
@@ -517,8 +518,6 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
 
 def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
                         AllTypesMatch<["lowerBound", "upperBound", "step"]>,
-                        ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp",
-                                     "WsloopOp"]>,
                         RecursiveMemoryEffects]> {
   let summary = "rectangular loop nest";
   let description = [{
@@ -568,6 +567,10 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
 
     /// Returns the induction variables of the loop nest.
     ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
+
+    /// Returns the list of wrapper operations around this loop nest. Wrappers
+    /// in the resulting vector will be sorted from innermost to outermost.
+    SmallVector<LoopWrapperInterface> getWrappers();
   }];
 
   let hasCustomAssemblyFormat = 1;
@@ -580,6 +583,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
 
 def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
                          AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                         DeclareOpInterfaceMethods<LoopWrapperInterface>,
                          RecursiveMemoryEffects, ReductionClauseInterface]> {
   let summary = "worksharing-loop construct";
   let description = [{
@@ -700,7 +704,9 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
 //===----------------------------------------------------------------------===//
 
 def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
-                         AllTypesMatch<["lowerBound", "upperBound", "step"]>]> {
+                         AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                         DeclareOpInterfaceMethods<LoopWrapperInterface>,
+                         RecursiveMemoryEffects]> {
  let summary = "simd loop construct";
   let description = [{
     The simd construct can be applied to a loop to indicate that the loop can be
@@ -809,7 +815,8 @@ def YieldOp : OpenMP_Op<"yield",
 // Distribute construct [2.9.4.1]
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
-                              MemoryEffects<[MemWrite]>]> {
+                             DeclareOpInterfaceMethods<LoopWrapperInterface>,
+                             RecursiveMemoryEffects]> {
   let summary = "distribute construct";
   let description = [{
     The distribute construct specifies that the iterations of one or more loops
@@ -980,6 +987,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
 def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
                            AutomaticAllocationScope, RecursiveMemoryEffects,
                            AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                           DeclareOpInterfaceMethods<LoopWrapperInterface>,
                            ReductionClauseInterface]> {
   let summary = "taskloop construct";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 2e37384ce3eb71..b6a3560b7da56a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -69,6 +69,74 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
   ];
 }
 
+def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
+  let description = [{
+    OpenMP operations that can wrap a single loop nest. When taking a wrapper
+    role, these operations must only contain a single region with a single block
+    in which there's a single operation and a terminator. That nested operation
+    must be another loop wrapper or an `omp.loop_nest`.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/[{
+        Tell whether the operation could be taking the role of a loop wrapper.
+        That is, it has a single region with a single block in which there are
+        two operations: another wrapper or `omp.loop_nest` operation and a
+        terminator.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isWrapper",
+      (ins ), [{}], [{
+        if ($_op->getNumRegions() != 1)
+          return false;
+
+        ::mlir::Region &r = $_op->getRegion(0);
+        if (!r.hasOneBlock())
+          return false;
+
+        if (std::distance(r.op_begin(), r.op_end()) != 2)
+          return false;
+
+        ::mlir::Operation &firstOp = *r.op_begin();
+        ::mlir::Operation &secondOp = *(++r.op_begin());
+        return ::llvm::isa<::mlir::omp::LoopNestOp,
+                           ::mlir::omp::LoopWrapperInterface>(firstOp) &&
+               secondOp.hasTrait<::mlir::OpTrait::IsTerminator>();
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        If there is another loop wrapper immediately nested inside, return that
+        operation. Assumes this operation is taking a loop wrapper role.
+      }],
+      /*retTy=*/"::mlir::omp::LoopWrapperInterface",
+      /*methodName=*/"getNestedWrapper",
+      (ins), [{}], [{
+        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin();
+        return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested);
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        Return the loop nest nested directly or indirectly inside of this loop
+        wrapper. Assumes this operation is taking a loop wrapper role.
+      }],
+      /*retTy=*/"::mlir::Operation *",
+      /*methodName=*/"getWrappedLoop",
+      (ins), [{}], [{
+        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper())
+          return nested.getWrappedLoop();
+        return &*$_op->getRegion(0).op_begin();
+      }]
+    >
+  ];
+}
+
 def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
   let description = [{
     OpenMP operations that support declare target have this interface.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 796df1d13e6564..564c23201db4fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1730,9 +1730,28 @@ LogicalResult LoopNestOp::verify() {
              << "range argument type does not match corresponding IV type";
   }
 
+  auto wrapper =
+      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
+
+  if (!wrapper || !wrapper.isWrapper())
+    return emitOpError() << "expects parent op to be a valid loop wrapper";
+
   return success();
 }
 
+SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
+  SmallVector<LoopWrapperInterface> wrappers;
+  Operation *parent = (*this)->getParentOp();
+  while (auto wrapper =
+             llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
+    if (!wrapper.isWrapper())
+      break;
+    wrappers.push_back(wrapper);
+    parent = parent->getParentOp();
+  }
+  return wrappers;
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 760ebb14d94121..8f4103dabee5df 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -88,7 +88,7 @@ func.func @proc_bind_once() {
 // -----
 
 func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
-  // expected-error at +1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}}
+  // expected-error@+1 {{op expects parent op to be a valid loop wrapper}}
   omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
     omp.yield
   }
@@ -96,6 +96,20 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
 
 // -----
 
+func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    %0 = arith.constant 0 : i32
+    // expected-error@+1 {{op expects parent op to be a valid loop wrapper}}
+    omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+    omp.yield
+  }
+}
+
+// -----
+
 func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
   // TODO Remove induction variables from omp.wsloop.
   omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {

>From 4537071171506b17de3727800e3754e412c9a967 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Mon, 1 Apr 2024 14:08:33 +0100
Subject: [PATCH 3/5] [MLIR][OpenMP] Make omp.distribute into a loop wrapper

This patch updates the definition of `omp.distribute` to enforce the
restrictions of a wrapper operation.
---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 22 +++++++++--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 11 ++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 34 ++++++++++++++++-
 mlir/test/Dialect/OpenMP/ops.mlir             | 38 +++++++++++++++----
 4 files changed, 93 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a7bf93deae2fb3..8dbfe447616e11 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -816,7 +816,8 @@ def YieldOp : OpenMP_Op<"yield",
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
                              DeclareOpInterfaceMethods<LoopWrapperInterface>,
-                             RecursiveMemoryEffects]> {
+                             RecursiveMemoryEffects,
+                             SingleBlockImplicitTerminator<"TerminatorOp">]> {
   let summary = "distribute construct";
   let description = [{
     The distribute construct specifies that the iterations of one or more loops
@@ -831,15 +832,28 @@ def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
     The distribute loop construct specifies that the iterations of the loop(s)
     will be executed in parallel by threads in the current context. These
     iterations are spread across threads that already exist in the enclosing
-    region. The lower and upper bounds specify a half-open range: the
-    range includes the lower bound but does not include the upper bound. If the
-    `inclusive` attribute is specified then the upper bound is also included.
+    region.
+
+    The body region can contain a single block which must contain a single
+    operation and a terminator. The operation must be another compatible loop
+    wrapper or an `omp.loop_nest`.
 
     The `dist_schedule_static` attribute specifies the  schedule for this
     loop, determining how the loop is distributed across the parallel threads.
     The optional `schedule_chunk` associated with this determines further
     controls this distribution.
 
+    ```mlir
+    omp.distribute <clauses> {
+      omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
+        %a = load %arrA[%i1, %i2] : memref<?x?xf32>
+        %b = load %arrB[%i1, %i2] : memref<?x?xf32>
+        %sum = arith.addf %a, %b : f32
+        store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
+        omp.yield
+      }
+    }
+    ```
     // TODO: private_var, firstprivate_var, lastprivate_var, collapse
   }];
   let arguments = (ins
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 564c23201db4fd..b407d27ef53e39 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1516,6 +1516,18 @@ LogicalResult DistributeOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // Checked first: getNestedWrapper() below asserts that this op is a wrapper.
+  if (!isWrapper())
+    return emitOpError() << "must be a loop wrapper";
+
+  if (LoopWrapperInterface nested = getNestedWrapper()) {
+    // Check for the allowed leaf constructs that may appear in a composite
+    // construct directly after DISTRIBUTE.
+    if (!isa<ParallelOp, SimdLoopOp>(nested))
+      return emitOpError() << "only supported nested wrappers are "
+                              "'omp.parallel' and 'omp.simdloop'";
+  }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8f4103dabee5df..35f5d24deb5d17 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1847,7 +1847,16 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
 
 // -----
 
-func.func @omp_distribute(%data_var : memref<i32>) -> () {
+func.func @omp_distribute_schedule(%chunk_size : i32) -> () {
+  // expected-error @below {{op chunk size set without dist_schedule_static being present}}
+  "omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 1, 0, 0>}> ({
+      "omp.terminator"() : () -> ()
+    }) : (i32) -> ()
+}
+
+// -----
+
+func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({
       "omp.terminator"() : () -> ()
@@ -1856,6 +1865,29 @@ func.func @omp_distribute(%data_var : memref<i32>) -> () {
 
 // -----
 
+func.func @omp_distribute_wrapper() -> () {
+  // expected-error @below {{op must be a loop wrapper}}
+  "omp.distribute"() ({
+      %0 = arith.constant 0 : i32
+      "omp.terminator"() : () -> ()
+    }) : () -> ()
+}
+
+// -----
+
+func.func @omp_distribute_nested_wrapper(%data_var : memref<i32>) -> () {
+  // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simdloop'}}
+  "omp.distribute"() ({
+      "omp.wsloop"() ({
+        %0 = arith.constant 0 : i32
+        "omp.terminator"() : () -> ()
+      }) : () -> ()
+      "omp.terminator"() : () -> ()
+    }) : () -> ()
+}
+
+// -----

 omp.private {type = private} @x.privatizer : i32 alloc {
 ^bb0(%arg0: i32):
   %0 = arith.constant 0.0 : f32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 8d9acab67e0358..a7b0832eff21f3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -559,30 +559,54 @@ func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : ind
 }
 
 // CHECK-LABEL: omp_distribute
-func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>) -> () {
+func.func @omp_distribute(%chunk_size : i32, %data_var : memref<i32>, %arg0 : i32) -> () {
   // CHECK: omp.distribute
   "omp.distribute" () ({
-    omp.terminator
+    "omp.loop_nest" (%arg0, %arg0, %arg0) ({
+    ^bb0(%iv: i32):
+      "omp.yield"() : () -> ()
+    }) : (i32, i32, i32) -> ()
+    "omp.terminator"() : () -> ()
   }) {} : () -> ()
   // CHECK: omp.distribute
   omp.distribute {
-    omp.terminator
+    omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+      omp.yield
+    }
   }
   // CHECK: omp.distribute dist_schedule_static
   omp.distribute dist_schedule_static {
-    omp.terminator
+    omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+      omp.yield
+    }
   }
   // CHECK: omp.distribute dist_schedule_static chunk_size(%{{.+}} : i32)
   omp.distribute dist_schedule_static chunk_size(%chunk_size : i32) {
-    omp.terminator
+    omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+      omp.yield
+    }
   }
   // CHECK: omp.distribute order(concurrent)
   omp.distribute order(concurrent) {
-    omp.terminator
+    omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+      omp.yield
+    }
   }
   // CHECK: omp.distribute allocate(%{{.+}} : memref<i32> -> %{{.+}} : memref<i32>)
   omp.distribute allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
-    omp.terminator
+    omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+      omp.yield
+    }
+  }
+  // CHECK: omp.distribute
+  omp.distribute {
+    // TODO Remove induction variables from omp.simdloop.
+    omp.simdloop for (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+      omp.loop_nest (%iv2) : i32 = (%arg0) to (%arg0) step (%arg0) {
+        omp.yield
+      }
+      omp.yield
+    }
   }
 return
 }

>From e3c440c9df93eedad1e3a96d58a759c869574c20 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 2 Apr 2024 14:15:25 +0100
Subject: [PATCH 4/5] Update op description according to review comments

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ffd00948915153..3d87585d52847c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -552,7 +552,8 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
     is intended to serve as a unique source for loop information during the
     transition to making `omp.distribute`, `omp.simdloop`, `omp.taskloop` and
     `omp.wsloop` wrapper operations. It is not intended to help with the
-    addition of support for loop transformations.
+    addition of support for loop transformations, non-rectangular loops and
+    non-perfectly nested loops.
   }];
 
   let arguments = (ins Variadic<IntLikeType>:$lowerBound,

>From 904f27489b0d3c27e773f085b18a9b85cb548f44 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 2 Apr 2024 15:36:41 +0100
Subject: [PATCH 5/5] Address review comments

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  4 ++--
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 19 +++++++++----------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  5 ++---
 3 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a7bf93deae2fb3..50627712ea3109 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -568,9 +568,9 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
     /// Returns the induction variables of the loop nest.
     ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
 
-    /// Returns the list of wrapper operations around this loop nest. Wrappers
+    /// Appends the wrappers around this loop nest to `wrappers`. Wrappers
     /// in the resulting vector will be sorted from innermost to outermost.
-    SmallVector<LoopWrapperInterface> getWrappers();
+    void gatherWrappers(SmallVectorImpl<LoopWrapperInterface> &wrappers);
   }];
 
   let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index b6a3560b7da56a..ab9b78e755d9d5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -93,18 +93,17 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
         if ($_op->getNumRegions() != 1)
           return false;
 
-        ::mlir::Region &r = $_op->getRegion(0);
+        Region &r = $_op->getRegion(0);
         if (!r.hasOneBlock())
           return false;
 
-        if (std::distance(r.op_begin(), r.op_end()) != 2)
+        if (::llvm::range_size(r.getOps()) != 2)
           return false;
 
-        ::mlir::Operation &firstOp = *r.op_begin();
-        ::mlir::Operation &secondOp = *(++r.op_begin());
-        return ::llvm::isa<::mlir::omp::LoopNestOp,
-                           ::mlir::omp::LoopWrapperInterface>(firstOp) &&
-               secondOp.hasTrait<::mlir::OpTrait::IsTerminator>();
+        Operation &firstOp = *r.op_begin();
+        Operation &secondOp = *(std::next(r.op_begin()));
+        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
+               secondOp.hasTrait<OpTrait::IsTerminator>();
       }]
     >,
     InterfaceMethod<
@@ -116,8 +115,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*methodName=*/"getNestedWrapper",
       (ins), [{}], [{
         assert($_op.isWrapper() && "Unexpected non-wrapper op");
-        ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin();
-        return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested);
+        Operation *nested = &*$_op->getRegion(0).op_begin();
+        return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
       }]
     >,
     InterfaceMethod<
@@ -129,7 +128,7 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*methodName=*/"getWrappedLoop",
       (ins), [{}], [{
         assert($_op.isWrapper() && "Unexpected non-wrapper op");
-        if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper())
+        if (LoopWrapperInterface nested = $_op.getNestedWrapper())
           return nested.getWrappedLoop();
         return &*$_op->getRegion(0).op_begin();
       }]
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 564c23201db4fd..a7d265328df6ef 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1739,8 +1739,8 @@ LogicalResult LoopNestOp::verify() {
   return success();
 }
 
-SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
-  SmallVector<LoopWrapperInterface> wrappers;
+void LoopNestOp::gatherWrappers(
+    SmallVectorImpl<LoopWrapperInterface> &wrappers) {
   Operation *parent = (*this)->getParentOp();
   while (auto wrapper =
              llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
@@ -1749,7 +1749,6 @@ SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
     wrappers.push_back(wrapper);
     parent = parent->getParentOp();
   }
-  return wrappers;
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list