[Mlir-commits] [mlir] [mlir][IR] Change block/region walkers to enumerate `this` block/region (PR #75020)
Matthias Springer
llvmlistbot at llvm.org
Mon Dec 11 16:53:42 PST 2023
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/75020
>From fc8cbc5248711e371f055de6c7f62dab31b5d5f8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 11 Dec 2023 12:02:24 +0900
Subject: [PATCH] [mlir][IR] Change block/region walkers to enumerate `this`
block/region
This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region.
Example:
```
// Current behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 is NOT enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ });
region1->walk([](Region *region2) { /* region1 is NOT enumerated */ });
// New behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 IS enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ });
region1->walk([](Region *region2) { /* region1 IS enumerated */ });
```
---
mlir/include/mlir/IR/Block.h | 111 +++++++++++-------
mlir/include/mlir/IR/Region.h | 82 +++++++------
.../OwnershipBasedBufferDeallocation.cpp | 16 +--
mlir/test/IR/visitors.mlir | 22 +++-
mlir/test/lib/IR/TestVisitors.cpp | 55 +++++++++
5 files changed, 193 insertions(+), 93 deletions(-)
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 3d00c405ead37..e58b87774b865 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -260,68 +260,91 @@ class Block : public IRObjectWithUseList<BlockOperand>,
SuccessorRange getSuccessors() { return SuccessorRange(this); }
//===--------------------------------------------------------------------===//
- // Operation Walkers
+ // Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this block. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
+ /// Walk all nested operations, blocks (including this block) or regions,
+ /// depending on the type of callback.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). A callback on a block or operation is allowed to
- /// erase that block or operation if either:
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on an operation or block is allowed to erase that operation or
+ /// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
+ typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
- return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
- }
-
- /// Walk the operations in the specified [begin, end) range of this block. The
- /// callback method is called for each nested region, block or operation,
- /// depending on the callback provided. The order in which regions, blocks and
- /// operations at the same nesting level are visited (e.g., lexicographical or
- /// reverse lexicographical order) is determined by 'Iterator'. The walk order
- /// for enclosing regions, blocks and operations with respect to their nested
- /// ones is specified by 'Order' (post-order by default). This method is
- /// invoked for void-returning callbacks. A callback on a block or operation
- /// is allowed to erase that block or operation only if the walk is in
- /// post-order. See non-void method for pre-order erasure.
- /// See Operation::walk for more details.
- template <WalkOrder Order = WalkOrder::PostOrder,
- typename Iterator = ForwardIterator, typename FnT,
- typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, void>::value, RetT>
- walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
- for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- detail::walk<Order, Iterator>(&op, callback);
+ if constexpr (std::is_same<ArgT, Block *>::value &&
+ Order == WalkOrder::PreOrder) {
+ // Pre-order walk on blocks: invoke the callback on this block.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ RetT result = callback(this);
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ callback(this);
+ }
+ }
+
+ // Walk nested operations, blocks or regions.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
+ .wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
+ }
+
+ if constexpr (std::is_same<ArgT, Block *>::value &&
+ Order == WalkOrder::PostOrder) {
+ // Post-order walk on blocks: invoke the callback on this block.
+ return callback(this);
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
- /// Walk the operations in the specified [begin, end) range of this block. The
- /// callback method is called for each nested region, block or operation,
- /// depending on the callback provided. The order in which regions, blocks and
- /// operations at the same nesting level are visited (e.g., lexicographical or
- /// reverse lexicographical order) is determined by 'Iterator'. The walk order
- /// for enclosing regions, blocks and operations with respect to their nested
- /// ones is specified by 'Order' (post-order by default). This method is
- /// invoked for skippable or interruptible callbacks. A callback on a block or
- /// operation is allowed to erase that block or operation if either:
+ /// Walk all nested operations, blocks (excluding this block) or regions,
+ /// depending on the type of callback, in the specified [begin, end) range of
+ /// this block.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
+ /// level are visited (e.g., lexicographical or reverse lexicographical order)
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on an operation or block is allowed to erase that operation or
+ /// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
- walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
- for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
- return WalkResult::interrupt();
- return WalkResult::advance();
+ RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+ for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ detail::walk<Order, Iterator>(&op, callback);
+ }
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 4f4812dda79b8..b626350d1b657 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -260,48 +260,60 @@ class Region {
void dropAllReferences();
//===--------------------------------------------------------------------===//
- // Operation Walkers
+ // Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this region. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
- /// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). This method is invoked for void-returning
- /// callbacks. A callback on a block or operation is allowed to erase that
- /// block or operation only if the walk is in post-order. See non-void method
- /// for pre-order erasure. See Operation::walk for more details.
- template <WalkOrder Order = WalkOrder::PostOrder,
- typename Iterator = ForwardIterator, typename FnT,
- typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
- for (auto &block : *this)
- block.walk<Order, Iterator>(callback);
- }
-
- /// Walk the operations in this region. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
+ /// Walk all nested operations, blocks or regions (including this region),
+ /// depending on the type of callback.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). This method is invoked for skippable or
- /// interruptible callbacks. A callback on a block or operation is allowed to
- /// erase that block or operation if either:
- /// * the walk is in post-order,
- /// * or the walk is in pre-order and the walk is skipped after the erasure.
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on an operation or block is allowed to erase that operation or
+ /// block if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
+ typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
- walk(FnT &&callback) {
- for (auto &block : *this)
- if (block.walk<Order, Iterator>(callback).wasInterrupted())
- return WalkResult::interrupt();
- return WalkResult::advance();
+ RetT walk(FnT &&callback) {
+ if constexpr (std::is_same<ArgT, Region *>::value &&
+ Order == WalkOrder::PreOrder) {
+ // Pre-order walk on regions: invoke the callback on this region.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ RetT result = callback(this);
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ callback(this);
+ }
+ }
+
+ // Walk nested operations, blocks or regions.
+ for (auto &block : *this) {
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (block.walk<Order, Iterator>(callback).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ block.walk<Order, Iterator>(callback);
+ }
+ }
+
+ if constexpr (std::is_same<ArgT, Region *>::value &&
+ Order == WalkOrder::PostOrder) {
+ // Post-order walk on regions: invoke the callback on this region.
+ return callback(this);
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 38ffae68a43de..9459cc43547fa 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -463,7 +463,7 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
}
static bool regionOperatesOnMemrefValues(Region &region) {
- auto checkBlock = [](Block *block) {
+ WalkResult result = region.walk([](Block *block) {
if (llvm::any_of(block->getArguments(), isMemref))
return WalkResult::interrupt();
for (Operation &op : *block) {
@@ -473,18 +473,8 @@ static bool regionOperatesOnMemrefValues(Region &region) {
return WalkResult::interrupt();
}
return WalkResult::advance();
- };
- WalkResult result = region.walk(checkBlock);
- if (result.wasInterrupted())
- return true;
-
- // Note: Block::walk/Region::walk visits only blocks that are nested under
- // nested operations, but not direct children.
- for (Block &block : region)
- if (checkBlock(&block).wasInterrupted())
- return true;
-
- return false;
+ });
+ return result.wasInterrupted();
}
LogicalResult
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 2d83d6922e0cd..ec7712a45d388 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -17,7 +17,7 @@ func.func @structured_cfg() {
"use2"(%i) : (index) -> ()
}
"use3"(%i) : (index) -> ()
- }
+ } {walk_blocks, walk_regions}
return
}
@@ -88,6 +88,26 @@ func.func @structured_cfg() {
// CHECK: Visiting op 'func.func'
// CHECK: Visiting op 'builtin.module'
+// CHECK-LABEL: Invoke block pre-order visits on blocks
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke block post-order visits on blocks
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Invoke region pre-order visits on region
+// CHECK: Visiting region 0 from operation 'scf.for'
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke region post-order visits on region
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+// CHECK: Visiting region 0 from operation 'scf.for'
+
// CHECK-LABEL: Op pre-order erasures
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'func.return'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index a3ef3f3515953..f4cff39cf2e52 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
cloned->erase();
}
+/// Invoke region/block walks on regions/blocks.
+static void testBlockAndRegionWalkers(Operation *op) {
+ auto blockPure = [](Block *block) {
+ llvm::outs() << "Visiting ";
+ printBlock(block);
+ llvm::outs() << "\n";
+ };
+ auto regionPure = [](Region *region) {
+ llvm::outs() << "Visiting ";
+ printRegion(region);
+ llvm::outs() << "\n";
+ };
+
+ llvm::outs() << "Invoke block pre-order visits on blocks\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_blocks"))
+ return;
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ block.walk<WalkOrder::PreOrder>(blockPure);
+ }
+ }
+ });
+
+ llvm::outs() << "Invoke block post-order visits on blocks\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_blocks"))
+ return;
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ block.walk<WalkOrder::PostOrder>(blockPure);
+ }
+ }
+ });
+
+ llvm::outs() << "Invoke region pre-order visits on region\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_regions"))
+ return;
+ for (Region &region : op->getRegions()) {
+ region.walk<WalkOrder::PreOrder>(regionPure);
+ }
+ });
+
+ llvm::outs() << "Invoke region post-order visits on region\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_regions"))
+ return;
+ for (Region &region : op->getRegions()) {
+ region.walk<WalkOrder::PostOrder>(regionPure);
+ }
+ });
+}
+
namespace {
/// This pass exercises the different configurations of the IR visitors.
struct TestIRVisitorsPass
@@ -215,6 +269,7 @@ struct TestIRVisitorsPass
void runOnOperation() override {
Operation *op = getOperation();
testPureCallbacks(op);
+ testBlockAndRegionWalkers(op);
testSkipErasureCallbacks(op);
testNoSkipErasureCallbacks(op);
}
More information about the Mlir-commits
mailing list