[Mlir-commits] [mlir] ca7167d - [mlir][Transforms][NFC] GreedyPatternRewriteDriver: Add worklist class
Matthias Springer
llvmlistbot at llvm.org
Thu May 25 00:16:24 PDT 2023
Author: Matthias Springer
Date: 2023-05-25T09:16:13+02:00
New Revision: ca7167d5a07f703a15ec9c3aea8b8461bf6bac29
URL: https://github.com/llvm/llvm-project/commit/ca7167d5a07f703a15ec9c3aea8b8461bf6bac29
DIFF: https://github.com/llvm/llvm-project/commit/ca7167d5a07f703a15ec9c3aea8b8461bf6bac29.diff
LOG: [mlir][Transforms][NFC] GreedyPatternRewriteDriver: Add worklist class
Encapsulate all worklist-related functionality in a separate `Worklist` class. This makes the remaining code more readable and allows for custom worklist implementations (e.g., a randomized worklist for fuzzing pattern application: D142447).
Differential Revision: https://reviews.llvm.org/D151345
Added:
Modified:
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4bc514b62c70f..050f18c8677b7 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -31,11 +31,12 @@ using namespace mlir;
#define DEBUG_TYPE "greedy-rewriter"
+namespace {
+
//===----------------------------------------------------------------------===//
// Debugging Infrastructure
//===----------------------------------------------------------------------===//
-namespace {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
/// A helper struct that stores finger prints of ops in order to detect broken
/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
@@ -130,6 +131,100 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+//===----------------------------------------------------------------------===//
+// Worklist
+//===----------------------------------------------------------------------===//
+
+/// A LIFO worklist of operations with efficient removal and set semantics.
+///
+/// This class maintains a vector of operations and a mapping of operations to
+/// positions in the vector, so that operations can be removed efficiently from
+/// arbitrary positions. When an operation is removed, its slot in the vector
+/// is replaced with nullptr. Such nullptr entries are skipped when popping
+/// elements.
+///
+/// Invariant: `map` contains exactly the non-null entries of `list`, each
+/// mapped to its index in `list`. This lets `empty()` run in constant time.
+class Worklist {
+public:
+  Worklist();
+
+  /// Clear the worklist.
+  void clear();
+
+  /// Return whether the worklist is empty, i.e., whether it contains no
+  /// (non-removed) operations.
+  bool empty() const;
+
+  /// Push an operation to the end of the worklist, unless the operation is
+  /// already on the worklist.
+  void push(Operation *op);
+
+  /// Pop an operation from the end of the worklist. Only allowed on
+  /// non-empty worklists.
+  Operation *pop();
+
+  /// Remove an operation from the worklist. This is a no-op if the operation
+  /// is not on the worklist.
+  void remove(Operation *op);
+
+  /// Reverse the worklist.
+  void reverse();
+
+private:
+  /// The worklist of operations. Removed operations leave a nullptr behind.
+  std::vector<Operation *> list;
+
+  /// A mapping of (non-null) operations to their positions in `list`.
+  DenseMap<Operation *, unsigned> map;
+};
+
+Worklist::Worklist() { list.reserve(64); }
+
+void Worklist::clear() {
+  list.clear();
+  map.clear();
+}
+
+bool Worklist::empty() const {
+  // `map` tracks exactly the non-null entries of `list`, so there is no need
+  // to scan `list` (which may hold many nullptr slots left by `remove`).
+  return map.empty();
+}
+
+void Worklist::push(Operation *op) {
+  assert(op && "cannot push nullptr to worklist");
+  // Record the position and detect duplicates with a single lookup; do
+  // nothing if the worklist already contains this op.
+  if (!map.try_emplace(op, list.size()).second)
+    return;
+  list.push_back(op);
+}
+
+Operation *Worklist::pop() {
+  assert(!empty() && "cannot pop from empty worklist");
+  // Skip and remove all trailing nullptr. This terminates because a
+  // non-empty worklist has at least one non-null entry in `list`.
+  while (!list.back())
+    list.pop_back();
+  Operation *op = list.back();
+  list.pop_back();
+  map.erase(op);
+  // Cleanup: Remove all trailing nullptr.
+  while (!list.empty() && !list.back())
+    list.pop_back();
+  return op;
+}
+
+void Worklist::remove(Operation *op) {
+  assert(op && "cannot remove nullptr from worklist");
+  auto it = map.find(op);
+  if (it != map.end()) {
+    assert(list[it->second] == op && "malformed worklist data structure");
+    // Leave a nullptr tombstone instead of shifting elements, so that the
+    // positions stored in `map` stay valid.
+    list[it->second] = nullptr;
+    map.erase(it);
+  }
+}
+
+void Worklist::reverse() {
+  std::reverse(list.begin(), list.end());
+  // Recompute positions. Skip nullptr entries (removed operations): they must
+  // not be added to `map`, which only tracks operations on the worklist.
+  for (size_t i = 0, e = list.size(); i != e; ++i)
+    if (list[i])
+      map[list[i]] = i;
+}
+
//===----------------------------------------------------------------------===//
// GreedyPatternRewriteDriver
//===----------------------------------------------------------------------===//
@@ -176,11 +271,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
bool processWorklist();
/// The worklist for this transformation keeps track of the operations that
- /// need to be revisited, plus their index in the worklist. This allows us to
- /// efficiently remove operations from the worklist when they are erased, even
- /// if they aren't the root of a pattern.
- std::vector<Operation *> worklist;
- DenseMap<Operation *, unsigned> worklistMap;
+ /// need to be (re)visited.
+ Worklist worklist;
/// Non-pattern based folder for operations.
OperationFolder folder;
@@ -201,9 +293,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// simplifications.
void addOperandsToWorklist(ValueRange operands);
- /// Pop the next operation from the worklist.
- Operation *popFromWorklist();
-
/// Notify the driver that the given block was created.
void notifyBlockCreated(Block *block) override;
@@ -212,9 +301,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
- /// If the specified operation is in the worklist, remove it.
- void removeFromWorklist(Operation *op);
-
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -239,8 +325,6 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
- worklist.reserve(64);
-
// Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel();
@@ -278,12 +362,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
while (!worklist.empty() &&
(numRewrites < config.maxNumRewrites ||
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
- auto *op = popFromWorklist();
-
- // Nulls get added to the worklist when operations are removed, ignore
- // them.
- if (op == nullptr)
- continue;
+ auto *op = worklist.pop();
LLVM_DEBUG({
logger.getOStream() << "\n";
@@ -395,33 +474,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
- strictModeFilteredOps.contains(op)) {
- // Check to see if the worklist already contains this op.
- if (worklistMap.count(op))
- return;
-
- worklistMap[op] = worklist.size();
- worklist.push_back(op);
- }
-}
-
-Operation *GreedyPatternRewriteDriver::popFromWorklist() {
- auto *op = worklist.back();
- worklist.pop_back();
-
- // This operation is no longer in the worklist, keep worklistMap up to date.
- if (op)
- worklistMap.erase(op);
- return op;
-}
-
-void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
- auto it = worklistMap.find(op);
- if (it != worklistMap.end()) {
- assert(worklist[it->second] == op && "malformed worklist data structure");
- worklist[it->second] = nullptr;
- worklistMap.erase(it);
- }
+ strictModeFilteredOps.contains(op))
+ worklist.push(op);
}
void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
@@ -475,7 +529,7 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
addOperandsToWorklist(op->getOperands());
op->walk([this](Operation *operation) {
- removeFromWorklist(operation);
+ worklist.remove(operation);
folder.notifyRemoval(operation);
});
@@ -580,7 +634,6 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
break;
worklist.clear();
- worklistMap.clear();
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
@@ -599,10 +652,7 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
});
// Reverse the list so our pop-back loop processes them in-order.
- std::reverse(worklist.begin(), worklist.end());
- // Remember the reverse index.
- for (size_t i = 0, e = worklist.size(); i != e; ++i)
- worklistMap[worklist[i]] = i;
+ worklist.reverse();
}
ctx->executeAction<GreedyPatternRewriteIteration>(
More information about the Mlir-commits
mailing list