[Mlir-commits] [mlir] [mlir][IR] Change block/region walkers to enumerate `this` block/region (PR #75020)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Dec 10 19:06:47 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region.
Example:
```
// Current behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 is NOT enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ });
region1->walk([](Region *region2) { /* region1 is NOT enumerated */ });
// New behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 IS enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ });
region1->walk([](Region *region2) { /* region1 IS enumerated */ });
```
Depends on #75016. Only review the top commit.
---
Full diff: https://github.com/llvm/llvm-project/pull/75020.diff
5 Files Affected:
- (modified) mlir/include/mlir/IR/Block.h (+67-44)
- (modified) mlir/include/mlir/IR/Region.h (+47-35)
- (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir (+16-1)
- (modified) mlir/test/IR/visitors.mlir (+21-1)
- (modified) mlir/test/lib/IR/TestVisitors.cpp (+55)
``````````diff
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 3d00c405ead374..e58b87774b8658 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -260,68 +260,91 @@ class Block : public IRObjectWithUseList<BlockOperand>,
SuccessorRange getSuccessors() { return SuccessorRange(this); }
//===--------------------------------------------------------------------===//
- // Operation Walkers
+ // Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this block. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
+ /// Walk all nested operations, blocks (including this block) or regions,
+ /// depending on the type of callback.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). A callback on a block or operation is allowed to
- /// erase that block or operation if either:
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on an operation or block is allowed to erase that operation or
+ /// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
+ typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
- return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
- }
-
- /// Walk the operations in the specified [begin, end) range of this block. The
- /// callback method is called for each nested region, block or operation,
- /// depending on the callback provided. The order in which regions, blocks and
- /// operations at the same nesting level are visited (e.g., lexicographical or
- /// reverse lexicographical order) is determined by 'Iterator'. The walk order
- /// for enclosing regions, blocks and operations with respect to their nested
- /// ones is specified by 'Order' (post-order by default). This method is
- /// invoked for void-returning callbacks. A callback on a block or operation
- /// is allowed to erase that block or operation only if the walk is in
- /// post-order. See non-void method for pre-order erasure.
- /// See Operation::walk for more details.
- template <WalkOrder Order = WalkOrder::PostOrder,
- typename Iterator = ForwardIterator, typename FnT,
- typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, void>::value, RetT>
- walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
- for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- detail::walk<Order, Iterator>(&op, callback);
+ if constexpr (std::is_same<ArgT, Block *>::value &&
+ Order == WalkOrder::PreOrder) {
+ // Pre-order walk on blocks: invoke the callback on this block.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ RetT result = callback(this);
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ callback(this);
+ }
+ }
+
+ // Walk nested operations, blocks or regions.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
+ .wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
+ }
+
+ if constexpr (std::is_same<ArgT, Block *>::value &&
+ Order == WalkOrder::PostOrder) {
+ // Post-order walk on blocks: invoke the callback on this block.
+ return callback(this);
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
- /// Walk the operations in the specified [begin, end) range of this block. The
- /// callback method is called for each nested region, block or operation,
- /// depending on the callback provided. The order in which regions, blocks and
- /// operations at the same nesting level are visited (e.g., lexicographical or
- /// reverse lexicographical order) is determined by 'Iterator'. The walk order
- /// for enclosing regions, blocks and operations with respect to their nested
- /// ones is specified by 'Order' (post-order by default). This method is
- /// invoked for skippable or interruptible callbacks. A callback on a block or
- /// operation is allowed to erase that block or operation if either:
+ /// Walk all nested operations, blocks (excluding this block) or regions,
+ /// depending on the type of callback, in the specified [begin, end) range of
+ /// this block.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
+ /// level are visited (e.g., lexicographical or reverse lexicographical order)
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on an operation or block is allowed to erase that operation or
+ /// block if either:
/// * the walk is in post-order, or
/// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
- walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
- for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
- return WalkResult::interrupt();
- return WalkResult::advance();
+ RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+ for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ detail::walk<Order, Iterator>(&op, callback);
+ }
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 4f4812dda79b89..b626350d1b657d 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -260,48 +260,60 @@ class Region {
void dropAllReferences();
//===--------------------------------------------------------------------===//
- // Operation Walkers
+ // Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this region. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
- /// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). This method is invoked for void-returning
- /// callbacks. A callback on a block or operation is allowed to erase that
- /// block or operation only if the walk is in post-order. See non-void method
- /// for pre-order erasure. See Operation::walk for more details.
- template <WalkOrder Order = WalkOrder::PostOrder,
- typename Iterator = ForwardIterator, typename FnT,
- typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
- for (auto &block : *this)
- block.walk<Order, Iterator>(callback);
- }
-
- /// Walk the operations in this region. The callback method is called for each
- /// nested region, block or operation, depending on the callback provided.
- /// The order in which regions, blocks and operations at the same nesting
+ /// Walk all nested operations, blocks or regions (including this region),
+ /// depending on the type of callback.
+ ///
+ /// The order in which operations, blocks or regions at the same nesting
/// level are visited (e.g., lexicographical or reverse lexicographical order)
- /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
- /// and operations with respect to their nested ones is specified by 'Order'
- /// (post-order by default). This method is invoked for skippable or
- /// interruptible callbacks. A callback on a block or operation is allowed to
- /// erase that block or operation if either:
- /// * the walk is in post-order,
- /// * or the walk is in pre-order and the walk is skipped after the erasure.
+ /// is determined by `Iterator`. The walk order for enclosing operations,
+ /// blocks or regions with respect to their nested ones is specified by
+ /// `Order` (post-order by default).
+ ///
+ /// A callback on an operation or block is allowed to erase that operation or
+ /// block if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder,
typename Iterator = ForwardIterator, typename FnT,
+ typename ArgT = detail::first_argument<FnT>,
typename RetT = detail::walkResultType<FnT>>
- std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
- walk(FnT &&callback) {
- for (auto &block : *this)
- if (block.walk<Order, Iterator>(callback).wasInterrupted())
- return WalkResult::interrupt();
- return WalkResult::advance();
+ RetT walk(FnT &&callback) {
+ if constexpr (std::is_same<ArgT, Region *>::value &&
+ Order == WalkOrder::PreOrder) {
+ // Pre-order walk on regions: invoke the callback on this region.
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ RetT result = callback(this);
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ callback(this);
+ }
+ }
+
+ // Walk nested operations, blocks or regions.
+ for (auto &block : *this) {
+ if constexpr (std::is_same<RetT, WalkResult>::value) {
+ if (block.walk<Order, Iterator>(callback).wasInterrupted())
+ return WalkResult::interrupt();
+ } else {
+ block.walk<Order, Iterator>(callback);
+ }
+ }
+
+ if constexpr (std::is_same<ArgT, Region *>::value &&
+ Order == WalkOrder::PostOrder) {
+ // Post-order walk on regions: invoke the callback on this region.
+ return callback(this);
+ }
+ if constexpr (std::is_same<RetT, WalkResult>::value)
+ return WalkResult::advance();
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index ad7c4c783e907f..1a8a930bc9002b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -531,8 +531,8 @@ func.func @noRegionBranchOpInterface() {
// This is not allowed in buffer deallocation.
func.func @noRegionBranchOpInterface() {
- // expected-error @+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
%0 = "test.bar"() ({
+ // expected-error @+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
%1 = "test.bar"() ({
%2 = "test.get_memref"() : () -> memref<2xi32>
"test.yield"(%2) : (memref<2xi32>) -> ()
@@ -544,6 +544,21 @@ func.func @noRegionBranchOpInterface() {
// -----
+// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
+// This is not allowed in buffer deallocation.
+
+func.func @noRegionBranchOpInterface() {
+ // expected-error @+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
+ %0 = "test.bar"() ({
+ %2 = "test.get_memref"() : () -> memref<2xi32>
+ %3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
+ "test.yield"(%3) : (i32) -> ()
+ }) : () -> (i32)
+ "test.terminator"() : () -> ()
+}
+
+// -----
+
func.func @while_two_arg(%arg0: index) {
%a = memref.alloc(%arg0) : memref<?xf32>
scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 2d83d6922e0cd0..ec7712a45d3882 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -17,7 +17,7 @@ func.func @structured_cfg() {
"use2"(%i) : (index) -> ()
}
"use3"(%i) : (index) -> ()
- }
+ } {walk_blocks, walk_regions}
return
}
@@ -88,6 +88,26 @@ func.func @structured_cfg() {
// CHECK: Visiting op 'func.func'
// CHECK: Visiting op 'builtin.module'
+// CHECK-LABEL: Invoke block pre-order visits on blocks
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke block post-order visits on blocks
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Invoke region pre-order visits on region
+// CHECK: Visiting region 0 from operation 'scf.for'
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke region post-order visits on region
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+// CHECK: Visiting region 0 from operation 'scf.for'
+
// CHECK-LABEL: Op pre-order erasures
// CHECK: Erasing op 'scf.for'
// CHECK: Erasing op 'func.return'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index a3ef3f35159534..f4cff39cf2e523 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
cloned->erase();
}
+/// Invoke region/block walks on regions/blocks.
+static void testBlockAndRegionWalkers(Operation *op) {
+ auto blockPure = [](Block *block) {
+ llvm::outs() << "Visiting ";
+ printBlock(block);
+ llvm::outs() << "\n";
+ };
+ auto regionPure = [](Region *region) {
+ llvm::outs() << "Visiting ";
+ printRegion(region);
+ llvm::outs() << "\n";
+ };
+
+ llvm::outs() << "Invoke block pre-order visits on blocks\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_blocks"))
+ return;
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ block.walk<WalkOrder::PreOrder>(blockPure);
+ }
+ }
+ });
+
+ llvm::outs() << "Invoke block post-order visits on blocks\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_blocks"))
+ return;
+ for (Region &region : op->getRegions()) {
+ for (Block &block : region.getBlocks()) {
+ block.walk<WalkOrder::PostOrder>(blockPure);
+ }
+ }
+ });
+
+ llvm::outs() << "Invoke region pre-order visits on region\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_regions"))
+ return;
+ for (Region &region : op->getRegions()) {
+ region.walk<WalkOrder::PreOrder>(regionPure);
+ }
+ });
+
+ llvm::outs() << "Invoke region post-order visits on region\n";
+ op->walk([&](Operation *op) {
+ if (!op->hasAttr("walk_regions"))
+ return;
+ for (Region &region : op->getRegions()) {
+ region.walk<WalkOrder::PostOrder>(regionPure);
+ }
+ });
+}
+
namespace {
/// This pass exercises the different configurations of the IR visitors.
struct TestIRVisitorsPass
@@ -215,6 +269,7 @@ struct TestIRVisitorsPass
void runOnOperation() override {
Operation *op = getOperation();
testPureCallbacks(op);
+ testBlockAndRegionWalkers(op);
testSkipErasureCallbacks(op);
testNoSkipErasureCallbacks(op);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/75020
More information about the Mlir-commits
mailing list