[llvm-branch-commits] [mlir] [MLIR][OpenMP] Introduce the LoopWrapperInterface (PR #87232)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Apr 2 07:37:14 PDT 2024
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/87232
>From 2452bc75a7f2efb67a0522bbe8b0e7ba5bc3365b Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Mon, 1 Apr 2024 13:04:14 +0100
Subject: [PATCH 1/2] [MLIR][OpenMP] Introduce the LoopWrapperInterface
This patch defines a common interface to be shared by all OpenMP loop wrapper
operations. The main restrictions these operations must meet in order to be
considered a wrapper are:
- They contain a single region.
- Their region contains a single block.
- Their block contains exactly two operations: another loop wrapper or an
`omp.loop_nest`, followed by a terminator.
The new interface is attached to the `omp.parallel`, `omp.wsloop`,
`omp.simdloop`, `omp.distribute` and `omp.taskloop` operations. These
operations are not yet required to meet the wrapper restrictions, because
enforcing them now would break existing OpenMP loop-generating code. Instead,
enforcement will be introduced progressively in subsequent patches.
---
.../mlir/Dialect/OpenMP/OpenMPInterfaces.h | 3 +
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 16 +++--
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 68 +++++++++++++++++++
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 19 ++++++
mlir/test/Dialect/OpenMP/invalid.mlir | 16 ++++-
5 files changed, 117 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
index b3184db8852161..787c48b05c5c5c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPInterfaces.h
@@ -21,6 +21,9 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#define GET_OP_FWD_DEFINES
+#include "mlir/Dialect/OpenMP/OpenMPOps.h.inc"
+
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.h.inc"
namespace mlir::omp {
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ffd00948915153..a7bf93deae2fb3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -236,6 +236,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove]> {
def ParallelOp : OpenMP_Op<"parallel", [
AutomaticAllocationScope, AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopWrapperInterface>,
DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
RecursiveMemoryEffects, ReductionClauseInterface]> {
let summary = "parallel construct";
@@ -517,8 +518,6 @@ def SingleOp : OpenMP_Op<"single", [AttrSizedOperandSegments]> {
def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
- ParentOneOf<["DistributeOp", "SimdLoopOp", "TaskloopOp",
- "WsloopOp"]>,
RecursiveMemoryEffects]> {
let summary = "rectangular loop nest";
let description = [{
@@ -568,6 +567,10 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
/// Returns the induction variables of the loop nest.
ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
+
+ /// Returns the list of wrapper operations around this loop nest. Wrappers
+ /// in the resulting vector will be sorted from innermost to outermost.
+ SmallVector<LoopWrapperInterface> getWrappers();
}];
let hasCustomAssemblyFormat = 1;
@@ -580,6 +583,7 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+ DeclareOpInterfaceMethods<LoopWrapperInterface>,
RecursiveMemoryEffects, ReductionClauseInterface]> {
let summary = "worksharing-loop construct";
let description = [{
@@ -700,7 +704,9 @@ def WsloopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
//===----------------------------------------------------------------------===//
def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
- AllTypesMatch<["lowerBound", "upperBound", "step"]>]> {
+ AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+ DeclareOpInterfaceMethods<LoopWrapperInterface>,
+ RecursiveMemoryEffects]> {
let summary = "simd loop construct";
let description = [{
The simd construct can be applied to a loop to indicate that the loop can be
@@ -809,7 +815,8 @@ def YieldOp : OpenMP_Op<"yield",
// Distribute construct [2.9.4.1]
//===----------------------------------------------------------------------===//
def DistributeOp : OpenMP_Op<"distribute", [AttrSizedOperandSegments,
- MemoryEffects<[MemWrite]>]> {
+ DeclareOpInterfaceMethods<LoopWrapperInterface>,
+ RecursiveMemoryEffects]> {
let summary = "distribute construct";
let description = [{
The distribute construct specifies that the iterations of one or more loops
@@ -980,6 +987,7 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
AutomaticAllocationScope, RecursiveMemoryEffects,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
+ DeclareOpInterfaceMethods<LoopWrapperInterface>,
ReductionClauseInterface]> {
let summary = "taskloop construct";
let description = [{
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 2e37384ce3eb71..b6a3560b7da56a 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -69,6 +69,74 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
];
}
+def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
+ let description = [{
+ OpenMP operations that can wrap a single loop nest. When taking a wrapper
+ role, these operations must only contain a single region with a single block
+ in which there's a single operation and a terminator. That nested operation
+ must be another loop wrapper or an `omp.loop_nest`.
+ }];
+
+ let cppNamespace = "::mlir::omp";
+
+ let methods = [
+ InterfaceMethod<
+ /*description=*/[{
+ Tell whether the operation could be taking the role of a loop wrapper.
+ That is, it has a single region with a single block in which there are
+ two operations: another wrapper or `omp.loop_nest` operation and a
+ terminator.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"isWrapper",
+ (ins ), [{}], [{
+ if ($_op->getNumRegions() != 1)
+ return false;
+
+ ::mlir::Region &r = $_op->getRegion(0);
+ if (!r.hasOneBlock())
+ return false;
+
+ if (std::distance(r.op_begin(), r.op_end()) != 2)
+ return false;
+
+ ::mlir::Operation &firstOp = *r.op_begin();
+ ::mlir::Operation &secondOp = *(++r.op_begin());
+ return ::llvm::isa<::mlir::omp::LoopNestOp,
+ ::mlir::omp::LoopWrapperInterface>(firstOp) &&
+ secondOp.hasTrait<::mlir::OpTrait::IsTerminator>();
+ }]
+ >,
+ InterfaceMethod<
+ /*description=*/[{
+ If there is another loop wrapper immediately nested inside, return that
+ operation. Assumes this operation is taking a loop wrapper role.
+ }],
+ /*retTy=*/"::mlir::omp::LoopWrapperInterface",
+ /*methodName=*/"getNestedWrapper",
+ (ins), [{}], [{
+ assert($_op.isWrapper() && "Unexpected non-wrapper op");
+ ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin();
+ return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested);
+ }]
+ >,
+ InterfaceMethod<
+ /*description=*/[{
+ Return the loop nest nested directly or indirectly inside of this loop
+ wrapper. Assumes this operation is taking a loop wrapper role.
+ }],
+ /*retTy=*/"::mlir::Operation *",
+ /*methodName=*/"getWrappedLoop",
+ (ins), [{}], [{
+ assert($_op.isWrapper() && "Unexpected non-wrapper op");
+ if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper())
+ return nested.getWrappedLoop();
+ return &*$_op->getRegion(0).op_begin();
+ }]
+ >
+ ];
+}
+
def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
let description = [{
OpenMP operations that support declare target have this interface.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 796df1d13e6564..564c23201db4fd 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1730,9 +1730,28 @@ LogicalResult LoopNestOp::verify() {
<< "range argument type does not match corresponding IV type";
}
+ auto wrapper =
+ llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
+
+ if (!wrapper || !wrapper.isWrapper())
+ return emitOpError() << "expects parent op to be a valid loop wrapper";
+
return success();
}
+SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
+ SmallVector<LoopWrapperInterface> wrappers;
+ Operation *parent = (*this)->getParentOp();
+ while (auto wrapper =
+ llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
+ if (!wrapper.isWrapper())
+ break;
+ wrappers.push_back(wrapper);
+ parent = parent->getParentOp();
+ }
+ return wrappers;
+}
+
//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 760ebb14d94121..8f4103dabee5df 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -88,7 +88,7 @@ func.func @proc_bind_once() {
// -----
func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
- // expected-error@+1 {{op expects parent op to be one of 'omp.distribute, omp.simdloop, omp.taskloop, omp.wsloop'}}
+ // expected-error@+1 {{op expects parent op to be a valid loop wrapper}}
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
@@ -96,6 +96,20 @@ func.func @invalid_parent(%lb : index, %ub : index, %step : index) {
// -----
+func.func @invalid_wrapper(%lb : index, %ub : index, %step : index) {
+ // TODO Remove induction variables from omp.wsloop.
+ omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+ %0 = arith.constant 0 : i32
+ // expected-error@+1 {{op expects parent op to be a valid loop wrapper}}
+ omp.loop_nest (%iv2) : index = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ omp.yield
+ }
+}
+
+// -----
+
func.func @type_mismatch(%lb : index, %ub : index, %step : index) {
// TODO Remove induction variables from omp.wsloop.
omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) {
>From 904f27489b0d3c27e773f085b18a9b85cb548f44 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Tue, 2 Apr 2024 15:36:41 +0100
Subject: [PATCH 2/2] Address review comments
---
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 4 ++--
.../Dialect/OpenMP/OpenMPOpsInterfaces.td | 19 +++++++++----------
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 ++---
3 files changed, 13 insertions(+), 15 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a7bf93deae2fb3..50627712ea3109 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -568,9 +568,9 @@ def LoopNestOp : OpenMP_Op<"loop_nest", [SameVariadicOperandSize,
/// Returns the induction variables of the loop nest.
ArrayRef<BlockArgument> getIVs() { return getRegion().getArguments(); }
- /// Returns the list of wrapper operations around this loop nest. Wrappers
+ /// Fills a list of wrapper operations around this loop nest. Wrappers
/// in the resulting vector will be sorted from innermost to outermost.
- SmallVector<LoopWrapperInterface> getWrappers();
+ void gatherWrappers(SmallVectorImpl<LoopWrapperInterface> &wrappers);
}];
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index b6a3560b7da56a..ab9b78e755d9d5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -93,18 +93,17 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
if ($_op->getNumRegions() != 1)
return false;
- ::mlir::Region &r = $_op->getRegion(0);
+ Region &r = $_op->getRegion(0);
if (!r.hasOneBlock())
return false;
- if (std::distance(r.op_begin(), r.op_end()) != 2)
+ if (::llvm::range_size(r.getOps()) != 2)
return false;
- ::mlir::Operation &firstOp = *r.op_begin();
- ::mlir::Operation &secondOp = *(++r.op_begin());
- return ::llvm::isa<::mlir::omp::LoopNestOp,
- ::mlir::omp::LoopWrapperInterface>(firstOp) &&
- secondOp.hasTrait<::mlir::OpTrait::IsTerminator>();
+ Operation &firstOp = *r.op_begin();
+ Operation &secondOp = *(std::next(r.op_begin()));
+ return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp) &&
+ secondOp.hasTrait<OpTrait::IsTerminator>();
}]
>,
InterfaceMethod<
@@ -116,8 +115,8 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
/*methodName=*/"getNestedWrapper",
(ins), [{}], [{
assert($_op.isWrapper() && "Unexpected non-wrapper op");
- ::mlir::Operation *nested = &*$_op->getRegion(0).op_begin();
- return ::llvm::dyn_cast<::mlir::omp::LoopWrapperInterface>(nested);
+ Operation *nested = &*$_op->getRegion(0).op_begin();
+ return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
}]
>,
InterfaceMethod<
@@ -129,7 +128,7 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
/*methodName=*/"getWrappedLoop",
(ins), [{}], [{
assert($_op.isWrapper() && "Unexpected non-wrapper op");
- if (::mlir::omp::LoopWrapperInterface nested = $_op.getNestedWrapper())
+ if (LoopWrapperInterface nested = $_op.getNestedWrapper())
return nested.getWrappedLoop();
return &*$_op->getRegion(0).op_begin();
}]
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 564c23201db4fd..a7d265328df6ef 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1739,8 +1739,8 @@ LogicalResult LoopNestOp::verify() {
return success();
}
-SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
- SmallVector<LoopWrapperInterface> wrappers;
+void LoopNestOp::gatherWrappers(
+ SmallVectorImpl<LoopWrapperInterface> &wrappers) {
Operation *parent = (*this)->getParentOp();
while (auto wrapper =
llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
@@ -1749,7 +1749,6 @@ SmallVector<LoopWrapperInterface> LoopNestOp::getWrappers() {
wrappers.push_back(wrapper);
parent = parent->getParentOp();
}
- return wrappers;
}
//===----------------------------------------------------------------------===//
More information about the llvm-branch-commits
mailing list