[Mlir-commits] [mlir] 71a8624 - [mlir] Extend Operation visitor with pre-order traversal

Diego Caballero llvmlistbot at llvm.org
Fri Mar 5 14:06:12 PST 2021


Author: Diego Caballero
Date: 2021-03-06T00:02:20+02:00
New Revision: 71a86245ca620d26ed63a86dda3b65a533f5df6b

URL: https://github.com/llvm/llvm-project/commit/71a86245ca620d26ed63a86dda3b65a533f5df6b
DIFF: https://github.com/llvm/llvm-project/commit/71a86245ca620d26ed63a86dda3b65a533f5df6b.diff

LOG: [mlir] Extend Operation visitor with pre-order traversal

This patch extends the Region, Block and Operation visitors to also support pre-order walks.
We introduce a new template argument that dictates the walk order (only pre-order and
post-order are supported for now). The default order for Regions, Blocks and Operations is
post-order. Mixed orders (e.g., Region/Block pre-order + Operation post-order) could easily
be implemented, as shown in NumberOfExecutions.cpp.

Reviewed By: rriddle, frgossen, bondhugula

Differential Revision: https://reviews.llvm.org/D97217

Added: 
    

Modified: 
    mlir/include/mlir/IR/Block.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Operation.h
    mlir/include/mlir/IR/Region.h
    mlir/include/mlir/IR/Visitors.h
    mlir/lib/Analysis/Liveness.cpp
    mlir/lib/Analysis/NumberOfExecutions.cpp
    mlir/lib/IR/Visitors.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index a04063814423..9f265b3b56f5 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -249,34 +249,40 @@ class Block : public IRObjectWithUseList<BlockOperand>,
   // Operation Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operations in this block in postorder, calling the callback for
-  /// each operation.
+  /// Walk the operations in this block, calling the callback for each
+  /// operation. The walk order for regions, blocks and operations is specified
+  /// by 'Order' (post-order by default).
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return walk(begin(), end(), std::forward<FnT>(callback));
+    return walk<Order>(begin(), end(), std::forward<FnT>(callback));
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block in
-  /// postorder, calling the callback for each operation. This method is invoked
-  /// for void return callbacks.
+  /// Walk the operations in the specified [begin, end) range of this block,
+  /// calling the callback for each operation. The walk order for regions,
+  /// blocks and operations is specified by 'Order' (post-order by default).
+  /// This method is invoked for void return callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
   walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
     for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      detail::walk(&op, callback);
+      detail::walk<Order>(&op, callback);
   }
 
-  /// Walk the operations in the specified [begin, end) range of this block in
-  /// postorder, calling the callback for each operation. This method is invoked
-  /// for interruptible callbacks.
+  /// Walk the operations in the specified [begin, end) range of this block,
+  /// calling the callback for each operation. The walk order for regions,
+  /// blocks and operations is specified by 'Order' (post-order by default).
+  /// This method is invoked for interruptible callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
   walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
     for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
-      if (detail::walk(&op, callback).wasInterrupted())
+      if (detail::walk<Order>(&op, callback).wasInterrupted())
         return WalkResult::interrupt();
     return WalkResult::advance();
   }

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index c65e653f2a50..4c9399f2a6fc 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -165,12 +165,14 @@ class OpState {
   /// handlers that may be listening.
   InFlightDiagnostic emitRemark(const Twine &message = {});
 
-  /// Walk the operation in postorder, calling the callback for each nested
-  /// operation(including this one).
+  /// Walk the operation by calling the callback for each nested
+  /// operation (including this one). The walk order for regions, blocks and
+  /// operations is specified by 'Order' (post-order by default).
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return state->walk(std::forward<FnT>(callback));
+    return state->walk<Order>(std::forward<FnT>(callback));
   }
 
   // These are default implementations of customization hooks.

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 027c878929f4..679071457e52 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -484,9 +484,10 @@ class alignas(8) Operation final
   // Operation Walkers
   //===--------------------------------------------------------------------===//
 
-  /// Walk the operation in postorder, calling the callback for each nested
-  /// operation(including this one). The callback method can take any of the
-  /// following forms:
+  /// Walk the operation by calling the callback for each nested operation
+  /// (including this one). The walk order for regions, blocks and operations is
+  /// specified by 'Order' (post-order by default). The callback method can take
+  /// any of the following forms:
   ///   void(Operation*) : Walk all operations opaquely.
   ///     * op->walk([](Operation *nestedOp) { ...});
   ///   void(OpT) : Walk all operations of the given derived type.
@@ -499,9 +500,10 @@ class alignas(8) Operation final
   ///           return WalkResult::interrupt();
   ///         return WalkResult::advance();
   ///       });
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   RetT walk(FnT &&callback) {
-    return detail::walk(this, std::forward<FnT>(callback));
+    return detail::walk<Order>(this, std::forward<FnT>(callback));
   }
 
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h
index 9b6f54fec936..349be0f5aedc 100644
--- a/mlir/include/mlir/IR/Region.h
+++ b/mlir/include/mlir/IR/Region.h
@@ -243,23 +243,29 @@ class Region {
   //===--------------------------------------------------------------------===//
 
   /// Walk the operations in this region in postorder, calling the callback for
-  /// each operation. This method is invoked for void-returning callbacks.
+  /// each operation. The walk order for regions, blocks and operations is
+  /// specified by 'Order' (post-order by default). This method is invoked for
+  /// void-returning callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
   walk(FnT &&callback) {
     for (auto &block : *this)
-      block.walk(callback);
+      block.walk<Order>(callback);
   }
 
   /// Walk the operations in this region in postorder, calling the callback for
-  /// each operation. This method is invoked for interruptible callbacks.
+  /// each operation. The walk order for regions, blocks and operations is
+  /// specified by 'Order' (post-order by default). This method is invoked for
+  /// interruptible callbacks.
   /// See Operation::walk for more details.
-  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
+            typename RetT = detail::walkResultType<FnT>>
   typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
   walk(FnT &&callback) {
     for (auto &block : *this)
-      if (block.walk(callback).wasInterrupted())
+      if (block.walk<Order>(callback).wasInterrupted())
         return WalkResult::interrupt();
     return WalkResult::advance();
   }

diff  --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index abb7f0f50e15..b4571e19f8fd 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -49,6 +49,9 @@ class WalkResult {
   bool wasInterrupted() const { return result == Interrupt; }
 };
 
+/// Traversal order for region, block and operation walk utilities.
+enum class WalkOrder { PreOrder, PostOrder };
+
 namespace detail {
 /// Helper templates to deduce the first argument of a callback parameter.
 template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
@@ -64,17 +67,21 @@ template <typename T>
 using first_argument = decltype(first_argument_type(std::declval<T>()));
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation.
-void walk(Operation *op, function_ref<void(Region *)> callback);
-void walk(Operation *op, function_ref<void(Block *)> callback);
-void walk(Operation *op, function_ref<void(Operation *)> callback);
-
+/// the given operation. The walk order is specified by 'order'.
+void walk(Operation *op, function_ref<void(Region *)> callback,
+          WalkOrder order);
+void walk(Operation *op, function_ref<void(Block *)> callback, WalkOrder order);
+void walk(Operation *op, function_ref<void(Operation *)> callback,
+          WalkOrder order);
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. These functions walk until an interrupt result is
-/// returned by the callback.
-WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
-WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
-WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
+/// the given operation. The walk order is specified by 'order'. These functions
+/// walk until an interrupt result is returned by the callback.
+WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback,
+                WalkOrder order);
+WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback,
+                WalkOrder order);
+WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback,
+                WalkOrder order);
 
 // Below are a set of functions to walk nested operations. Users should favor
 // the direct `walk` methods on the IR classes(Operation/Block/etc) over these
@@ -82,7 +89,8 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
 // upon the type of the callback function.
 
 /// Walk all of the regions, blocks, or operations nested under (and including)
-/// the given operation. This method is selected for callbacks that operate on
+/// the given operation. The walk order is specified by 'Order' (post-order
+/// by default). This method is selected for callbacks that operate on
 /// Region*, Block*, and Operation*.
 ///
 /// Example:
@@ -90,22 +98,25 @@ WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
 ///   op->walk([](Block *b) { ... });
 ///   op->walk([](Operation *op) { ... });
 template <
-    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+    typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 typename std::enable_if<
     llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
 walk(Operation *op, FuncTy &&callback) {
-  return walk(op, function_ref<RetT(ArgT)>(callback));
+  return detail::walk(op, function_ref<RetT(ArgT)>(callback), Order);
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. This method is selected for void returning callbacks that
-/// operate on a specific derived operation type.
+/// given operation. The walk order for regions, blocks and operations is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// void-returning callbacks that operate on a specific derived operation type.
 ///
 /// Example:
 ///   op->walk([](ReturnOp op) { ... });
 template <
-    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+    typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 typename std::enable_if<
     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -116,12 +127,14 @@ walk(Operation *op, FuncTy &&callback) {
     if (auto derivedOp = dyn_cast<ArgT>(op))
       callback(derivedOp);
   };
-  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
+  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
 }
 
 /// Walk all of the operations of type 'ArgT' nested under and including the
-/// given operation. This method is selected for WalkReturn returning
-/// interruptible callbacks that operate on a specific derived operation type.
+/// given operation. The walk order for regions, blocks and operations is
+/// specified by 'Order' (post-order by default). This method is selected for
+/// WalkResult-returning interruptible callbacks that operate on a specific
+/// derived operation type.
 ///
 /// Example:
 ///   op->walk([](ReturnOp op) {
@@ -130,7 +143,8 @@ walk(Operation *op, FuncTy &&callback) {
 ///     return WalkResult::advance();
 ///   });
 template <
-    typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+    WalkOrder Order = WalkOrder::PostOrder, typename FuncTy,
+    typename ArgT = detail::first_argument<FuncTy>,
     typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
 typename std::enable_if<
     !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
@@ -142,7 +156,7 @@ walk(Operation *op, FuncTy &&callback) {
       return callback(derivedOp);
     return WalkResult::advance();
   };
-  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
+  return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
 }
 
 /// Utility to provide the return type of a templated walk method.

diff  --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
index 4dae386e94b2..c9558a7d502c 100644
--- a/mlir/lib/Analysis/Liveness.cpp
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -130,7 +130,7 @@ static void buildBlockMapping(Operation *operation,
                               DenseMap<Block *, BlockInfoBuilder> &builders) {
   llvm::SetVector<Block *> toProcess;
 
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     BlockInfoBuilder &builder =
         builders.try_emplace(block, block).first->second;
 
@@ -270,7 +270,7 @@ void Liveness::print(raw_ostream &os) const {
   DenseMap<Block *, size_t> blockIds;
   DenseMap<Operation *, size_t> operationIds;
   DenseMap<Value, size_t> valueIds;
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     blockIds.insert({block, blockIds.size()});
     for (BlockArgument argument : block->getArguments())
       valueIds.insert({argument, valueIds.size()});
@@ -304,7 +304,7 @@ void Liveness::print(raw_ostream &os) const {
   };
 
   // Dump information about in and out values.
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     os << "// - Block: " << blockIds[block] << "\n";
     const auto *liveness = getLiveness(block);
     os << "// --- LiveIn: ";

diff  --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp
index 425936b5eaaf..26f31c0913d1 100644
--- a/mlir/lib/Analysis/NumberOfExecutions.cpp
+++ b/mlir/lib/Analysis/NumberOfExecutions.cpp
@@ -115,7 +115,7 @@ static void computeRegionBlockNumberOfExecutions(
 /// Creates a new NumberOfExecutions analysis that computes how many times a
 /// block within a region is executed for all associated regions.
 NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
-  operation->walk([&](Region *region) {
+  operation->walk<WalkOrder::PreOrder>([&](Region *region) {
     computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
   });
 }
@@ -191,7 +191,7 @@ void NumberOfExecutions::printBlockExecutions(
     raw_ostream &os, Region *perEntryOfThisRegion) const {
   unsigned blockId = 0;
 
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     llvm::errs() << "Block: " << blockId++ << "\n";
     llvm::errs() << "Number of executions: ";
     if (auto n = getNumberOfExecutions(block, perEntryOfThisRegion))
@@ -203,7 +203,7 @@ void NumberOfExecutions::printBlockExecutions(
 
 void NumberOfExecutions::printOperationExecutions(
     raw_ostream &os, Region *perEntryOfThisRegion) const {
-  operation->walk([&](Block *block) {
+  operation->walk<WalkOrder::PreOrder>([&](Block *block) {
     block->walk([&](Operation *operation) {
       // Skip the operation that was used to build the analysis.
       if (operation == this->operation)

diff  --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index d03bdb508d37..be995a2a4fb2 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -12,79 +12,112 @@
 using namespace mlir;
 
 /// Walk all of the regions/blocks/operations nested under and including the
-/// given operation.
-void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
+/// given operation. The walk order is specified by 'order'.
+
+void detail::walk(Operation *op, function_ref<void(Region *)> callback,
+                  WalkOrder order) {
   for (auto &region : op->getRegions()) {
-    callback(&region);
+    if (order == WalkOrder::PreOrder)
+      callback(&region);
     for (auto &block : region) {
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+        walk(&nestedOp, callback, order);
     }
+    if (order == WalkOrder::PostOrder)
+      callback(&region);
   }
 }
 
-void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
+void detail::walk(Operation *op, function_ref<void(Block *)> callback,
+                  WalkOrder order) {
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
-      callback(&block);
+      if (order == WalkOrder::PreOrder)
+        callback(&block);
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+        walk(&nestedOp, callback, order);
+      if (order == WalkOrder::PostOrder)
+        callback(&block);
     }
   }
 }
 
-void detail::walk(Operation *op, function_ref<void(Operation *op)> callback) {
+void detail::walk(Operation *op, function_ref<void(Operation *)> callback,
+                  WalkOrder order) {
+  if (order == WalkOrder::PreOrder)
+    callback(op);
+
   // TODO: This walk should be iterative over the operations.
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
       // Early increment here in the case where the operation is erased.
       for (auto &nestedOp : llvm::make_early_inc_range(block))
-        walk(&nestedOp, callback);
+        walk(&nestedOp, callback, order);
     }
   }
-  callback(op);
+
+  if (order == WalkOrder::PostOrder)
+    callback(op);
 }
 
 /// Walk all of the regions/blocks/operations nested under and including the
-/// given operation. These functions walk operations until an interrupt result
-/// is returned by the callback.
+/// given operation. The walk order is specified by 'order'. These functions
+/// walk operations until an interrupt result is returned by the callback.
 WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Region *op)> callback) {
+                        function_ref<WalkResult(Region *)> callback,
+                        WalkOrder order) {
   for (auto &region : op->getRegions()) {
-    if (callback(&region).wasInterrupted())
-      return WalkResult::interrupt();
+    if (order == WalkOrder::PreOrder)
+      if (callback(&region).wasInterrupted())
+        return WalkResult::interrupt();
     for (auto &block : region) {
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+        walk(&nestedOp, callback, order);
     }
+    if (order == WalkOrder::PostOrder)
+      if (callback(&region).wasInterrupted())
+        return WalkResult::interrupt();
   }
   return WalkResult::advance();
 }
 
 WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Block *op)> callback) {
+                        function_ref<WalkResult(Block *)> callback,
+                        WalkOrder order) {
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
-      if (callback(&block).wasInterrupted())
-        return WalkResult::interrupt();
+      if (order == WalkOrder::PreOrder)
+        if (callback(&block).wasInterrupted())
+          return WalkResult::interrupt();
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback);
+        walk(&nestedOp, callback, order);
+      if (order == WalkOrder::PostOrder)
+        if (callback(&block).wasInterrupted())
+          return WalkResult::interrupt();
     }
   }
   return WalkResult::advance();
 }
 
 WalkResult detail::walk(Operation *op,
-                        function_ref<WalkResult(Operation *op)> callback) {
+                        function_ref<WalkResult(Operation *)> callback,
+                        WalkOrder order) {
+  if (order == WalkOrder::PreOrder)
+    if (callback(op).wasInterrupted())
+      return WalkResult::interrupt();
+
   // TODO: This walk should be iterative over the operations.
   for (auto &region : op->getRegions()) {
     for (auto &block : region) {
       // Early increment here in the case where the operation is erased.
       for (auto &nestedOp : llvm::make_early_inc_range(block)) {
-        if (walk(&nestedOp, callback).wasInterrupted())
+        if (walk(&nestedOp, callback, order).wasInterrupted())
           return WalkResult::interrupt();
       }
     }
   }
-  return callback(op);
+
+  if (order == WalkOrder::PostOrder)
+    return callback(op);
+  return WalkResult::advance();
 }


        


More information about the Mlir-commits mailing list