[Mlir-commits] [mlir] 2de6dbd - [mlir] Add 'Skip' result to Operation visitor

Diego Caballero llvmlistbot at llvm.org
Fri Mar 5 14:06:14 PST 2021


Author: Diego Caballero
Date: 2021-03-06T00:02:20+02:00
New Revision: 2de6dbda66b3ff23f1e0cb52862d90224852ae59

URL: https://github.com/llvm/llvm-project/commit/2de6dbda66b3ff23f1e0cb52862d90224852ae59
DIFF: https://github.com/llvm/llvm-project/commit/2de6dbda66b3ff23f1e0cb52862d90224852ae59.diff

LOG: [mlir] Add 'Skip' result to Operation visitor

This patch is a follow-up on D97217. It adds a new 'Skip' result to the Operation visitor
so that a callback can stop the ongoing visit of an operation/block/region and
continue visiting the next one without fully interrupting the walk. Skipping is
needed to be able to erase an operation/block in pre-order without continuing
to visit the internals of that operation/block.

Related to the skipping mechanism, the patch also introduces the following changes:
 * Added new TestIRVisitors pass with basic testing for the IR visitors.
 * Fixed missing early increment ranges in visitor implementation.
 * Updated documentation of walk methods to include erasure information and walk
   order information.
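
A minimal standalone model of why pre-order erasure needs Skip. This is a hypothetical sketch, not MLIR code: `Node`, `Result` and `walkPreOrder` are made-up stand-ins for `Operation`, `WalkResult` and `detail::walk`. The callback erases the node it is visiting and returns `Skip`, so the walker never descends into the freed node's children. The loop also advances its iterator before calling back, which is the early-increment fix listed above.

```cpp
#include <cassert>
#include <functional>
#include <list>
#include <string>
#include <vector>

// Hypothetical stand-in for WalkResult with the new Skip state.
enum class Result { Interrupt, Advance, Skip };

// Hypothetical stand-in for an operation: a name plus nested children.
struct Node {
  std::string name;
  std::list<Node> children;
};

using Callback =
    std::function<Result(std::list<Node> &, std::list<Node>::iterator)>;

// Pre-order walk over `nodes`. The callback gets the owning list and an
// iterator to the visited node so that it can erase that node.
// Returns false if the walk was interrupted.
bool walkPreOrder(std::list<Node> &nodes, const Callback &cb) {
  for (auto it = nodes.begin(); it != nodes.end();) {
    auto cur = it++;  // Early increment: `cur` may be erased by the callback.
    Result r = cb(nodes, cur);
    if (r == Result::Interrupt)
      return false;
    if (r == Result::Skip)
      continue;  // `cur` may be gone; never touch its children.
    if (!walkPreOrder(cur->children, cb))
      return false;
  }
  return true;
}

// Erase every node named "loop" (and everything nested in it), returning the
// names that were visited.
std::vector<std::string> eraseLoops(std::list<Node> &roots) {
  std::vector<std::string> visited;
  walkPreOrder(roots, [&](std::list<Node> &owner, std::list<Node>::iterator it) {
    visited.push_back(it->name);
    if (it->name == "loop") {
      owner.erase(it);      // Erase in pre-order...
      return Result::Skip;  // ...and skip, since the children are freed.
    }
    return Result::Advance;
  });
  return visited;
}
```

With a tree `func { loop { body }, ret }`, the walk visits `func`, `loop` and `ret`; `body` is never visited because its parent was erased and skipped.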

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D97820

Added: 
    mlir/test/IR/visitors.mlir
    mlir/test/lib/IR/TestVisitors.cpp

Modified: 
    mlir/include/mlir/IR/Block.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/Region.h
    mlir/include/mlir/IR/Visitors.h
    mlir/lib/IR/Visitors.cpp
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 9f265b3b56f5..9f26155b5265 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -249,9 +249,15 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   // Operation Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this block, calling the callback for each
-  /// operation. The walk order for regions, blocks and operations is specified
-  /// by 'Order' (post-order by default).
+  /// Walk the operations in this block. The callback method is called for each
+  /// nested region, block or operation, depending on the callback provided.
+  /// Regions, blocks and operations at the same nesting level are visited in
+  /// lexicographical order. The walk order for enclosing regions, blocks and
+  /// operations with respect to their nested ones is specified by 'Order'
+  /// (post-order by default). A callback on a block or operation is allowed to
+  /// erase that block or operation if either:
+  ///   * the walk is in post-order, or
+  ///   * the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
@@ -259,10 +265,15 @@ class Block : public IRObjectWithUseList<BlockOperand>,
     return walk<Order>(begin(), end(), std::forward<FnT>(callback));
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block,
-  /// calling the callback for each operation. The walk order for regions,
-  /// blocks and operations is specified by 'Order' (post-order by default).
-  /// This method is invoked for void return callbacks.
+  /// Walk the operations in the specified [begin, end) range of this block. The
+  /// callback method is called for each nested region, block or operation,
+  /// depending on the callback provided. Regions, blocks and operations at the
+  /// same nesting level are visited in lexicographical order. The walk order
+  /// for enclosing regions, blocks and operations with respect to their nested
+  /// ones is specified by 'Order' (post-order by default). This method is
+  /// invoked for void-returning callbacks. A callback on a block or operation
+  /// is allowed to erase that block or operation only if the walk is in
+  /// post-order. See non-void method for pre-order erasure.
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
@@ -272,10 +283,16 @@ class Block : public IRObjectWithUseList<BlockOperand>,
       detail::walk<Order>(&op, callback);
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block,
-  /// calling the callback for each operation. The walk order for regions,
-  /// blocks and operations is specified by 'Order' (post-order by default).
-  /// This method is invoked for interruptible callbacks.
+  /// Walk the operations in the specified [begin, end) range of this block. The
+  /// callback method is called for each nested region, block or operation,
+  /// depending on the callback provided. Regions, blocks and operations at the
+  /// same nesting level are visited in lexicographical order. The walk order
+  /// for enclosing regions, blocks and operations with respect to their nested
+  /// ones is specified by 'Order' (post-order by default). This method is
+  /// invoked for skippable or interruptible callbacks. A callback on a block or
+  /// operation is allowed to erase that block or operation if either:
+  ///   * the walk is in post-order, or
+  ///   * the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
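
The erasure rules above rely on the loop never advancing from an element that the callback has just erased. A plain range-based `for (auto &block : region)` increments from the current element, which is a use-after-free once that element is gone; this patch switches the block loops to `llvm::make_early_inc_range` instead. The same idea on a `std::list`, as an illustrative sketch (`forEachEarlyInc` and `eraseEvens` are hypothetical helpers, not LLVM APIs):

```cpp
#include <cassert>
#include <list>

// Visit each element, allowing the callback to erase the element it receives.
// The iterator is advanced before the callback runs, which is the same idea
// as wrapping the range in llvm::make_early_inc_range.
template <typename T, typename Fn>
void forEachEarlyInc(std::list<T> &items, Fn &&callback) {
  for (auto it = items.begin(); it != items.end();) {
    auto current = it++;  // `it` stays valid even if `current` is erased.
    callback(items, current);
  }
}

// Hypothetical usage: erase all even values while iterating.
std::list<int> eraseEvens(std::list<int> values) {
  forEachEarlyInc(values, [](std::list<int> &owner, std::list<int>::iterator pos) {
    if (*pos % 2 == 0)
      owner.erase(pos);  // Safe: the loop already holds the next iterator.
  });
  return values;
}
```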

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 4c9399f2a6fc..44bd7fb3283b 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -165,9 +165,15 @@ class OpState {
   /// handlers that may be listening.
   InFlightDiagnostic emitRemark(const Twine &message = {});
 
-  /// Walk the operation by calling the callback for each nested
-  /// operation(including this one). The walk order for regions, blocks and
-  /// operations is specified by 'Order' (post-order by default).
+  /// Walk the operation by calling the callback for each nested operation
+  /// (including this one), block or region, depending on the callback provided.
+  /// Regions, blocks and operations at the same nesting level are visited in
+  /// lexicographical order. The walk order for enclosing regions, blocks and
+  /// operations with respect to their nested ones is specified by 'Order'
+  /// (post-order by default). A callback on a block or operation is allowed to
+  /// erase that block or operation if either:
+  ///   * the walk is in post-order, or
+  ///   * the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 679071457e52..f76ea2657811 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -485,18 +485,28 @@ class alignas(8) Operation final
   //===--------------------------------------------------------------------===//
 
   /// Walk the operation by calling the callback for each nested operation
-  /// (including this one). The walk order for regions, blocks and operations is
-  /// specified by 'Order' (post-order by default). The callback method can take
-  /// any of the following forms:
+  /// (including this one), block or region, depending on the callback provided.
+  /// Regions, blocks and operations at the same nesting level are visited in
+  /// lexicographical order. The walk order for enclosing regions, blocks and
+  /// operations with respect to their nested ones is specified by 'Order'
+  /// (post-order by default). A callback on a block or operation is allowed to
+  /// erase that block or operation if either:
+  ///   * the walk is in post-order, or
+  ///   * the walk is in pre-order and the walk is skipped after the erasure.
+  ///
+  /// The callback method can take any of the following forms:
   ///   void(Operation*) : Walk all operations opaquely.
   ///     * op->walk([](Operation *nestedOp) { ...});
   ///   void(OpT) : Walk all operations of the given derived type.
   ///     * op->walk([](ReturnOp returnOp) { ...});
   ///   WalkResult(Operation*|OpT) : Walk operations, but allow for
-  ///                                interruption/cancellation.
+  ///                                interruption/skipping.
   ///     * op->walk([](... op) {
-  ///         // Interrupt, i.e cancel, the walk based on some invariant.
+  ///         // Skip the walk of this op based on some invariant.
   ///         if (some_invariant)
+  ///           return WalkResult::skip();
+  ///         // Interrupt, i.e., cancel, the walk based on some invariant.
+  ///         if (another_invariant)
   ///           return WalkResult::interrupt();
   ///         return WalkResult::advance();
   ///       });
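
The `void(OpT)` form only invokes the callback on operations of the given derived type. A rough standalone approximation follows. It is a sketch under stated assumptions: `Op`, `ReturnOp`, `AddOp` and `walkOfType` are hypothetical, and `dynamic_cast` stands in for the LLVM-style `dyn_cast` that MLIR actually uses. Like the default, it walks in post-order.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical op hierarchy standing in for Operation / ReturnOp.
struct Op {
  virtual ~Op() = default;
  std::vector<std::unique_ptr<Op>> nested;
};
struct ReturnOp : Op {};
struct AddOp : Op {};

// Post-order walk over all ops, calling `fn` only for ops of type OpT.
template <typename OpT, typename Fn>
void walkOfType(Op &op, Fn &&fn) {
  for (auto &child : op.nested)
    walkOfType<OpT>(*child, fn);
  if (auto *derived = dynamic_cast<OpT *>(&op))  // Filter by derived type.
    fn(*derived);
}
```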

diff  --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 349be0f5aedc..8888862dd10c 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -242,11 +242,15 @@ class Region {
   // Operation Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this region in postorder, calling the callback for
-  /// each operation. The walk order for regions, blocks and operations is
-  /// specified by 'Order' (post-order by default). This method is invoked for
-  /// void-returning callbacks.
-  /// See Operation::walk for more details.
+  /// Walk the operations in this region. The callback method is called for each
+  /// nested region, block or operation, depending on the callback provided.
+  /// Regions, blocks and operations at the same nesting level are visited in
+  /// lexicographical order. The walk order for enclosing regions, blocks and
+  /// operations with respect to their nested ones is specified by 'Order'
+  /// (post-order by default). This method is invoked for void-returning
+  /// callbacks. A callback on a block or operation is allowed to erase that
+  /// block or operation only if the walk is in post-order. See non-void method
+  /// for pre-order erasure. See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
@@ -255,10 +259,16 @@ class Region {
       block.walk<Order>(callback);
   }
 
-  /// Walk the operations in this region in postorder, calling the callback for
-  /// each operation. The walk order for regions, blocks and operations is
-  /// specified by 'Order' (post-order by default). This method is invoked for
-  /// interruptible callbacks.
+  /// Walk the operations in this region. The callback method is called for each
+  /// nested region, block or operation, depending on the callback provided.
+  /// Regions, blocks and operations at the same nesting level are visited in
+  /// lexicographical order. The walk order for enclosing regions, blocks and
+  /// operations with respect to their nested ones is specified by 'Order'
+  /// (post-order by default). This method is invoked for skippable or
+  /// interruptible callbacks. A callback on a block or operation is allowed to
+  /// erase that block or operation if either:
+  ///   * the walk is in post-order, or
+  ///   * the walk is in pre-order and the walk is skipped after the erasure.
   /// See Operation::walk for more details.
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
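
Region::walk picks between its void and WalkResult overloads with `std::enable_if` on the callback's return type `RetT`. The pattern can be reproduced in isolation as below; `WalkStatus`, `resultOf` and the `std::vector<int>` "IR" are hypothetical simplifications (the real code deduces the type with `detail::walkResultType`). A flat list has no nesting, so `Skip` has nothing to prune here and behaves like `Advance`.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical stand-in for WalkResult.
enum class WalkStatus { Interrupt, Advance, Skip };

// Hypothetical simplification of detail::walkResultType: the type returned
// by calling the callback with one element.
template <typename FnT>
using resultOf = decltype(std::declval<FnT>()(0));

// Chosen only when the callback returns void: visit every element.
template <typename FnT, typename RetT = resultOf<FnT>>
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
walk(const std::vector<int> &items, FnT &&callback) {
  for (int item : items)
    callback(item);
}

// Chosen only when the callback returns WalkStatus: stop on Interrupt.
// With no nesting there is nothing to prune, so Skip acts like Advance.
template <typename FnT, typename RetT = resultOf<FnT>>
typename std::enable_if<std::is_same<RetT, WalkStatus>::value, RetT>::type
walk(const std::vector<int> &items, FnT &&callback) {
  for (int item : items)
    if (callback(item) == WalkStatus::Interrupt)
      return WalkStatus::Interrupt;
  return WalkStatus::Advance;
}
```

Because the return type is part of each template's signature, the two overloads can coexist, and substitution failure removes whichever one does not match the callback.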

diff  --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index b4571e19f8fd..af2d05379662 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -24,10 +24,15 @@ class Operation;
 class Block;
 class Region;
 
-/// A utility result that is used to signal if a walk method should be
-/// interrupted or advance.
+/// A utility result that is used to signal how to proceed with an ongoing walk:
+///   * Interrupt: the walk will be interrupted and no more operations, regions
+///   or blocks will be visited.
+///   * Advance: the walk will continue.
+///   * Skip: the walk of the current operation, region or block and of any
+///   nested elements that haven't been visited yet will be skipped, and the
+///   walk will continue with the next operation, region or block.
 class WalkResult {
-  enum ResultEnum { Interrupt, Advance } result;
+  enum ResultEnum { Interrupt, Advance, Skip } result;
 
 public:
   WalkResult(ResultEnum result) : result(result) {}
@@ -44,9 +49,13 @@ class WalkResult {
 
   static WalkResult interrupt() { return {Interrupt}; }
   static WalkResult advance() { return {Advance}; }
+  static WalkResult skip() { return {Skip}; }
 
   /// Returns true if the walk was interrupted.
   bool wasInterrupted() const { return result == Interrupt; }
+
+  /// Returns true if the walk was skipped.
+  bool wasSkipped() const { return result == Skip; }
 };
 
 /// Traversal order for region, block and operation walk utilities.
@@ -67,15 +76,27 @@ template <typename T>
 using first_argument = decltype(first_argument_type(std::declval<T>()));
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. The walk order is specified by 'order'.
+/// the given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'order'. These methods are invoked for void-returning
+/// callbacks. A callback on a block or operation is allowed to erase that block
+/// or operation only if the walk is in post-order. See non-void method for
+/// pre-order erasure.
 void walk(Operation *op, function_ref<void(Region *)> callback,
           WalkOrder order);
 void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
 void walk(Operation *op, function_ref<void(Operation *)> callback,
           WalkOrder order);
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. The walk order is specified by 'order'. These functions
-/// walk until an interrupt result is returned by the callback.
+/// the given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'order'. These methods are invoked for skippable or
+/// interruptible callbacks. A callback on a block or operation is allowed to
+/// erase that block or operation if either:
+///   * the walk is in post-order, or
+///   * the walk is in pre-order and the walk is skipped after the erasure.
 WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
                 WalkOrder order);
 WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
@@ -89,9 +110,15 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
 // upon the type of the callback function.
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. The walk order is specified by 'Order' (post-order
-/// by default). This method is selected for callbacks that operate on
-/// Region*, Block*, and Operation*.
+/// the given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'Order' (post-order by default). A callback on a block or
+/// operation is allowed to erase that block or operation if either:
+///   * the walk is in post-order, or
+///   * the walk is in pre-order and the walk is skipped after the erasure.
+/// This method is selected for callbacks that operate on Region*, Block*, and
+/// Operation*.
 ///
 /// Example:
 ///   op->walk([](Region *r) { ... });
@@ -108,9 +135,13 @@ walk(Operation *op, FuncTy &&callback) {
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. The walk order for regions, blocks and operations is
-/// specified by 'Order' (post-order by default). This method is selected for
-/// void returning callbacks that operate on a specific derived operation type.
+/// given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// void-returning callbacks that operate on a specific derived operation type.
+/// A callback on an operation is allowed to erase that operation only if the
+/// walk is in post-order. See non-void method for pre-order erasure.
 ///
 /// Example:
 ///   op->walk([](ReturnOp op) { ... });
@@ -131,14 +162,21 @@ walk(Operation *op, FuncTy &&callback) {
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. The walk order for regions, blocks and operations is
-/// specified by 'Order' (post-order by default). This method is selected for
-/// WalkReturn returning interruptible callbacks that operate on a specific
-/// derived operation type.
+/// given operation. Regions, blocks and operations at the same nesting level
+/// are visited in lexicographical order. The walk order for enclosing regions,
+/// blocks and operations with respect to their nested ones is specified by
+/// 'Order' (post-order by default). This method is selected for
+/// WalkResult-returning skippable or interruptible callbacks that operate on a
+/// specific derived operation type. A callback on an operation is allowed to
+/// erase that operation if either:
+///   * the walk is in post-order, or
+///   * the walk is in pre-order and the walk is skipped after the erasure.
 ///
 /// Example:
 ///   op->walk([](ReturnOp op) {
 ///     if (some_invariant)
+///       return WalkResult::skip();
+///     if (another_invariant)
 ///       return WalkResult::interrupt();
 ///     return WalkResult::advance();
 ///   });
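
The example above combines both results in one callback. As a hedged standalone sketch (`Walk`, `Node`, `preOrder` and `findFirstCall` are hypothetical names, not MLIR APIs), the search below looks for the first `call` node in pre-order, uses `Skip` to avoid looking inside `opaque` nodes, and uses `Interrupt` to stop as soon as a match is found:

```cpp
#include <cassert>
#include <string>
#include <vector>

enum class Walk { Interrupt, Advance, Skip };  // Hypothetical WalkResult.

struct Node {
  std::string name;
  std::vector<Node> children;
};

// Pre-order walk: Skip prunes the node's children, Interrupt stops the whole
// walk. Returns false if the walk was interrupted.
template <typename Fn>
bool preOrder(const Node &node, Fn &fn) {
  Walk result = fn(node);
  if (result == Walk::Interrupt)
    return false;
  if (result == Walk::Skip)
    return true;  // Move on to the next sibling.
  for (const Node &child : node.children)
    if (!preOrder(child, fn))
      return false;
  return true;
}

struct SearchResult {
  bool found = false;
  std::vector<std::string> visited;
};

// Find the first "call" node without looking inside "opaque" nodes.
SearchResult findFirstCall(const Node &root) {
  SearchResult search;
  auto fn = [&](const Node &node) {
    search.visited.push_back(node.name);
    if (node.name == "opaque")
      return Walk::Skip;  // Do not look inside.
    if (node.name == "call") {
      search.found = true;
      return Walk::Interrupt;  // Stop the whole walk.
    }
    return Walk::Advance;
  };
  preOrder(root, fn);
  return search;
}
```

For `func { opaque { call }, add, call, ret }` the visits are `func`, `opaque`, `add`, `call`: the nested `call` is skipped and `ret` is never reached.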

diff  --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index be995a2a4fb2..efe7da403291 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -12,10 +12,16 @@
 using namespace mlir;
 
 /// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. The walk order is specified by 'Order'.
-
+/// given operation. Regions, blocks and operations at the same nesting level
+/// are visited in lexicographical order. The walk order for enclosing regions,
+/// blocks and operations with respect to their nested ones is specified by
+/// 'order'. These methods are invoked for void-returning callbacks. A callback
+/// on a block or operation is allowed to erase that block or operation only if
+/// the walk is in post-order. See non-void method for pre-order erasure.
 void detail::walk(Operation *op, function_ref<void(Region *)> callback,
                   WalkOrder order) {
+  // We don't use early increment for regions because they can't be erased from
+  // a callback.
   for (auto &region : op->getRegions()) {
     if (order == WalkOrder::PreOrder)
       callback(&region);
@@ -31,7 +37,8 @@ void detail::walk(Operation *op, function_ref<void(Region *)> callback,
 void detail::walk(Operation *op, function_ref<void(Block *)> callback,
                   WalkOrder order) {
   for (auto &region : op->getRegions()) {
-    for (auto &block : region) {
+    // Early increment here in the case where the block is erased.
+    for (auto &block : llvm::make_early_inc_range(region)) {
       if (order == WalkOrder::PreOrder)
         callback(&block);
       for (auto &nestedOp : block)
@@ -61,22 +68,38 @@ void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
 }
 
 /// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. The walk order is specified by 'order'. These functions
-/// walk operations until an interrupt result is returned by the callback.
+/// given operation. These functions walk operations until an interrupt result
+/// is returned by the callback. Walks on regions, blocks and operations may
+/// also be skipped if the callback returns a skip result. Regions, blocks and
+/// operations at the same nesting level are visited in lexicographical order.
+/// The walk order for enclosing regions, blocks and operations with respect to
+/// their nested ones is specified by 'order'. A callback on a block or
+/// operation is allowed to erase that block or operation if either:
+///   * the walk is in post-order, or
+///   * the walk is in pre-order and the walk is skipped after the erasure.
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Region *)> callback,
                         WalkOrder order) {
+  // We don't use early increment for regions because they can't be erased from
+  // a callback.
   for (auto &region : op->getRegions()) {
-    if (order == WalkOrder::PreOrder)
-      if (callback(&region).wasInterrupted())
+    if (order == WalkOrder::PreOrder) {
+      WalkResult result = callback(&region);
+      if (result.wasSkipped())
+        continue;
+      if (result.wasInterrupted())
         return WalkResult::interrupt();
+    }
     for (auto &block : region) {
       for (auto &nestedOp : block)
         walk(&nestedOp, callback, order);
     }
-    if (order == WalkOrder::PostOrder)
+    if (order == WalkOrder::PostOrder) {
       if (callback(&region).wasInterrupted())
         return WalkResult::interrupt();
+      // We don't check if this region was skipped because its walk already
+      // finished and the walk will continue with the next region.
+    }
   }
   return WalkResult::advance();
 }
@@ -85,15 +108,23 @@ WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Block *)> callback,
                         WalkOrder order) {
   for (auto &region : op->getRegions()) {
-    for (auto &block : region) {
-      if (order == WalkOrder::PreOrder)
-        if (callback(&block).wasInterrupted())
+    // Early increment here in the case where the block is erased.
+    for (auto &block : llvm::make_early_inc_range(region)) {
+      if (order == WalkOrder::PreOrder) {
+        WalkResult result = callback(&block);
+        if (result.wasSkipped())
+          continue;
+        if (result.wasInterrupted())
           return WalkResult::interrupt();
+      }
       for (auto &nestedOp : block)
         walk(&nestedOp, callback, order);
-      if (order == WalkOrder::PostOrder)
+      if (order == WalkOrder::PostOrder) {
         if (callback(&block).wasInterrupted())
           return WalkResult::interrupt();
+        // We don't check if this block was skipped because its walk already
+        // finished and the walk will continue with the next block.
+      }
     }
   }
   return WalkResult::advance();
@@ -102,9 +133,14 @@ WalkResult detail::walk(Operation *op,
 WalkResult detail::walk(Operation *op,
                         function_ref<WalkResult(Operation *)> callback,
                         WalkOrder order) {
-  if (order == WalkOrder::PreOrder)
-    if (callback(op).wasInterrupted())
+  if (order == WalkOrder::PreOrder) {
+    WalkResult result = callback(op);
+    // If skipped, caller will continue the walk on the next operation.
+    if (result.wasSkipped())
+      return WalkResult::advance();
+    if (result.wasInterrupted())
       return WalkResult::interrupt();
+  }
 
   // TODO: This walk should be iterative over the operations.
   for (auto &region : op->getRegions()) {
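
The comments in this file about how skip results are handled can be checked with a small model. In this sketch (hypothetical names `Walk`, `Order`, `Node`, `walk`, `traceSkippingLoops`; loosely modelled on, not copied from, MLIR's implementation), a pre-order `Skip` prunes the node's children and is reported to the caller as `Advance`, while a post-order `Skip` changes nothing because the children were already visited:

```cpp
#include <cassert>
#include <string>
#include <vector>

enum class Walk { Interrupt, Advance, Skip };  // Hypothetical WalkResult.
enum class Order { PreOrder, PostOrder };      // Hypothetical WalkOrder.

struct Node {
  std::string name;
  std::vector<Node> children;
};

// A pre-order Skip prunes the children and is reported as Advance; a
// post-order Skip is ignored because the children were already walked.
template <typename Fn>
Walk walk(const Node &node, Fn &fn, Order order) {
  if (order == Order::PreOrder) {
    Walk result = fn(node);
    if (result == Walk::Skip)
      return Walk::Advance;  // The caller continues with the next sibling.
    if (result == Walk::Interrupt)
      return Walk::Interrupt;
  }
  for (const Node &child : node.children)
    if (walk(child, fn, order) == Walk::Interrupt)
      return Walk::Interrupt;
  if (order == Order::PostOrder && fn(node) == Walk::Interrupt)
    return Walk::Interrupt;
  return Walk::Advance;
}

struct Trace {
  std::vector<std::string> visited;
  Walk result;
};

// Record every visit and return Skip for nodes named "loop".
Trace traceSkippingLoops(const Node &root, Order order) {
  Trace trace;
  auto fn = [&](const Node &node) {
    trace.visited.push_back(node.name);
    return node.name == "loop" ? Walk::Skip : Walk::Advance;
  };
  trace.result = walk(root, fn, order);
  return trace;
}
```

For `func { loop { body }, ret }`, pre-order visits `func`, `loop`, `ret` (no `body`), while post-order still visits `body`, `loop`, `ret`, `func`. In both cases the overall result is `Advance`.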

diff  --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
new file mode 100644
index 000000000000..789ae8f07570
--- /dev/null
+++ b/mlir/test/IR/visitors.mlir
@@ -0,0 +1,212 @@
+// RUN: mlir-opt -test-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Verify the different configurations of IR visitors.
+// Constant, yield and other terminator ops are not matched for simplicity.
+// Module and function op and their immediately nested blocks are not erased in
+// callbacks with return so that the output includes more cases in pre-order.
+
+func @structured_cfg() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c10 = constant 10 : index
+  scf.for %i = %c1 to %c10 step %c1 {
+    %cond = "use0"(%i) : (index) -> (i1)
+    scf.if %cond {
+      "use1"(%i) : (index) -> ()
+    } else {
+      "use2"(%i) : (index) -> ()
+    }
+    "use3"(%i) : (index) -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: Op pre-order visit
+// CHECK:       Visiting op 'module'
+// CHECK:       Visiting op 'func'
+// CHECK:       Visiting op 'scf.for'
+// CHECK:       Visiting op 'use0'
+// CHECK:       Visiting op 'scf.if'
+// CHECK:       Visiting op 'use1'
+// CHECK:       Visiting op 'use2'
+// CHECK:       Visiting op 'use3'
+// CHECK:       Visiting op 'std.return'
+
+// CHECK-LABEL: Block pre-order visits
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'module'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Region pre-order visits
+// CHECK:       Visiting region 0 from operation 'module'
+// CHECK:       Visiting region 0 from operation 'func'
+// CHECK:       Visiting region 0 from operation 'scf.for'
+// CHECK:       Visiting region 0 from operation 'scf.if'
+// CHECK:       Visiting region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Op post-order visits
+// CHECK:       Visiting op 'use0'
+// CHECK:       Visiting op 'use1'
+// CHECK:       Visiting op 'use2'
+// CHECK:       Visiting op 'scf.if'
+// CHECK:       Visiting op 'use3'
+// CHECK:       Visiting op 'scf.for'
+// CHECK:       Visiting op 'std.return'
+// CHECK:       Visiting op 'func'
+// CHECK:       Visiting op 'module'
+
+// CHECK-LABEL: Block post-order visits
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 1 from operation 'scf.if'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'module'
+
+// CHECK-LABEL: Region post-order visits
+// CHECK:       Visiting region 0 from operation 'scf.if'
+// CHECK:       Visiting region 1 from operation 'scf.if'
+// CHECK:       Visiting region 0 from operation 'scf.for'
+// CHECK:       Visiting region 0 from operation 'func'
+// CHECK:       Visiting region 0 from operation 'module'
+
+// CHECK-LABEL: Op pre-order erasures
+// CHECK:       Erasing op 'scf.for'
+// CHECK:       Erasing op 'std.return'
+
+// CHECK-LABEL: Block pre-order erasures
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Op post-order erasures (skip)
+// CHECK:       Erasing op 'use0'
+// CHECK:       Erasing op 'use1'
+// CHECK:       Erasing op 'use2'
+// CHECK:       Erasing op 'scf.if'
+// CHECK:       Erasing op 'use3'
+// CHECK:       Erasing op 'scf.for'
+// CHECK:       Erasing op 'std.return'
+
+// CHECK-LABEL: Block post-order erasures (skip)
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Erasing block ^bb0 from region 1 from operation 'scf.if'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Op post-order erasures (no skip)
+// CHECK:       Erasing op 'use0'
+// CHECK:       Erasing op 'use1'
+// CHECK:       Erasing op 'use2'
+// CHECK:       Erasing op 'scf.if'
+// CHECK:       Erasing op 'use3'
+// CHECK:       Erasing op 'scf.for'
+// CHECK:       Erasing op 'std.return'
+// CHECK:       Erasing op 'func'
+// CHECK:       Erasing op 'module'
+
+// CHECK-LABEL: Block post-order erasures (no skip)
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'scf.if'
+// CHECK:       Erasing block ^bb0 from region 1 from operation 'scf.if'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'scf.for'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'func'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'module'
+
+// -----
+
+func @unstructured_cfg() {
+  "regionOp0"() ({
+    ^bb0:
+      "op0"() : () -> ()
+      br ^bb2
+    ^bb1:
+      "op1"() : () -> ()
+      br ^bb2
+    ^bb2:
+      "op2"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// CHECK-LABEL: Op pre-order visits
+// CHECK:       Visiting op 'module'
+// CHECK:       Visiting op 'func'
+// CHECK:       Visiting op 'regionOp0'
+// CHECK:       Visiting op 'op0'
+// CHECK:       Visiting op 'std.br'
+// CHECK:       Visiting op 'op1'
+// CHECK:       Visiting op 'std.br'
+// CHECK:       Visiting op 'op2'
+// CHECK:       Visiting op 'std.return'
+
+// CHECK-LABEL: Block pre-order visits
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'module'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb2 from region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Region pre-order visits
+// CHECK:       Visiting region 0 from operation 'module'
+// CHECK:       Visiting region 0 from operation 'func'
+// CHECK:       Visiting region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Op post-order visits
+// CHECK:       Visiting op 'op0'
+// CHECK:       Visiting op 'std.br'
+// CHECK:       Visiting op 'op1'
+// CHECK:       Visiting op 'std.br'
+// CHECK:       Visiting op 'op2'
+// CHECK:       Visiting op 'regionOp0'
+// CHECK:       Visiting op 'std.return'
+// CHECK:       Visiting op 'func'
+// CHECK:       Visiting op 'module'
+
+// CHECK-LABEL: Block post-order visits
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK:       Visiting block ^bb0 from region 0 from operation 'module'
+
+// CHECK-LABEL: Region post-order visits
+// CHECK:       Visiting region 0 from operation 'regionOp0'
+// CHECK:       Visiting region 0 from operation 'func'
+// CHECK:       Visiting region 0 from operation 'module'
+
+// CHECK-LABEL: Op pre-order erasures (skip)
+// CHECK:       Erasing op 'regionOp0'
+// CHECK:       Erasing op 'std.return'
+
+// CHECK-LABEL: Block pre-order erasures (skip)
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Op post-order erasures (skip)
+// CHECK:       Erasing op 'op0'
+// CHECK:       Erasing op 'std.br'
+// CHECK:       Erasing op 'op1'
+// CHECK:       Erasing op 'std.br'
+// CHECK:       Erasing op 'op2'
+// CHECK:       Erasing op 'regionOp0'
+// CHECK:       Erasing op 'std.return'
+
+// CHECK-LABEL: Block post-order erasures (skip)
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Op post-order erasures (no skip)
+// CHECK:       Erasing op 'op0'
+// CHECK:       Erasing op 'std.br'
+// CHECK:       Erasing op 'op1'
+// CHECK:       Erasing op 'std.br'
+// CHECK:       Erasing op 'op2'
+// CHECK:       Erasing op 'regionOp0'
+// CHECK:       Erasing op 'std.return'
+
+// CHECK-LABEL: Block post-order erasures (no skip)
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'func'
+// CHECK:       Erasing block ^bb0 from region 0 from operation 'module'
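The CHECK lines above fix both walk orders. In pre-order the callback sees an operation before anything nested inside it. In post-order it sees the operation only after everything nested inside it. The following self-contained C++ sketch reproduces the op-level sequences on a toy tree. It is only a rough illustration. It assumes a simplified model in which each 'Node' owns its nested nodes directly, with the Region and Block levels flattened away. 'Node', 'walk', 'addChild', 'buildExample' and 'visitOrder' are hypothetical names, not MLIR API.

```cpp
#include <cassert>
#include <functional>
#include <initializer_list>
#include <memory>
#include <string>
#include <vector>

// Toy model, not MLIR's API: a Node stands in for an Operation and owns its
// nested nodes directly (MLIR's Region and Block levels are flattened away).
struct Node {
  std::string name;
  std::vector<std::unique_ptr<Node>> children;
};

enum class WalkOrder { PreOrder, PostOrder };

// Calls `fn` on `node` and on everything nested in it: before the nested
// nodes in pre-order, after all of them in post-order.
void walk(Node &node, WalkOrder order, const std::function<void(Node &)> &fn) {
  if (order == WalkOrder::PreOrder)
    fn(node);
  for (auto &child : node.children)
    walk(*child, order, fn);
  if (order == WalkOrder::PostOrder)
    fn(node);
}

Node &addChild(Node &parent, const std::string &name) {
  assert(!name.empty() && "toy ops are identified by name");
  parent.children.push_back(std::make_unique<Node>());
  parent.children.back()->name = name;
  return *parent.children.back();
}

// Same nesting as the test input, with regionOp0's three blocks flattened
// into one list: module { func { regionOp0 { ... } std.return } }.
std::unique_ptr<Node> buildExample() {
  auto root = std::make_unique<Node>();
  root->name = "module";
  Node &func = addChild(*root, "func");
  Node &regionOp = addChild(func, "regionOp0");
  for (const char *name : {"op0", "std.br", "op1", "std.br", "op2"})
    addChild(regionOp, name);
  addChild(func, "std.return");
  return root;
}

// Returns the node names in the order the walk visits them.
std::vector<std::string> visitOrder(Node &root, WalkOrder order) {
  std::vector<std::string> names;
  walk(root, order, [&](Node &node) { names.push_back(node.name); });
  return names;
}
```

On this tree, visitOrder(*root, WalkOrder::PostOrder) yields op0, std.br, op1, std.br, op2, regionOp0, std.return, func, module. That is the sequence in the "Op post-order visits" CHECK block.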

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index a42f90bb9268..337029e0dac8 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRTestIR
   TestSlicing.cpp
   TestSymbolUses.cpp
   TestTypes.cpp
+  TestVisitors.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
new file mode 100644
index 000000000000..7ce3422904cd
--- /dev/null
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -0,0 +1,171 @@
+//===- TestVisitors.cpp - Pass to test the IR visitors --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+static void printRegion(Region *region) {
+  llvm::outs() << "region " << region->getRegionNumber() << " from operation '"
+               << region->getParentOp()->getName() << "'";
+}
+
+static void printBlock(Block *block) {
+  llvm::outs() << "block ";
+  block->printAsOperand(llvm::outs(), /*printType=*/false);
+  llvm::outs() << " from ";
+  printRegion(block->getParent());
+}
+
+static void printOperation(Operation *op) {
+  llvm::outs() << "op '" << op->getName() << "'";
+}
+
+/// Tests pure callbacks.
+static void testPureCallbacks(Operation *op) {
+  auto opPure = [](Operation *op) {
+    llvm::outs() << "Visiting ";
+    printOperation(op);
+    llvm::outs() << "\n";
+  };
+  auto blockPure = [](Block *block) {
+    llvm::outs() << "Visiting ";
+    printBlock(block);
+    llvm::outs() << "\n";
+  };
+  auto regionPure = [](Region *region) {
+    llvm::outs() << "Visiting ";
+    printRegion(region);
+    llvm::outs() << "\n";
+  };
+
+  llvm::outs() << "Op pre-order visits"
+               << "\n";
+  op->walk<WalkOrder::PreOrder>(opPure);
+  llvm::outs() << "Block pre-order visits"
+               << "\n";
+  op->walk<WalkOrder::PreOrder>(blockPure);
+  llvm::outs() << "Region pre-order visits"
+               << "\n";
+  op->walk<WalkOrder::PreOrder>(regionPure);
+
+  llvm::outs() << "Op post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder>(opPure);
+  llvm::outs() << "Block post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder>(blockPure);
+  llvm::outs() << "Region post-order visits"
+               << "\n";
+  op->walk<WalkOrder::PostOrder>(regionPure);
+}
+
+/// Tests erasure callbacks that skip the walk.
+static void testSkipErasureCallbacks(Operation *op) {
+  auto skipOpErasure = [](Operation *op) {
+    // Do not erase module and function ops. Otherwise there wouldn't be much
+    // left to test in pre-order.
+    if (isa<ModuleOp>(op) || isa<FuncOp>(op))
+      return WalkResult::advance();
+
+    llvm::outs() << "Erasing ";
+    printOperation(op);
+    llvm::outs() << "\n";
+    op->dropAllUses();
+    op->erase();
+    return WalkResult::skip();
+  };
+  auto skipBlockErasure = [](Block *block) {
+    // Do not erase module and function blocks. Otherwise there wouldn't be
+    // much left to test in pre-order.
+    Operation *parentOp = block->getParentOp();
+    if (isa<ModuleOp>(parentOp) || isa<FuncOp>(parentOp))
+      return WalkResult::advance();
+
+    llvm::outs() << "Erasing ";
+    printBlock(block);
+    llvm::outs() << "\n";
+    block->erase();
+    return WalkResult::skip();
+  };
+
+  llvm::outs() << "Op pre-order erasures (skip)"
+               << "\n";
+  Operation *cloned = op->clone();
+  cloned->walk<WalkOrder::PreOrder>(skipOpErasure);
+  cloned->erase();
+
+  llvm::outs() << "Block pre-order erasures (skip)"
+               << "\n";
+  cloned = op->clone();
+  cloned->walk<WalkOrder::PreOrder>(skipBlockErasure);
+  cloned->erase();
+
+  llvm::outs() << "Op post-order erasures (skip)"
+               << "\n";
+  cloned = op->clone();
+  cloned->walk<WalkOrder::PostOrder>(skipOpErasure);
+  cloned->erase();
+
+  llvm::outs() << "Block post-order erasures (skip)"
+               << "\n";
+  cloned = op->clone();
+  cloned->walk<WalkOrder::PostOrder>(skipBlockErasure);
+  cloned->erase();
+}
+
+/// Tests callbacks that erase the op or block but don't return 'Skip'. These
+/// callbacks are only valid in post-order.
+static void testNoSkipErasureCallbacks(Operation *op) {
+  auto noSkipOpErasure = [](Operation *op) {
+    llvm::outs() << "Erasing ";
+    printOperation(op);
+    llvm::outs() << "\n";
+    op->dropAllUses();
+    op->erase();
+  };
+  auto noSkipBlockErasure = [](Block *block) {
+    llvm::outs() << "Erasing ";
+    printBlock(block);
+    llvm::outs() << "\n";
+    block->erase();
+  };
+
+  llvm::outs() << "Op post-order erasures (no skip)"
+               << "\n";
+  Operation *cloned = op->clone();
+  cloned->walk<WalkOrder::PostOrder>(noSkipOpErasure);
+
+  llvm::outs() << "Block post-order erasures (no skip)"
+               << "\n";
+  cloned = op->clone();
+  cloned->walk<WalkOrder::PostOrder>(noSkipBlockErasure);
+  cloned->erase();
+}
+
+namespace {
+/// This pass exercises the different configurations of the IR visitors.
+struct TestIRVisitorsPass
+    : public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    testPureCallbacks(op);
+    testSkipErasureCallbacks(op);
+    testNoSkipErasureCallbacks(op);
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestIRVisitorsPass() {
+  PassRegistration<TestIRVisitorsPass>("test-ir-visitors",
+                                       "Test various visitors.");
+}
+} // namespace test
+} // namespace mlir
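TestVisitors.cpp checks the contract this commit introduces. A pre-order callback that erases the operation or block it was handed returns Skip, so the walk never descends into the erased IR. A post-order callback may erase without Skip, because nothing nested is visited afterwards. In both cases the walker has to move its iterator past an element before visiting it, which is presumably what the early-increment-range fix in the log addresses. The sketch below models that contract in plain C++ on a toy tree instead of MLIR's Operation/Region/Block. 'Node', 'erase', 'walk', 'eraseWithSkip' and the 'WalkResult' enum are hypothetical stand-ins and are not the real mlir::WalkResult class.

```cpp
#include <cassert>
#include <functional>
#include <initializer_list>
#include <list>
#include <memory>
#include <string>
#include <vector>

// Toy model, not MLIR's API: a Node stands in for an Operation. Children live
// in a std::list so erasing one leaves iterators to its siblings valid, much
// like the intrusive lists MLIR keeps operations and blocks in.
struct Node {
  std::string name;
  Node *parent = nullptr;
  std::list<std::unique_ptr<Node>> children;
};

enum class WalkOrder { PreOrder, PostOrder };
enum class WalkResult { Advance, Skip, Interrupt };

Node &addChild(Node &parent, const std::string &name) {
  parent.children.push_back(std::make_unique<Node>());
  Node &child = *parent.children.back();
  child.name = name;
  child.parent = &parent;
  return child;
}

// Unlinks `node` from its parent, destroying it and everything nested in it.
void erase(Node &node) {
  assert(node.parent && "the toy model cannot erase the root");
  auto &siblings = node.parent->children;
  for (auto it = siblings.begin(); it != siblings.end(); ++it) {
    if (it->get() == &node) {
      siblings.erase(it);
      return;
    }
  }
}

// Returns Interrupt if a callback interrupted the walk, Advance otherwise.
WalkResult walk(Node &node, WalkOrder order,
                const std::function<WalkResult(Node &)> &fn) {
  if (order == WalkOrder::PreOrder) {
    WalkResult result = fn(node);
    // Skip: don't descend. The callback may have erased `node`, so this
    // function must not touch it again.
    if (result == WalkResult::Skip)
      return WalkResult::Advance;
    if (result == WalkResult::Interrupt)
      return result;
  }
  // Early increment: step past the child before visiting it, so the visit
  // may erase that child without invalidating `it`.
  for (auto it = node.children.begin(); it != node.children.end();) {
    Node &child = **it++;
    if (walk(child, order, fn) == WalkResult::Interrupt)
      return WalkResult::Interrupt;
  }
  // Post-order: the callback is the last use of `node`, so erasing it here is
  // safe even without Skip, which has nothing left to skip at this point.
  if (order == WalkOrder::PostOrder && fn(node) == WalkResult::Interrupt)
    return WalkResult::Interrupt;
  return WalkResult::Advance;
}

// module { func { regionOp0 { op0 std.br op1 std.br op2 } std.return } }
std::unique_ptr<Node> buildExample() {
  auto root = std::make_unique<Node>();
  root->name = "module";
  Node &func = addChild(*root, "func");
  Node &regionOp = addChild(func, "regionOp0");
  for (const char *name : {"op0", "std.br", "op1", "std.br", "op2"})
    addChild(regionOp, name);
  addChild(func, "std.return");
  return root;
}

// Mirrors skipOpErasure: keep 'module' and 'func', erase anything else and
// return Skip. Returns the erased names in erasure order.
std::vector<std::string> eraseWithSkip(Node &root, WalkOrder order) {
  std::vector<std::string> erased;
  walk(root, order, [&](Node &node) {
    if (node.name == "module" || node.name == "func")
      return WalkResult::Advance;
    erased.push_back(node.name);
    erase(node);
    return WalkResult::Skip;
  });
  return erased;
}
```

On a fresh buildExample() tree, eraseWithSkip erases regionOp0 and std.return in pre-order. In post-order it erases op0, std.br, op1, std.br, op2, regionOp0, std.return. These match the "(skip)" op erasure CHECK blocks. Suppose the pre-order callback returned Advance instead of Skip. walk() would then iterate over the children of a destroyed node, which is the use-after-free that Skip avoids.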

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e03e7e8f8907..82db6ea256d7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
 void registerTestExpandTanhPass();
 void registerTestGpuParallelLoopMappingPass();
+void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgFusionTransforms();
@@ -146,6 +147,7 @@ void registerTestPasses() {
   test::registerTestDynamicPipelinePass();
   test::registerTestExpandTanhPass();
   test::registerTestGpuParallelLoopMappingPass();
+  test::registerTestIRVisitorsPass();
   test::registerTestInterfaces();
   test::registerTestLinalgCodegenStrategy();
   test::registerTestLinalgFusionTransforms();


        


More information about the Mlir-commits mailing list