[Mlir-commits] [mlir] df067f1 - [mlir][IR][NFC] Move `walk` definitions to header file

Matthias Springer llvmlistbot at llvm.org
Mon Mar 6 00:21:52 PST 2023


Author: Matthias Springer
Date: 2023-03-06T09:21:32+01:00
New Revision: df067f13de569979b0d8ad8e9fc91ca06630e58f

URL: https://github.com/llvm/llvm-project/commit/df067f13de569979b0d8ad8e9fc91ca06630e58f
DIFF: https://github.com/llvm/llvm-project/commit/df067f13de569979b0d8ad8e9fc91ca06630e58f.diff

LOG: [mlir][IR][NFC] Move `walk` definitions to header file

This allows users to provide custom `Iterator` templates. A new iterator will be added in a subsequent change.

Also rename `makeRange` to `makeIterable` and add a test case for the reverse iterator.

Differential Revision: https://reviews.llvm.org/D144887

Added: 
    

Modified: 
    mlir/include/mlir/IR/Visitors.h
    mlir/lib/IR/Visitors.cpp
    mlir/test/IR/visitors.mlir
    mlir/test/lib/IR/TestVisitors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 3fc105510d06d..871c19ff5c9b0 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -64,8 +64,12 @@ enum class WalkOrder { PreOrder, PostOrder };
 
 /// This iterator enumerates the elements in "forward" order.
 struct ForwardIterator {
-  template <typename RangeT>
-  static constexpr RangeT &makeRange(RangeT &range) {
+  /// Make operations iterable: return the list of regions.
+  static MutableArrayRef<Region> makeIterable(Operation &range);
+
+  /// Regions and blocks are already iterable.
+  template <typename T>
+  static constexpr T &makeIterable(T &range) {
     return range;
   }
 };
@@ -74,9 +78,10 @@ struct ForwardIterator {
 /// llvm::reverse.
 struct ReverseIterator {
   template <typename RangeT>
-  static constexpr auto makeRange(RangeT &&range) {
+  static constexpr auto makeIterable(RangeT &&range) {
     // llvm::reverse uses RangeT::rbegin and RangeT::rend.
-    return llvm::reverse(std::forward<RangeT>(range));
+    return llvm::reverse(
+        ForwardIterator::makeIterable(std::forward<RangeT>(range)));
   }
 };
 
@@ -141,12 +146,58 @@ using first_argument = decltype(first_argument_type(std::declval<T>()));
 /// pre-order erasure.
 template <typename Iterator>
 void walk(Operation *op, function_ref<void(Region *)> callback,
-          WalkOrder order);
+          WalkOrder order) {
+  // We don't use early increment for regions because they can't be erased from
+  // a callback.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    if (order == WalkOrder::PreOrder)
+      callback(&region);
+    for (auto &block : Iterator::makeIterable(region)) {
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        walk<Iterator>(&nestedOp, callback, order);
+    }
+    if (order == WalkOrder::PostOrder)
+      callback(&region);
+  }
+}
+
 template <typename Iterator>
-void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
+void walk(Operation *op, function_ref<void(Block *)> callback,
+          WalkOrder order) {
+  for (auto &region : Iterator::makeIterable(*op)) {
+    // Early increment here in the case where the block is erased.
+    for (auto &block :
+         llvm::make_early_inc_range(Iterator::makeIterable(region))) {
+      if (order == WalkOrder::PreOrder)
+        callback(&block);
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        walk<Iterator>(&nestedOp, callback, order);
+      if (order == WalkOrder::PostOrder)
+        callback(&block);
+    }
+  }
+}
+
 template <typename Iterator>
 void walk(Operation *op, function_ref<void(Operation *)> callback,
-          WalkOrder order);
+          WalkOrder order) {
+  if (order == WalkOrder::PreOrder)
+    callback(op);
+
+  // TODO: This walk should be iterative over the operations.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    for (auto &block : Iterator::makeIterable(region)) {
+      // Early increment here in the case where the operation is erased.
+      for (auto &nestedOp :
+           llvm::make_early_inc_range(Iterator::makeIterable(block)))
+        walk<Iterator>(&nestedOp, callback, order);
+    }
+  }
+
+  if (order == WalkOrder::PostOrder)
+    callback(op);
+}
+
 /// Walk all of the regions, blocks, or operations nested under (and including)
 /// the given operation. The order in which regions, blocks and operations at
 /// the same nesting level are visited (e.g., lexicographical or reverse
@@ -159,13 +210,88 @@ void walk(Operation *op, function_ref<void(Operation *)> callback,
 ///   * the walk is in pre-order and the walk is skipped after the erasure.
 template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
-                WalkOrder order);
+                WalkOrder order) {
+  // We don't use early increment for regions because they can't be erased from
+  // a callback.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    if (order == WalkOrder::PreOrder) {
+      WalkResult result = callback(&region);
+      if (result.wasSkipped())
+        continue;
+      if (result.wasInterrupted())
+        return WalkResult::interrupt();
+    }
+    for (auto &block : Iterator::makeIterable(region)) {
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
+    }
+    if (order == WalkOrder::PostOrder) {
+      if (callback(&region).wasInterrupted())
+        return WalkResult::interrupt();
+      // We don't check if this region was skipped because its walk already
+      // finished and the walk will continue with the next region.
+    }
+  }
+  return WalkResult::advance();
+}
+
 template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
-                WalkOrder order);
+                WalkOrder order) {
+  for (auto &region : Iterator::makeIterable(*op)) {
+    // Early increment here in the case where the block is erased.
+    for (auto &block :
+         llvm::make_early_inc_range(Iterator::makeIterable(region))) {
+      if (order == WalkOrder::PreOrder) {
+        WalkResult result = callback(&block);
+        if (result.wasSkipped())
+          continue;
+        if (result.wasInterrupted())
+          return WalkResult::interrupt();
+      }
+      for (auto &nestedOp : Iterator::makeIterable(block))
+        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
+      if (order == WalkOrder::PostOrder) {
+        if (callback(&block).wasInterrupted())
+          return WalkResult::interrupt();
+        // We don't check if this block was skipped because its walk already
+        // finished and the walk will continue with the next block.
+      }
+    }
+  }
+  return WalkResult::advance();
+}
+
 template <typename Iterator>
 WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
-                WalkOrder order);
+                WalkOrder order) {
+  if (order == WalkOrder::PreOrder) {
+    WalkResult result = callback(op);
+    // If skipped, caller will continue the walk on the next operation.
+    if (result.wasSkipped())
+      return WalkResult::advance();
+    if (result.wasInterrupted())
+      return WalkResult::interrupt();
+  }
+
+  // TODO: This walk should be iterative over the operations.
+  for (auto &region : Iterator::makeIterable(*op)) {
+    for (auto &block : Iterator::makeIterable(region)) {
+      // Early increment here in the case where the operation is erased.
+      for (auto &nestedOp :
+           llvm::make_early_inc_range(Iterator::makeIterable(block))) {
+        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
+      }
+    }
+  }
+
+  if (order == WalkOrder::PostOrder)
+    return callback(op);
+  return WalkResult::advance();
+}
 
 // Below are a set of functions to walk nested operations. Users should favor
 // the direct `walk` methods on the IR classes(Operation/Block/etc) over these

diff  --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index a54eca053dcc2..73235f0aa5244 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -14,92 +14,9 @@ using namespace mlir;
 WalkStage::WalkStage(Operation *op)
     : numRegions(op->getNumRegions()), nextRegion(0) {}
 
-/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. The order in which regions, blocks and operations at the
-/// same nesting level are visited (e.g., lexicographical or reverse
-/// lexicographical order) is determined by 'Iterator'. The walk order for
-/// enclosing regions,  blocks and operations with respect to their nested ones
-/// is specified by 'order'. These methods are invoked for void-returning
-/// callbacks. A callback on a block or operation is allowed to erase that block
-/// or operation only if the walk is in post-order. See non-void method for
-/// pre-order erasure.
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Region *)> callback,
-                  WalkOrder order) {
-  // We don't use early increment for regions because they can't be erased from
-  // a callback.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    if (order == WalkOrder::PreOrder)
-      callback(&region);
-    for (auto &block : Iterator::makeRange(region)) {
-      for (auto &nestedOp : Iterator::makeRange(block))
-        walk<Iterator>(&nestedOp, callback, order);
-    }
-    if (order == WalkOrder::PostOrder)
-      callback(&region);
-  }
-}
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
-                                            function_ref<void(Region *)>,
-                                            WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
-                                            function_ref<void(Region *)>,
-                                            WalkOrder);
-
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Block *)> callback,
-                  WalkOrder order) {
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    // Early increment here in the case where the block is erased.
-    for (auto &block :
-         llvm::make_early_inc_range(Iterator::makeRange(region))) {
-      if (order == WalkOrder::PreOrder)
-        callback(&block);
-      for (auto &nestedOp : Iterator::makeRange(block))
-        walk<Iterator>(&nestedOp, callback, order);
-      if (order == WalkOrder::PostOrder)
-        callback(&block);
-    }
-  }
-}
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
-                                            function_ref<void(Block *)>,
-                                            WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
-                                            function_ref<void(Block *)>,
-                                            WalkOrder);
-
-template <typename Iterator>
-void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
-                  WalkOrder order) {
-  if (order == WalkOrder::PreOrder)
-    callback(op);
-
-  // TODO: This walk should be iterative over the operations.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    for (auto &block : Iterator::makeRange(region)) {
-      // Early increment here in the case where the operation is erased.
-      for (auto &nestedOp :
-           llvm::make_early_inc_range(Iterator::makeRange(block)))
-        walk<Iterator>(&nestedOp, callback, order);
-    }
-  }
-
-  if (order == WalkOrder::PostOrder)
-    callback(op);
+MutableArrayRef<Region> ForwardIterator::makeIterable(Operation &range) {
+  return range.getRegions();
 }
-// Explicit template instantiations for all supported iterators.
-template void detail::walk<ForwardIterator>(Operation *,
-                                            function_ref<void(Operation *)>,
-                                            WalkOrder);
-template void detail::walk<ReverseIterator>(Operation *,
-                                            function_ref<void(Operation *)>,
-                                            WalkOrder);
 
 void detail::walk(Operation *op,
                   function_ref<void(Operation *, const WalkStage &)> callback) {
@@ -120,128 +37,6 @@ void detail::walk(Operation *op,
   callback(op, stage);
 }
 
-/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. These functions walk operations until an interrupt result
-/// is returned by the callback. Walks on regions, blocks and operations may
-/// also be skipped if the callback returns a skip result. Regions, blocks and
-/// operations at the same nesting level are visited in lexicographical order.
-/// The walk order for enclosing regions, blocks and operations with respect to
-/// their nested ones is specified by 'order'. A callback on a block or
-/// operation is allowed to erase that block or operation if either:
-///   * the walk is in post-order, or
-///   * the walk is in pre-order and the walk is skipped after the erasure.
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Region *)> callback,
-                        WalkOrder order) {
-  // We don't use early increment for regions because they can't be erased from
-  // a callback.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    if (order == WalkOrder::PreOrder) {
-      WalkResult result = callback(&region);
-      if (result.wasSkipped())
-        continue;
-      if (result.wasInterrupted())
-        return WalkResult::interrupt();
-    }
-    for (auto &block : Iterator::makeRange(region)) {
-      for (auto &nestedOp : Iterator::makeRange(block))
-        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
-          return WalkResult::interrupt();
-    }
-    if (order == WalkOrder::PostOrder) {
-      if (callback(&region).wasInterrupted())
-        return WalkResult::interrupt();
-      // We don't check if this region was skipped because its walk already
-      // finished and the walk will continue with the next region.
-    }
-  }
-  return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Region *)>,
-                              WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Region *)>,
-                              WalkOrder);
-
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Block *)> callback,
-                        WalkOrder order) {
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    // Early increment here in the case where the block is erased.
-    for (auto &block :
-         llvm::make_early_inc_range(Iterator::makeRange(region))) {
-      if (order == WalkOrder::PreOrder) {
-        WalkResult result = callback(&block);
-        if (result.wasSkipped())
-          continue;
-        if (result.wasInterrupted())
-          return WalkResult::interrupt();
-      }
-      for (auto &nestedOp : Iterator::makeRange(block))
-        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
-          return WalkResult::interrupt();
-      if (order == WalkOrder::PostOrder) {
-        if (callback(&block).wasInterrupted())
-          return WalkResult::interrupt();
-        // We don't check if this block was skipped because its walk already
-        // finished and the walk will continue with the next block.
-      }
-    }
-  }
-  return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *, function_ref<WalkResult(Block *)>,
-                              WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *, function_ref<WalkResult(Block *)>,
-                              WalkOrder);
-
-template <typename Iterator>
-WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Operation *)> callback,
-                        WalkOrder order) {
-  if (order == WalkOrder::PreOrder) {
-    WalkResult result = callback(op);
-    // If skipped, caller will continue the walk on the next operation.
-    if (result.wasSkipped())
-      return WalkResult::advance();
-    if (result.wasInterrupted())
-      return WalkResult::interrupt();
-  }
-
-  // TODO: This walk should be iterative over the operations.
-  MutableArrayRef<Region> regions = op->getRegions();
-  for (auto &region : Iterator::makeRange(regions)) {
-    for (auto &block : Iterator::makeRange(region)) {
-      // Early increment here in the case where the operation is erased.
-      for (auto &nestedOp :
-           llvm::make_early_inc_range(Iterator::makeRange(block))) {
-        if (walk<Iterator>(&nestedOp, callback, order).wasInterrupted())
-          return WalkResult::interrupt();
-      }
-    }
-  }
-
-  if (order == WalkOrder::PostOrder)
-    return callback(op);
-  return WalkResult::advance();
-}
-// Explicit template instantiations for all supported iterators.
-template WalkResult
-detail::walk<ForwardIterator>(Operation *,
-                              function_ref<WalkResult(Operation *)>, WalkOrder);
-template WalkResult
-detail::walk<ReverseIterator>(Operation *,
-                              function_ref<WalkResult(Operation *)>, WalkOrder);
-
 WalkResult detail::walk(
     Operation *op,
     function_ref<WalkResult(Operation *, const WalkStage &)> callback) {

diff  --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
index 3a78a0c384045..0c9fabeee5377 100644
--- a/mlir/test/IR/visitors.mlir
+++ b/mlir/test/IR/visitors.mlir
@@ -71,6 +71,23 @@ func.func @structured_cfg() {
 // CHECK:       Visiting region 0 from operation 'func.func'
 // CHECK:       Visiting region 0 from operation 'builtin.module'
 
+// CHECK-LABEL: Op reverse post-order visits
+// CHECK:       Visiting op 'func.return'
+// CHECK:       Visiting op 'scf.yield'
+// CHECK:       Visiting op 'use3'
+// CHECK:       Visiting op 'scf.yield'
+// CHECK:       Visiting op 'use2'
+// CHECK:       Visiting op 'scf.yield'
+// CHECK:       Visiting op 'use1'
+// CHECK:       Visiting op 'scf.if'
+// CHECK:       Visiting op 'use0'
+// CHECK:       Visiting op 'scf.for'
+// CHECK:       Visiting op 'arith.constant'
+// CHECK:       Visiting op 'arith.constant'
+// CHECK:       Visiting op 'arith.constant'
+// CHECK:       Visiting op 'func.func'
+// CHECK:       Visiting op 'builtin.module'
+
 // CHECK-LABEL: Op pre-order erasures
 // CHECK:       Erasing op 'scf.for'
 // CHECK:       Erasing op 'func.return'
@@ -172,6 +189,29 @@ func.func @unstructured_cfg() {
 // CHECK:       Visiting region 0 from operation 'func.func'
 // CHECK:       Visiting region 0 from operation 'builtin.module'
 
+// CHECK-LABEL: Op reverse post-order visits
+// CHECK:       Visiting op 'func.return'
+// CHECK:       Visiting op 'op2'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op1'
+// CHECK:       Visiting op 'cf.br'
+// CHECK:       Visiting op 'op0'
+// CHECK:       Visiting op 'regionOp0'
+// CHECK:       Visiting op 'func.func'
+// CHECK:       Visiting op 'builtin.module'
+
+// CHECK-LABEL: Block reverse post-order visits
+// CHECK:       Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'func.func'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'builtin.module'
+
+// CHECK-LABEL: Region reverse post-order visits
+// CHECK:       Visiting region 0 from operation 'regionOp0'
+// CHECK:       Visiting region 0 from operation 'func.func'
+// CHECK:       Visiting region 0 from operation 'builtin.module'
+
 // CHECK-LABEL: Op pre-order erasures (skip)
 // CHECK:       Erasing op 'regionOp0'
 // CHECK:       Erasing op 'func.return'

diff  --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
index f63576e548c96..3211d10bab9fa 100644
--- a/mlir/test/lib/IR/TestVisitors.cpp
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -64,6 +64,16 @@ static void testPureCallbacks(Operation *op) {
   llvm::outs() << "Region post-order visits"
                << "\n";
   op->walk<WalkOrder::PostOrder>(regionPure);
+
+  llvm::outs() << "Op reverse post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder, ReverseIterator>(opPure);
+  llvm::outs() << "Block reverse post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder, ReverseIterator>(blockPure);
+  llvm::outs() << "Region reverse post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder, ReverseIterator>(regionPure);
 }
 
 /// Tests erasure callbacks that skip the walk.


        


More information about the Mlir-commits mailing list