[Mlir-commits] [mlir] ca7167d - [mlir][Transforms][NFC] GreedyPatternRewriteDriver: Add worklist class

Matthias Springer llvmlistbot at llvm.org
Thu May 25 00:16:24 PDT 2023


Author: Matthias Springer
Date: 2023-05-25T09:16:13+02:00
New Revision: ca7167d5a07f703a15ec9c3aea8b8461bf6bac29

URL: https://github.com/llvm/llvm-project/commit/ca7167d5a07f703a15ec9c3aea8b8461bf6bac29
DIFF: https://github.com/llvm/llvm-project/commit/ca7167d5a07f703a15ec9c3aea8b8461bf6bac29.diff

LOG: [mlir][Transforms][NFC] GreedyPatternRewriteDriver: Add worklist class

Encapsulate all worklist-related functionality in a separate `Worklist` class. This makes the remaining code more readable and allows for custom worklist implementations (e.g., a randomized worklist for fuzzing pattern application: D142447).

Differential Revision: https://reviews.llvm.org/D151345

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4bc514b62c70f..050f18c8677b7 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -31,11 +31,12 @@ using namespace mlir;
 
 #define DEBUG_TYPE "greedy-rewriter"
 
+namespace {
+
 //===----------------------------------------------------------------------===//
 // Debugging Infrastructure
 //===----------------------------------------------------------------------===//
 
-namespace {
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 /// A helper struct that stores finger prints of ops in order to detect broken
 /// RewritePatterns. A rewrite pattern is broken if it modifies IR without
@@ -130,6 +131,100 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
 };
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
+//===----------------------------------------------------------------------===//
+// Worklist
+//===----------------------------------------------------------------------===//
+
+/// A LIFO worklist of operations with efficient removal and set semantics.
+///
+/// This class maintains a vector of operations and a mapping of operations to
+/// positions in the vector, so that operations can be removed efficiently at
+/// random. When an operation is removed, it is replaced with nullptr. Such
+/// nullptr entries are skipped when popping elements.
+class Worklist {
+public:
+  Worklist();
+
+  /// Clear the worklist.
+  void clear();
+
+  /// Return whether the worklist is empty.
+  bool empty() const;
+
+  /// Push an operation to the end of the worklist, unless the operation is
+  /// already on the worklist.
+  void push(Operation *op);
+
+  /// Pop an operation from the end of the worklist. Only allowed on
+  /// non-empty worklists.
+  Operation *pop();
+
+  /// Remove an operation from the worklist.
+  void remove(Operation *op);
+
+  /// Reverse the worklist.
+  void reverse();
+
+private:
+  /// The worklist of operations.
+  std::vector<Operation *> list;
+
+  /// A mapping of operations to positions in `list`.
+  DenseMap<Operation *, unsigned> map;
+};
+
+Worklist::Worklist() { list.reserve(64); }
+
+void Worklist::clear() {
+  list.clear();
+  map.clear();
+}
+
+bool Worklist::empty() const {
+  // Skip all nullptr.
+  return !llvm::any_of(list,
+                       [](Operation *op) { return static_cast<bool>(op); });
+}
+
+void Worklist::push(Operation *op) {
+  assert(op && "cannot push nullptr to worklist");
+  // Check to see if the worklist already contains this op.
+  if (map.count(op))
+    return;
+  map[op] = list.size();
+  list.push_back(op);
+}
+
+Operation *Worklist::pop() {
+  assert(!empty() && "cannot pop from empty worklist");
+  // Skip and remove all trailing nullptr.
+  while (!list.back())
+    list.pop_back();
+  Operation *op = list.back();
+  list.pop_back();
+  map.erase(op);
+  // Cleanup: Remove all trailing nullptr.
+  while (!list.empty() && !list.back())
+    list.pop_back();
+  return op;
+}
+
+void Worklist::remove(Operation *op) {
+  assert(op && "cannot remove nullptr from worklist");
+  auto it = map.find(op);
+  if (it != map.end()) {
+    assert(list[it->second] == op && "malformed worklist data structure");
+    list[it->second] = nullptr;
+    map.erase(it);
+  }
+}
+
+void Worklist::reverse() {
+  std::reverse(list.begin(), list.end());
+  for (size_t i = 0, e = list.size(); i != e; ++i)
+    map[list[i]] = i;
+}
+
 //===----------------------------------------------------------------------===//
 // GreedyPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -176,11 +271,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   bool processWorklist();
 
   /// The worklist for this transformation keeps track of the operations that
-  /// need to be revisited, plus their index in the worklist.  This allows us to
-  /// efficiently remove operations from the worklist when they are erased, even
-  /// if they aren't the root of a pattern.
-  std::vector<Operation *> worklist;
-  DenseMap<Operation *, unsigned> worklistMap;
+  /// need to be (re)visited.
+  Worklist worklist;
 
   /// Non-pattern based folder for operations.
   OperationFolder folder;
@@ -201,9 +293,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// simplifications.
   void addOperandsToWorklist(ValueRange operands);
 
-  /// Pop the next operation from the worklist.
-  Operation *popFromWorklist();
-
   /// Notify the driver that the given block was created.
   void notifyBlockCreated(Block *block) override;
 
@@ -212,9 +301,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   notifyMatchFailure(Location loc,
                      function_ref<void(Diagnostic &)> reasonCallback) override;
 
-  /// If the specified operation is in the worklist, remove it.
-  void removeFromWorklist(Operation *op);
-
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -239,8 +325,6 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 // clang-format on
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 {
-  worklist.reserve(64);
-
   // Apply a simple cost model based solely on pattern benefit.
   matcher.applyDefaultCostModel();
 
@@ -278,12 +362,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
   while (!worklist.empty() &&
          (numRewrites < config.maxNumRewrites ||
           config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
-    auto *op = popFromWorklist();
-
-    // Nulls get added to the worklist when operations are removed, ignore
-    // them.
-    if (op == nullptr)
-      continue;
+    auto *op = worklist.pop();
 
     LLVM_DEBUG({
       logger.getOStream() << "\n";
@@ -395,33 +474,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
   if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
-      strictModeFilteredOps.contains(op)) {
-    // Check to see if the worklist already contains this op.
-    if (worklistMap.count(op))
-      return;
-
-    worklistMap[op] = worklist.size();
-    worklist.push_back(op);
-  }
-}
-
-Operation *GreedyPatternRewriteDriver::popFromWorklist() {
-  auto *op = worklist.back();
-  worklist.pop_back();
-
-  // This operation is no longer in the worklist, keep worklistMap up to date.
-  if (op)
-    worklistMap.erase(op);
-  return op;
-}
-
-void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
-  auto it = worklistMap.find(op);
-  if (it != worklistMap.end()) {
-    assert(worklist[it->second] == op && "malformed worklist data structure");
-    worklist[it->second] = nullptr;
-    worklistMap.erase(it);
-  }
+      strictModeFilteredOps.contains(op))
+    worklist.push(op);
 }
 
 void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
@@ -475,7 +529,7 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
 
   addOperandsToWorklist(op->getOperands());
   op->walk([this](Operation *operation) {
-    removeFromWorklist(operation);
+    worklist.remove(operation);
     folder.notifyRemoval(operation);
   });
 
@@ -580,7 +634,6 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
       break;
 
     worklist.clear();
-    worklistMap.clear();
 
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
@@ -599,10 +652,7 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
       });
 
       // Reverse the list so our pop-back loop processes them in-order.
-      std::reverse(worklist.begin(), worklist.end());
-      // Remember the reverse index.
-      for (size_t i = 0, e = worklist.size(); i != e; ++i)
-        worklistMap[worklist[i]] = i;
+      worklist.reverse();
     }
 
     ctx->executeAction<GreedyPatternRewriteIteration>(


        


More information about the Mlir-commits mailing list