[llvm-branch-commits] [mlir] [MLIR][OpenMP] Introduce the LoopWrapperInterface (PR #87232)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Apr 1 05:18:03 PDT 2024


https://github.com/skatrak created https://github.com/llvm/llvm-project/pull/87232

This patch defines a common interface to be shared by all OpenMP loop wrapper operations. The main restrictions these operations must meet in order to be considered a wrapper are:

- They contain a single region.
- Their region contains a single block.
- Their block only contains another loop wrapper or `omp.loop_nest` and a terminator.

The new interface is attached to the `omp.parallel`, `omp.wsloop`, `omp.simdloop`, `omp.distribute` and `omp.taskloop` operations. These operations are not yet required to meet the wrapper restrictions, because enforcing them now would break existing OpenMP loop-generating code; the restrictions will instead be introduced progressively in subsequent patches.

>From 2452bc75a7f2efb67a0522bbe8b0e7ba5bc3365b Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Mon, 1 Apr 2024 13:04:14 +0100
Subject: [PATCH] [MLIR][OpenMP] Introduce the LoopWrapperInterface

This patch defines a common interface to be shared by all OpenMP loop wrapper
operations. The main restrictions these operations must meet in order to be
considered a wrapper are:

- They contain a single region.
- Their region contains a single block.
- Their block only contains another loop wrapper or `omp.loop_nest` and a
terminator.

The new interface is attached to the `omp.parallel`, `omp.wsloop`,
`omp.simdloop`, `omp.distribute` and `omp.taskloop` operations. These
operations are not yet required to meet the wrapper restrictions, because
enforcing them now would break existing OpenMP loop-generating code. Instead,
the restrictions will be introduced progressively in subsequent patches.
---
 .../mlir/Dialect/OpenMP/OpenMPInterfaces.h    |  3 +
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 16 +++--
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 68 +++++++++++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 19 ++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 16 ++++-
 5 files changed, 117 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
index b3184db8852161..787c48b05c5c5c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
@@ -21,6 +21,9 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#define GET_OP_FWD_DEFINES
+#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
+
 #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.h.inc"
 
 namespace mlir::omp {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ffd00948915153..a7bf93deae2fb3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -236,6 +236,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove]> {
 
 def ParallelOp : OpenMP_Op<"parallel", [
                  AutomaticAllocationScope, AttrSizedOperandSegments,
+                 DeclareOpInterfaceMethods<LoopWrapperInterface>,
                  DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
                  RecursiveMemoryEffects, ReductionClauseInterface]> {
   let summary = "parallel construct";
@@ -517,8 +518,6 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
 
 def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
                         AllTypesMatch<["lowerBound", "upperBound", "step"]>,
-                        ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp",
-                                     "WsloopOp"]>,
                         RecursiveMemoryEffects]> {
   let summary = "rectangular loop nest";
   let description = [{
@@ -568,6 +567,10 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
 
     /// Returns the induction variables of the loop nest.
     ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
+
+    /// Returns the list of wrapper operations around this loop nest. Wrappers
+    /// in the resulting vector will be sorted from innermost to outermost.
+    SmallVector<LoopWrapperInterface> getWrappers();
   }];
 
   let hasCustomAssemblyFormat = 1;
@@ -580,6 +583,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
 
 def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
                          AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                         DeclareOpInterfaceMethods<LoopWrapperInterface>,
                          RecursiveMemoryEffects, ReductionClauseInterface]> {
   let summary = "worksharing-loop construct";
   let description = [{
@@ -700,7 +704,9 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
 //===----------------------------------------------------------------------===//
 
 def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
-                         AllTypesMatch<["lowerBound", "upperBound", "step"]>]> {
+                         AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                         DeclareOpInterfaceMethods<LoopWrapperInterface>,
+                         RecursiveMemoryEffects]> {
  let summary = "simd loop construct";
   let description = [{
     The simd construct can be applied to a loop to indicate that the loop can be
@@ -809,7 +815,8 @@ def YieldOp : OpenMP_Op<"yield",
 // Distribute construct [2.9.4.1]
 //===----------------------------------------------------------------------===//
 def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
-                              MemoryEffects<[MemWrite]>]> {
+                             DeclareOpInterfaceMethods<LoopWrapperInterface>,
+                             RecursiveMemoryEffects]> {
   let summary = "distribute construct";
   let description = [{
     The distribute construct specifies that the iterations of one or more loops
@@ -980,6 +987,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
 def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
                            AutomaticAllocationScope, RecursiveMemoryEffects,
                            AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+                           DeclareOpInterfaceMethods<LoopWrapperInterface>,
                            ReductionClauseInterface]> {
   let summary = "taskloop construct";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 2e37384ce3eb71..b6a3560b7da56a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -69,6 +69,74 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
   ];
 }
 
+def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
+  let description = [{
+    OpenMP operations that can wrap a single loop nest. When taking a wrapper
+    role, these operations must only contain a single region with a single block
+    in which there's a single operation and a terminator. That nested operation
+    must be another loop wrapper or an `omp.loop_nest`.
+  }];
+
+  let cppNamespace = "::mlir::omp";
+
+  let methods = [
+    InterfaceMethod<
+      /*description=*/[{
+        Tell whether the operation could be taking the role of a loop wrapper.
+        That is, it has a single region with a single block in which there are
+        two operations: another wrapper or `omp.loop_nest` operation and a
+        terminator.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"isWrapper",
+      (ins), [{}], [{
+        if ($_op->getNumRegions() != 1)
+          return false;
+
+        ::mlir::Region &r = $_op->getRegion(0);
+        if (!r.hasOneBlock())
+          return false;
+
+        // hasNItems stops early instead of walking a large non-wrapper body.
+        ::mlir::Block &block = r.front();
+        if (!::llvm::hasNItems(block, 2))
+          return false;
+
+        return ::llvm::isa<::mlir::omp::LoopNestOp,
+                           ::mlir::omp::LoopWrapperInterface>(block.front()) &&
+               block.back().hasTrait<::mlir::OpTrait::IsTerminator>();
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        If there is another loop wrapper immediately nested inside, return that
+        operation. Assumes this operation is taking a loop wrapper role.
+      }],
+      /*retTy=*/"::mlir::omp::LoopWrapperInterface",
+      /*methodName=*/"getNestedWrapper",
+      (ins), [{}], [{
+        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        ::mlir::Operation *nested = &$_op->getRegion(0).front().front();
+        return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested);
+      }]
+    >,
+    InterfaceMethod<
+      /*description=*/[{
+        Return the loop nest nested directly or indirectly inside of this loop
+        wrapper. Assumes this operation is taking a loop wrapper role.
+      }],
+      /*retTy=*/"::mlir::Operation *",
+      /*methodName=*/"getWrappedLoop",
+      (ins), [{}], [{
+        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper())
+          return nested.getWrappedLoop();
+        return &$_op->getRegion(0).front().front();
+      }]
+    >
+  ];
+}
+
 def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
   let description = [{
     OpenMP operations that support declare target have this interface.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 796df1d13e6564..564c23201db4fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1730,9 +1730,28 @@ LogicalResult LoopNestOp::verify() {
              << "range argument type does not match corresponding IV type";
   }
 
+  auto wrapper =
+      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
+
+  if (!wrapper || !wrapper.isWrapper())
+    return emitOpError() << "expects parent op to be a valid loop wrapper";
+
   return success();
 }
 
+SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
+  // Walk up from this loop nest, stopping at the first ancestor that is not
+  // acting as a loop wrapper. Results are ordered innermost-first.
+  SmallVector<LoopWrapperInterface> result;
+  for (Operation *op = (*this)->getParentOp(); op; op = op->getParentOp()) {
+    auto candidate = llvm::dyn_cast<LoopWrapperInterface>(op);
+    if (!candidate || !candidate.isWrapper())
+      break;
+    result.push_back(candidate);
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // WsloopOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 760ebb14d94121..8f4103dabee5df 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -88,7 +88,7 @@ func.func @proc_bind_once() {
 // -----
 
 func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
-  // expected-error at +1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}}
+  // expected-error at +1 {{op expects parent op to be a valid loop wrapper}}
   omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
     omp.yield
   }
@@ -96,6 +96,20 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
 
 // -----
 
+func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) {
+  // TODO Remove induction variables from omp.wsloop.
+  omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    %0 = arith.constant 0 : i32
+    // expected-error @+1 {{op expects parent op to be a valid loop wrapper}}
+    omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+    omp.yield
+  }
+}
+
+// -----
+
 func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
   // TODO Remove induction variables from omp.wsloop.
   omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {



More information about the llvm-branch-commits mailing list