[Mlir-commits] [mlir] 1664462 - [MLIR] Support walks over regions and blocks
Frederik Gossen
llvmlistbot at llvm.org
Wed Nov 4 04:50:26 PST 2020
Author: Frederik Gossen
Date: 2020-11-04T12:50:05Z
New Revision: 1664462d70cc399c6a95b88a9e8e73cb44d8a151
URL: https://github.com/llvm/llvm-project/commit/1664462d70cc399c6a95b88a9e8e73cb44d8a151
DIFF: https://github.com/llvm/llvm-project/commit/1664462d70cc399c6a95b88a9e8e73cb44d8a151.diff
LOG: [MLIR] Support walks over regions and blocks
Relands
- [MLIR] Support walks over regions and blocks
(dbae3d50f114a8ec0a7c3211e3b1b9fb6ef22dbd)
- [MLIR] Use llvm::is_one_of in walk templates
(56299b1e58bf3720dff2fe60163739ee1554a371)
Differential Revision: https://reviews.llvm.org/D90753
Added:
Modified:
mlir/include/mlir/Analysis/Liveness.h
mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/Visitors.h
mlir/lib/Analysis/Liveness.cpp
mlir/lib/IR/Visitors.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Liveness.h b/mlir/include/mlir/Analysis/Liveness.h
index be9cb7166b8f..3bd298a0fbe7 100644
--- a/mlir/include/mlir/Analysis/Liveness.h
+++ b/mlir/include/mlir/Analysis/Liveness.h
@@ -86,7 +86,7 @@ class Liveness {
private:
/// Initializes the internal mappings.
- void build(MutableArrayRef<Region> regions);
+ void build(); // Walks the blocks nested under the stored `operation`.
private:
/// The operation this analysis was constructed from.
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
index 3e867976cc32..a04063814423 100644
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -265,7 +265,7 @@ class Block : public IRObjectWithUseList<BlockOperand>,
typename std::enable_if<std::is_same<RetT, void>::value, RetT>::type
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- detail::walkOperations(&op, callback);
+ detail::walk(&op, callback); // Overload chosen by the callback's argument type.
}
/// Walk the operations in the specified [begin, end) range of this block in
@@ -276,7 +276,7 @@ class Block : public IRObjectWithUseList<BlockOperand>,
typename std::enable_if<std::is_same<RetT, WalkResult>::value, RetT>::type
walk(Block::iterator begin, Block::iterator end, FnT &&callback) {
for (auto &op : llvm::make_early_inc_range(llvm::make_range(begin, end)))
- if (detail::walkOperations(&op, callback).wasInterrupted())
+ if (detail::walk(&op, callback).wasInterrupted()) // Stop at the first interrupt.
return WalkResult::interrupt();
return WalkResult::advance();
}
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index d3dce868ca64..fa54cb608cf5 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -520,7 +520,7 @@ class Operation final
/// });
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
- return detail::walkOperations(this, std::forward<FnT>(callback));
+ return detail::walk(this, std::forward<FnT>(callback)); // Region/Block/Operation/derived-op walk picked by callback argument type.
}
//===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
index 490ba92662a2..abb7f0f50e15 100644
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -21,6 +21,8 @@ namespace mlir {
class Diagnostic;
class InFlightDiagnostic;
class Operation;
+class Block;
+class Region;
/// A utility result that is used to signal if a walk method should be
/// interrupted or advance.
@@ -61,31 +63,39 @@ decltype(first_argument_type(&F::operator())) first_argument_type(F);
template <typename T>
using first_argument = decltype(first_argument_type(std::declval<T>()));
-/// Walk all of the operations nested under and including the given operation.
-void walkOperations(Operation *op, function_ref<void(Operation *op)> callback);
+/// Walk all of the regions, blocks, or operations nested under the given
+/// operation. Regions and blocks are visited in pre-order (before anything
+/// nested inside them); operations are visited in post-order, and only the
+/// Operation* overload also visits `op` itself.
+void walk(Operation *op, function_ref<void(Region *)> callback);
+void walk(Operation *op, function_ref<void(Block *)> callback);
+void walk(Operation *op, function_ref<void(Operation *)> callback);
-/// Walk all of the operations nested under and including the given operation.
-/// This methods walks operations until an interrupt result is returned by the
-/// callback.
-WalkResult walkOperations(Operation *op,
- function_ref<WalkResult(Operation *op)> callback);
+/// Walk all of the regions, blocks, or operations nested under the given
+/// operation (the Operation* overload also visits `op` itself). These functions
+/// walk until an interrupt result is returned by the callback.
+WalkResult walk(Operation *op, function_ref<WalkResult(Region *)> callback);
+WalkResult walk(Operation *op, function_ref<WalkResult(Block *)> callback);
+WalkResult walk(Operation *op, function_ref<WalkResult(Operation *)> callback);
// Below are a set of functions to walk nested operations. Users should favor
// the direct `walk` methods on the IR classes(Operation/Block/etc) over these
// methods. They are also templated to allow for statically dispatching based
// upon the type of the callback function.
-/// Walk all of the operations nested under and including the given operation.
-/// This method is selected for callbacks that operate on Operation*.
+/// Walk all of the regions, blocks, or operations nested under the given
+/// operation (including `op` itself for Operation* callbacks). This method is
+/// selected for callbacks whose argument type is exactly Region*, Block*, or
+/// Operation*, and forwards to the matching function_ref overload above.
///
/// Example:
+/// op->walk([](Region *r) { ... });
+/// op->walk([](Block *b) { ... });
/// op->walk([](Operation *op) { ... });
template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
-typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
-walkOperations(Operation *op, FuncTy &&callback) {
- return detail::walkOperations(op, function_ref<RetT(ArgT)>(callback));
+typename std::enable_if<
+ llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value, RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+ return walk(op, function_ref<RetT(ArgT)>(callback));
}
/// Walk all of the operations of type 'ArgT' nested under and including the
@@ -97,15 +107,16 @@ walkOperations(Operation *op, FuncTy &&callback) {
template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
-typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
- std::is_same<RetT, void>::value,
- RetT>::type
-walkOperations(Operation *op, FuncTy &&callback) {
+typename std::enable_if<
+ !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
+ std::is_same<RetT, void>::value,
+ RetT>::type
+walk(Operation *op, FuncTy &&callback) {
auto wrapperFn = [&](Operation *op) {
if (auto derivedOp = dyn_cast<ArgT>(op))
callback(derivedOp);
};
- return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
+ return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
}
/// Walk all of the operations of type 'ArgT' nested under and including the
@@ -121,21 +132,22 @@ walkOperations(Operation *op, FuncTy &&callback) {
template <
typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
typename RetT = decltype(std::declval<FuncTy>()(std::declval<ArgT>()))>
-typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
- std::is_same<RetT, WalkResult>::value,
- RetT>::type
-walkOperations(Operation *op, FuncTy &&callback) {
+typename std::enable_if<
+ !llvm::is_one_of<ArgT, Operation *, Region *, Block *>::value &&
+ std::is_same<RetT, WalkResult>::value,
+ RetT>::type
+walk(Operation *op, FuncTy &&callback) {
auto wrapperFn = [&](Operation *op) {
if (auto derivedOp = dyn_cast<ArgT>(op))
return callback(derivedOp);
return WalkResult::advance();
};
- return detail::walkOperations(op, function_ref<RetT(Operation *)>(wrapperFn));
+ return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn));
}
/// Utility to provide the return type of a templated walk method.
template <typename FnT>
-using walkResultType = decltype(walkOperations(nullptr, std::declval<FnT>()));
+using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
} // end namespace detail
} // namespace mlir
diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp
index 38fb386f8000..4dae386e94b2 100644
--- a/mlir/lib/Analysis/Liveness.cpp
+++ b/mlir/lib/Analysis/Liveness.cpp
@@ -125,31 +125,17 @@ struct BlockInfoBuilder {
};
} // namespace
-/// Walks all regions (including nested regions recursively) and invokes the
-/// given function for every block.
-template <typename FuncT>
-static void walkRegions(MutableArrayRef<Region> regions, const FuncT &func) {
- for (Region &region : regions)
- for (Block &block : region) {
- func(block);
-
- // Traverse all nested regions.
- for (Operation &operation : block)
- walkRegions(operation.getRegions(), func);
- }
-}
-
/// Builds the internal liveness block mapping.
-static void buildBlockMapping(MutableArrayRef<Region> regions,
+static void buildBlockMapping(Operation *operation,
DenseMap<Block *, BlockInfoBuilder> &builders) {
llvm::SetVector<Block *> toProcess;
- walkRegions(regions, [&](Block &block) {
+ // Visit every block nested under `operation` in pre-order: a block is seen
+ // before the blocks nested inside its operations.
+ operation->walk([&](Block *block) {
BlockInfoBuilder &builder =
- builders.try_emplace(&block, &block).first->second;
+ builders.try_emplace(block, block).first->second;
if (builder.updateLiveIn())
- toProcess.insert(block.pred_begin(), block.pred_end());
+ toProcess.insert(block->pred_begin(), block->pred_end());
});
// Propagate the in and out-value sets (fixpoint iteration)
@@ -172,14 +158,14 @@ static void buildBlockMapping(MutableArrayRef<Region> regions,
/// Creates a new Liveness analysis that computes liveness information for all
/// associated regions.
-Liveness::Liveness(Operation *op) : operation(op) { build(op->getRegions()); }
+Liveness::Liveness(Operation *op) : operation(op) { build(); }
/// Initializes the internal mappings.
-void Liveness::build(MutableArrayRef<Region> regions) {
+void Liveness::build() {
// Build internal block mapping.
DenseMap<Block *, BlockInfoBuilder> builders;
- buildBlockMapping(regions, builders);
+ buildBlockMapping(operation, builders);
// Store internal block data.
for (auto &entry : builders) {
@@ -284,11 +270,11 @@ void Liveness::print(raw_ostream &os) const {
DenseMap<Block *, size_t> blockIds;
DenseMap<Operation *, size_t> operationIds;
DenseMap<Value, size_t> valueIds;
- walkRegions(operation->getRegions(), [&](Block &block) {
- blockIds.insert({&block, blockIds.size()});
- for (BlockArgument argument : block.getArguments())
+ operation->walk([&](Block *block) {
+ blockIds.insert({block, blockIds.size()});
+ for (BlockArgument argument : block->getArguments())
valueIds.insert({argument, valueIds.size()});
- for (Operation &operation : block) {
+ for (Operation &operation : *block) {
operationIds.insert({&operation, operationIds.size()});
for (Value result : operation.getResults())
valueIds.insert({result, valueIds.size()});
@@ -318,9 +304,9 @@ void Liveness::print(raw_ostream &os) const {
};
// Dump information about in and out values.
- walkRegions(operation->getRegions(), [&](Block &block) {
- os << "// - Block: " << blockIds[&block] << "\n";
- auto liveness = getLiveness(&block);
+ operation->walk([&](Block *block) {
+ os << "// - Block: " << blockIds[block] << "\n";
+ const auto *liveness = getLiveness(block);
os << "// --- LiveIn: ";
printValueRefs(liveness->inValues);
os << "\n// --- LiveOut: ";
@@ -329,7 +315,7 @@ void Liveness::print(raw_ostream &os) const {
// Print liveness intervals.
os << "// --- BeginLiveness";
- for (Operation &op : block) {
+ for (Operation &op : *block) {
if (op.getNumResults() < 1)
continue;
os << "\n";
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index bbccdcbf7592..d03bdb508d37 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -11,31 +11,79 @@
using namespace mlir;
-/// Walk all of the operations nested under and including the given operations.
-void detail::walkOperations(Operation *op,
- function_ref<void(Operation *op)> callback) {
+/// Walk all of the regions nested under the given operation, in pre-order: a
+/// region is passed to the callback before any region nested inside it. `op`
+/// itself is not visited.
+void detail::walk(Operation *op, function_ref<void(Region *)> callback) {
+ for (auto &region : op->getRegions()) {
+ callback(&region);
+ for (auto &block : region) {
+ for (auto &nestedOp : block)
+ walk(&nestedOp, callback);
+ }
+ }
+}
+
+/// Walk all of the blocks nested under the given operation, in pre-order: a
+/// block is passed to the callback before the blocks nested in its operations.
+/// NOTE(review): blocks are iterated without early increment, so the callback
+/// presumably must not erase the block it is given.
+void detail::walk(Operation *op, function_ref<void(Block *)> callback) {
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region) {
+ callback(&block);
+ for (auto &nestedOp : block)
+ walk(&nestedOp, callback);
+ }
+ }
+}
+
+/// Walk all of the operations nested under and including the given operation,
+/// in post-order: nested operations are visited before `op` itself.
+void detail::walk(Operation *op, function_ref<void(Operation *)> callback) {
// TODO: This walk should be iterative over the operations.
- for (auto &region : op->getRegions())
- for (auto &block : region)
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region) {
// Early increment here in the case where the operation is erased.
for (auto &nestedOp : llvm::make_early_inc_range(block))
- walkOperations(&nestedOp, callback);
-
+ walk(&nestedOp, callback);
+ }
+ }
callback(op);
}
-/// Walk all of the operations nested under and including the given operations.
-/// This methods walks operations until an interrupt signal is received.
-WalkResult
-detail::walkOperations(Operation *op,
- function_ref<WalkResult(Operation *op)> callback) {
+/// Walk all of the regions/blocks/operations nested under the given operation
+/// (the Operation* overload also visits `op` itself, last). These functions
+/// stop as soon as the callback returns an interrupt result -- including from
+/// within a nested walk -- and return WalkResult::interrupt() to the caller;
+/// otherwise they return WalkResult::advance().
+WalkResult detail::walk(Operation *op,
+ function_ref<WalkResult(Region *)> callback) {
+ for (auto &region : op->getRegions()) {
+ if (callback(&region).wasInterrupted())
+ return WalkResult::interrupt();
+ for (auto &block : region) {
+ for (auto &nestedOp : block) {
+ // Propagate interrupts raised for regions nested inside `nestedOp`;
+ // discarding this result would silently continue the walk.
+ if (walk(&nestedOp, callback).wasInterrupted())
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ return WalkResult::advance();
+}
+
+WalkResult detail::walk(Operation *op,
+ function_ref<WalkResult(Block *)> callback) {
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region) {
+ if (callback(&block).wasInterrupted())
+ return WalkResult::interrupt();
+ for (auto &nestedOp : block) {
+ // Propagate interrupts raised for blocks nested inside `nestedOp`.
+ if (walk(&nestedOp, callback).wasInterrupted())
+ return WalkResult::interrupt();
+ }
+ }
+ }
+ return WalkResult::advance();
+}
+
+WalkResult detail::walk(Operation *op,
+ function_ref<WalkResult(Operation *)> callback) {
// TODO: This walk should be iterative over the operations.
for (auto &region : op->getRegions()) {
for (auto &block : region) {
// Early increment here in the case where the operation is erased.
- for (auto &nestedOp : llvm::make_early_inc_range(block))
- if (walkOperations(&nestedOp, callback).wasInterrupted())
+ for (auto &nestedOp : llvm::make_early_inc_range(block)) {
+ if (walk(&nestedOp, callback).wasInterrupted())
return WalkResult::interrupt();
+ }
}
}
return callback(op);
More information about the Mlir-commits mailing list