[Mlir-commits] [mlir] 71a8624 - [mlir] Extend Operation visitor with pre-order traversal
Diego Caballero
llvmlistbot at llvm.org
Fri Mar 5 14:06:12 PST 2021
Author: Diego Caballero
Date: 2021-03-06T00:02:20+02:00
New Revision: 71a86245ca620d26ed63a86dda3b65a533f5df6b
URL: https://github.com/llvm/llvm-project/commit/71a86245ca620d26ed63a86dda3b65a533f5df6b
DIFF: https://github.com/llvm/llvm-project/commit/71a86245ca620d26ed63a86dda3b65a533f5df6b.diff
LOG: [mlir] Extend Operation visitor with pre-order traversal
This patch extends the Region, Block and Operation visitors to also support pre-order walks.
We introduce a new template argument that dictates the walk order (only pre-order and
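post-order are supported for now). The default order for Regions, Blocks and Operations is
post-order. Note that this changes the default for Region and Block callbacks, which were
previously invoked before their nested operations (i.e., pre-order); Liveness.cpp and
NumberOfExecutions.cpp are updated to request WalkOrder::PreOrder explicitly so their
traversal order is unchanged. Mixed orders (e.g., Region/Block pre-order + Operation
post-order) could easily be implemented, as shown in NumberOfExecutions.cpp.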
Reviewed By: rriddle, frgossen, bondhugula
Differential Revision: https://reviews.llvm.org/D97217
Added:
Modified:
mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/Analysis/Liveness.cpp
mlir/lib/Analysis/NumberOfExecutions.cpp
mlir/lib/IR/Visitors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index a04063814423..9f265b3b56f5 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -249,34 +249,40 @@ class Block : public IRObjectWithUseList<BlockOperand>,
// Operation Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this block in postorder, calling the callback for
- /// each operation.
+ /// Walk the operations in this block, calling the callback for each
+ /// operation. The walk order for regions, blocks and operations is specified
+ /// by 'Order' (post-order by default).
/// See Operation::walk for more details.
- template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+ typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
- return walk(begin(), end(), std::forward<FnT>(callback));
+ return walk<Order>(begin(), end(), std::forward<FnT>(callback));
}
- /// Walk the operations in the specified [begin, end) range of this block in
- /// postorder, calling the callback for each operation. This method is invoked
- /// for void return callbacks.
+ /// Walk the operations in the specified [begin, end) range of this block,
+ /// calling the callback for each operation. The walk order for regions,
+ /// blocks and operations is specified by 'Order' (post-order by default).
+ /// This method is invoked for void-returning callbacks.
/// See Operation::walk for more details.
- template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+ typename RetT = detail::walkResultType<FnT>>
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- detail::walk(&op, callback);
+ detail::walk<Order>(&op, callback);
}
- /// Walk the operations in the specified [begin, end) range of this block in
- /// postorder, calling the callback for each operation. This method is invoked
- /// for interruptible callbacks.
+ /// Walk the operations in the specified [begin, end) range of this block,
+ /// calling the callback for each operation. The walk order for regions,
+ /// blocks and operations is specified by 'Order' (post-order by default).
+ /// This method is invoked for interruptible callbacks.
/// See Operation::walk for more details.
- template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+ typename RetT = detail::walkResultType<FnT>>
typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- if (detail::walk(&op, callback).wasInterrupted())
+ if (detail::walk<Order>(&op, callback).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
}
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c65e653f2a50..4c9399f2a6fc 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -165,12 +165,14 @@ class OpState {
/// handlers that may be listening.
InFlightDiagnostic emitRemark(const Twine &message = {});
- /// Walk the operation in postorder, calling the callback for each nested
- /// operation(including this one).
+ /// Walk the operation by calling the callback for each nested
+ /// operation (including this one). The walk order for regions, blocks and
+ /// operations is specified by 'Order' (post-order by default).
/// See Operation::walk for more details.
- template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+ typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
- return state->walk(std::forward<FnT>(callback));
+ return state->walk<Order>(std::forward<FnT>(callback));
}
// These are default implementations of customization hooks.
// These are default implementations of customization hooks.
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 027c878929f4..679071457e52 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -484,9 +484,10 @@ class alignas(8) Operation final
// Operation Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operation in postorder, calling the callback for each nested
- /// operation(including this one). The callback method can take any of the
- /// following forms:
+ /// Walk the operation by calling the callback for each nested operation
+ /// (including this one). The walk order for regions, blocks and operations is
+ /// specified by 'Order' (post-order by default). The callback method can take
+ /// any of the following forms:
/// void(Operation*) : Walk all operations opaquely.
/// * op->walk([](Operation *nestedOp) { ...});
/// void(OpT) : Walk all operations of the given derived type.
@@ -499,9 +500,10 @@ class alignas(8) Operation final
/// return WalkResult::interrupt();
/// return WalkResult::advance();
/// });
- template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+ typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
- return detail::walk(this, std::forward<FnT>(callback));
+ return detail::walk<Order>(this, std::forward<FnT>(callback));
}
//===--------------------------------------------------------------------===//
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 9b6f54fec936..349be0f5aedc 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -243,23 +243,29 @@ class Region {
//===--------------------------------------------------------------------===//
- /// Walk the operations in this region in postorder, calling the callback for
- /// each operation. This method is invoked for void-returning callbacks.
+ /// Walk the operations in this region, calling the callback for each
+ /// operation. The walk order for regions, blocks and operations is
+ /// specified by 'Order' (post-order by default). This method is invoked for
+ /// void-returning callbacks.
/// See Operation::walk for more details.
- template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+ typename RetT = detail::walkResultType<FnT>>
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
walk(FnT &&callback) {
for (auto &block : *this)
- block.walk(callback);
+ block.walk<Order>(callback);
}
- /// Walk the operations in this region in postorder, calling the callback for
- /// each operation. This method is invoked for interruptible callbacks.
+ /// Walk the operations in this region, calling the callback for each
+ /// operation. The walk order for regions, blocks and operations is
+ /// specified by 'Order' (post-order by default). This method is invoked for
+ /// interruptible callbacks.
/// See Operation::walk for more details.
- template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+ typename RetT = detail::walkResultType<FnT>>
typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
walk(FnT &&callback) {
for (auto &block : *this)
- if (block.walk(callback).wasInterrupted())
+ if (block.walk<Order>(callback).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
}
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index abb7f0f50e15..b4571e19f8fd 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -49,6 +49,9 @@ class WalkResult {
bool wasInterrupted() const { return result == Interrupt; }
};
+/// Traversal order for region, block and operation walk utilities.
+enum class WalkOrder { PreOrder, PostOrder };
+
namespace detail {
/// Helper templates to deduce the first argument of a callback parameter.
template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
@@ -64,17 +67,21 @@ template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
/// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation.
-void walk(Operation *op, function_ref<void(Region *)> callback);
-void walk(Operation *op, function_ref<void(Block *)> callback);
-void walk(Operation *op, function_ref<void(Operation *)> callback);
-
+/// the given operation. The walk order is specified by 'order'.
+void walk(Operation *op, function_ref<void(Region *)> callback,
+ WalkOrder order);
+void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
+void walk(Operation *op, function_ref<void(Operation *)> callback,
+ WalkOrder order);
/// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. These functions walk until an interrupt result is
-/// returned by the callback.
-WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
-WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
-WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
+/// the given operation. The walk order is specified by 'order'. These functions
+/// walk until an interrupt result is returned by the callback.
+WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
+ WalkOrder order);
+WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
+ WalkOrder order);
+WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
+ WalkOrder order);
// Below are a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
@@ -82,7 +89,8 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
// upon the type of the callback function.
/// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. This method is selected for callbacks that operate on
+/// the given operation. The walk order is specified by 'Order' (post-order
+/// by default). This method is selected for callbacks that operate on
/// Region*, Block*, and Operation*.
///
/// Example:
@@ -90,22 +98,25 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
/// op->walk([](Block *b) { ... });
/// op->walk([](Operation *op) { ... });
template <
- typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+ WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+ typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<
llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
walk(Operation *op, FuncTy &&callback) {
- return walk(op, function_ref<RetT(ArgT)>(callback));
+ return detail::walk(op, function_ref<RetT(ArgT)>(callback), Order);
}
/// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. This method is selected for void returning callbacks that
-/// operate on a specific derived operation type.
+/// given operation. The walk order for regions, blocks and operations is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// void returning callbacks that operate on a specific derived operation type.
///
/// Example:
/// op->walk([](ReturnOp op) { ... });
template <
- typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+ WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+ typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<
!llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -116,12 +127,14 @@ walk(Operation *op, FuncTy &&callback) {
if (auto derivedOp = dyn_cast<ArgT>(op))
callback(derivedOp);
};
- return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
+ return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
}
/// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. This method is selected for WalkReturn returning
-/// interruptible callbacks that operate on a specific derived operation type.
+/// given operation. The walk order for regions, blocks and operations is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// WalkResult returning interruptible callbacks that operate on a specific
+/// derived operation type.
///
/// Example:
/// op->walk([](ReturnOp op) {
@@ -130,7 +143,8 @@ walk(Operation *op, FuncTy &&callback) {
/// return WalkResult::advance();
/// });
template <
- typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+ WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+ typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
typename std::enable_if<
!llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -142,7 +156,7 @@ walk(Operation *op, FuncTy &&callback) {
return callback(derivedOp);
return WalkResult::advance();
};
- return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
+ return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
}
/// Utility to provide the return type of a templated walk method.
/// Utility to provide the return type of a templated walk method.
diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
index 4dae386e94b2..c9558a7d502c 100644
--- a/mlir/lib/Analysis/Liveness.cpp
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -130,7 +130,7 @@ static void buildBlockMapping(Operation *operation,
DenseMap<Block *, BlockInfoBuilder> &builders) {
llvm::SetVector<Block *> toProcess;
- operation->walk([&](Block *block) {
+ operation->walk<WalkOrder::PreOrder>([&](Block *block) {
BlockInfoBuilder &builder =
builders.try_emplace(block, block).first->second;
@@ -270,7 +270,7 @@ void Liveness::print(raw_ostream &os) const {
DenseMap<Block *, size_t> blockIds;
DenseMap<Operation *, size_t> operationIds;
DenseMap<Value, size_t> valueIds;
- operation->walk([&](Block *block) {
+ operation->walk<WalkOrder::PreOrder>([&](Block *block) {
blockIds.insert({block, blockIds.size()});
for (BlockArgument argument : block->getArguments())
valueIds.insert({argument, valueIds.size()});
@@ -304,7 +304,7 @@ void Liveness::print(raw_ostream &os) const {
};
// Dump information about in and out values.
- operation->walk([&](Block *block) {
+ operation->walk<WalkOrder::PreOrder>([&](Block *block) {
os << "// - Block: " << blockIds[block] << "\n";
const auto *liveness = getLiveness(block);
os << "// --- LiveIn: ";
diff --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp
index 425936b5eaaf..26f31c0913d1 100644
--- a/mlir/lib/Analysis/NumberOfExecutions.cpp
+++ b/mlir/lib/Analysis/NumberOfExecutions.cpp
@@ -115,7 +115,7 @@ static void computeRegionBlockNumberOfExecutions(
/// Creates a new NumberOfExecutions analysis that computes how many times a
/// block within a region is executed for all associated regions.
NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
- operation->walk([&](Region *region) {
+ operation->walk<WalkOrder::PreOrder>([&](Region *region) {
computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
});
}
@@ -191,7 +191,7 @@ void NumberOfExecutions::printBlockExecutions(
raw_ostream &os, Region *perEntryOfThisRegion) const {
unsigned blockId = 0;
- operation->walk([&](Block *block) {
+ operation->walk<WalkOrder::PreOrder>([&](Block *block) {
llvm::errs() << "Block: " << blockId++ << "\n";
llvm::errs() << "Number of executions: ";
if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
@@ -203,7 +203,7 @@ void NumberOfExecutions::printBlockExecutions(
void NumberOfExecutions::printOperationExecutions(
raw_ostream &os, Region *perEntryOfThisRegion) const {
- operation->walk([&](Block *block) {
+ operation->walk<WalkOrder::PreOrder>([&](Block *block) {
block->walk([&](Operation *operation) {
// Skip the operation that was used to build the analysis.
if (operation == this->operation)
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index d03bdb508d37..be995a2a4fb2 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -12,79 +12,113 @@
using namespace mlir;
/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation.
-void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
+/// given operation. The walk order is specified by 'order'.
+void detail::walk(Operation *op, function_ref<void(Region *)> callback,
+ WalkOrder order) {
for (auto &region : op->getRegions()) {
- callback(&region);
+ if (order == WalkOrder::PreOrder)
+ callback(&region);
for (auto &block : region) {
for (auto &nestedOp : block)
- walk(&nestedOp, callback);
+ walk(&nestedOp, callback, order);
}
+ if (order == WalkOrder::PostOrder)
+ callback(&region);
}
}
-void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
+void detail::walk(Operation *op, function_ref<void(Block *)> callback,
+ WalkOrder order) {
for (auto &region : op->getRegions()) {
for (auto &block : region) {
- callback(&block);
+ if (order == WalkOrder::PreOrder)
+ callback(&block);
for (auto &nestedOp : block)
- walk(&nestedOp, callback);
+ walk(&nestedOp, callback, order);
+ if (order == WalkOrder::PostOrder)
+ callback(&block);
}
}
}
-void detail::walk(Operation *op, function_ref<void(Operation *op)> callback) {
+void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
+ WalkOrder order) {
+ if (order == WalkOrder::PreOrder)
+ callback(op);
+
// TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// Early increment here in the case where the operation is erased.
for (auto &nestedOp : llvm::make_early_inc_range(block))
- walk(&nestedOp, callback);
+ walk(&nestedOp, callback, order);
}
}
- callback(op);
+
+ if (order == WalkOrder::PostOrder)
+ callback(op);
}
/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. These functions walk operations until an interrupt result
-/// is returned by the callback.
+/// given operation. The walk order is specified by 'order'. These functions
+/// walk operations until an interrupt result is returned by the callback.
WalkResult detail::walk(Operation *op,
- function_ref<WalkResult(Region *op)> callback) {
+ function_ref<WalkResult(Region *)> callback,
+ WalkOrder order) {
for (auto &region : op->getRegions()) {
- if (callback(&region).wasInterrupted())
- return WalkResult::interrupt();
+ if (order == WalkOrder::PreOrder)
+ if (callback(&region).wasInterrupted())
+ return WalkResult::interrupt();
for (auto &block : region) {
for (auto &nestedOp : block)
- walk(&nestedOp, callback);
+ if (walk(&nestedOp, callback, order).wasInterrupted())
+ return WalkResult::interrupt();
}
+ if (order == WalkOrder::PostOrder)
+ if (callback(&region).wasInterrupted())
+ return WalkResult::interrupt();
}
return WalkResult::advance();
}
WalkResult detail::walk(Operation *op,
- function_ref<WalkResult(Block *op)> callback) {
+ function_ref<WalkResult(Block *)> callback,
+ WalkOrder order) {
for (auto &region : op->getRegions()) {
for (auto &block : region) {
- if (callback(&block).wasInterrupted())
- return WalkResult::interrupt();
+ if (order == WalkOrder::PreOrder)
+ if (callback(&block).wasInterrupted())
+ return WalkResult::interrupt();
for (auto &nestedOp : block)
- walk(&nestedOp, callback);
+ if (walk(&nestedOp, callback, order).wasInterrupted())
+ return WalkResult::interrupt();
+ if (order == WalkOrder::PostOrder)
+ if (callback(&block).wasInterrupted())
+ return WalkResult::interrupt();
}
}
return WalkResult::advance();
}
WalkResult detail::walk(Operation *op,
- function_ref<WalkResult(Operation *op)> callback) {
+ function_ref<WalkResult(Operation *)> callback,
+ WalkOrder order) {
+ if (order == WalkOrder::PreOrder)
+ if (callback(op).wasInterrupted())
+ return WalkResult::interrupt();
+
// TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// Early increment here in the case where the operation is erased.
for (auto &nestedOp : llvm::make_early_inc_range(block)) {
- if (walk(&nestedOp, callback).wasInterrupted())
+ if (walk(&nestedOp, callback, order).wasInterrupted())
return WalkResult::interrupt();
}
}
}
- return callback(op);
+
+ if (order == WalkOrder::PostOrder)
+ return callback(op);
+ return WalkResult::advance();
}
More information about the Mlir-commits
mailing list