[llvm-branch-commits] [mlir] [MLIR][OpenMP] Update op verifiers dependent on omp.wsloop (2/5) (PR #89211)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Apr 24 04:14:21 PDT 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/89211

>From f9b14e37a6f437768c405291c064f541f0655b1c Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 17 Apr 2024 16:40:03 +0100
Subject: [PATCH 1/2] [MLIR][OpenMP] Update op verifiers dependent on
 omp.wsloop (2/5)

This patch updates verifiers for `omp.ordered.region`, `omp.cancel` and
`omp.cancellation_point`, which check for a parent `omp.wsloop`.

After transitioning to a loop wrapper-based approach, the expected direct
parent will become `omp.loop_nest` instead, so verifiers need to take this into
account.

This PR on its own will not pass premerge tests. All patches in the stack are
needed before it can be compiled and pass tests.
---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 23 ++++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d66186effc31d6..d014c27e1ee157 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1977,9 +1977,10 @@ LogicalResult OrderedRegionOp::verify() {
   if (getSimd())
     return failure();
 
-  if (auto container = (*this)->getParentOfType<WsloopOp>()) {
-    if (!container.getOrderedValAttr() ||
-        container.getOrderedValAttr().getInt() != 0)
+  if (auto loopOp = dyn_cast<LoopNestOp>((*this)->getParentOp())) {
+    auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(loopOp->getParentOp());
+    if (!wsloopOp || !wsloopOp.getOrderedValAttr() ||
+        wsloopOp.getOrderedValAttr().getInt() != 0)
       return emitOpError() << "ordered region must be closely nested inside "
                            << "a worksharing-loop region with an ordered "
                            << "clause without parameter present";
@@ -2130,15 +2131,19 @@ LogicalResult CancelOp::verify() {
                          << "inside a parallel region";
   }
   if (cct == ClauseCancellationConstructType::Loop) {
-    if (!isa<WsloopOp>(parentOp)) {
-      return emitOpError() << "cancel loop must appear "
-                           << "inside a worksharing-loop region";
+    auto loopOp = dyn_cast<LoopNestOp>(parentOp);
+    auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
+        loopOp ? loopOp->getParentOp() : nullptr);
+
+    if (!wsloopOp) {
+      return emitOpError()
+             << "cancel loop must appear inside a worksharing-loop region";
     }
-    if (cast<WsloopOp>(parentOp).getNowaitAttr()) {
+    if (wsloopOp.getNowaitAttr()) {
       return emitError() << "A worksharing construct that is canceled "
                          << "must not have a nowait clause";
     }
-    if (cast<WsloopOp>(parentOp).getOrderedValAttr()) {
+    if (wsloopOp.getOrderedValAttr()) {
       return emitError() << "A worksharing construct that is canceled "
                          << "must not have an ordered clause";
     }
@@ -2176,7 +2181,7 @@ LogicalResult CancellationPointOp::verify() {
                          << "inside a parallel region";
   }
   if ((cct == ClauseCancellationConstructType::Loop) &&
-      !isa<WsloopOp>(parentOp)) {
+      (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
     return emitOpError() << "cancellation point loop must appear "
                          << "inside a worksharing-loop region";
   }

>From 18c8bda112d60824c74904d8c27a16f7f016c020 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 19 Apr 2024 11:37:21 +0100
Subject: [PATCH 2/2] Address review comment, improve tests

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 57 +++++++++++------
 mlir/test/Dialect/OpenMP/invalid.mlir        | 67 ++++++++++++++++----
 2 files changed, 94 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d014c27e1ee157..c546d3a9044de1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1945,6 +1945,39 @@ LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // Ordered construct
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verifyOrderedParent(Operation &op) {
+  bool hasRegion = op.getNumRegions() > 0;
+  auto loopOp = op.getParentOfType<LoopNestOp>();
+  if (!loopOp) {
+    if (hasRegion)
+      return success();
+
+    // TODO: Consider if this needs to be the case only for the standalone
+    // variant of the ordered construct.
+    return op.emitOpError() << "must be nested inside of a loop";
+  }
+
+  Operation *wrapper = loopOp->getParentOp();
+  if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
+    IntegerAttr orderedAttr = wsloopOp.getOrderedValAttr();
+    if (!orderedAttr)
+      return op.emitOpError() << "the enclosing worksharing-loop region must "
+                                 "have an ordered clause";
+
+    if (hasRegion && orderedAttr.getInt() != 0)
+      return op.emitOpError() << "the enclosing loop's ordered clause must not "
+                                 "have a parameter present";
+
+    if (!hasRegion && orderedAttr.getInt() == 0)
+      return op.emitOpError() << "the enclosing loop's ordered clause must "
+                                 "have a parameter present";
+  } else if (!isa<SimdOp>(wrapper)) {
+    return op.emitOpError() << "must be nested inside of a worksharing, simd "
+                               "or worksharing simd loop";
+  }
+  return success();
+}
+
 void OrderedOp::build(OpBuilder &builder, OperationState &state,
                       const OrderedOpClauseOps &clauses) {
   OrderedOp::build(builder, state, clauses.doacrossDependTypeAttr,
@@ -1952,14 +1985,11 @@ void OrderedOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult OrderedOp::verify() {
-  auto container = (*this)->getParentOfType<WsloopOp>();
-  if (!container || !container.getOrderedValAttr() ||
-      container.getOrderedValAttr().getInt() == 0)
-    return emitOpError() << "ordered depend directive must be closely "
-                         << "nested inside a worksharing-loop with ordered "
-                         << "clause with parameter present";
-
-  if (container.getOrderedValAttr().getInt() != (int64_t)*getNumLoopsVal())
+  if (failed(verifyOrderedParent(**this)))
+    return failure();
+
+  auto wrapper = (*this)->getParentOfType<WsloopOp>();
+  if (!wrapper || *wrapper.getOrderedVal() != *getNumLoopsVal())
     return emitOpError() << "number of variables in depend clause does not "
                          << "match number of iteration variables in the "
                          << "doacross loop";
@@ -1977,16 +2007,7 @@ LogicalResult OrderedRegionOp::verify() {
   if (getSimd())
     return failure();
 
-  if (auto loopOp = dyn_cast<LoopNestOp>((*this)->getParentOp())) {
-    auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(loopOp->getParentOp());
-    if (!wsloopOp || !wsloopOp.getOrderedValAttr() ||
-        wsloopOp.getOrderedValAttr().getInt() != 0)
-      return emitOpError() << "ordered region must be closely nested inside "
-                           << "a worksharing-loop region with an ordered "
-                           << "clause without parameter present";
-  }
-
-  return success();
+  return verifyOrderedParent(**this);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 3bd4838c4e0f42..3361092425e73d 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -683,10 +683,10 @@ omp.critical.declare @mutex hint(invalid_hint)
 
 // -----
 
-func.func @omp_ordered1(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () {
-  omp.wsloop ordered(1) {
-    omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
-      // expected-error @below {{ordered region must be closely nested inside a worksharing-loop region with an ordered clause without parameter present}}
+func.func @omp_ordered_region1(%x : i32) -> () {
+  omp.distribute {
+    omp.loop_nest (%i) : i32 = (%x) to (%x) step (%x) {
+      // expected-error @below {{op must be nested inside of a worksharing, simd or worksharing simd loop}}
       omp.ordered.region {
         omp.terminator
       }
@@ -699,10 +699,10 @@ func.func @omp_ordered1(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () {
 
 // -----
 
-func.func @omp_ordered2(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () {
+func.func @omp_ordered_region2(%x : i32) -> () {
   omp.wsloop {
-    omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
-      // expected-error @below {{ordered region must be closely nested inside a worksharing-loop region with an ordered clause without parameter present}}
+    omp.loop_nest (%i) : i32 = (%x) to (%x) step (%x) {
+      // expected-error @below {{the enclosing worksharing-loop region must have an ordered clause}}
       omp.ordered.region {
         omp.terminator
       }
@@ -715,26 +715,70 @@ func.func @omp_ordered2(%arg1 : i32, %arg2 : i32, %arg3 : i32) -> () {
 
 // -----
 
-func.func @omp_ordered3(%vec0 : i64) -> () {
-  // expected-error @below {{ordered depend directive must be closely nested inside a worksharing-loop with ordered clause with parameter present}}
+func.func @omp_ordered_region3(%x : i32) -> () {
+  omp.wsloop ordered(1) {
+    omp.loop_nest (%i) : i32 = (%x) to (%x) step (%x) {
+      // expected-error @below {{the enclosing loop's ordered clause must not have a parameter present}}
+      omp.ordered.region {
+        omp.terminator
+      }
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_ordered1(%vec0 : i64) -> () {
+  // expected-error @below {{op must be nested inside of a loop}}
   omp.ordered depend_type(dependsink) depend_vec(%vec0 : i64) {num_loops_val = 1 : i64}
   return
 }
 
 // -----
 
+func.func @omp_ordered2(%arg1 : i32, %arg2 : i32, %arg3 : i32, %vec0 : i64) -> () {
+  omp.distribute {
+    omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
+      // expected-error @below {{op must be nested inside of a worksharing, simd or worksharing simd loop}}
+      omp.ordered depend_type(dependsink) depend_vec(%vec0 : i64) {num_loops_val = 1 : i64}
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_ordered3(%arg1 : i32, %arg2 : i32, %arg3 : i32, %vec0 : i64) -> () {
+  omp.wsloop {
+    omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
+      // expected-error @below {{the enclosing worksharing-loop region must have an ordered clause}}
+      omp.ordered depend_type(dependsink) depend_vec(%vec0 : i64) {num_loops_val = 1 : i64}
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_ordered4(%arg1 : i32, %arg2 : i32, %arg3 : i32, %vec0 : i64) -> () {
   omp.wsloop ordered(0) {
     omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
-      // expected-error @below {{ordered depend directive must be closely nested inside a worksharing-loop with ordered clause with parameter present}}
+      // expected-error @below {{the enclosing loop's ordered clause must have a parameter present}}
       omp.ordered depend_type(dependsink) depend_vec(%vec0 : i64) {num_loops_val = 1 : i64}
-
       omp.yield
     }
     omp.terminator
   }
   return
 }
+
 // -----
 
 func.func @omp_ordered5(%arg1 : i32, %arg2 : i32, %arg3 : i32, %vec0 : i64, %vec1 : i64) -> () {
@@ -742,7 +786,6 @@ func.func @omp_ordered5(%arg1 : i32, %arg2 : i32, %arg3 : i32, %vec0 : i64, %vec
     omp.loop_nest (%0) : i32 = (%arg1) to (%arg2) step (%arg3) {
       // expected-error @below {{number of variables in depend clause does not match number of iteration variables in the doacross loop}}
       omp.ordered depend_type(dependsource) depend_vec(%vec0, %vec1 : i64, i64) {num_loops_val = 2 : i64}
-
       omp.yield
     }
     omp.terminator



More information about the llvm-branch-commits mailing list