[Mlir-commits] [mlir] 2de6dbd - [mlir] Add 'Skip' result to Operation visitor
Diego Caballero
llvmlistbot at llvm.org
Fri Mar 5 14:06:14 PST 2021
Author: Diego Caballero
Date: 2021-03-06T00:02:20+02:00
New Revision: 2de6dbda66b3ff23f1e0cb52862d90224852ae59
URL: https://github.com/llvm/llvm-project/commit/2de6dbda66b3ff23f1e0cb52862d90224852ae59
DIFF: https://github.com/llvm/llvm-project/commit/2de6dbda66b3ff23f1e0cb52862d90224852ae59.diff
LOG: [mlir] Add 'Skip' result to Operation visitor
This patch is a follow-up on D97217. It adds a new 'Skip' result to the Operation visitor
so that a callback can stop the ongoing visit of an operation/block/region and
continue visiting the next one without fully interrupting the walk. Skipping is
needed to be able to erase an operation/block in pre-order and not continue
visiting the internals of that operation/block.
Related to the skipping mechanism, the patch also introduces the following changes:
* Added new TestIRVisitors pass with basic testing for the IR visitors.
* Fixed missing early increment ranges in visitor implementation.
* Updated documentation of walk methods to include erasure information and walk
order information.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D97820
Added:
mlir/test/IR/visitors.mlir
mlir/test/lib/IR/TestVisitors.cpp
Modified:
mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Region.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/IR/Visitors.cpp
mlir/test/lib/IR/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 9f265b3b56f5..9f26155b5265 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -249,9 +249,15 @@ class Block : public IRObjectWithUseList<BlockOperand>,
// Operation Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this block, calling the callback for each
- /// operation. The walk order for regions, blocks and operations is specified
- /// by 'Order' (post-order by default).
+ /// Walk the operations in this block. The callback method is called for each
+ /// nested region, block or operation, depending on the callback provided.
+ /// Regions, blocks and operations at the same nesting level are visited in
+ /// lexicographical order. The walk order for enclosing regions, blocks and
+ /// operations with respect to their nested ones is specified by 'Order'
+ /// (post-order by default). A callback on a block or operation is allowed to
+ /// erase that block or operation if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
@@ -259,10 +265,15 @@ class Block : public IRObjectWithUseList<BlockOperand>,
return walk<Order>(begin(), end(), std::forward<FnT>(callback));
}
- /// Walk the operations in the specified [begin, end) range of this block,
- /// calling the callback for each operation. The walk order for regions,
- /// blocks and operations is specified by 'Order' (post-order by default).
- /// This method is invoked for void return callbacks.
+ /// Walk the operations in the specified [begin, end) range of this block. The
+ /// callback method is called for each nested region, block or operation,
+ /// depending on the callback provided. Regions, blocks and operations at the
+ /// same nesting level are visited in lexicographical order. The walk order
+ /// for enclosing regions, blocks and operations with respect to their nested
+ /// ones is specified by 'Order' (post-order by default). This method is
+ /// invoked for void-returning callbacks. A callback on a block or operation
+ /// is allowed to erase that block or operation only if the walk is in
+ /// post-order. See non-void method for pre-order erasure.
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
@@ -272,10 +283,16 @@ class Block : public IRObjectWithUseList<BlockOperand>,
detail::walk<Order>(&op, callback);
}
- /// Walk the operations in the specified [begin, end) range of this block,
- /// calling the callback for each operation. The walk order for regions,
- /// blocks and operations is specified by 'Order' (post-order by default).
- /// This method is invoked for interruptible callbacks.
+ /// Walk the operations in the specified [begin, end) range of this block. The
+ /// callback method is called for each nested region, block or operation,
+ /// depending on the callback provided. Regions, blocks and operations at the
+ /// same nesting level are visited in lexicographical order. The walk order
+ /// for enclosing regions, blocks and operations with respect to their nested
+ /// ones is specified by 'Order' (post-order by default). This method is
+ /// invoked for skippable or interruptible callbacks. A callback on a block or
+ /// operation is allowed to erase that block or operation if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 4c9399f2a6fc..44bd7fb3283b 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -165,9 +165,15 @@ class OpState {
/// handlers that may be listening.
InFlightDiagnostic emitRemark(const Twine &message = {});
- /// Walk the operation by calling the callback for each nested
- /// operation(including this one). The walk order for regions, blocks and
- /// operations is specified by 'Order' (post-order by default).
+ /// Walk the operation by calling the callback for each nested operation
+ /// (including this one), block or region, depending on the callback provided.
+ /// Regions, blocks and operations at the same nesting level are visited in
+ /// lexicographical order. The walk order for enclosing regions, blocks and
+ /// operations with respect to their nested ones is specified by 'Order'
+ /// (post-order by default). A callback on a block or operation is allowed to
+ /// erase that block or operation if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 679071457e52..f76ea2657811 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -485,18 +485,28 @@ class alignas(8) Operation final
//===--------------------------------------------------------------------===//
/// Walk the operation by calling the callback for each nested operation
- /// (including this one). The walk order for regions, blocks and operations is
- /// specified by 'Order' (post-order by default). The callback method can take
- /// any of the following forms:
+ /// (including this one), block or region, depending on the callback provided.
+ /// Regions, blocks and operations at the same nesting level are visited in
+ /// lexicographical order. The walk order for enclosing regions, blocks and
+ /// operations with respect to their nested ones is specified by 'Order'
+ /// (post-order by default). A callback on a block or operation is allowed to
+ /// erase that block or operation if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
+ ///
+ /// The callback method can take any of the following forms:
/// void(Operation*) : Walk all operations opaquely.
/// * op->walk([](Operation *nestedOp) { ...});
/// void(OpT) : Walk all operations of the given derived type.
/// * op->walk([](ReturnOp returnOp) { ...});
/// WalkResult(Operation*|OpT) : Walk operations, but allow for
- /// interruption/cancellation.
+ /// interruption/skipping.
/// * op->walk([](... op) {
- /// // Interrupt, i.e cancel, the walk based on some invariant.
+ /// // Skip the walk of this op based on some invariant.
/// if (some_invariant)
+ /// return WalkResult::skip();
+ /// // Interrupt, i.e., cancel, the walk based on some invariant.
+ /// if (another_invariant)
/// return WalkResult::interrupt();
/// return WalkResult::advance();
/// });
diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 349be0f5aedc..8888862dd10c 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -242,11 +242,15 @@ class Region {
// Operation Walkers
//===--------------------------------------------------------------------===//
- /// Walk the operations in this region in postorder, calling the callback for
- /// each operation. The walk order for regions, blocks and operations is
- /// specified by 'Order' (post-order by default). This method is invoked for
- /// void-returning callbacks.
- /// See Operation::walk for more details.
+ /// Walk the operations in this region. The callback method is called for each
+ /// nested region, block or operation, depending on the callback provided.
+ /// Regions, blocks and operations at the same nesting level are visited in
+ /// lexicographical order. The walk order for enclosing regions, blocks and
+ /// operations with respect to their nested ones is specified by 'Order'
+ /// (post-order by default). This method is invoked for void-returning
+ /// callbacks. A callback on a block or operation is allowed to erase that
+ /// block or operation only if the walk is in post-order. See non-void method
+ /// for pre-order erasure. See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
@@ -255,10 +259,16 @@ class Region {
block.walk<Order>(callback);
}
- /// Walk the operations in this region in postorder, calling the callback for
- /// each operation. The walk order for regions, blocks and operations is
- /// specified by 'Order' (post-order by default). This method is invoked for
- /// interruptible callbacks.
+ /// Walk the operations in this region. The callback method is called for each
+ /// nested region, block or operation, depending on the callback provided.
+ /// Regions, blocks and operations at the same nesting level are visited in
+ /// lexicographical order. The walk order for enclosing regions, blocks and
+ /// operations with respect to their nested ones is specified by 'Order'
+ /// (post-order by default). This method is invoked for skippable or
+ /// interruptible callbacks. A callback on a block or operation is allowed to
+ /// erase that block or operation if either:
+ /// * the walk is in post-order, or
+ /// * the walk is in pre-order and the walk is skipped after the erasure.
/// See Operation::walk for more details.
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index b4571e19f8fd..af2d05379662 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -24,10 +24,15 @@ class Operation;
class Block;
class Region;
-/// A utility result that is used to signal if a walk method should be
-/// interrupted or advance.
+/// A utility result that is used to signal how to proceed with an ongoing walk:
+/// * Interrupt: the walk will be interrupted and no more operations, regions
+/// or blocks will be visited.
+/// * Advance: the walk will continue.
+/// * Skip: the walk of the current operation, region or block and of their
+/// nested elements that haven't been visited yet will be skipped, and the
+/// walk will continue with the next operation, region or block.
class WalkResult {
- enum ResultEnum { Interrupt, Advance } result;
+ enum ResultEnum { Interrupt, Advance, Skip } result;
public:
WalkResult(ResultEnum result) : result(result) {}
@@ -44,9 +49,13 @@ class WalkResult {
static WalkResult interrupt() { return {Interrupt}; }
static WalkResult advance() { return {Advance}; }
+ static WalkResult skip() { return {Skip}; }
/// Returns true if the walk was interrupted.
bool wasInterrupted() const { return result == Interrupt; }
+
+ /// Returns true if the walk was skipped.
+ bool wasSkipped() const { return result == Skip; }
};
/// Traversal order for region, block and operation walk utilities.
@@ -67,15 +76,27 @@ template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
/// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. The walk order is specified by 'order'.
+/// the given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'order'. These methods are invoked for void-returning
+/// callbacks. A callback on a block or operation is allowed to erase that block
+/// or operation only if the walk is in post-order. See non-void method for
+/// pre-order erasure.
void walk(Operation *op, function_ref<void(Region *)> callback,
WalkOrder order);
void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
void walk(Operation *op, function_ref<void(Operation *)> callback,
WalkOrder order);
/// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. The walk order is specified by 'order'. These functions
-/// walk until an interrupt result is returned by the callback.
+/// the given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'order'. These methods are invoked for skippable or
+/// interruptible callbacks. A callback on a block or operation is allowed to
+/// erase that block or operation if either:
+/// * the walk is in post-order, or
+/// * the walk is in pre-order and the walk is skipped after the erasure.
WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
WalkOrder order);
WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
@@ -89,9 +110,15 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
// upon the type of the callback function.
/// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. The walk order is specified by 'Order' (post-order
-/// by default). This method is selected for callbacks that operate on
-/// Region*, Block*, and Operation*.
+/// the given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'Order' (post-order by default). A callback on a block or
+/// operation is allowed to erase that block or operation if either:
+/// * the walk is in post-order, or
+/// * the walk is in pre-order and the walk is skipped after the erasure.
+/// This method is selected for callbacks that operate on Region*, Block*, and
+/// Operation*.
///
/// Example:
/// op->walk([](Region *r) { ... });
@@ -108,9 +135,13 @@ walk(Operation *op, FuncTy &&callback) {
}
/// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. The walk order for regions, blocks and operations is
-/// specified by 'Order' (post-order by default). This method is selected for
-/// void returning callbacks that operate on a specific derived operation type.
+/// given operation. Regions, blocks and operations at the same nesting
+/// level are visited in lexicographical order. The walk order for enclosing
+/// regions, blocks and operations with respect to their nested ones is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// void-returning callbacks that operate on a specific derived operation type.
+/// A callback on an operation is allowed to erase that operation only if the
+/// walk is in post-order. See non-void method for pre-order erasure.
///
/// Example:
/// op->walk([](ReturnOp op) { ... });
@@ -131,14 +162,21 @@ walk(Operation *op, FuncTy &&callback) {
}
/// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. The walk order for regions, blocks and operations is
-/// specified by 'Order' (post-order by default). This method is selected for
-/// WalkReturn returning interruptible callbacks that operate on a specific
-/// derived operation type.
+/// given operation. Regions, blocks and operations at the same nesting level
+/// are visited in lexicographical order. The walk order for enclosing regions,
+/// blocks and operations with respect to their nested ones is specified by
+/// 'Order' (post-order by default). This method is selected for WalkResult
+/// returning skippable or interruptible callbacks that operate on a specific
+/// derived operation type. A callback on an operation is allowed to erase that
+/// operation if either:
+/// * the walk is in post-order, or
+/// * the walk is in pre-order and the walk is skipped after the erasure.
///
/// Example:
/// op->walk([](ReturnOp op) {
/// if (some_invariant)
+/// return WalkResult::skip();
+/// if (another_invariant)
/// return WalkResult::interrupt();
/// return WalkResult::advance();
/// });
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index be995a2a4fb2..efe7da403291 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -12,10 +12,16 @@
using namespace mlir;
/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. The walk order is specified by 'Order'.
-
+/// given operation. Regions, blocks and operations at the same nesting level
+/// are visited in lexicographical order. The walk order for enclosing regions,
+/// blocks and operations with respect to their nested ones is specified by
+/// 'order'. These methods are invoked for void-returning callbacks. A callback
+/// on a block or operation is allowed to erase that block or operation only if
+/// the walk is in post-order. See non-void method for pre-order erasure.
void detail::walk(Operation *op, function_ref<void(Region *)> callback,
WalkOrder order) {
+ // We don't use early increment for regions because they can't be erased from
+ // a callback.
for (auto &region : op->getRegions()) {
if (order == WalkOrder::PreOrder)
callback(&region);
@@ -31,7 +37,8 @@ void detail::walk(Operation *op, function_ref<void(Region *)> callback,
void detail::walk(Operation *op, function_ref<void(Block *)> callback,
WalkOrder order) {
for (auto &region : op->getRegions()) {
- for (auto &block : region) {
+ // Early increment here in the case where the block is erased.
+ for (auto &block : llvm::make_early_inc_range(region)) {
if (order == WalkOrder::PreOrder)
callback(&block);
for (auto &nestedOp : block)
@@ -61,22 +68,38 @@ void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
}
/// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. The walk order is specified by 'order'. These functions
-/// walk operations until an interrupt result is returned by the callback.
+/// given operation. These functions walk operations until an interrupt result
+/// is returned by the callback. Walks on regions, blocks and operations may
+/// also be skipped if the callback returns a skip result. Regions, blocks and
+/// operations at the same nesting level are visited in lexicographical order.
+/// The walk order for enclosing regions, blocks and operations with respect to
+/// their nested ones is specified by 'order'. A callback on a block or
+/// operation is allowed to erase that block or operation if either:
+/// * the walk is in post-order, or
+/// * the walk is in pre-order and the walk is skipped after the erasure.
WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Region *)> callback,
WalkOrder order) {
+ // We don't use early increment for regions because they can't be erased from
+ // a callback.
for (auto &region : op->getRegions()) {
- if (order == WalkOrder::PreOrder)
- if (callback(&region).wasInterrupted())
+ if (order == WalkOrder::PreOrder) {
+ WalkResult result = callback(&region);
+ if (result.wasSkipped())
+ continue;
+ if (result.wasInterrupted())
return WalkResult::interrupt();
+ }
for (auto &block : region) {
for (auto &nestedOp : block)
walk(&nestedOp, callback, order);
}
- if (order == WalkOrder::PostOrder)
+ if (order == WalkOrder::PostOrder) {
if (callback(&region).wasInterrupted())
return WalkResult::interrupt();
+ // We don't check if this region was skipped because its walk already
+ // finished and the walk will continue with the next region.
+ }
}
return WalkResult::advance();
}
@@ -85,15 +108,23 @@ WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Block *)> callback,
WalkOrder order) {
for (auto &region : op->getRegions()) {
- for (auto &block : region) {
- if (order == WalkOrder::PreOrder)
- if (callback(&block).wasInterrupted())
+ // Early increment here in the case where the block is erased.
+ for (auto &block : llvm::make_early_inc_range(region)) {
+ if (order == WalkOrder::PreOrder) {
+ WalkResult result = callback(&block);
+ if (result.wasSkipped())
+ continue;
+ if (result.wasInterrupted())
return WalkResult::interrupt();
+ }
for (auto &nestedOp : block)
walk(&nestedOp, callback, order);
- if (order == WalkOrder::PostOrder)
+ if (order == WalkOrder::PostOrder) {
if (callback(&block).wasInterrupted())
return WalkResult::interrupt();
+ // We don't check if this block was skipped because its walk already
+ // finished and the walk will continue with the next block.
+ }
}
}
return WalkResult::advance();
@@ -102,9 +133,14 @@ WalkResult detail::walk(Operation *op,
WalkResult detail::walk(Operation *op,
function_ref<WalkResult(Operation *)> callback,
WalkOrder order) {
- if (order == WalkOrder::PreOrder)
- if (callback(op).wasInterrupted())
+ if (order == WalkOrder::PreOrder) {
+ WalkResult result = callback(op);
+ // If skipped, caller will continue the walk on the next operation.
+ if (result.wasSkipped())
+ return WalkResult::advance();
+ if (result.wasInterrupted())
return WalkResult::interrupt();
+ }
// TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) {
diff --git a/mlir/test/IR/visitors.mlir b/mlir/test/IR/visitors.mlir
new file mode 100644
index 000000000000..789ae8f07570
--- /dev/null
+++ b/mlir/test/IR/visitors.mlir
@@ -0,0 +1,212 @@
+// RUN: mlir-opt -test-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Verify the different configurations of IR visitors.
+// Constant, yield and other terminator ops are not matched for simplicity.
+// Module and function op and their immediately nested blocks are not erased in
+// callbacks with return so that the output includes more cases in pre-order.
+
+func @structured_cfg() {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c10 = constant 10 : index
+ scf.for %i = %c1 to %c10 step %c1 {
+ %cond = "use0"(%i) : (index) -> (i1)
+ scf.if %cond {
+ "use1"(%i) : (index) -> ()
+ } else {
+ "use2"(%i) : (index) -> ()
+ }
+ "use3"(%i) : (index) -> ()
+ }
+ return
+}
+
+// CHECK-LABEL: Op pre-order visit
+// CHECK: Visiting op 'module'
+// CHECK: Visiting op 'func'
+// CHECK: Visiting op 'scf.for'
+// CHECK: Visiting op 'use0'
+// CHECK: Visiting op 'scf.if'
+// CHECK: Visiting op 'use1'
+// CHECK: Visiting op 'use2'
+// CHECK: Visiting op 'use3'
+// CHECK: Visiting op 'std.return'
+
+// CHECK-LABEL: Block pre-order visits
+// CHECK: Visiting block ^bb0 from region 0 from operation 'module'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Region pre-order visits
+// CHECK: Visiting region 0 from operation 'module'
+// CHECK: Visiting region 0 from operation 'func'
+// CHECK: Visiting region 0 from operation 'scf.for'
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+
+// CHECK-LABEL: Op post-order visits
+// CHECK: Visiting op 'use0'
+// CHECK: Visiting op 'use1'
+// CHECK: Visiting op 'use2'
+// CHECK: Visiting op 'scf.if'
+// CHECK: Visiting op 'use3'
+// CHECK: Visiting op 'scf.for'
+// CHECK: Visiting op 'std.return'
+// CHECK: Visiting op 'func'
+// CHECK: Visiting op 'module'
+
+// CHECK-LABEL: Block post-order visits
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 1 from operation 'scf.if'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'scf.for'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'module'
+
+// CHECK-LABEL: Region post-order visits
+// CHECK: Visiting region 0 from operation 'scf.if'
+// CHECK: Visiting region 1 from operation 'scf.if'
+// CHECK: Visiting region 0 from operation 'scf.for'
+// CHECK: Visiting region 0 from operation 'func'
+// CHECK: Visiting region 0 from operation 'module'
+
+// CHECK-LABEL: Op pre-order erasures
+// CHECK: Erasing op 'scf.for'
+// CHECK: Erasing op 'std.return'
+
+// CHECK-LABEL: Block pre-order erasures
+// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Op post-order erasures (skip)
+// CHECK: Erasing op 'use0'
+// CHECK: Erasing op 'use1'
+// CHECK: Erasing op 'use2'
+// CHECK: Erasing op 'scf.if'
+// CHECK: Erasing op 'use3'
+// CHECK: Erasing op 'scf.for'
+// CHECK: Erasing op 'std.return'
+
+// CHECK-LABEL: Block post-order erasures (skip)
+// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for'
+
+// CHECK-LABEL: Op post-order erasures (no skip)
+// CHECK: Erasing op 'use0'
+// CHECK: Erasing op 'use1'
+// CHECK: Erasing op 'use2'
+// CHECK: Erasing op 'scf.if'
+// CHECK: Erasing op 'use3'
+// CHECK: Erasing op 'scf.for'
+// CHECK: Erasing op 'std.return'
+// CHECK: Erasing op 'func'
+// CHECK: Erasing op 'module'
+
+// CHECK-LABEL: Block post-order erasures (no skip)
+// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.if'
+// CHECK: Erasing block ^bb0 from region 1 from operation 'scf.if'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'scf.for'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'func'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'module'
+
+// -----
+
+func @unstructured_cfg() {
+ "regionOp0"() ({
+ ^bb0:
+ "op0"() : () -> ()
+ br ^bb2
+ ^bb1:
+ "op1"() : () -> ()
+ br ^bb2
+ ^bb2:
+ "op2"() : () -> ()
+ }) : () -> ()
+ return
+}
+
+// CHECK-LABEL: Op pre-order visits
+// CHECK: Visiting op 'module'
+// CHECK: Visiting op 'func'
+// CHECK: Visiting op 'regionOp0'
+// CHECK: Visiting op 'op0'
+// CHECK: Visiting op 'std.br'
+// CHECK: Visiting op 'op1'
+// CHECK: Visiting op 'std.br'
+// CHECK: Visiting op 'op2'
+// CHECK: Visiting op 'std.return'
+
+// CHECK-LABEL: Block pre-order visits
+// CHECK: Visiting block ^bb0 from region 0 from operation 'module'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Region pre-order visits
+// CHECK: Visiting region 0 from operation 'module'
+// CHECK: Visiting region 0 from operation 'func'
+// CHECK: Visiting region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Op post-order visits
+// CHECK: Visiting op 'op0'
+// CHECK: Visiting op 'std.br'
+// CHECK: Visiting op 'op1'
+// CHECK: Visiting op 'std.br'
+// CHECK: Visiting op 'op2'
+// CHECK: Visiting op 'regionOp0'
+// CHECK: Visiting op 'std.return'
+// CHECK: Visiting op 'func'
+// CHECK: Visiting op 'module'
+
+// CHECK-LABEL: Block post-order visits
+// CHECK: Visiting block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb1 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb2 from region 0 from operation 'regionOp0'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'func'
+// CHECK: Visiting block ^bb0 from region 0 from operation 'module'
+
+// CHECK-LABEL: Region post-order visits
+// CHECK: Visiting region 0 from operation 'regionOp0'
+// CHECK: Visiting region 0 from operation 'func'
+// CHECK: Visiting region 0 from operation 'module'
+
+// CHECK-LABEL: Op pre-order erasures (skip)
+// CHECK: Erasing op 'regionOp0'
+// CHECK: Erasing op 'std.return'
+
+// CHECK-LABEL: Block pre-order erasures (skip)
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Op post-order erasures (skip)
+// CHECK: Erasing op 'op0'
+// CHECK: Erasing op 'std.br'
+// CHECK: Erasing op 'op1'
+// CHECK: Erasing op 'std.br'
+// CHECK: Erasing op 'op2'
+// CHECK: Erasing op 'regionOp0'
+// CHECK: Erasing op 'std.return'
+
+// CHECK-LABEL: Block post-order erasures (skip)
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+
+// CHECK-LABEL: Op post-order erasures (no skip)
+// CHECK: Erasing op 'op0'
+// CHECK: Erasing op 'std.br'
+// CHECK: Erasing op 'op1'
+// CHECK: Erasing op 'std.br'
+// CHECK: Erasing op 'op2'
+// CHECK: Erasing op 'regionOp0'
+// CHECK: Erasing op 'std.return'
+
+// CHECK-LABEL: Block post-order erasures (no skip)
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'regionOp0'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'func'
+// CHECK: Erasing block ^bb0 from region 0 from operation 'module'
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index a42f90bb9268..337029e0dac8 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -9,6 +9,7 @@ add_mlir_library(MLIRTestIR
TestSlicing.cpp
TestSymbolUses.cpp
TestTypes.cpp
+ TestVisitors.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/IR/TestVisitors.cpp b/mlir/test/lib/IR/TestVisitors.cpp
new file mode 100644
index 000000000000..7ce3422904cd
--- /dev/null
+++ b/mlir/test/lib/IR/TestVisitors.cpp
@@ -0,0 +1,171 @@
+//===- TestVisitors.cpp - Pass to test the IR visitors --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+/// Writes "region <number> from operation '<parent op name>'" for `region` to
+/// llvm::outs(). No trailing newline is emitted.
+static void printRegion(Region *region) {
+  auto index = region->getRegionNumber();
+  Operation *owner = region->getParentOp();
+  llvm::outs() << "region " << index << " from operation '" << owner->getName()
+               << "'";
+}
+
+/// Writes "block <block operand> from region ..." for `block` to llvm::outs(),
+/// printing the block as an operand without its type and then describing the
+/// region that contains it. No trailing newline is emitted.
+static void printBlock(Block *block) {
+  auto &os = llvm::outs();
+  os << "block ";
+  block->printAsOperand(os, /*printType=*/false);
+  os << " from ";
+  Region *container = block->getParent();
+  printRegion(container);
+}
+
+/// Writes "op '<op name>'" for `op` to llvm::outs(). No trailing newline is
+/// emitted.
+static void printOperation(Operation *op) {
+  auto &os = llvm::outs();
+  os << "op '";
+  os << op->getName();
+  os << "'";
+}
+
+/// Exercises callbacks that only observe the IR and never mutate it. Each
+/// callback prints a "Visiting ..." line for the op, block or region it is
+/// handed. The op, block and region walks are run over `op` first in pre-order
+/// and then in post-order, each preceded by a header line.
+static void testPureCallbacks(Operation *op) {
+  auto visitOp = [](Operation *visited) {
+    llvm::outs() << "Visiting ";
+    printOperation(visited);
+    llvm::outs() << "\n";
+  };
+  auto visitBlock = [](Block *visited) {
+    llvm::outs() << "Visiting ";
+    printBlock(visited);
+    llvm::outs() << "\n";
+  };
+  auto visitRegion = [](Region *visited) {
+    llvm::outs() << "Visiting ";
+    printRegion(visited);
+    llvm::outs() << "\n";
+  };
+
+  // Pre-order walks.
+  llvm::outs() << "Op pre-order visits\n";
+  op->walk<WalkOrder::PreOrder>(visitOp);
+  llvm::outs() << "Block pre-order visits\n";
+  op->walk<WalkOrder::PreOrder>(visitBlock);
+  llvm::outs() << "Region pre-order visits\n";
+  op->walk<WalkOrder::PreOrder>(visitRegion);
+
+  // Post-order walks.
+  llvm::outs() << "Op post-order visits\n";
+  op->walk<WalkOrder::PostOrder>(visitOp);
+  llvm::outs() << "Block post-order visits\n";
+  op->walk<WalkOrder::PostOrder>(visitBlock);
+  llvm::outs() << "Region post-order visits\n";
+  op->walk<WalkOrder::PostOrder>(visitRegion);
+}
+
+/// Exercises erasure callbacks that return 'Skip' after erasing what they were
+/// handed. Module and function ops, and blocks owned by them, are kept so that
+/// the walk still has nested IR to reach; every other op/block visited is
+/// erased. Each of the four walks (op/block x pre-/post-order) runs on a fresh
+/// clone of `op`, and that clone is erased once the walk finishes.
+static void testSkipErasureCallbacks(Operation *op) {
+  auto skipOpErasure = [](Operation *visited) {
+    // Keep the module and function alive; otherwise there would be little
+    // left to exercise in pre-order.
+    if (isa<ModuleOp, FuncOp>(visited))
+      return WalkResult::advance();
+
+    llvm::outs() << "Erasing ";
+    printOperation(visited);
+    llvm::outs() << "\n";
+    // Detach any users so the op can be erased.
+    visited->dropAllUses();
+    visited->erase();
+    // In pre-order, 'Skip' keeps the walk from entering the erased op.
+    return WalkResult::skip();
+  };
+  auto skipBlockErasure = [](Block *visited) {
+    // Keep blocks owned by the module and function for the same reason.
+    Operation *owner = visited->getParentOp();
+    if (isa<ModuleOp, FuncOp>(owner))
+      return WalkResult::advance();
+
+    llvm::outs() << "Erasing ";
+    printBlock(visited);
+    llvm::outs() << "\n";
+    visited->erase();
+    // In pre-order, 'Skip' keeps the walk from entering the erased block.
+    return WalkResult::skip();
+  };
+
+  llvm::outs() << "Op pre-order erasures (skip)\n";
+  Operation *opPreOrder = op->clone();
+  opPreOrder->walk<WalkOrder::PreOrder>(skipOpErasure);
+  opPreOrder->erase();
+
+  llvm::outs() << "Block pre-order erasures (skip)\n";
+  Operation *blockPreOrder = op->clone();
+  blockPreOrder->walk<WalkOrder::PreOrder>(skipBlockErasure);
+  blockPreOrder->erase();
+
+  llvm::outs() << "Op post-order erasures (skip)\n";
+  Operation *opPostOrder = op->clone();
+  opPostOrder->walk<WalkOrder::PostOrder>(skipOpErasure);
+  opPostOrder->erase();
+
+  llvm::outs() << "Block post-order erasures (skip)\n";
+  Operation *blockPostOrder = op->clone();
+  blockPostOrder->walk<WalkOrder::PostOrder>(skipBlockErasure);
+  blockPostOrder->erase();
+}
+
+/// Exercises callbacks that erase the op or block they are handed but do not
+/// return 'Skip'. These callbacks are only valid in post-order. Each walk runs
+/// on a fresh clone of `op`.
+static void testNoSkipErasureCallbacks(Operation *op) {
+  auto eraseOp = [](Operation *visited) {
+    llvm::outs() << "Erasing ";
+    printOperation(visited);
+    llvm::outs() << "\n";
+    // Detach any users so the op can be erased.
+    visited->dropAllUses();
+    visited->erase();
+  };
+  auto eraseBlock = [](Block *visited) {
+    llvm::outs() << "Erasing ";
+    printBlock(visited);
+    llvm::outs() << "\n";
+    visited->erase();
+  };
+
+  llvm::outs() << "Op post-order erasures (no skip)\n";
+  // The op walk also reaches (and therefore erases) the root clone itself, so
+  // it is not erased explicitly afterwards.
+  Operation *opClone = op->clone();
+  opClone->walk<WalkOrder::PostOrder>(eraseOp);
+
+  llvm::outs() << "Block post-order erasures (no skip)\n";
+  // Only blocks are erased by this walk; the root clone survives and is
+  // erased here.
+  Operation *blockClone = op->clone();
+  blockClone->walk<WalkOrder::PostOrder>(eraseBlock);
+  blockClone->erase();
+}
+
+namespace {
+/// This pass exercises the different configurations of the IR visitors: pure
+/// (observing) callbacks, erasure callbacks that return 'Skip', and erasure
+/// callbacks that do not. It can run on any operation (`OperationPass<>`).
+struct TestIRVisitorsPass
+    : public PassWrapper<TestIRVisitorsPass, OperationPass<>> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    testPureCallbacks(op);
+    testSkipErasureCallbacks(op);
+    testNoSkipErasureCallbacks(op);
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+/// Registers TestIRVisitorsPass under the `test-ir-visitors` pass argument so
+/// that tools such as mlir-opt can run it.
+void registerTestIRVisitorsPass() {
+  PassRegistration<TestIRVisitorsPass>(
+      /*arg=*/"test-ir-visitors",
+      /*description=*/"Test various visitors.");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index e03e7e8f8907..82db6ea256d7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@ void registerTestDominancePass();
void registerTestDynamicPipelinePass();
void registerTestExpandTanhPass();
void registerTestGpuParallelLoopMappingPass();
+void registerTestIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgFusionTransforms();
@@ -146,6 +147,7 @@ void registerTestPasses() {
test::registerTestDynamicPipelinePass();
test::registerTestExpandTanhPass();
test::registerTestGpuParallelLoopMappingPass();
+ test::registerTestIRVisitorsPass();
test::registerTestInterfaces();
test::registerTestLinalgCodegenStrategy();
test::registerTestLinalgFusionTransforms();
More information about the Mlir-commits
mailing list