[Mlir-commits] [flang] [mlir] [MLIR][Flang][OpenMP] Remove omp.parallel from loop wrapper ops (PR #105833)

Sergio Afonso llvmlistbot at llvm.org
Tue Aug 27 05:33:03 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/105833

>From c2d8afbe73d658cac006b19f502c23caceccf751 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 23 Aug 2024 14:39:44 +0100
Subject: [PATCH 1/2] [MLIR][Flang][OpenMP] Remove omp.parallel from loop
 wrapper ops

This patch updates the `omp.parallel` operation according to the results of
the discussion in [this RFC](https://discourse.llvm.org/t/rfc-disambiguation-between-loop-and-block-associated-omp-parallelop/79972).
The operation is removed from the set of loop wrapper operations, which changes
the expected MLIR representation of composite `distribute parallel do/for` to
the following:

```mlir
omp.parallel {
  ...
  omp.distribute {
    omp.wsloop {
      omp.loop_nest ... { ... }
      omp.terminator
    }
    omp.terminator
  }
  ...
  omp.terminator
}
```

MLIR verifiers for operations impacted by this representation change are
updated, as well as related tests. The `LoopWrapperInterface` is also updated,
since it no longer represents an optional "role" of an operation but rather a
mandatory set of restrictions.
---
 .../lib/Lower/OpenMP/DataSharingProcessor.cpp |   4 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   1 -
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     |  30 ++--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  75 ++++-----
 mlir/test/Dialect/OpenMP/invalid.mlir         | 152 ++++++++----------
 mlir/test/Dialect/OpenMP/ops.mlir             |  20 +--
 6 files changed, 126 insertions(+), 156 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e1a193edc004a7..f3ed6eb9a08370 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -231,9 +231,7 @@ void DataSharingProcessor::insertBarrier() {
 void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
   mlir::omp::LoopNestOp loopOp;
   if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
-    loopOp = wrapper.isWrapper()
-                 ? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
-                 : nullptr;
+    loopOp = mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop());
 
   bool cmpCreated = false;
   mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a7dae0b5f3074..1aa4e771cd4dea 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -129,7 +129,6 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>
 def ParallelOp : OpenMP_Op<"parallel", traits = [
     AttrSizedOperandSegments, AutomaticAllocationScope,
     DeclareOpInterfaceMethods<ComposableOpInterface>,
-    DeclareOpInterfaceMethods<LoopWrapperInterface>,
     DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
     RecursiveMemoryEffects
   ], clauses = [
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 78637f0ab8c2da..1dd787ebe7de4f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -71,10 +71,10 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
 
 def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
   let description = [{
-    OpenMP operations that can wrap a single loop nest. When taking a wrapper
-    role, these operations must only contain a single region with a single block
-    in which there's a single operation and a terminator. That nested operation
-    must be another loop wrapper or an `omp.loop_nest`.
+    OpenMP operations that wrap a single loop nest. They must only contain a
+    single region with a single block in which there's a single operation and a
+    terminator. That nested operation must be another loop wrapper or an
+    `omp.loop_nest`.
   }];
 
   let cppNamespace = "::mlir::omp";
@@ -82,13 +82,12 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
   let methods = [
     InterfaceMethod<
       /*description=*/[{
-        Tell whether the operation could be taking the role of a loop wrapper.
-        That is, it has a single region with a single block in which there are
-        two operations: another wrapper (also taking a loop wrapper role) or
-        `omp.loop_nest` operation and a terminator.
+        Check whether the operation is a valid loop wrapper. That is, it has a
+        single region with a single block in which there are two operations:
+        another loop wrapper or `omp.loop_nest` operation and a terminator.
       }],
       /*retTy=*/"bool",
-      /*methodName=*/"isWrapper",
+      /*methodName=*/"isValidWrapper",
       (ins ), [{}], [{
         if ($_op->getNumRegions() != 1)
           return false;
@@ -106,21 +105,18 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
         if (!secondOp.hasTrait<OpTrait::IsTerminator>())
           return false;
 
-        if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
-          return wrapper.isWrapper();
-
-        return ::llvm::isa<LoopNestOp>(firstOp);
+        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp);
       }]
     >,
     InterfaceMethod<
       /*description=*/[{
         If there is another loop wrapper immediately nested inside, return that
-        operation. Assumes this operation is taking a loop wrapper role.
+        operation. Assumes this operation is a valid loop wrapper.
       }],
       /*retTy=*/"::mlir::omp::LoopWrapperInterface",
       /*methodName=*/"getNestedWrapper",
       (ins), [{}], [{
-        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
         Operation *nested = &*$_op->getRegion(0).op_begin();
         return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
       }]
@@ -128,12 +124,12 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
     InterfaceMethod<
       /*description=*/[{
         Return the loop nest nested directly or indirectly inside of this loop
-        wrapper. Assumes this operation is taking a loop wrapper role.
+        wrapper. Assumes this operation is a valid loop wrapper.
       }],
       /*retTy=*/"::mlir::Operation *",
       /*methodName=*/"getWrappedLoop",
       (ins), [{}], [{
-        assert($_op.isWrapper() && "Unexpected non-wrapper op");
+        assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
         if (LoopWrapperInterface nested = $_op.getNestedWrapper())
           return nested.getWrappedLoop();
         return &*$_op->getRegion(0).op_begin();
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index eb4f9cb041841b..6db4796eb37e0d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1541,26 +1541,25 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
 }
 
 LogicalResult ParallelOp::verify() {
-  // Check that it is a valid loop wrapper if it's taking that role.
-  if (isa<DistributeOp>((*this)->getParentOp())) {
-    if (!isWrapper())
-      return emitOpError() << "must take a loop wrapper role if nested inside "
-                              "of 'omp.distribute'";
+  auto distributeChildOps = getOps<DistributeOp>();
+  if (!distributeChildOps.empty()) {
     if (!isComposite())
       return emitError()
-             << "'omp.composite' attribute missing from composite wrapper";
+             << "'omp.composite' attribute missing from composite operation";
 
-    if (LoopWrapperInterface nested = getNestedWrapper()) {
-      // Check for the allowed leaf constructs that may appear in a composite
-      // construct directly after PARALLEL.
-      if (!isa<WsloopOp>(nested))
-        return emitError() << "only supported nested wrapper is 'omp.wsloop'";
-    } else {
-      return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
+    auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
+    Operation &distributeOp = **distributeChildOps.begin();
+    for (Operation &childOp : getOps()) {
+      if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
+        continue;
+
+      if (!childOp.hasTrait<OpTrait::IsTerminator>())
+        return emitError() << "unexpected OpenMP operation inside of composite "
+                              "'omp.parallel'";
     }
   } else if (isComposite()) {
     return emitError()
-           << "'omp.composite' attribute present in non-composite wrapper";
+           << "'omp.composite' attribute present in non-composite operation";
   }
 
   if (getAllocateVars().size() != getAllocatorVars().size())
@@ -1751,15 +1750,12 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult WsloopOp::verify() {
-  if (!isWrapper())
-    return emitOpError() << "must be a loop wrapper";
+  if (!isValidWrapper())
+    return emitOpError() << "must be a valid loop wrapper";
 
-  auto wrapper =
-      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
   bool isCompositeChildLeaf =
-      wrapper && wrapper.isWrapper() &&
-      (!llvm::isa<ParallelOp>(wrapper) ||
-       llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));
+      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
+
   if (LoopWrapperInterface nested = getNestedWrapper()) {
     if (!isComposite())
       return emitError()
@@ -1813,18 +1809,14 @@ LogicalResult SimdOp::verify() {
   if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
     return failure();
 
-  if (!isWrapper())
-    return emitOpError() << "must be a loop wrapper";
+  if (!isValidWrapper())
+    return emitOpError() << "must be a valid loop wrapper";
 
   if (getNestedWrapper())
     return emitOpError() << "must wrap an 'omp.loop_nest' directly";
 
-  auto wrapper =
-      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
   bool isCompositeChildLeaf =
-      wrapper && wrapper.isWrapper() &&
-      (!llvm::isa<ParallelOp>(wrapper) ||
-       llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));
+      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
 
   if (!isComposite() && isCompositeChildLeaf)
     return emitError()
@@ -1859,8 +1851,8 @@ LogicalResult DistributeOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
-  if (!isWrapper())
-    return emitOpError() << "must be a loop wrapper";
+  if (!isValidWrapper())
+    return emitOpError() << "must be a valid loop wrapper";
 
   if (LoopWrapperInterface nested = getNestedWrapper()) {
     if (!isComposite())
@@ -1868,9 +1860,13 @@ LogicalResult DistributeOp::verify() {
              << "'omp.composite' attribute missing from composite wrapper";
     // Check for the allowed leaf constructs that may appear in a composite
     // construct directly after DISTRIBUTE.
-    if (!isa<ParallelOp, SimdOp>(nested))
-      return emitError() << "only supported nested wrappers are 'omp.parallel' "
-                            "and 'omp.simd'";
+    if (isa<WsloopOp>(nested)) {
+      if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
+        return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
+                              "when 'omp.parallel' is the direct parent";
+    } else if (!isa<SimdOp>(nested))
+      return emitError() << "only supported nested wrappers are 'omp.simd' and "
+                            "'omp.wsloop'";
   } else if (isComposite()) {
     return emitError()
            << "'omp.composite' attribute present in non-composite wrapper";
@@ -2063,8 +2059,8 @@ LogicalResult TaskloopOp::verify() {
         "may not appear on the same taskloop directive");
   }
 
-  if (!isWrapper())
-    return emitOpError() << "must be a loop wrapper";
+  if (!isValidWrapper())
+    return emitOpError() << "must be a valid loop wrapper";
 
   if (LoopWrapperInterface nested = getNestedWrapper()) {
     if (!isComposite())
@@ -2161,11 +2157,8 @@ LogicalResult LoopNestOp::verify() {
              << "range argument type does not match corresponding IV type";
   }
 
-  auto wrapper =
-      llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
-
-  if (!wrapper || !wrapper.isWrapper())
-    return emitOpError() << "expects parent op to be a valid loop wrapper";
+  if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
+    return emitOpError() << "expects parent op to be a loop wrapper";
 
   return success();
 }
@@ -2175,8 +2168,6 @@ void LoopNestOp::gatherWrappers(
   Operation *parent = (*this)->getParentOp();
   while (auto wrapper =
              llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
-    if (!wrapper.isWrapper())
-      break;
     wrappers.push_back(wrapper);
     parent = parent->getParentOp();
   }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 332d22fc2c6425..8d29d1622e3b62 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -10,58 +10,6 @@ func.func @unknown_clause() {
 
 // -----
 
-func.func @not_wrapper() {
-  // expected-error at +1 {{op must be a loop wrapper}}
-  omp.distribute {
-    omp.parallel {
-      %0 = arith.constant 0 : i32
-      omp.terminator
-    }
-    omp.terminator
-  }
-
-  return
-}
-
-// -----
-
-func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) {
-  omp.distribute {
-    // expected-error at +1 {{only supported nested wrapper is 'omp.wsloop'}}
-    omp.parallel {
-      omp.simd {
-        omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
-          omp.yield
-        }
-        omp.terminator
-      } {omp.composite}
-      omp.terminator
-    } {omp.composite}
-    omp.terminator
-  } {omp.composite}
-
-  return
-}
-
-// -----
-
-func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) {
-  omp.distribute {
-    // expected-error at +1 {{op must not wrap an 'omp.loop_nest' directly}}
-    omp.parallel {
-      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
-        omp.yield
-      }
-      omp.terminator
-    } {omp.composite}
-    omp.terminator
-  } {omp.composite}
-
-  return
-}
-
-// -----
-
 func.func @if_once(%n : i1) {
   // expected-error at +1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
   omp.parallel if(%n) if(%n) {
@@ -140,7 +88,7 @@ func.func @proc_bind_once() {
 // -----
 
 func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
-  // expected-error at +1 {{op expects parent op to be a valid loop wrapper}}
+  // expected-error at +1 {{op expects parent op to be a loop wrapper}}
   omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
     omp.yield
   }
@@ -148,19 +96,6 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
 
 // -----
 
-func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) {
-  omp.parallel {
-    %0 = arith.constant 0 : i32
-    // expected-error at +1 {{op expects parent op to be a valid loop wrapper}}
-    omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
-      omp.yield
-    }
-    omp.terminator
-  }
-}
-
-// -----
-
 func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
   omp.wsloop {
     // expected-error at +1 {{range argument type does not match corresponding IV type}}
@@ -188,7 +123,7 @@ func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
 // -----
 
 func.func @no_wrapper(%lb : index, %ub : index, %step : index) {
-  // expected-error @below {{op must be a loop wrapper}}
+  // expected-error @below {{op must be a valid loop wrapper}}
   omp.wsloop {
     %0 = arith.constant 0 : i32
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -374,7 +309,7 @@ llvm.func @test_omp_wsloop_dynamic_wrong_modifier3(%lb : i64, %ub : i64, %step :
 // -----
 
 func.func @omp_simd() -> () {
-  // expected-error @below {{op must be a loop wrapper}}
+  // expected-error @below {{op must be a valid loop wrapper}}
   omp.simd {
     omp.terminator
   }
@@ -2028,7 +1963,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
 // -----
 
 func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
-  // expected-error @below {{op must be a loop wrapper}}
+  // expected-error @below {{op must be a valid loop wrapper}}
   omp.taskloop {
     %0 = arith.constant 0 : i32
     omp.terminator
@@ -2237,7 +2172,7 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
 // -----
 
 func.func @omp_distribute_wrapper() -> () {
-  // expected-error @below {{op must be a loop wrapper}}
+  // expected-error @below {{op must be a valid loop wrapper}}
   omp.distribute {
       %0 = arith.constant 0 : i32
       "omp.terminator"() : () -> ()
@@ -2247,7 +2182,7 @@ func.func @omp_distribute_wrapper() -> () {
 // -----
 
 func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
-  // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
+  // expected-error @below {{an 'omp.wsloop' nested wrapper is only allowed when 'omp.parallel' is the direct parent}}
   omp.distribute {
     "omp.wsloop"() ({
       omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -2261,6 +2196,36 @@ func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -
 
 // -----
 
+func.func @omp_distribute_nested_wrapper2(%lb: index, %ub: index, %step: index) -> () {
+  // expected-error @below {{only supported nested wrappers are 'omp.simd' and 'omp.wsloop'}}
+  omp.distribute {
+    "omp.taskloop"() ({
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
+      "omp.terminator"() : () -> ()
+    }) {omp.composite} : () -> ()
+    "omp.terminator"() : () -> ()
+  } {omp.composite}
+}
+
+// -----
+
+func.func @omp_distribute_nested_wrapper3(%lb: index, %ub: index, %step: index) -> () {
+  // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
+  omp.distribute {
+    "omp.simd"() ({
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
+      "omp.terminator"() : () -> ()
+    }) {omp.composite} : () -> ()
+    "omp.terminator"() : () -> ()
+  }
+}
+
+// -----
+
 func.func @omp_distribute_order() -> () {
 // expected-error @below {{invalid clause value: 'default'}}
   omp.distribute order(default) {
@@ -2469,9 +2434,9 @@ func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
 
 // -----
 func.func @omp_parallel_missing_composite(%lb: index, %ub: index, %step: index) -> () {
-  omp.distribute {
-    // expected-error at +1 {{'omp.composite' attribute missing from composite wrapper}}
-    omp.parallel {
+  // expected-error at +1 {{'omp.composite' attribute missing from composite operation}}
+  omp.parallel {
+    omp.distribute {
       omp.wsloop {
         omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
           omp.yield
@@ -2479,15 +2444,15 @@ func.func @omp_parallel_missing_composite(%lb: index, %ub: index, %step: index)
         omp.terminator
       } {omp.composite}
       omp.terminator
-    }
+    } {omp.composite}
     omp.terminator
-  } {omp.composite}
+  }
   return
 }
 
 // -----
 func.func @omp_parallel_invalid_composite(%lb: index, %ub: index, %step: index) -> () {
-  // expected-error @below {{'omp.composite' attribute present in non-composite wrapper}}
+  // expected-error @below {{'omp.composite' attribute present in non-composite operation}}
   omp.parallel {
     omp.wsloop {
       omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -2500,6 +2465,25 @@ func.func @omp_parallel_invalid_composite(%lb: index, %ub: index, %step: index)
   return
 }
 
+// -----
+func.func @omp_parallel_invalid_composite2(%lb: index, %ub: index, %step: index) -> () {
+  // expected-error @below {{unexpected OpenMP operation inside of composite 'omp.parallel'}}
+  omp.parallel {
+    omp.barrier
+    omp.distribute {
+      omp.wsloop {
+        omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+          omp.yield
+        }
+        omp.terminator
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  } {omp.composite}
+  return
+}
+
 // -----
 func.func @omp_wsloop_missing_composite(%lb: index, %ub: index, %step: index) -> () {
   // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
@@ -2529,8 +2513,8 @@ func.func @omp_wsloop_invalid_composite(%lb: index, %ub: index, %step: index) ->
 
 // -----
 func.func @omp_wsloop_missing_composite_2(%lb: index, %ub: index, %step: index) -> () {
-  omp.distribute {
-    omp.parallel {
+  omp.parallel {
+    omp.distribute {
       // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
       omp.wsloop {
         omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -2574,9 +2558,9 @@ func.func @omp_simd_invalid_composite(%lb: index, %ub: index, %step: index) -> (
 
 // -----
 func.func @omp_distribute_missing_composite(%lb: index, %ub: index, %step: index) -> () {
-  // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
-  omp.distribute {
-    omp.parallel {
+  omp.parallel {
+    // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
+    omp.distribute {
       omp.wsloop {
         omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
           omp.yield
@@ -2584,9 +2568,9 @@ func.func @omp_distribute_missing_composite(%lb: index, %ub: index, %step: index
         omp.terminator
       } {omp.composite}
       omp.terminator
-    } {omp.composite}
+    }
     omp.terminator
-  }
+  } {omp.composite}
   return
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 9c308cc0108493..cf79f2207129e2 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -99,10 +99,11 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
     omp.terminator
   }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
 
-  // CHECK: omp.distribute
-  omp.distribute {
-    // CHECK-NEXT: omp.parallel
-    omp.parallel {
+  // CHECK: omp.parallel
+  omp.parallel {
+    // CHECK-NOT: omp.terminator
+    // CHECK: omp.distribute
+    omp.distribute {
       // CHECK-NEXT: omp.wsloop
       omp.wsloop {
         // CHECK-NEXT: omp.loop_nest
@@ -116,14 +117,15 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
     omp.terminator
   } {omp.composite}
 
-  // CHECK: omp.distribute
-  omp.distribute {
-    // CHECK-NEXT: omp.parallel
-    omp.parallel {
+  // CHECK: omp.parallel
+  omp.parallel {
+    // CHECK-NOT: omp.terminator
+    // CHECK: omp.distribute
+    omp.distribute {
       // CHECK-NEXT: omp.wsloop
       omp.wsloop {
         // CHECK-NEXT: omp.simd
-        omp.simd{
+        omp.simd {
           // CHECK-NEXT: omp.loop_nest
           omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
             omp.yield

>From 678a1a387721d1ad23a9383eaea7a6e43c59d917 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 27 Aug 2024 13:32:49 +0100
Subject: [PATCH 2/2] Address review comments

---
 .../Dialect/OpenMP/OpenMPOpsInterfaces.td     | 30 -------------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 42 +++++++++++++++----
 mlir/test/Dialect/OpenMP/invalid.mlir         | 14 ++++---
 3 files changed, 42 insertions(+), 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 1dd787ebe7de4f..0078e22b1c89a6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -80,34 +80,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
   let cppNamespace = "::mlir::omp";
 
   let methods = [
-    InterfaceMethod<
-      /*description=*/[{
-        Check whether the operation is a valid loop wrapper. That is, it has a
-        single region with a single block in which there are two operations:
-        another loop wrapper or `omp.loop_nest` operation and a terminator.
-      }],
-      /*retTy=*/"bool",
-      /*methodName=*/"isValidWrapper",
-      (ins ), [{}], [{
-        if ($_op->getNumRegions() != 1)
-          return false;
-
-        Region &r = $_op->getRegion(0);
-        if (!r.hasOneBlock())
-          return false;
-
-        if (::llvm::range_size(r.getOps()) != 2)
-          return false;
-
-        Operation &firstOp = *r.op_begin();
-        Operation &secondOp = *(std::next(r.op_begin()));
-
-        if (!secondOp.hasTrait<OpTrait::IsTerminator>())
-          return false;
-
-        return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp);
-      }]
-    >,
     InterfaceMethod<
       /*description=*/[{
         If there is another loop wrapper immediately nested inside, return that
@@ -116,7 +88,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*retTy=*/"::mlir::omp::LoopWrapperInterface",
       /*methodName=*/"getNestedWrapper",
       (ins), [{}], [{
-        assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
         Operation *nested = &*$_op->getRegion(0).op_begin();
         return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
       }]
@@ -129,7 +100,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
       /*retTy=*/"::mlir::Operation *",
       /*methodName=*/"getWrappedLoop",
       (ins), [{}], [{
-        assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
         if (LoopWrapperInterface nested = $_op.getNestedWrapper())
           return nested.getWrappedLoop();
         return &*$_op->getRegion(0).op_begin();
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6db4796eb37e0d..1a9b87f0d68c9d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1720,6 +1720,32 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
+static LogicalResult verifyLoopWrapperInterface(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return op->emitOpError() << "loop wrapper contains multiple regions";
+
+  Region &region = op->getRegion(0);
+  if (!region.hasOneBlock())
+    return op->emitOpError() << "loop wrapper contains multiple blocks";
+
+  if (::llvm::range_size(region.getOps()) != 2)
+    return op->emitOpError()
+           << "loop wrapper does not contain exactly two nested ops";
+
+  Operation &firstOp = *region.op_begin();
+  Operation &secondOp = *(std::next(region.op_begin()));
+
+  if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+    return op->emitOpError()
+           << "second nested op in loop wrapper is not a terminator";
+
+  if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
+    return op->emitOpError() << "first nested op in loop wrapper is not "
+                                "another loop wrapper or `omp.loop_nest`";
+
+  return success();
+}
+
 void WsloopOp::build(OpBuilder &builder, OperationState &state,
                      ArrayRef<NamedAttribute> attributes) {
   build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
@@ -1750,8 +1776,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult WsloopOp::verify() {
-  if (!isValidWrapper())
-    return emitOpError() << "must be a valid loop wrapper";
+  if (verifyLoopWrapperInterface(*this).failed())
+    return failure();
 
   bool isCompositeChildLeaf =
       llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
@@ -1809,8 +1835,8 @@ LogicalResult SimdOp::verify() {
   if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
     return failure();
 
-  if (!isValidWrapper())
-    return emitOpError() << "must be a valid loop wrapper";
+  if (verifyLoopWrapperInterface(*this).failed())
+    return failure();
 
   if (getNestedWrapper())
     return emitOpError() << "must wrap an 'omp.loop_nest' directly";
@@ -1851,8 +1877,8 @@ LogicalResult DistributeOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
-  if (!isValidWrapper())
-    return emitOpError() << "must be a valid loop wrapper";
+  if (verifyLoopWrapperInterface(*this).failed())
+    return failure();
 
   if (LoopWrapperInterface nested = getNestedWrapper()) {
     if (!isComposite())
@@ -2059,8 +2085,8 @@ LogicalResult TaskloopOp::verify() {
         "may not appear on the same taskloop directive");
   }
 
-  if (!isValidWrapper())
-    return emitOpError() << "must be a valid loop wrapper";
+  if (verifyLoopWrapperInterface(*this).failed())
+    return failure();
 
   if (LoopWrapperInterface nested = getNestedWrapper()) {
     if (!isComposite())
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8d29d1622e3b62..d8745f1015af83 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -123,7 +123,7 @@ func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
 // -----
 
 func.func @no_wrapper(%lb : index, %ub : index, %step : index) {
-  // expected-error @below {{op must be a valid loop wrapper}}
+  // expected-error @below {{op loop wrapper does not contain exactly two nested ops}}
   omp.wsloop {
     %0 = arith.constant 0 : i32
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -309,7 +309,7 @@ llvm.func @test_omp_wsloop_dynamic_wrong_modifier3(%lb : i64, %ub : i64, %step :
 // -----
 
 func.func @omp_simd() -> () {
-  // expected-error @below {{op must be a valid loop wrapper}}
+  // expected-error @below {{op loop wrapper does not contain exactly two nested ops}}
   omp.simd {
     omp.terminator
   }
@@ -1963,7 +1963,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
 // -----
 
 func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
-  // expected-error @below {{op must be a valid loop wrapper}}
+  // expected-error @below {{op first nested op in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
   omp.taskloop {
     %0 = arith.constant 0 : i32
     omp.terminator
@@ -2171,11 +2171,13 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
 
 // -----
 
-func.func @omp_distribute_wrapper() -> () {
-  // expected-error @below {{op must be a valid loop wrapper}}
+func.func @omp_distribute_wrapper(%lb: index, %ub: index, %step: index) -> () {
+  // expected-error @below {{op second nested op in loop wrapper is not a terminator}}
   omp.distribute {
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
       %0 = arith.constant 0 : i32
-      "omp.terminator"() : () -> ()
   }
 }
 



More information about the Mlir-commits mailing list