[Mlir-commits] [mlir] 8067ced - [MLIR] Introduce generic visitors.

Rahul Joshi llvmlistbot at llvm.org
Fri Jan 14 09:16:08 PST 2022


Author: Rahul Joshi
Date: 2022-01-14T09:15:27-08:00
New Revision: 8067ced144a213574b6c8cb4c15aba276d5cf906

URL: https://github.com/llvm/llvm-project/commit/8067ced144a213574b6c8cb4c15aba276d5cf906
DIFF: https://github.com/llvm/llvm-project/commit/8067ced144a213574b6c8cb4c15aba276d5cf906.diff

LOG: [MLIR] Introduce generic visitors.

- Generic visitors invoke operation callbacks before/in-between/after visiting the regions
  attached to an operation and use a `WalkStage` to indicate which regions have been
  visited.
- This can be useful for cases where we need to visit the operation in between visiting
  regions attached to the operation.

Differential Revision: https://reviews.llvm.org/D116230

Added: 
    mlir/test/IR/generic-visitors-interrupt.mlir
    mlir/test/IR/generic-visitors.mlir
    mlir/test/lib/IR/TestVisitorsGeneric.cpp

Modified: 
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/Visitors.h
    mlir/lib/IR/Visitors.cpp
    mlir/test/lib/IR/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 0f0f16603ce43..b65026187d1df 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -510,10 +510,40 @@ class alignas(8) Operation final
   ///       });
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
-  RetT walk(FnT &&callback) {
+  typename std::enable_if<
+      llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
+  walk(FnT &&callback) {
     return detail::walk<Order>(this, std::forward<FnT>(callback));
   }
 
+  /// Generic walker with a stage-aware callback. Walk the operation by calling
+  /// the callback for each nested operation (including this one) N+1 times,
+  /// where N is the number of regions attached to that operation.
+  ///
+  /// The callback method can take any of the following forms:
+  ///   void(Operation *, const WalkStage &) : Walk all operations opaquely
+  ///     * op->walk([](Operation *nestedOp, const WalkStage &stage) { ...});
+  ///   void(OpT, const WalkStage &) : Walk all operations of the given derived
+  ///                                  type.
+  ///     * op->walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
+  ///   WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
+  ///          but allow for interruption/skipping.
+  ///     * op->walk([](... op, const WalkStage &stage) {
+  ///         // Skip the walk of this op based on some invariant.
+  ///         if (some_invariant)
+  ///           return WalkResult::skip();
+  ///         // Interrupt, i.e. cancel, the walk based on some invariant.
+  ///         if (another_invariant)
+  ///           return WalkResult::interrupt();
+  ///         return WalkResult::advance();
+  ///       });
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  typename std::enable_if<
+      llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
+  walk(FnT &&callback) {
+    return detail::walk(this, std::forward<FnT>(callback));
+  }
+
   //===--------------------------------------------------------------------===//
   // Uses
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 7827a9518cfbb..661541a555233 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -61,13 +61,49 @@ class WalkResult {
 /// Traversal order for region, block and operation walk utilities.
 enum class WalkOrder { PreOrder, PostOrder };
 
+/// A utility class to encode the current walk stage for "generic" walkers.
+/// A Pre/Post order walker invokes the callback on an operation before/after
+/// all of its attached regions have been visited, whereas a "generic" walker
+/// invokes the callback on the operation N+1 times, where N is the number of
+/// regions attached to it. `WalkStage` encodes the current stage of such a
+/// walk, i.e., which regions have already been visited, and is passed to the
+/// stage-aware callback as an additional argument. For an operation without
+/// regions, the single visit is both before and after all regions. Generic
+/// walkers with stage-aware callbacks are only applicable when the callback
+/// operates on an operation (i.e., not applicable for callbacks on Blocks or
+/// Regions).
+class WalkStage {
+public:
+  explicit WalkStage(Operation *op);
+
+  /// Returns true if the parent operation is being visited before all regions.
+  bool isBeforeAllRegions() const { return nextRegion == 0; }
+  /// Returns true if the parent operation is being visited just before
+  /// visiting region number `region` (0-based).
+  bool isBeforeRegion(int region) const { return nextRegion == region; }
+  /// Returns true if the parent operation is being visited just after
+  /// visiting region number `region` (0-based).
+  bool isAfterRegion(int region) const { return nextRegion == region + 1; }
+  /// Returns true if the parent operation is being visited after all regions.
+  bool isAfterAllRegions() const { return nextRegion == numRegions; }
+  /// Advances the walk stage to the next region.
+  void advance() { nextRegion++; }
+  /// Returns the index of the next region that will be visited.
+  int getNextRegion() const { return nextRegion; }
+
+private:
+  const int numRegions; // Number of regions attached to the operation.
+  int nextRegion;       // Index of the next region to visit (starts at 0).
+};
+
 namespace detail {
 /// Helper templates to deduce the first argument of a callback parameter.
-template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
-template <typename Ret, typename F, typename Arg>
-Arg first_argument_type(Ret (F::*)(Arg));
-template <typename Ret, typename F, typename Arg>
-Arg first_argument_type(Ret (F::*)(Arg) const);
+template <typename Ret, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (*)(Arg, Rest...)); // Free functions.
+template <typename Ret, typename F, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (F::*)(Arg, Rest...)); // Member functions.
+template <typename Ret, typename F, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (F::*)(Arg, Rest...) const); // Const members.
 template <typename F>
 decltype(first_argument_type(&F::operator())) first_argument_type(F);
 
@@ -197,6 +233,87 @@ walk(Operation *op, FuncTy &&callback) {
   return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
 }
 
+// Generic walkers with stage-aware callbacks.
+
+/// Walk all the operations nested under (and including) the given operation,
+/// with the callback being invoked on each operation N+1 times, where N is the
+/// number of regions attached to the operation. The `stage` input to the
+/// callback indicates the current walk stage. This method is invoked for
+/// void-returning callbacks.
+void walk(Operation *op,
+          function_ref<void(Operation *, const WalkStage &stage)> callback);
+
+/// Walk all the operations nested under (and including) the given operation,
+/// with the callback being invoked on each operation N+1 times, where N is the
+/// number of regions attached to the operation. The `stage` input to the
+/// callback indicates the current walk stage. This method is invoked for
+/// skippable or interruptible callbacks.
+WalkResult
+walk(Operation *op,
+     function_ref<WalkResult(Operation *, const WalkStage &stage)> callback);
+
+/// Walk all of the operations nested under and including the given operation.
+/// This method is selected for stage-aware callbacks that operate on
+/// Operation*.
+///
+/// Example:
+///   op->walk([](Operation *op, const WalkStage &stage) { ... });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+          typename RetT = decltype(std::declval<FuncTy>()(
+              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+  return detail::walk(op,
+                      function_ref<RetT(ArgT, const WalkStage &)>(callback));
+}
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for void-returning callbacks that
+/// operate on a specific derived operation type.
+///
+/// Example:
+///   op->walk([](ReturnOp op, const WalkStage &stage) { ... });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+          typename RetT = decltype(std::declval<FuncTy>()(
+              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+                            std::is_same<RetT, void>::value,
+                        RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
+    if (auto derivedOp = dyn_cast<ArgT>(op))
+      callback(derivedOp, stage);
+  };
+  return detail::walk(
+      op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
+}
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for WalkResult-returning
+/// interruptible callbacks that operate on a specific derived operation type.
+///
+/// Example:
+///   op->walk([](ReturnOp op, const WalkStage &stage) {
+///     if (some_invariant)
+///       return WalkResult::interrupt();
+///     return WalkResult::advance();
+///   });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+          typename RetT = decltype(std::declval<FuncTy>()(
+              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+                            std::is_same<RetT, WalkResult>::value,
+                        RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
+    if (auto derivedOp = dyn_cast<ArgT>(op))
+      return callback(derivedOp, stage);
+    return WalkResult::advance();
+  };
+  return detail::walk(
+      op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
+}
+
 /// Utility to provide the return type of a templated walk method.
 template <typename FnT>
 using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));

diff  --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index efe7da4032913..822d881400ee0 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -11,6 +11,9 @@
 
 using namespace mlir;
 
+WalkStage::WalkStage(Operation *op)
+    : numRegions(op->getNumRegions()), nextRegion(0) {} // Before all regions.
+
 /// Walk all of the regions/blocks/operations nested under and including the
 /// given operation. Regions, blocks and operations at the same nesting level
 /// are visited in lexicographical order. The walk order for enclosing regions,
@@ -67,6 +70,25 @@ void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
     callback(op);
 }
 
+// Calls `callback` on `op` before each region and once after all regions.
+void detail::walk(Operation *op,
+                  function_ref<void(Operation *, const WalkStage &)> callback) {
+  WalkStage stage(op);
+  for (Region &region : op->getRegions()) {
+    // Invoke callback on the parent op before visiting each child region.
+    callback(op, stage);
+    stage.advance();
+
+    // Early increment so the callback may erase a nested op at its last stage.
+    for (Block &block : region)
+      for (Operation &nestedOp : llvm::make_early_inc_range(block))
+        walk(&nestedOp, callback);
+  }
+
+  // Invoke callback after all regions have been visited.
+  callback(op, stage);
+}
+
 /// Walk all of the regions/blocks/operations nested under and including the
 /// given operation. These functions walk operations until an interrupt result
 /// is returned by the callback. Walks on regions, blocks and operations may
@@ -157,3 +179,29 @@ WalkResult detail::walk(Operation *op,
     return callback(op);
   return WalkResult::advance();
 }
+
+WalkResult detail::walk(
+    Operation *op,
+    function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
+  WalkStage stage(op);
+
+  for (Region &region : op->getRegions()) {
+    // Invoke callback on the parent op before visiting each child region.
+    WalkResult result = callback(op, stage);
+    // Skip: ignore this op's remaining regions; interrupt: abort the walk.
+    if (result.wasSkipped())
+      return WalkResult::advance();
+    if (result.wasInterrupted())
+      return WalkResult::interrupt();
+
+    stage.advance();
+
+    for (Block &block : region) {
+      // Early increment here in the case where the operation is erased.
+      for (Operation &nestedOp : llvm::make_early_inc_range(block))
+        if (walk(&nestedOp, callback).wasInterrupted())
+          return WalkResult::interrupt();
+    }
+  }
+  return callback(op, stage); // After all regions; result returned as-is.
+}

diff  --git a/mlir/test/IR/generic-visitors-interrupt.mlir b/mlir/test/IR/generic-visitors-interrupt.mlir
new file mode 100644
index 0000000000000..3b4cb0496b71c
--- /dev/null
+++ b/mlir/test/IR/generic-visitors-interrupt.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Walk is interrupted before visiting "foo".
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() {interrupt_before_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 walk was interrupted
+
+// -----
+
+// Walk is interrupted after visiting "foo" and its single region.
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({ "bar"() : ()-> () }) {interrupt_after_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar' before all regions
+// CHECK: step 4 walk was interrupted
+
+// -----
+
+// Walk is interrupted after visiting "foo"'s 1st region.
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {interrupt_after_region = 0} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 walk was interrupted
+
+
+// -----
+
+// Test static filtering (interrupt after all regions of the two-region op).
+func @main() {
+  "foo"() : () -> ()
+  "test.two_region_op"()(
+    {"work"() : () -> ()},
+    {"work"() : () -> ()}
+  ) {interrupt_after_all = true} : () -> ()
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'test.two_region_op' before all regions
+// CHECK: step 4 op 'work' before all regions
+// CHECK: step 5 op 'test.two_region_op' before region #1
+// CHECK: step 6 op 'work' before all regions
+// CHECK: step 7 walk was interrupted
+// CHECK: step 8 op 'test.two_region_op' before all regions
+// CHECK: step 9 op 'test.two_region_op' before region #1
+// CHECK: step 10 walk was interrupted
+
+// -----
+
+// Test static filtering (interrupt after region #0 of the two-region op).
+func @main() {
+  "foo"() : () -> ()
+  "test.two_region_op"()(
+    {"work"() : () -> ()},
+    {"work"() : () -> ()}
+  ) {interrupt_after_region = 0} : () -> ()
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'test.two_region_op' before all regions
+// CHECK: step 4 op 'work' before all regions
+// CHECK: step 5 walk was interrupted
+// CHECK: step 6 op 'test.two_region_op' before all regions
+// CHECK: step 7 walk was interrupted
+
+// -----
+// Test skipping.
+
+// Walk is skipped before visiting "foo".
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {skip_before_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'arith.addf' before all regions
+// CHECK: step 3 op 'std.return' before all regions
+// CHECK: step 4 op 'builtin.func' after all regions
+// CHECK: step 5 op 'builtin.module' after all regions
+
+// -----
+// Walk is skipped after visiting all regions of "foo".
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {skip_after_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 op 'foo' before region #1
+// CHECK: step 5 op 'bar1' before all regions
+// CHECK: step 6 op 'arith.addf' before all regions
+// CHECK: step 7 op 'std.return' before all regions
+// CHECK: step 8 op 'builtin.func' after all regions
+// CHECK: step 9 op 'builtin.module' after all regions
+
+// -----
+// Walk is skipped after visiting first region of "foo".
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {skip_after_region = 0} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 op 'arith.addf' before all regions
+// CHECK: step 5 op 'std.return' before all regions
+// CHECK: step 6 op 'builtin.func' after all regions
+// CHECK: step 7 op 'builtin.module' after all regions

diff  --git a/mlir/test/IR/generic-visitors.mlir b/mlir/test/IR/generic-visitors.mlir
new file mode 100644
index 0000000000000..c87bd559be9b3
--- /dev/null
+++ b/mlir/test/IR/generic-visitors.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -test-generic-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Verify the different configurations of generic IR visitors.
+
+func @structured_cfg() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  scf.for %i = %c1 to %c10 step %c1 {
+    %cond = "use0"(%i) : (index) -> (i1)
+    scf.if %cond {
+      "use1"(%i) : (index) -> ()
+    } else {
+      "use2"(%i) : (index) -> ()
+    }
+    "use3"(%i) : (index) -> ()
+  }
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'arith.constant' before all regions
+// CHECK: step 3 op 'arith.constant' before all regions
+// CHECK: step 4 op 'arith.constant' before all regions
+// CHECK: step 5 op 'scf.for' before all regions
+// CHECK: step 6 op 'use0' before all regions
+// CHECK: step 7 op 'scf.if' before all regions
+// CHECK: step 8 op 'use1' before all regions
+// CHECK: step 9 op 'scf.yield' before all regions
+// CHECK: step 10 op 'scf.if' before region #1
+// CHECK: step 11 op 'use2' before all regions
+// CHECK: step 12 op 'scf.yield' before all regions
+// CHECK: step 13 op 'scf.if' after all regions
+// CHECK: step 14 op 'use3' before all regions
+// CHECK: step 15 op 'scf.yield' before all regions
+// CHECK: step 16 op 'scf.for' after all regions
+// CHECK: step 17 op 'std.return' before all regions
+// CHECK: step 18 op 'builtin.func' after all regions
+// CHECK: step 19 op 'builtin.module' after all regions
+
+// -----
+// Test the visitor restricted to one operation type (test.two_region_op).
+
+func @correct_number_of_regions() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  scf.for %i = %c1 to %c10 step %c1 {
+    "test.two_region_op"()(
+      {"work"() : () -> ()},
+      {"work"() : () -> ()}
+    ) : () -> ()
+  }
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 15 op 'builtin.module' after all regions
+// CHECK: step 16 op 'test.two_region_op' before all regions
+// CHECK: step 17 op 'test.two_region_op' before region #1
+// CHECK: step 18 op 'test.two_region_op' after all regions

diff  --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
index 0b5833979ac94..f656a4e6934ef 100644
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_library(MLIRTestIR
   TestSymbolUses.cpp
   TestTypes.cpp
   TestVisitors.cpp
+  TestVisitorsGeneric.cpp
 
   EXCLUDE_FROM_LIBMLIR
 

diff  --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
new file mode 100644
index 0000000000000..0a73c71061562
--- /dev/null
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -0,0 +1,124 @@
+//===- TestVisitorsGeneric.cpp - Pass to test the Generic IR visitors -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+/// Returns a human-readable description of `stage` for the step output.
+static std::string getStageDescription(const WalkStage &stage) {
+  if (stage.isBeforeAllRegions())
+    return "before all regions";
+  if (stage.isAfterAllRegions())
+    return "after all regions";
+  return "before region #" + std::to_string(stage.getNextRegion());
+}
+
+namespace {
+/// This pass exercises the generic visitor with void callbacks and prints the
+/// order and stage in which operations are visited.
+class TestGenericIRVisitorPass
+    : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
+public:
+  StringRef getArgument() const final { return "test-generic-ir-visitors"; }
+  StringRef getDescription() const final { return "Test generic IR visitors."; }
+  void runOnOperation() override {
+    Operation *outerOp = getOperation();
+    int stepNo = 0;
+    outerOp->walk([&](Operation *op, const WalkStage &stage) {
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+    });
+
+    // Exercise static inference of operation type.
+    outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+    });
+  }
+};
+
+/// This pass exercises the generic visitor with non-void callbacks and prints
+/// the order and stage in which operations are visited. It will interrupt the
+/// walk based on attributes present in the IR.
+class TestGenericIRVisitorInterruptPass
+    : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
+public:
+  StringRef getArgument() const final {
+    return "test-generic-ir-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts.";
+  }
+  void runOnOperation() override {
+    Operation *outerOp = getOperation();
+    int stepNo = 0;
+
+    auto walker = [&](Operation *op, const WalkStage &stage) {
+      if (auto interruptBeforeAll =
+              op->getAttrOfType<BoolAttr>("interrupt_before_all"))
+        if (interruptBeforeAll.getValue() && stage.isBeforeAllRegions())
+          return WalkResult::interrupt();
+
+      if (auto interruptAfterAll =
+              op->getAttrOfType<BoolAttr>("interrupt_after_all"))
+        if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
+          return WalkResult::interrupt();
+
+      if (auto interruptAfterRegion =
+              op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
+        if (stage.isAfterRegion(
+                static_cast<int>(interruptAfterRegion.getInt())))
+          return WalkResult::interrupt();
+
+      if (auto skipBeforeAll = op->getAttrOfType<BoolAttr>("skip_before_all"))
+        if (skipBeforeAll.getValue() && stage.isBeforeAllRegions())
+          return WalkResult::skip();
+
+      if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
+        if (skipAfterAll.getValue() && stage.isAfterAllRegions())
+          return WalkResult::skip();
+
+      if (auto skipAfterRegion =
+              op->getAttrOfType<IntegerAttr>("skip_after_region"))
+        if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
+          return WalkResult::skip();
+
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+      return WalkResult::advance();
+    };
+
+    // Interrupt the walk based on attributes on the operation.
+    auto result = outerOp->walk(walker);
+
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+
+    // Exercise static inference of operation type.
+    result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
+      return walker(op, stage);
+    });
+
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestGenericIRVisitorsPass() {
+  PassRegistration<TestGenericIRVisitorPass>();
+  PassRegistration<TestGenericIRVisitorInterruptPass>();
+}
+
+} // namespace test
+} // namespace mlir

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 3a2d83ebbe77b..120585c806a94 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,8 @@ void registerTestExpandTanhPass();
 void registerTestComposeSubView();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
+// Registers both -test-generic-ir-visitors and its -interrupt variant.
+void registerTestGenericIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgControlFuseByExpansion();
@@ -171,6 +173,7 @@ void registerTestPasses() {
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestGpuParallelLoopMappingPass();
   mlir::test::registerTestIRVisitorsPass();
+  mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
   mlir::test::registerTestLinalgControlFuseByExpansion();


        


More information about the Mlir-commits mailing list