[Mlir-commits] [mlir] df067f1 - [mlir][IR][NFC] Move `walk` definitions to header file
Matthias Springer
llvmlistbot at llvm.org
Mon Mar 6 00:21:52 PST 2023
Author: Matthias Springer
Date: 2023-03-06T09:21:32+01:00
New Revision: df067f13de569979b0d8ad8e9fc91ca06630e58f
URL: https://github.com/llvm/llvm-project/commit/df067f13de569979b0d8ad8e9fc91ca06630e58f
DIFF: https://github.com/llvm/llvm-project/commit/df067f13de569979b0d8ad8e9fc91ca06630e58f.diff
LOG: [mlir][IR][NFC] Move `walk` definitions to header file
This allows users to provide custom `Iterator` templates. A new iterator will be added in a subsequent change.
Also rename `makeRange` to `makeIterable` and add a test case for the reverse iterator.
Differential Revision: https://reviews.llvm.org/D144887
Added:
Modified:
mlir/include/mlir/IR/Visitors.h
mlir/lib/IR/Visitors.cpp
mlir/test/IR/visitors.mlir
mlir/test/lib/IR/TestVisitors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 3fc105510d06d..871c19ff5c9b0 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -64,8 +64,12 @@ enum class WalkOrder { PreOrder, PostOrder };
/// This iterator enumerates the elements in "forward" order.
struct ForwardIterator {
- template <typename RangeT>
- static constexpr RangeT &makeRange(RangeT &range) {
+ /// Make operations iterable: return the list of regions.
+ static MutableArrayRef<Region> makeIterable(Operation &range);
+
+ /// Regions and blocks are already iterable.
+ template <typename T>
+ static constexpr T &makeIterable(T &range) {
return range;
}
};
@@ -74,9 +78,10 @@ struct ForwardIterator {
/// llvm::reverse.
struct ReverseIterator {
template <typename RangeT>
- static constexpr auto makeRange(RangeT &&range) {
+ static constexpr auto makeIterable(RangeT &&range) {
// llvm::reverse uses RangeT::rbegin and RangeT::rend.
- return llvm::reverse(std::forward<RangeT>(range));
+ return llvm::reverse(
+ ForwardIterator::makeIterable(std::forward<RangeT>(range)));
}
};
@@ -141,12 +146,58 @@ using first_argument = decltype(first_argument_type(std::declval<T>()));
/// pre-order erasure.
template <typename Iterator>
void walk(Operation *op, function_ref<void(Region *)> callback,
- WalkOrder order);
+ WalkOrder order) {
+ // We don't use early increment for regions because they can't be erased from
+ // a callback.
+ for (auto &region : Iterator::makeIterable(*op)) {
+ if (order == WalkOrder::PreOrder)
+ callback(&region);
+ for (auto &block : Iterator::makeIterable(region)) {
+ for (auto &nestedOp : Iterator::makeIterable(block))
+ walk<Iterator>(&nestedOp, callback, order);
+ }
+ if (order == WalkOrder::PostOrder)
+ callback(&region);
+ }
+}
+
template <typename Iterator>
-void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
+void walk(Operation *op, function_ref<void(Block *)> callback,
+ WalkOrder order) {
+ for (auto &region : Iterator::makeIterable(*op)) {
+ // Early increment here in the case where the block is erased.
+ for (auto &block :
+ llvm::make_early_inc_range(Iterator::makeIterable(region))) {
+ if (order == WalkOrder::PreOrder)
+ callback(&block);
+ for (auto &nestedOp : Iterator::makeIterable(block))
+ walk<Iterator>(&nestedOp, callback, order);
+ if (order == WalkOrder::PostOrder)
+ callback(&block);
+ }
+ }
+}
+
template <typename Iterator>
void walk(Operation *op, function_ref<void(Operation *)> callback,
- WalkOrder order);
+ WalkOrder order) {
+ if (order == WalkOrder::PreOrder)
+ callback(op);
+
+ // TODO: This walk should be iterative over the operations.
+ for (auto &region : Iterator::makeIterable(*op)) {
+ for (auto &block : Iterator::makeIterable(region)) {
+ // Early increment here in the case where the operation is erased.
+ for (auto &nestedOp :
+ llvm::make_early_inc_range(Iterator::makeIterable(block)))
+ walk<Iterator>(&nestedOp, callback, order);
+ }
+ }
+
+ if (order == WalkOrder::PostOrder)
+ callback(op);
+}
+
/// Walk all of the regions, blocks, or operations nested under (and including)
/// the given operation. The order in which regions, blocks and operations at
/// the same nesting level are visited (e.g., lexicographical or reverse
@@ -159,13 +210,88 @@ void walk(Operation *op, function_ref<void(Operation *)> callback,
/// * the walk is in pre-order and the walk is skipped after the erasure.
template <typename Iterator>
WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
- WalkOrder order);
+ WalkOrder order) {
+ // We don't use early increment for regions because they can't be erased from
+ // a callback.
+ for (auto &region : Iterator::makeIterable(*op)) {
+ if (order == WalkOrder::PreOrder) {
+ WalkResult result = callback(&region);
+ if (result.wasSkipped())
+ continue;
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ }
+ for (auto &block : Iterator::makeIterable(region)) {
+ for (auto &nestedOp : Iterator::makeIterable(block))
+ if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
+ return WalkResult::interrupt();
+ }
+ if (order == WalkOrder::PostOrder) {
+ if (callback(&region).wasInterrupted())
+ return WalkResult::interrupt();
+ // We don't check if this region was skipped because its walk already
+ // finished and the walk will continue with the next region.
+ }
+ }
+ return WalkResult::advance();
+}
+
template <typename Iterator>
WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
- WalkOrder order);
+ WalkOrder order) {
+ for (auto &region : Iterator::makeIterable(*op)) {
+ // Early increment here in the case where the block is erased.
+ for (auto &block :
+ llvm::make_early_inc_range(Iterator::makeIterable(region))) {
+ if (order == WalkOrder::PreOrder) {
+ WalkResult result = callback(&block);
+ if (result.wasSkipped())
+ continue;
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ }
+ for (auto &nestedOp : Iterator::makeIterable(block))
+ if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
+ return WalkResult::interrupt();
+ if (order == WalkOrder::PostOrder) {
+ if (callback(&block).wasInterrupted())
+ return WalkResult::interrupt();
+ // We don't check if this block was skipped because its walk already
+ // finished and the walk will continue with the next block.
+ }
+ }
+ }
+ return WalkResult::advance();
+}
+
template <typename Iterator>
WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
- WalkOrder order);
+ WalkOrder order) {
+ if (order == WalkOrder::PreOrder) {
+ WalkResult result = callback(op);
+ // If skipped, caller will continue the walk on the next operation.
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ }
+
+ // TODO: This walk should be iterative over the operations.
+ for (auto &region : Iterator::makeIterable(*op)) {
+ for (auto &block : Iterator::makeIterable(region)) {
+ // Early increment here in the case where the operation is erased.
+ for (auto &nestedOp :
+ llvm::make_early_inc_range(Iterator::makeIterable(block))) {
+ if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
+ return WalkResult::interrupt();
+ }
+ }
+ }
+
+ if (order == WalkOrder::PostOrder)
+ return callback(op);
+ return WalkResult::advance();
+}
// Below are a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index a54eca053dcc2..73235f0aa5244 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -14,92 +14,9 @@ using namespace mlir;
WalkStage::WalkStage(Operation *op)
: numRegions(op->getNumRegions()), nextRegion(0) {}
-/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. The order in which regions, blocks and operations at the
-/// same nesting level are visited (e.g., lexicographical or reverse
-/// lexicographical order) is determined by 'Iterator'. The walk order for
-/// enclosing regions, blocks and operations with respect to their nested ones
-/// is specified by 'order'. These methods are invoked for void-returning
-/// callbacks. A callback on a block or operation is allowed to erase that block
-/// or operation only if the walk is in post-order. See non-void method for
-/// pre-order erasure.
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Region *)> callback,
- WalkOrder order) {
- // We don't use early increment for regions because they can't be erased from
- // a callback.
- MutableArrayRef<Region> regions = op->getRegions();
- for (auto &region : Iterator::makeRange(regions)) {
- if (order == WalkOrder::PreOrder)
- callback(&region);
- for (auto &block : Iterator::makeRange(region)) {
- for (auto &nestedOp : Iterator::makeRange(block))
- walk<Iterator>(&nestedOp, callback, order);
- }
- if (order == WalkOrder::PostOrder)
- callback(&region);
- }
-}
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
- function_ref<void(Region *)>,
- WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
- function_ref<void(Region *)>,
- WalkOrder);
-
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Block *)> callback,
- WalkOrder order) {
- MutableArrayRef<Region> regions = op->getRegions();
- for (auto &region : Iterator::makeRange(regions)) {
- // Early increment here in the case where the block is erased.
- for (auto &block :
- llvm::make_early_inc_range(Iterator::makeRange(region))) {
- if (order == WalkOrder::PreOrder)
- callback(&block);
- for (auto &nestedOp : Iterator::makeRange(block))
- walk<Iterator>(&nestedOp, callback, order);
- if (order == WalkOrder::PostOrder)
- callback(&block);
- }
- }
-}
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
- function_ref<void(Block *)>,
- WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
- function_ref<void(Block *)>,
- WalkOrder);
-
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
- WalkOrder order) {
- if (order == WalkOrder::PreOrder)
- callback(op);
-
- // TODO: This walk should be iterative over the operations.
- MutableArrayRef<Region> regions = op->getRegions();
- for (auto &region : Iterator::makeRange(regions)) {
- for (auto &block : Iterator::makeRange(region)) {
- // Early increment here in the case where the operation is erased.
- for (auto &nestedOp :
- llvm::make_early_inc_range(Iterator::makeRange(block)))
- walk<Iterator>(&nestedOp, callback, order);
- }
- }
-
- if (order == WalkOrder::PostOrder)
- callback(op);
+MutableArrayRef<Region> ForwardIterator::makeIterable(Operation &range) {
+ return range.getRegions();
}
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
- function_ref<void(Operation *)>,
- WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
- function_ref<void(Operation *)>,
- WalkOrder);
void detail::walk(Operation *op,
function_ref<void(Operation *, const WalkStage &)> callback) {
@@ -120,128 +37,6 @@ void detail::walk(Operation *op,
callback(op, stage);
}
-/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. These functions walk operations until an interrupt result
-/// is returned by the callback. Walks on regions, blocks and operations may
-/// also be skipped if the callback returns a skip result. Regions, blocks and
-/// operations at the same nesting level are visited in lexicographical order.
-/// The walk order for enclosing regions, blocks and operations with respect to
-/// their nested ones is specified by 'order'. A callback on a block or
-/// operation is allowed to erase that block or operation if either:
-/// * the walk is in post-order, or
-/// * the walk is in pre-order and the walk is skipped after the erasure.
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
- function_ref<WalkResult(Region *)> callback,
- WalkOrder order) {
- // We don't use early increment for regions because they can't be erased from
- // a callback.
- MutableArrayRef<Region> regions = op->getRegions();
- for (auto &region : Iterator::makeRange(regions)) {
- if (order == WalkOrder::PreOrder) {
- WalkResult result = callback(&region);
- if (result.wasSkipped())
- continue;
- if (result.wasInterrupted())
- return WalkResult::interrupt();
- }
- for (auto &block : Iterator::makeRange(region)) {
- for (auto &nestedOp : Iterator::makeRange(block))
- if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
- return WalkResult::interrupt();
- }
- if (order == WalkOrder::PostOrder) {
- if (callback(&region).wasInterrupted())
- return WalkResult::interrupt();
- // We don't check if this region was skipped because its walk already
- // finished and the walk will continue with the next region.
- }
- }
- return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Region *)>,
- WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Region *)>,
- WalkOrder);
-
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
- function_ref<WalkResult(Block *)> callback,
- WalkOrder order) {
- MutableArrayRef<Region> regions = op->getRegions();
- for (auto &region : Iterator::makeRange(regions)) {
- // Early increment here in the case where the block is erased.
- for (auto &block :
- llvm::make_early_inc_range(Iterator::makeRange(region))) {
- if (order == WalkOrder::PreOrder) {
- WalkResult result = callback(&block);
- if (result.wasSkipped())
- continue;
- if (result.wasInterrupted())
- return WalkResult::interrupt();
- }
- for (auto &nestedOp : Iterator::makeRange(block))
- if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
- return WalkResult::interrupt();
- if (order == WalkOrder::PostOrder) {
- if (callback(&block).wasInterrupted())
- return WalkResult::interrupt();
- // We don't check if this block was skipped because its walk already
- // finished and the walk will continue with the next block.
- }
- }
- }
- return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Block *)>,
- WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Block *)>,
- WalkOrder);
-
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
- function_ref<WalkResult(Operation *)> callback,
- WalkOrder order) {
- if (order == WalkOrder::PreOrder) {
- WalkResult result = callback(op);
- // If skipped, caller will continue the walk on the next operation.
- if (result.wasSkipped())
- return WalkResult::advance();
- if (result.wasInterrupted())
- return WalkResult::interrupt();
- }
-
- // TODO: This walk should be iterative over the operations.
- MutableArrayRef<Region> regions = op->getRegions();
- for (auto &region : Iterator::makeRange(regions)) {
- for (auto &block : Iterator::makeRange(region)) {
- // Early increment here in the case where the operation is erased.
- for (auto &nestedOp :
- llvm::make_early_inc_range(Iterator::makeRange(block))) {
- if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
- return WalkResult::interrupt();
- }
- }
- }
-
- if (order == WalkOrder::PostOrder)
- return callback(op);
- return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *,
- function_ref<WalkResult(Operation *)>, WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *,
- function_ref<WalkResult(Operation *)>, WalkOrder);
-
WalkResult detail::walk(
Operation *op,
function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 3a78a0c384045..0c9fabeee5377 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -71,6 +71,23 @@ func.func @structured_cfg() {
// CHECK: Visiting region 0 from operation 'func.func'
// CHECK: Visiting region 0 from operation 'builtin.module'
+// CHECK-LABEL: Op reverse post-order visits
+// CHECK: Visiting op 'func.return'
+// CHECK: Visiting op 'scf.yield'
+// CHECK: Visiting op 'use3'
+// CHECK: Visiting op 'scf.yield'
+// CHECK: Visiting op 'use2'
+// CHECK: Visiting op 'scf.yield'
+// CHECK: Visiting op 'use1'
+// CHECK: Visiting op 'scf.if'
+// CHECK: Visiting op 'use0'
+// CHECK: Visiting op 'scf.for'
+// CHECK: Visiting op 'arith.constant'
+// CHECK: Visiting op 'arith.constant'
+// CHECK: Visiting op 'arith.constant'
+// CHECK: Visiting op 'func.func'
+// CHECK: Visiting op 'builtin.module'
+
// CHECK-LABEL: Op pre-order erasures
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'func.return'
@@ -172,6 +189,29 @@ func.func @unstructured_cfg() {
// CHECK: Visiting region 0 from operation 'func.func'
// CHECK: Visiting region 0 from operation 'builtin.module'
+// CHECK-LABEL: Op reverse post-order visits
+// CHECK: Visiting op 'func.return'
+// CHECK: Visiting op 'op2'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op1'
+// CHECK: Visiting op 'cf.br'
+// CHECK: Visiting op 'op0'
+// CHECK: Visiting op 'regionOp0'
+// CHECK: Visiting op 'func.func'
+// CHECK: Visiting op 'builtin.module'
+
+// CHECK-LABEL: Block reverse post-order visits
+// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func.func'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'builtin.module'
+
+// CHECK-LABEL: Region reverse post-order visits
+// CHECK: Visiting region 0 from operation 'regionOp0'
+// CHECK: Visiting region 0 from operation 'func.func'
+// CHECK: Visiting region 0 from operation 'builtin.module'
+
// CHECK-LABEL: Op pre-order erasures (skip)
// CHECK: Erasing op 'regionOp0'
// CHECK: Erasing op 'func.return'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index f63576e548c96..3211d10bab9fa 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -64,6 +64,16 @@ static void testPureCallbacks(Operation *op) {
llvm::outs() << "Region post-order visits"
<< "\n";
op->walk<WalkOrder::PostOrder>(regionPure);
+
+ llvm::outs() << "Op reverse post-order visits"
+ << "\n";
+ op->walk<WalkOrder::PostOrder, ReverseIterator>(opPure);
+ llvm::outs() << "Block reverse post-order visits"
+ << "\n";
+ op->walk<WalkOrder::PostOrder, ReverseIterator>(blockPure);
+ llvm::outs() << "Region reverse post-order visits"
+ << "\n";
+ op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
}
/// Tests erasure callbacks that skip the walk.
More information about the Mlir-commits
mailing list