[Mlir-commits] [mlir] [mlir][IR] Change block/region walkers to enumerate `this` block/region (PR #75020)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Dec 10 19:06:47 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

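This change makes block/region walkers consistent with operation walkers. An operation walk enumerates the current operation. Similarly, block/region walks should enumerate the current block/region. The range-based `Block::walk(begin, end, callback)` overload is not affected by this and still does not enumerate the block itself.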

Example:
```
// Current behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 is NOT enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 are NOT enumerated */ });
region1->walk([](Region *region2) { /* region1 is NOT enumerated */ });

// New behavior:
op1->walk([](Operation *op2) { /* op1 is enumerated */ });
block1->walk([](Block *block2) { /* block1 IS enumerated */ });
region1->walk([](Block *block) { /* blocks of region1 ARE enumerated */ });
region1->walk([](Region *region2) { /* region1 IS enumerated */ });
```

Depends on #75016. Only review the top commit.


---
Full diff: https://github.com/llvm/llvm-project/pull/75020.diff


5 Files Affected:

- (modified) mlir/include/mlir/IR/Block.h (+67-44) 
- (modified) mlir/include/mlir/IR/Region.h (+47-35) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir (+16-1) 
- (modified) mlir/test/IR/visitors.mlir (+21-1) 
- (modified) mlir/test/lib/IR/TestVisitors.cpp (+55) 


``````````diff
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 3d00c405ead374..e58b87774b8658 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -260,68 +260,91 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   SuccessorRange getSuccessors() { return SuccessorRange(this); }
 
   //===--------------------------------------------------------------------===//
-  // Operation Walkers
+  // Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this block. The callback method is called for each
-  /// nested region, block or operation, depending on the callback provided.
-  /// The order in which regions, blocks and operations at the same nesting
+  /// Walk all nested operations, blocks (including this block) or regions,
+  /// depending on the type of callback.
+  ///
+  /// The order in which operations, blocks or regions at the same nesting
   /// level are visited (e.g., lexicographical or reverse lexicographical order)
-  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
-  /// and operations with respect to their nested ones is specified by 'Order'
-  /// (post-order by default). A callback on a block or operation is allowed to
-  /// erase that block or operation if either:
+  /// is determined by `Iterator`. The walk order for enclosing operations,
+  /// blocks or regions with respect to their nested ones is specified by
+  /// `Order` (post-order by default).
+  ///
+  /// A callback on an operation or block is allowed to erase that operation or
+  /// block if either:
   ///   * the walk is in post-order, or
   ///   * the walk is in pre-order and the walk is skipped after the erasure.
+  ///
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder,
             typename Iterator = ForwardIterator, typename FnT,
+            typename ArgT = detail::first_argument<FnT>,
             typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
-  }
-
-  /// Walk the operations in the specified [begin, end) range of this block. The
-  /// callback method is called for each nested region, block or operation,
-  /// depending on the callback provided. The order in which regions, blocks and
-  /// operations at the same nesting level are visited (e.g., lexicographical or
-  /// reverse lexicographical order) is determined by 'Iterator'. The walk order
-  /// for enclosing regions, blocks and operations with respect to their nested
-  /// ones is specified by 'Order' (post-order by default). This method is
-  /// invoked for void-returning callbacks. A callback on a block or operation
-  /// is allowed to erase that block or operation only if the walk is in
-  /// post-order. See non-void method for pre-order erasure.
-  /// See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder,
-            typename Iterator = ForwardIterator, typename FnT,
-            typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, void>::value, RetT>
-  walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
-    for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      detail::walk<Order, Iterator>(&op, callback);
+    if constexpr (std::is_same<ArgT, Block *>::value &&
+                  Order == WalkOrder::PreOrder) {
+      // Pre-order walk on blocks: invoke the callback on this block.
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        RetT result = callback(this);
+        if (result.wasSkipped())
+          return WalkResult::advance();
+        if (result.wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        callback(this);
+      }
+    }
+
+    // Walk nested operations, blocks or regions.
+    if constexpr (std::is_same<RetT, WalkResult>::value) {
+      if (walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback))
+              .wasInterrupted())
+        return WalkResult::interrupt();
+    } else {
+      walk<Order, Iterator>(begin(), end(), std::forward<FnT>(callback));
+    }
+
+    if constexpr (std::is_same<ArgT, Block *>::value &&
+                  Order == WalkOrder::PostOrder) {
+      // Post-order walk on blocks: invoke the callback on this block.
+      return callback(this);
+    }
+    if constexpr (std::is_same<RetT, WalkResult>::value)
+      return WalkResult::advance();
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block. The
-  /// callback method is called for each nested region, block or operation,
-  /// depending on the callback provided. The order in which regions, blocks and
-  /// operations at the same nesting level are visited (e.g., lexicographical or
-  /// reverse lexicographical order) is determined by 'Iterator'. The walk order
-  /// for enclosing regions, blocks and operations with respect to their nested
-  /// ones is specified by 'Order' (post-order by default). This method is
-  /// invoked for skippable or interruptible callbacks. A callback on a block or
-  /// operation is allowed to erase that block or operation if either:
+  /// Walk all nested operations, blocks (excluding this block) or regions,
+  /// depending on the type of callback, in the specified [begin, end) range of
+  /// this block.
+  ///
+  /// The order in which operations, blocks or regions at the same nesting
+  /// level are visited (e.g., lexicographical or reverse lexicographical order)
+  /// is determined by `Iterator`. The walk order for enclosing operations,
+  /// blocks or regions with respect to their nested ones is specified by
+  /// `Order` (post-order by default).
+  ///
+  /// A callback on an operation or block is allowed to erase that operation or
+  /// block if either:
   ///   * the walk is in post-order, or
   ///   * the walk is in pre-order and the walk is skipped after the erasure.
+  ///
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder,
             typename Iterator = ForwardIterator, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
-  walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
-    for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
-        return WalkResult::interrupt();
-    return WalkResult::advance();
+  RetT walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
+    for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end))) {
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        if (detail::walk<Order, Iterator>(&op, callback).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        detail::walk<Order, Iterator>(&op, callback);
+      }
+    }
+    if constexpr (std::is_same<RetT, WalkResult>::value)
+      return WalkResult::advance();
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 4f4812dda79b89..b626350d1b657d 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -260,48 +260,60 @@ class Region {
   void dropAllReferences();
 
   //===--------------------------------------------------------------------===//
-  // Operation Walkers
+  // Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this region. The callback method is called for each
-  /// nested region, block or operation, depending on the callback provided.
-  /// The order in which regions, blocks and operations at the same nesting
-  /// level are visited (e.g., lexicographical or reverse lexicographical order)
-  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
-  /// and operations with respect to their nested ones is specified by 'Order'
-  /// (post-order by default). This method is invoked for void-returning
-  /// callbacks. A callback on a block or operation is allowed to erase that
-  /// block or operation only if the walk is in post-order. See non-void method
-  /// for pre-order erasure. See Operation::walk for more details.
-  template <WalkOrder Order = WalkOrder::PostOrder,
-            typename Iterator = ForwardIterator, typename FnT,
-            typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, void>::value, RetT> walk(FnT &&callback) {
-    for (auto &block : *this)
-      block.walk<Order, Iterator>(callback);
-  }
-
-  /// Walk the operations in this region. The callback method is called for each
-  /// nested region, block or operation, depending on the callback provided.
-  /// The order in which regions, blocks and operations at the same nesting
+  /// Walk all nested operations, blocks or regions (including this region),
+  /// depending on the type of callback.
+  ///
+  /// The order in which operations, blocks or regions at the same nesting
   /// level are visited (e.g., lexicographical or reverse lexicographical order)
-  /// is determined by 'Iterator'. The walk order for enclosing regions, blocks
-  /// and operations with respect to their nested ones is specified by 'Order'
-  /// (post-order by default). This method is invoked for skippable or
-  /// interruptible callbacks. A callback on a block or operation is allowed to
-  /// erase that block or operation if either:
-  ///   * the walk is in post-order,
-  ///   * or the walk is in pre-order and the walk is skipped after the erasure.
+  /// is determined by `Iterator`. The walk order for enclosing operations,
+  /// blocks or regions with respect to their nested ones is specified by
+  /// `Order` (post-order by default).
+  ///
+  /// A callback on an operation or block is allowed to erase that operation or
+  /// block if either:
+  ///   * the walk is in post-order, or
+  ///   * the walk is in pre-order and the walk is skipped after the erasure.
+  ///
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder,
             typename Iterator = ForwardIterator, typename FnT,
+            typename ArgT = detail::first_argument<FnT>,
             typename RetT = detail::walkResultType<FnT>>
-  std::enable_if_t<std::is_same<RetT, WalkResult>::value, RetT>
-  walk(FnT &&callback) {
-    for (auto &block : *this)
-      if (block.walk<Order, Iterator>(callback).wasInterrupted())
-        return WalkResult::interrupt();
-    return WalkResult::advance();
+  RetT walk(FnT &&callback) {
+    if constexpr (std::is_same<ArgT, Region *>::value &&
+                  Order == WalkOrder::PreOrder) {
+      // Pre-order walk on regions: invoke the callback on this region.
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        RetT result = callback(this);
+        if (result.wasSkipped())
+          return WalkResult::advance();
+        if (result.wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        callback(this);
+      }
+    }
+
+    // Walk nested operations, blocks or regions.
+    for (auto &block : *this) {
+      if constexpr (std::is_same<RetT, WalkResult>::value) {
+        if (block.walk<Order, Iterator>(callback).wasInterrupted())
+          return WalkResult::interrupt();
+      } else {
+        block.walk<Order, Iterator>(callback);
+      }
+    }
+
+    if constexpr (std::is_same<ArgT, Region *>::value &&
+                  Order == WalkOrder::PostOrder) {
+      // Post-order walk on regions: invoke the callback on this region.
+      return callback(this);
+    }
+    if constexpr (std::is_same<RetT, WalkResult>::value)
+      return WalkResult::advance();
   }
 
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index ad7c4c783e907f..1a8a930bc9002b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -531,8 +531,8 @@ func.func @noRegionBranchOpInterface() {
 // This is not allowed in buffer deallocation.
 
 func.func @noRegionBranchOpInterface() {
-  // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
   %0 = "test.bar"() ({
+    // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
     %1 = "test.bar"() ({
       %2 = "test.get_memref"() : () -> memref<2xi32>
       "test.yield"(%2) : (memref<2xi32>) -> ()
@@ -544,6 +544,21 @@ func.func @noRegionBranchOpInterface() {
 
 // -----
 
+// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
+// This is not allowed in buffer deallocation.
+
+func.func @noRegionBranchOpInterface() {
+  // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
+  %0 = "test.bar"() ({
+    %2 = "test.get_memref"() : () -> memref<2xi32>
+    %3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
+    "test.yield"(%3) : (i32) -> ()
+  }) : () -> (i32)
+  "test.terminator"() : () -> ()
+}
+
+// -----
+
 func.func @while_two_arg(%arg0: index) {
   %a = memref.alloc(%arg0) : memref<?xf32>
   scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 2d83d6922e0cd0..ec7712a45d3882 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -17,7 +17,7 @@ func.func @structured_cfg() {
       "use2"(%i) : (index) -> ()
     }
     "use3"(%i) : (index) -> ()
-  }
+  } {walk_blocks, walk_regions}
   return
 }
 
@@ -88,6 +88,26 @@ func.func @structured_cfg() {
 // CHECK:       Visiting op 'func.func'
 // CHECK:       Visiting op 'builtin.module'
 
+// CHECK-LABEL: Invoke block pre-order visits on blocks
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke block post-order visits on blocks
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 1 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Invoke region pre-order visits on region
+// CHECK:       Visiting region 0 from operation 'scf.for'
+// CHECK:       Visiting region 0 from operation 'scf.if'
+// CHECK:       Visiting region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Invoke region post-order visits on region
+// CHECK:       Visiting region 0 from operation 'scf.if'
+// CHECK:       Visiting region 1 from operation 'scf.if'
+// CHECK:       Visiting region 0 from operation 'scf.for'
+
 // CHECK-LABEL: Op pre-order erasures
 // CHECK:       Erasing op 'scf.for'
 // CHECK:       Erasing op 'func.return'
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index a3ef3f35159534..f4cff39cf2e523 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -204,6 +204,60 @@ static void testNoSkipErasureCallbacks(Operation *op) {
   cloned->erase();
 }
 
+/// Invoke region/block walks on regions/blocks.
+static void testBlockAndRegionWalkers(Operation *op) {
+  auto blockPure = [](Block *block) {
+    llvm::outs() << "Visiting ";
+    printBlock(block);
+    llvm::outs() << "\n";
+  };
+  auto regionPure = [](Region *region) {
+    llvm::outs() << "Visiting ";
+    printRegion(region);
+    llvm::outs() << "\n";
+  };
+
+  llvm::outs() << "Invoke block pre-order visits on blocks\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_blocks"))
+      return;
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        block.walk<WalkOrder::PreOrder>(blockPure);
+      }
+    }
+  });
+
+  llvm::outs() << "Invoke block post-order visits on blocks\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_blocks"))
+      return;
+    for (Region &region : op->getRegions()) {
+      for (Block &block : region.getBlocks()) {
+        block.walk<WalkOrder::PostOrder>(blockPure);
+      }
+    }
+  });
+
+  llvm::outs() << "Invoke region pre-order visits on region\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_regions"))
+      return;
+    for (Region &region : op->getRegions()) {
+      region.walk<WalkOrder::PreOrder>(regionPure);
+    }
+  });
+
+  llvm::outs() << "Invoke region post-order visits on region\n";
+  op->walk([&](Operation *op) {
+    if (!op->hasAttr("walk_regions"))
+      return;
+    for (Region &region : op->getRegions()) {
+      region.walk<WalkOrder::PostOrder>(regionPure);
+    }
+  });
+}
+
 namespace {
 /// This pass exercises the different configurations of the IR visitors.
 struct TestIRVisitorsPass
@@ -215,6 +269,7 @@ struct TestIRVisitorsPass
   void runOnOperation() override {
     Operation *op = getOperation();
     testPureCallbacks(op);
+    testBlockAndRegionWalkers(op);
     testSkipErasureCallbacks(op);
     testNoSkipErasureCallbacks(op);
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/75020


More information about the Mlir-commits mailing list