[Mlir-commits] [flang] [mlir] [MLIR][Flang][OpenMP] Remove omp.parallel from loop wrapper ops (PR #105833)
Sergio Afonso
llvmlistbot at llvm.org
Thu Aug 29 02:43:15 PDT 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/105833
>From 37792b4ddec46beb33bdc400433782b42f310217 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 23 Aug 2024 14:39:44 +0100
Subject: [PATCH 1/2] [MLIR][Flang][OpenMP] Remove omp.parallel from loop
wrapper ops
This patch updates the `omp.parallel` operation according to the results of
the discussion in [this RFC](https://discourse.llvm.org/t/rfc-disambiguation-between-loop-and-block-associated-omp-parallelop/79972).
It is removed from the set of loop wrapper operations, changing the expected
MLIR representation for composite `distribute parallel do/for` into the
following:
```mlir
omp.parallel {
  ...
  omp.distribute {
    omp.wsloop {
      omp.loop_nest ... { ... }
      omp.terminator
    }
    omp.terminator
  }
  ...
  omp.terminator
}
```
MLIR verifiers for operations impacted by this representation change are
updated, as well as related tests. The `LoopWrapperInterface` is also updated,
since it's no longer representing an optional "role" of an operation but a
mandatory set of restrictions instead.
---
.../lib/Lower/OpenMP/DataSharingProcessor.cpp | 4 +-
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 1 -
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 30 ++--
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 75 ++++-----
mlir/test/Dialect/OpenMP/invalid.mlir | 152 ++++++++----------
mlir/test/Dialect/OpenMP/ops.mlir | 20 +--
6 files changed, 126 insertions(+), 156 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 1b2f926e21bed8..a2003473a0fd80 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -231,9 +231,7 @@ void DataSharingProcessor::insertBarrier() {
void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
mlir::omp::LoopNestOp loopOp;
if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
- loopOp = wrapper.isWrapper()
- ? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
- : nullptr;
+ loopOp = mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop());
bool cmpCreated = false;
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5a7dae0b5f3074..1aa4e771cd4dea 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -129,7 +129,6 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>
def ParallelOp : OpenMP_Op<"parallel", traits = [
AttrSizedOperandSegments, AutomaticAllocationScope,
DeclareOpInterfaceMethods<ComposableOpInterface>,
- DeclareOpInterfaceMethods<LoopWrapperInterface>,
DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
RecursiveMemoryEffects
], clauses = [
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 78637f0ab8c2da..1dd787ebe7de4f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -71,10 +71,10 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
let description = [{
- OpenMP operations that can wrap a single loop nest. When taking a wrapper
- role, these operations must only contain a single region with a single block
- in which there's a single operation and a terminator. That nested operation
- must be another loop wrapper or an `omp.loop_nest`.
+ OpenMP operations that wrap a single loop nest. They must only contain a
+ single region with a single block in which there's a single operation and a
+ terminator. That nested operation must be another loop wrapper or an
+ `omp.loop_nest`.
}];
let cppNamespace = "::mlir::omp";
@@ -82,13 +82,12 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
let methods = [
InterfaceMethod<
/*description=*/[{
- Tell whether the operation could be taking the role of a loop wrapper.
- That is, it has a single region with a single block in which there are
- two operations: another wrapper (also taking a loop wrapper role) or
- `omp.loop_nest` operation and a terminator.
+ Check whether the operation is a valid loop wrapper. That is, it has a
+ single region with a single block in which there are two operations:
+ another loop wrapper or `omp.loop_nest` operation and a terminator.
}],
/*retTy=*/"bool",
- /*methodName=*/"isWrapper",
+ /*methodName=*/"isValidWrapper",
(ins ), [{}], [{
if ($_op->getNumRegions() != 1)
return false;
@@ -106,21 +105,18 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
return false;
- if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
- return wrapper.isWrapper();
-
- return ::llvm::isa<LoopNestOp>(firstOp);
+ return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp);
}]
>,
InterfaceMethod<
/*description=*/[{
If there is another loop wrapper immediately nested inside, return that
- operation. Assumes this operation is taking a loop wrapper role.
+ operation. Assumes this operation is a valid loop wrapper.
}],
/*retTy=*/"::mlir::omp::LoopWrapperInterface",
/*methodName=*/"getNestedWrapper",
(ins), [{}], [{
- assert($_op.isWrapper() && "Unexpected non-wrapper op");
+ assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
Operation *nested = &*$_op->getRegion(0).op_begin();
return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
}]
@@ -128,12 +124,12 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
InterfaceMethod<
/*description=*/[{
Return the loop nest nested directly or indirectly inside of this loop
- wrapper. Assumes this operation is taking a loop wrapper role.
+ wrapper. Assumes this operation is a valid loop wrapper.
}],
/*retTy=*/"::mlir::Operation *",
/*methodName=*/"getWrappedLoop",
(ins), [{}], [{
- assert($_op.isWrapper() && "Unexpected non-wrapper op");
+ assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
if (LoopWrapperInterface nested = $_op.getNestedWrapper())
return nested.getWrappedLoop();
return &*$_op->getRegion(0).op_begin();
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index eb4f9cb041841b..6db4796eb37e0d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1541,26 +1541,25 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}
LogicalResult ParallelOp::verify() {
- // Check that it is a valid loop wrapper if it's taking that role.
- if (isa<DistributeOp>((*this)->getParentOp())) {
- if (!isWrapper())
- return emitOpError() << "must take a loop wrapper role if nested inside "
- "of 'omp.distribute'";
+ auto distributeChildOps = getOps<DistributeOp>();
+ if (!distributeChildOps.empty()) {
if (!isComposite())
return emitError()
- << "'omp.composite' attribute missing from composite wrapper";
+ << "'omp.composite' attribute missing from composite operation";
- if (LoopWrapperInterface nested = getNestedWrapper()) {
- // Check for the allowed leaf constructs that may appear in a composite
- // construct directly after PARALLEL.
- if (!isa<WsloopOp>(nested))
- return emitError() << "only supported nested wrapper is 'omp.wsloop'";
- } else {
- return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
+ auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
+ Operation &distributeOp = **distributeChildOps.begin();
+ for (Operation &childOp : getOps()) {
+ if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
+ continue;
+
+ if (!childOp.hasTrait<OpTrait::IsTerminator>())
+ return emitError() << "unexpected OpenMP operation inside of composite "
+ "'omp.parallel'";
}
} else if (isComposite()) {
return emitError()
- << "'omp.composite' attribute present in non-composite wrapper";
+ << "'omp.composite' attribute present in non-composite operation";
}
if (getAllocateVars().size() != getAllocatorVars().size())
@@ -1751,15 +1750,12 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult WsloopOp::verify() {
- if (!isWrapper())
- return emitOpError() << "must be a loop wrapper";
+ if (!isValidWrapper())
+ return emitOpError() << "must be a valid loop wrapper";
- auto wrapper =
- llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
bool isCompositeChildLeaf =
- wrapper && wrapper.isWrapper() &&
- (!llvm::isa<ParallelOp>(wrapper) ||
- llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));
+ llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
+
if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
return emitError()
@@ -1813,18 +1809,14 @@ LogicalResult SimdOp::verify() {
if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
return failure();
- if (!isWrapper())
- return emitOpError() << "must be a loop wrapper";
+ if (!isValidWrapper())
+ return emitOpError() << "must be a valid loop wrapper";
if (getNestedWrapper())
return emitOpError() << "must wrap an 'omp.loop_nest' directly";
- auto wrapper =
- llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
bool isCompositeChildLeaf =
- wrapper && wrapper.isWrapper() &&
- (!llvm::isa<ParallelOp>(wrapper) ||
- llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));
+ llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
if (!isComposite() && isCompositeChildLeaf)
return emitError()
@@ -1859,8 +1851,8 @@ LogicalResult DistributeOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- if (!isWrapper())
- return emitOpError() << "must be a loop wrapper";
+ if (!isValidWrapper())
+ return emitOpError() << "must be a valid loop wrapper";
if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
@@ -1868,9 +1860,13 @@ LogicalResult DistributeOp::verify() {
<< "'omp.composite' attribute missing from composite wrapper";
// Check for the allowed leaf constructs that may appear in a composite
// construct directly after DISTRIBUTE.
- if (!isa<ParallelOp, SimdOp>(nested))
- return emitError() << "only supported nested wrappers are 'omp.parallel' "
- "and 'omp.simd'";
+ if (isa<WsloopOp>(nested)) {
+ if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
+ return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
+ "when 'omp.parallel' is the direct parent";
+ } else if (!isa<SimdOp>(nested))
+ return emitError() << "only supported nested wrappers are 'omp.simd' and "
+ "'omp.wsloop'";
} else if (isComposite()) {
return emitError()
<< "'omp.composite' attribute present in non-composite wrapper";
@@ -2063,8 +2059,8 @@ LogicalResult TaskloopOp::verify() {
"may not appear on the same taskloop directive");
}
- if (!isWrapper())
- return emitOpError() << "must be a loop wrapper";
+ if (!isValidWrapper())
+ return emitOpError() << "must be a valid loop wrapper";
if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
@@ -2161,11 +2157,8 @@ LogicalResult LoopNestOp::verify() {
<< "range argument type does not match corresponding IV type";
}
- auto wrapper =
- llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
-
- if (!wrapper || !wrapper.isWrapper())
- return emitOpError() << "expects parent op to be a valid loop wrapper";
+ if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
+ return emitOpError() << "expects parent op to be a loop wrapper";
return success();
}
@@ -2175,8 +2168,6 @@ void LoopNestOp::gatherWrappers(
Operation *parent = (*this)->getParentOp();
while (auto wrapper =
llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
- if (!wrapper.isWrapper())
- break;
wrappers.push_back(wrapper);
parent = parent->getParentOp();
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 332d22fc2c6425..8d29d1622e3b62 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -10,58 +10,6 @@ func.func @unknown_clause() {
// -----
-func.func @not_wrapper() {
- // expected-error at +1 {{op must be a loop wrapper}}
- omp.distribute {
- omp.parallel {
- %0 = arith.constant 0 : i32
- omp.terminator
- }
- omp.terminator
- }
-
- return
-}
-
-// -----
-
-func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) {
- omp.distribute {
- // expected-error at +1 {{only supported nested wrapper is 'omp.wsloop'}}
- omp.parallel {
- omp.simd {
- omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- omp.terminator
- } {omp.composite}
- omp.terminator
- } {omp.composite}
- omp.terminator
- } {omp.composite}
-
- return
-}
-
-// -----
-
-func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) {
- omp.distribute {
- // expected-error at +1 {{op must not wrap an 'omp.loop_nest' directly}}
- omp.parallel {
- omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- omp.terminator
- } {omp.composite}
- omp.terminator
- } {omp.composite}
-
- return
-}
-
-// -----
-
func.func @if_once(%n : i1) {
// expected-error at +1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
omp.parallel if(%n) if(%n) {
@@ -140,7 +88,7 @@ func.func @proc_bind_once() {
// -----
func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
- // expected-error at +1 {{op expects parent op to be a valid loop wrapper}}
+ // expected-error at +1 {{op expects parent op to be a loop wrapper}}
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
@@ -148,19 +96,6 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
// -----
-func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) {
- omp.parallel {
- %0 = arith.constant 0 : i32
- // expected-error at +1 {{op expects parent op to be a valid loop wrapper}}
- omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- omp.terminator
- }
-}
-
-// -----
-
func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
omp.wsloop {
// expected-error at +1 {{range argument type does not match corresponding IV type}}
@@ -188,7 +123,7 @@ func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
// -----
func.func @no_wrapper(%lb : index, %ub : index, %step : index) {
- // expected-error @below {{op must be a loop wrapper}}
+ // expected-error @below {{op must be a valid loop wrapper}}
omp.wsloop {
%0 = arith.constant 0 : i32
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -374,7 +309,7 @@ llvm.func @test_omp_wsloop_dynamic_wrong_modifier3(%lb : i64, %ub : i64, %step :
// -----
func.func @omp_simd() -> () {
- // expected-error @below {{op must be a loop wrapper}}
+ // expected-error @below {{op must be a valid loop wrapper}}
omp.simd {
omp.terminator
}
@@ -2028,7 +1963,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// -----
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
- // expected-error @below {{op must be a loop wrapper}}
+ // expected-error @below {{op must be a valid loop wrapper}}
omp.taskloop {
%0 = arith.constant 0 : i32
omp.terminator
@@ -2237,7 +2172,7 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
// -----
func.func @omp_distribute_wrapper() -> () {
- // expected-error @below {{op must be a loop wrapper}}
+ // expected-error @below {{op must be a valid loop wrapper}}
omp.distribute {
%0 = arith.constant 0 : i32
"omp.terminator"() : () -> ()
@@ -2247,7 +2182,7 @@ func.func @omp_distribute_wrapper() -> () {
// -----
func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
- // expected-error @below {{only supported nested wrappers are 'omp.parallel' and 'omp.simd'}}
+ // expected-error @below {{an 'omp.wsloop' nested wrapper is only allowed when 'omp.parallel' is the direct parent}}
omp.distribute {
"omp.wsloop"() ({
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -2261,6 +2196,36 @@ func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -
// -----
+func.func @omp_distribute_nested_wrapper2(%lb: index, %ub: index, %step: index) -> () {
+  // expected-error @below {{only supported nested wrappers are 'omp.simd' and 'omp.wsloop'}}
+  omp.distribute {
+    "omp.taskloop"() ({
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
+      "omp.terminator"() : () -> ()
+    }) {omp.composite} : () -> ()
+    "omp.terminator"() : () -> ()
+  } {omp.composite}
+  // Terminate the function body so the only invalid construct in this test
+  // is the unsupported nested wrapper.
+  return
+}
+
+// -----
+
+func.func @omp_distribute_nested_wrapper3(%lb: index, %ub: index, %step: index) -> () {
+  // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
+  omp.distribute {
+    "omp.simd"() ({
+      omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+        "omp.yield"() : () -> ()
+      }
+      "omp.terminator"() : () -> ()
+    }) {omp.composite} : () -> ()
+    "omp.terminator"() : () -> ()
+  }
+  // Terminate the function body so the only invalid construct in this test
+  // is the missing 'omp.composite' attribute on omp.distribute.
+  return
+}
+
+// -----
+
func.func @omp_distribute_order() -> () {
// expected-error @below {{invalid clause value: 'default'}}
omp.distribute order(default) {
@@ -2469,9 +2434,9 @@ func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
// -----
func.func @omp_parallel_missing_composite(%lb: index, %ub: index, %step: index) -> () {
- omp.distribute {
- // expected-error at +1 {{'omp.composite' attribute missing from composite wrapper}}
- omp.parallel {
+ // expected-error at +1 {{'omp.composite' attribute missing from composite operation}}
+ omp.parallel {
+ omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
@@ -2479,15 +2444,15 @@ func.func @omp_parallel_missing_composite(%lb: index, %ub: index, %step: index)
omp.terminator
} {omp.composite}
omp.terminator
- }
+ } {omp.composite}
omp.terminator
- } {omp.composite}
+ }
return
}
// -----
func.func @omp_parallel_invalid_composite(%lb: index, %ub: index, %step: index) -> () {
- // expected-error @below {{'omp.composite' attribute present in non-composite wrapper}}
+ // expected-error @below {{'omp.composite' attribute present in non-composite operation}}
omp.parallel {
omp.wsloop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -2500,6 +2465,25 @@ func.func @omp_parallel_invalid_composite(%lb: index, %ub: index, %step: index)
return
}
+// -----
+func.func @omp_parallel_invalid_composite2(%lb: index, %ub: index, %step: index) -> () {
+ // A composite omp.parallel rejects any OpenMP-dialect child other than its
+ // omp.distribute and terminator; the omp.barrier below triggers that check.
+ // expected-error @below {{unexpected OpenMP operation inside of composite 'omp.parallel'}}
+ omp.parallel {
+ omp.barrier
+ omp.distribute {
+ omp.wsloop {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ return
+}
+
// -----
func.func @omp_wsloop_missing_composite(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
@@ -2529,8 +2513,8 @@ func.func @omp_wsloop_invalid_composite(%lb: index, %ub: index, %step: index) ->
// -----
func.func @omp_wsloop_missing_composite_2(%lb: index, %ub: index, %step: index) -> () {
- omp.distribute {
- omp.parallel {
+ omp.parallel {
+ omp.distribute {
// expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
omp.wsloop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -2574,9 +2558,9 @@ func.func @omp_simd_invalid_composite(%lb: index, %ub: index, %step: index) -> (
// -----
func.func @omp_distribute_missing_composite(%lb: index, %ub: index, %step: index) -> () {
- // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
- omp.distribute {
- omp.parallel {
+ omp.parallel {
+ // expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
+ omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
@@ -2584,9 +2568,9 @@ func.func @omp_distribute_missing_composite(%lb: index, %ub: index, %step: index
omp.terminator
} {omp.composite}
omp.terminator
- } {omp.composite}
+ }
omp.terminator
- }
+ } {omp.composite}
return
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 7ba55fc957a473..dce5b3950def49 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -99,10 +99,11 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
omp.terminator
}) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
- // CHECK: omp.distribute
- omp.distribute {
- // CHECK-NEXT: omp.parallel
- omp.parallel {
+ // CHECK: omp.parallel
+ omp.parallel {
+ // CHECK-NOT: omp.terminator
+ // CHECK: omp.distribute
+ omp.distribute {
// CHECK-NEXT: omp.wsloop
omp.wsloop {
// CHECK-NEXT: omp.loop_nest
@@ -116,14 +117,15 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
omp.terminator
} {omp.composite}
- // CHECK: omp.distribute
- omp.distribute {
- // CHECK-NEXT: omp.parallel
- omp.parallel {
+ // CHECK: omp.parallel
+ omp.parallel {
+ // CHECK-NOT: omp.terminator
+ // CHECK: omp.distribute
+ omp.distribute {
// CHECK-NEXT: omp.wsloop
omp.wsloop {
// CHECK-NEXT: omp.simd
- omp.simd{
+ omp.simd {
// CHECK-NEXT: omp.loop_nest
omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
omp.yield
>From cbfb06925563eea2a61966327e6cb2bd68ea8f64 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 27 Aug 2024 13:32:49 +0100
Subject: [PATCH 2/2] Address review comments
---
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 30 -------------
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 42 +++++++++++++++----
mlir/test/Dialect/OpenMP/invalid.mlir | 14 ++++---
3 files changed, 42 insertions(+), 44 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 1dd787ebe7de4f..0078e22b1c89a6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -80,34 +80,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
let cppNamespace = "::mlir::omp";
let methods = [
- InterfaceMethod<
- /*description=*/[{
- Check whether the operation is a valid loop wrapper. That is, it has a
- single region with a single block in which there are two operations:
- another loop wrapper or `omp.loop_nest` operation and a terminator.
- }],
- /*retTy=*/"bool",
- /*methodName=*/"isValidWrapper",
- (ins ), [{}], [{
- if ($_op->getNumRegions() != 1)
- return false;
-
- Region &r = $_op->getRegion(0);
- if (!r.hasOneBlock())
- return false;
-
- if (::llvm::range_size(r.getOps()) != 2)
- return false;
-
- Operation &firstOp = *r.op_begin();
- Operation &secondOp = *(std::next(r.op_begin()));
-
- if (!secondOp.hasTrait<OpTrait::IsTerminator>())
- return false;
-
- return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp);
- }]
- >,
InterfaceMethod<
/*description=*/[{
If there is another loop wrapper immediately nested inside, return that
@@ -116,7 +88,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
/*retTy=*/"::mlir::omp::LoopWrapperInterface",
/*methodName=*/"getNestedWrapper",
(ins), [{}], [{
- assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
Operation *nested = &*$_op->getRegion(0).op_begin();
return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
}]
@@ -129,7 +100,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
/*retTy=*/"::mlir::Operation *",
/*methodName=*/"getWrappedLoop",
(ins), [{}], [{
- assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
if (LoopWrapperInterface nested = $_op.getNestedWrapper())
return nested.getWrappedLoop();
return &*$_op->getRegion(0).op_begin();
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6db4796eb37e0d..1a9b87f0d68c9d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1720,6 +1720,32 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
+/// Verifies that \p op has the structure required of a loop wrapper: exactly
+/// one region holding exactly one block, which in turn contains exactly two
+/// operations. The first must be another loop wrapper or an `omp.loop_nest`
+/// and the second must be a terminator.
+///
+/// On violation, an error is emitted on \p op and failure is returned;
+/// otherwise success is returned.
+static LogicalResult verifyLoopWrapperInterface(Operation *op) {
+  // Phrase the region/block diagnostics as "not exactly one": having zero
+  // regions, or a region with no blocks, also fails these checks, so saying
+  // "multiple" would be wrong in those cases.
+  if (op->getNumRegions() != 1)
+    return op->emitOpError()
+           << "loop wrapper does not contain exactly one region";
+
+  Region &region = op->getRegion(0);
+  if (!region.hasOneBlock())
+    return op->emitOpError()
+           << "loop wrapper does not contain exactly one block";
+
+  if (::llvm::range_size(region.getOps()) != 2)
+    return op->emitOpError()
+           << "loop wrapper does not contain exactly two nested ops";
+
+  Operation &firstOp = *region.op_begin();
+  Operation &secondOp = *(std::next(region.op_begin()));
+
+  // Checked before the kind of the first op: when both ops are wrong, this is
+  // the diagnostic that gets reported.
+  if (!secondOp.hasTrait<OpTrait::IsTerminator>())
+    return op->emitOpError()
+           << "second nested op in loop wrapper is not a terminator";
+
+  if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
+    return op->emitOpError() << "first nested op in loop wrapper is not "
+                                "another loop wrapper or `omp.loop_nest`";
+
+  return success();
+}
+
void WsloopOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
@@ -1750,8 +1776,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult WsloopOp::verify() {
- if (!isValidWrapper())
- return emitOpError() << "must be a valid loop wrapper";
+ if (verifyLoopWrapperInterface(*this).failed())
+ return failure();
bool isCompositeChildLeaf =
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
@@ -1809,8 +1835,8 @@ LogicalResult SimdOp::verify() {
if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
return failure();
- if (!isValidWrapper())
- return emitOpError() << "must be a valid loop wrapper";
+ if (verifyLoopWrapperInterface(*this).failed())
+ return failure();
if (getNestedWrapper())
return emitOpError() << "must wrap an 'omp.loop_nest' directly";
@@ -1851,8 +1877,8 @@ LogicalResult DistributeOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- if (!isValidWrapper())
- return emitOpError() << "must be a valid loop wrapper";
+ if (verifyLoopWrapperInterface(*this).failed())
+ return failure();
if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
@@ -2059,8 +2085,8 @@ LogicalResult TaskloopOp::verify() {
"may not appear on the same taskloop directive");
}
- if (!isValidWrapper())
- return emitOpError() << "must be a valid loop wrapper";
+ if (verifyLoopWrapperInterface(*this).failed())
+ return failure();
if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8d29d1622e3b62..d8745f1015af83 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -123,7 +123,7 @@ func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
// -----
func.func @no_wrapper(%lb : index, %ub : index, %step : index) {
- // expected-error @below {{op must be a valid loop wrapper}}
+ // expected-error @below {{op loop wrapper does not contain exactly two nested ops}}
omp.wsloop {
%0 = arith.constant 0 : i32
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -309,7 +309,7 @@ llvm.func @test_omp_wsloop_dynamic_wrong_modifier3(%lb : i64, %ub : i64, %step :
// -----
func.func @omp_simd() -> () {
- // expected-error @below {{op must be a valid loop wrapper}}
+ // expected-error @below {{op loop wrapper does not contain exactly two nested ops}}
omp.simd {
omp.terminator
}
@@ -1963,7 +1963,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
// -----
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
- // expected-error @below {{op must be a valid loop wrapper}}
+ // expected-error @below {{op first nested op in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
omp.taskloop {
%0 = arith.constant 0 : i32
omp.terminator
@@ -2171,11 +2171,13 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
// -----
-func.func @omp_distribute_wrapper() -> () {
- // expected-error @below {{op must be a valid loop wrapper}}
+func.func @omp_distribute_wrapper(%lb: index, %ub: index, %step: index) -> () {
+ // expected-error @below {{op second nested op in loop wrapper is not a terminator}}
omp.distribute {
+ omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+ "omp.yield"() : () -> ()
+ }
%0 = arith.constant 0 : i32
- "omp.terminator"() : () -> ()
}
+  // Terminate the function body so the only invalid construct in this test
+  // is the non-terminator second op inside omp.distribute.
+  return
}
More information about the Mlir-commits
mailing list