[Mlir-commits] [mlir] 8067ced - [MLIR] Introduce generic visitors.
Rahul Joshi
llvmlistbot at llvm.org
Fri Jan 14 09:16:08 PST 2022
Author: Rahul Joshi
Date: 2022-01-14T09:15:27-08:00
New Revision: 8067ced144a213574b6c8cb4c15aba276d5cf906
URL: https://github.com/llvm/llvm-project/commit/8067ced144a213574b6c8cb4c15aba276d5cf906
DIFF: https://github.com/llvm/llvm-project/commit/8067ced144a213574b6c8cb4c15aba276d5cf906.diff
LOG: [MLIR] Introduce generic visitors.
- Generic visitors invoke operation callbacks before/in-between/after visiting the regions
attached to an operation and use a `WalkStage` to indicate which regions have been
visited.
- This can be useful for cases where we need to visit the operation in between visiting
regions attached to the operation.
Differential Revision: https://reviews.llvm.org/D116230
Added:
mlir/test/IR/generic-visitors-interrupt.mlir
mlir/test/IR/generic-visitors.mlir
mlir/test/lib/IR/TestVisitorsGeneric.cpp
Modified:
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/IR/Visitors.cpp
mlir/test/lib/IR/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 0f0f16603ce43..b65026187d1df 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -510,10 +510,40 @@ class alignas(8) Operation final
/// });
template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
typename RetT = detail::walkResultType<FnT>>
- RetT walk(FnT &&callback) {
+ // Enabled only for single-argument callbacks; two-argument (stage-aware)
+ // callbacks are routed to the generic `walk` overload below instead.
+ typename std::enable_if<
+ llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
+ walk(FnT &&callback) {
return detail::walk<Order>(this, std::forward<FnT>(callback));
}
+ /// Generic walker with a stage aware callback. Walk the operation by calling
+ /// the callback for each nested operation (including this one) N+1 times,
+ /// where N is the number of regions attached to that operation.
+ ///
+ /// The callback method can take any of the following forms:
+ /// void(Operation *, const WalkStage &) : Walk all operations opaquely
+ /// * op->walk([](Operation *nestedOp, const WalkStage &stage) { ...});
+ /// void(OpT, const WalkStage &) : Walk all operations of the given derived
+ /// type.
+ /// * op->walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
+ /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
+ /// but allow for interruption/skipping.
+ /// * op->walk([](... op, const WalkStage &stage) {
+ /// // Skip the walk of this op based on some invariant.
+ /// if (some_invariant)
+ /// return WalkResult::skip();
+ /// // Interrupt, i.e., cancel, the walk based on some invariant.
+ /// if (another_invariant)
+ /// return WalkResult::interrupt();
+ /// return WalkResult::advance();
+ /// });
+ ///
+ /// Returning skip() at any stage ends the walk of that operation: its
+ /// remaining stages and not-yet-visited regions are not walked, while the
+ /// walk of its siblings continues.
+ ///
+ /// This overload only participates for callbacks taking exactly two
+ /// arguments; single-argument callbacks use the Pre/PostOrder `walk` above.
+ template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+ typename std::enable_if<
+ llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
+ walk(FnT &&callback) {
+ return detail::walk(this, std::forward<FnT>(callback));
+ }
+
//===--------------------------------------------------------------------===//
// Uses
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 7827a9518cfbb..661541a555233 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -61,13 +61,49 @@ class WalkResult {
/// Traversal order for region, block and operation walk utilities.
enum class WalkOrder { PreOrder, PostOrder };
+/// A utility class to encode the current walk stage for "generic" walkers.
+/// When walking an operation, we can either choose a Pre/Post order walker
+/// which invokes the callback on an operation before/after all its attached
+/// regions have been visited, or choose a "generic" walker where the callback
+/// is invoked on the operation N+1 times where N is the number of regions
+/// attached to that operation. The `WalkStage` class below encodes the current
+/// stage of the walk, i.e., which regions have already been visited, and the
+/// callback accepts an additional argument for the current stage. Such
+/// generic walkers that accept stage-aware callbacks are only applicable when
+/// the callback operates on an operation (i.e., not applicable for callbacks
+/// on Blocks or Regions).
+class WalkStage {
+public:
+ explicit WalkStage(Operation *op);
+
+ /// Return true if parent operation is being visited before all regions.
+ /// Note: for an operation with no regions, this and isAfterAllRegions() are
+ /// both true, since its single callback invocation is both "before" and
+ /// "after" all (zero) regions.
+ bool isBeforeAllRegions() const { return nextRegion == 0; }
+ /// Returns true if parent operation is being visited just before visiting
+ /// region number `region` (0-based).
+ bool isBeforeRegion(int region) const { return nextRegion == region; }
+ /// Returns true if parent operation is being visited just after visiting
+ /// region number `region` (0-based).
+ bool isAfterRegion(int region) const { return nextRegion == region + 1; }
+ /// Return true if parent operation is being visited after all regions.
+ bool isAfterAllRegions() const { return nextRegion == numRegions; }
+ /// Advance the walk stage.
+ void advance() { nextRegion++; }
+ /// Returns the next region that will be visited.
+ int getNextRegion() const { return nextRegion; }
+
+private:
+ /// Number of regions attached to the operation being walked.
+ const int numRegions;
+ /// Index of the next region to visit; starts at 0 and reaches `numRegions`
+ /// once every region has been visited.
+ int nextRegion;
+};
+
namespace detail {
/// Helper templates to deduce the first argument of a callback parameter.
-template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
-template <typename Ret, typename F, typename Arg>
-Arg first_argument_type(Ret (F::*)(Arg));
-template <typename Ret, typename F, typename Arg>
-Arg first_argument_type(Ret (F::*)(Arg) const);
+// The trailing `Rest...` pack lets these also deduce the first argument of
+// multi-argument callbacks, e.g. `(Operation *, const WalkStage &)`.
+template <typename Ret, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (*)(Arg, Rest...));
+template <typename Ret, typename F, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (F::*)(Arg, Rest...));
+template <typename Ret, typename F, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
template <typename F>
decltype(first_argument_type(&F::operator())) first_argument_type(F);
@@ -197,6 +233,87 @@ walk(Operation *op, FuncTy &&callback) {
return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
}
+/// Generic walkers with stage aware callbacks.
+
+/// Walk all the operations nested under (and including) the given operation,
+/// with the callback being invoked on each operation N+1 times, where N is the
+/// number of regions attached to the operation. The `stage` input to the
+/// callback indicates the current walk stage. This method is invoked for void
+/// returning callbacks. Nested operations are walked between the invocations
+/// on their parent, in region and block order.
+void walk(Operation *op,
+ function_ref<void(Operation *, const WalkStage &stage)> callback);
+
+/// Walk all the operations nested under (and including) the given operation,
+/// with the callback being invoked on each operation N+1 times, where N is the
+/// number of regions attached to the operation. The `stage` input to the
+/// callback indicates the current walk stage. This method is invoked for
+/// skippable or interruptible callbacks. If the callback returns skip() before
+/// a region, the rest of that operation (remaining regions and stages) is not
+/// walked but its siblings are; if it returns interrupt(), the whole walk stops
+/// and interrupt is returned.
+WalkResult
+walk(Operation *op,
+ function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
+
+/// Walk all of the operations nested under and including the given operation.
+/// This method is selected for stage-aware callbacks that operate on
+/// Operation*.
+///
+/// `RetT` is the callback's own return type (void or WalkResult); it selects
+/// the matching non-template stage-aware overload declared above.
+///
+/// Example:
+/// op->walk([](Operation *op, const WalkStage &stage) { ... });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+ typename RetT = decltype(std::declval<FuncTy>()(
+ std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+ return detail::walk(op,
+ function_ref<RetT(ArgT, const WalkStage &)>(callback));
+}
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for void returning callbacks that
+/// operate on a specific derived operation type.
+///
+/// Every operation is still traversed; the callback is only invoked for those
+/// that `dyn_cast` to `ArgT`.
+///
+/// Example:
+/// op->walk([](ReturnOp op, const WalkStage &stage) { ... });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+ typename RetT = decltype(std::declval<FuncTy>()(
+ std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+ std::is_same<RetT, void>::value,
+ RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+ auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
+ if (auto derivedOp = dyn_cast<ArgT>(op))
+ callback(derivedOp, stage);
+ };
+ return detail::walk(
+ op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
+}
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for WalkResult returning
+/// interruptible callbacks that operate on a specific derived operation type.
+///
+/// Operations that do not `dyn_cast` to `ArgT` are traversed as if the
+/// callback had returned WalkResult::advance().
+///
+/// Example:
+/// op->walk([](ReturnOp op, const WalkStage &stage) {
+/// if (some_invariant)
+/// return WalkResult::interrupt();
+/// return WalkResult::advance();
+/// });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+ typename RetT = decltype(std::declval<FuncTy>()(
+ std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+ std::is_same<RetT, WalkResult>::value,
+ RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+ auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
+ if (auto derivedOp = dyn_cast<ArgT>(op))
+ return callback(derivedOp, stage);
+ return WalkResult::advance();
+ };
+ return detail::walk(
+ op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
+}
+
/// Utility to provide the return type of a templated walk method.
+/// Operation::walk also uses it for two-argument (stage-aware) callbacks, in
+/// which case `walk(nullptr, fn)` resolves to one of the generic overloads.
template <typename FnT>
using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index efe7da4032913..822d881400ee0 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -11,6 +11,9 @@
using namespace mlir;
+/// Start the walk stage for `op`: no region has been visited yet, so the next
+/// region to visit is region 0.
+WalkStage::WalkStage(Operation *op)
+ : numRegions(op->getNumRegions()), nextRegion(0) {}
+
/// Walk all of the regions/blocks/operations nested under and including the
/// given operation. Regions, blocks and operations at the same nesting level
/// are visited in lexicographical order. The walk order for enclosing regions,
@@ -67,6 +70,25 @@ void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
callback(op);
}
+/// Generic walk with a void stage-aware callback: invokes `callback` on `op`
+/// before each of its regions and once more after all of them, recursively
+/// walking every nested operation in between.
+void detail::walk(Operation *op,
+                  function_ref<void(Operation *, const WalkStage &)> callback) {
+  WalkStage stage(op);
+
+  for (Region &region : op->getRegions()) {
+    // Invoke callback on the parent op before visiting each child region.
+    callback(op, stage);
+    stage.advance();
+
+    for (Block &block : region) {
+      // Early increment here in the case where the operation is erased by the
+      // callback (same as the WalkResult-returning overload).
+      for (Operation &nestedOp : llvm::make_early_inc_range(block))
+        walk(&nestedOp, callback);
+    }
+  }
+
+  // Invoke callback after all regions have been visited.
+  callback(op, stage);
+}
+
/// Walk all of the regions/blocks/operations nested under and including the
/// given operation. These functions walk operations until an interrupt result
/// is returned by the callback. Walks on regions, blocks and operations may
@@ -157,3 +179,29 @@ WalkResult detail::walk(Operation *op,
return callback(op);
return WalkResult::advance();
}
+
+/// Generic walk with a WalkResult-returning stage-aware callback. The callback
+/// is invoked on `op` before each of its regions and once after all of them.
+/// Returning skip() stops the walk of `op` (its remaining regions and stages
+/// are not visited) without affecting its siblings; returning interrupt()
+/// stops the whole walk. The result reported to the caller is always either
+/// advance() or interrupt().
+WalkResult detail::walk(
+    Operation *op,
+    function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
+  WalkStage stage(op);
+
+  for (Region &region : op->getRegions()) {
+    // Invoke callback on the parent op before visiting each child region.
+    WalkResult result = callback(op, stage);
+
+    if (result.wasSkipped())
+      return WalkResult::advance();
+    if (result.wasInterrupted())
+      return WalkResult::interrupt();
+
+    stage.advance();
+
+    for (Block &block : region) {
+      // Early increment here in the case where the operation is erased.
+      for (Operation &nestedOp : llvm::make_early_inc_range(block))
+        if (walk(&nestedOp, callback).wasInterrupted())
+          return WalkResult::interrupt();
+    }
+  }
+
+  // Invoke callback after all regions have been visited. A skip() at this
+  // final stage has nothing left to skip, so report it as advance(), matching
+  // the in-loop handling above.
+  if (callback(op, stage).wasInterrupted())
+    return WalkResult::interrupt();
+  return WalkResult::advance();
+}
diff --git a/mlir/test/IR/generic-visitors-interrupt.mlir b/mlir/test/IR/generic-visitors-interrupt.mlir
new file mode 100644
index 0000000000000..3b4cb0496b71c
--- /dev/null
+++ b/mlir/test/IR/generic-visitors-interrupt.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Walk is interrupted before visiting "foo"
+func @main(%arg0: f32) -> f32 {
+ %v1 = "foo"() {interrupt_before_all = true} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 walk was interrupted
+
+// -----
+
+// Walk is interrupted after visiting all regions of "foo" (which has a single region containing "bar").
+func @main(%arg0: f32) -> f32 {
+ %v1 = "foo"() ({ "bar"() : ()-> () }) {interrupt_after_all = true} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar' before all regions
+// CHECK: step 4 walk was interrupted
+
+// -----
+
+// Walk is interrupted after visiting "foo"'s 1st region.
+func @main(%arg0: f32) -> f32 {
+ %v1 = "foo"() ({
+ "bar0"() : () -> ()
+ }, {
+ "bar1"() : () -> ()
+ }) {interrupt_after_region = 0} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 walk was interrupted
+
+
+// -----
+
+// Test static filtering.
+func @main() {
+ "foo"() : () -> ()
+ "test.two_region_op"()(
+ {"work"() : () -> ()},
+ {"work"() : () -> ()}
+ ) {interrupt_after_all = true} : () -> ()
+ return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'test.two_region_op' before all regions
+// CHECK: step 4 op 'work' before all regions
+// CHECK: step 5 op 'test.two_region_op' before region #1
+// CHECK: step 6 op 'work' before all regions
+// CHECK: step 7 walk was interrupted
+// CHECK: step 8 op 'test.two_region_op' before all regions
+// CHECK: step 9 op 'test.two_region_op' before region #1
+// CHECK: step 10 walk was interrupted
+
+// -----
+
+// Test static filtering.
+func @main() {
+ "foo"() : () -> ()
+ "test.two_region_op"()(
+ {"work"() : () -> ()},
+ {"work"() : () -> ()}
+ ) {interrupt_after_region = 0} : () -> ()
+ return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'test.two_region_op' before all regions
+// CHECK: step 4 op 'work' before all regions
+// CHECK: step 5 walk was interrupted
+// CHECK: step 6 op 'test.two_region_op' before all regions
+// CHECK: step 7 walk was interrupted
+
+// -----
+// Test skipping.
+
+// Walk is skipped before visiting "foo".
+func @main(%arg0: f32) -> f32 {
+ %v1 = "foo"() ({
+ "bar0"() : () -> ()
+ }, {
+ "bar1"() : () -> ()
+ }) {skip_before_all = true} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'arith.addf' before all regions
+// CHECK: step 3 op 'std.return' before all regions
+// CHECK: step 4 op 'builtin.func' after all regions
+// CHECK: step 5 op 'builtin.module' after all regions
+
+// -----
+// Walk is skipped after visiting all regions of "foo".
+func @main(%arg0: f32) -> f32 {
+ %v1 = "foo"() ({
+ "bar0"() : () -> ()
+ }, {
+ "bar1"() : () -> ()
+ }) {skip_after_all = true} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 op 'foo' before region #1
+// CHECK: step 5 op 'bar1' before all regions
+// CHECK: step 6 op 'arith.addf' before all regions
+// CHECK: step 7 op 'std.return' before all regions
+// CHECK: step 8 op 'builtin.func' after all regions
+// CHECK: step 9 op 'builtin.module' after all regions
+
+// -----
+// Walk is skipped after visiting first region of "foo".
+func @main(%arg0: f32) -> f32 {
+ %v1 = "foo"() ({
+ "bar0"() : () -> ()
+ }, {
+ "bar1"() : () -> ()
+ }) {skip_after_region = 0} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 op 'arith.addf' before all regions
+// CHECK: step 5 op 'std.return' before all regions
+// CHECK: step 6 op 'builtin.func' after all regions
+// CHECK: step 7 op 'builtin.module' after all regions
diff --git a/mlir/test/IR/generic-visitors.mlir b/mlir/test/IR/generic-visitors.mlir
new file mode 100644
index 0000000000000..c87bd559be9b3
--- /dev/null
+++ b/mlir/test/IR/generic-visitors.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -test-generic-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Verify the different configurations of generic IR visitors.
+
+func @structured_cfg() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ scf.for %i = %c1 to %c10 step %c1 {
+ %cond = "use0"(%i) : (index) -> (i1)
+ scf.if %cond {
+ "use1"(%i) : (index) -> ()
+ } else {
+ "use2"(%i) : (index) -> ()
+ }
+ "use3"(%i) : (index) -> ()
+ }
+ return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'arith.constant' before all regions
+// CHECK: step 3 op 'arith.constant' before all regions
+// CHECK: step 4 op 'arith.constant' before all regions
+// CHECK: step 5 op 'scf.for' before all regions
+// CHECK: step 6 op 'use0' before all regions
+// CHECK: step 7 op 'scf.if' before all regions
+// CHECK: step 8 op 'use1' before all regions
+// CHECK: step 9 op 'scf.yield' before all regions
+// CHECK: step 10 op 'scf.if' before region #1
+// CHECK: step 11 op 'use2' before all regions
+// CHECK: step 12 op 'scf.yield' before all regions
+// CHECK: step 13 op 'scf.if' after all regions
+// CHECK: step 14 op 'use3' before all regions
+// CHECK: step 15 op 'scf.yield' before all regions
+// CHECK: step 16 op 'scf.for' after all regions
+// CHECK: step 17 op 'std.return' before all regions
+// CHECK: step 18 op 'builtin.func' after all regions
+// CHECK: step 19 op 'builtin.module' after all regions
+
+// -----
+// Test the specific operation type visitor.
+
+func @correct_number_of_regions() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ scf.for %i = %c1 to %c10 step %c1 {
+ "test.two_region_op"()(
+ {"work"() : () -> ()},
+ {"work"() : () -> ()}
+ ) : () -> ()
+ }
+ return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 15 op 'builtin.module' after all regions
+// CHECK: step 16 op 'test.two_region_op' before all regions
+// CHECK: step 17 op 'test.two_region_op' before region #1
+// CHECK: step 18 op 'test.two_region_op' after all regions
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 0b5833979ac94..f656a4e6934ef 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_library(MLIRTestIR
TestSymbolUses.cpp
TestTypes.cpp
TestVisitors.cpp
+ TestVisitorsGeneric.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
new file mode 100644
index 0000000000000..0a73c71061562
--- /dev/null
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -0,0 +1,123 @@
+//===- TestIRVisitorsGeneric.cpp - Pass to test the Generic IR visitors ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+/// Describe `stage` for the test output. An operation without regions is
+/// reported as "before all regions", because that check is made first.
+static std::string getStageDescription(const WalkStage &stage) {
+ if (stage.isBeforeAllRegions())
+ return "before all regions";
+ if (stage.isAfterAllRegions())
+ return "after all regions";
+ return "before region #" + std::to_string(stage.getNextRegion());
+}
+
+namespace {
+/// This pass exercises the generic visitor with void callbacks and prints the
+/// order and stage in which operations are visited.
+class TestGenericIRVisitorPass
+    : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
+public:
+  StringRef getArgument() const final { return "test-generic-ir-visitors"; }
+  StringRef getDescription() const final { return "Test generic IR visitors."; }
+  void runOnOperation() override {
+    Operation *outerOp = getOperation();
+    int stepNo = 0;
+    outerOp->walk([&](Operation *op, const WalkStage &stage) {
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+    });
+
+    // Exercise static inference of operation type.
+    outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+    });
+  }
+};
+
+/// This pass exercises the generic visitor with non-void callbacks and prints
+/// the order and stage in which operations are visited. It will interrupt or
+/// skip the walk based on attributes present in the IR.
+class TestGenericIRVisitorInterruptPass
+    : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
+public:
+  StringRef getArgument() const final {
+    return "test-generic-ir-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts.";
+  }
+  void runOnOperation() override {
+    Operation *outerOp = getOperation();
+    int stepNo = 0;
+
+    // Returns interrupt()/skip() when `op` carries an attribute requesting it
+    // for the current stage; otherwise prints the step and advances.
+    auto walker = [&](Operation *op, const WalkStage &stage) {
+      if (auto interruptBeforeAll =
+              op->getAttrOfType<BoolAttr>("interrupt_before_all"))
+        if (interruptBeforeAll.getValue() && stage.isBeforeAllRegions())
+          return WalkResult::interrupt();
+
+      if (auto interruptAfterAll =
+              op->getAttrOfType<BoolAttr>("interrupt_after_all"))
+        if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
+          return WalkResult::interrupt();
+
+      if (auto interruptAfterRegion =
+              op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
+        if (stage.isAfterRegion(
+                static_cast<int>(interruptAfterRegion.getInt())))
+          return WalkResult::interrupt();
+
+      if (auto skipBeforeAll = op->getAttrOfType<BoolAttr>("skip_before_all"))
+        if (skipBeforeAll.getValue() && stage.isBeforeAllRegions())
+          return WalkResult::skip();
+
+      if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
+        if (skipAfterAll.getValue() && stage.isAfterAllRegions())
+          return WalkResult::skip();
+
+      if (auto skipAfterRegion =
+              op->getAttrOfType<IntegerAttr>("skip_after_region"))
+        if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
+          return WalkResult::skip();
+
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+      return WalkResult::advance();
+    };
+
+    // Interrupt the walk based on attributes on the operation.
+    auto result = outerOp->walk(walker);
+
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+
+    // Exercise static inference of operation type.
+    result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
+      return walker(op, stage);
+    });
+
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+/// Registers both generic IR visitor test passes: the void-callback variant
+/// ("test-generic-ir-visitors") and the interrupt/skip variant
+/// ("test-generic-ir-visitors-interrupt").
+void registerTestGenericIRVisitorsPass() {
+ PassRegistration<TestGenericIRVisitorPass>();
+ PassRegistration<TestGenericIRVisitorInterruptPass>();
+}
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3a2d83ebbe77b..120585c806a94 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,8 @@ void registerTestExpandTanhPass();
void registerTestComposeSubView();
void registerTestGpuParallelLoopMappingPass();
void registerTestIRVisitorsPass();
+// Registers both the plain and the interrupt/skip generic IR visitor passes.
+void registerTestGenericIRVisitorsPass();
void registerTestInterfaces();
void registerTestLinalgCodegenStrategy();
void registerTestLinalgControlFuseByExpansion();
@@ -171,6 +173,7 @@ void registerTestPasses() {
mlir::test::registerTestComposeSubView();
mlir::test::registerTestGpuParallelLoopMappingPass();
mlir::test::registerTestIRVisitorsPass();
+ mlir::test::registerTestGenericIRVisitorsPass();
mlir::test::registerTestInterfaces();
mlir::test::registerTestLinalgCodegenStrategy();
mlir::test::registerTestLinalgControlFuseByExpansion();
More information about the Mlir-commits
mailing list