[Mlir-commits] [mlir] 9d5c63f - [mlir][NFC] GreedyPatternRewriteDriver: Merge region-based and multi-op-based drivers

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 08:32:09 PST 2023


Author: Matthias Springer
Date: 2023-01-27T17:32:00+01:00
New Revision: 9d5c63f641c8318808e8e62df0a9290d1072ae41

URL: https://github.com/llvm/llvm-project/commit/9d5c63f641c8318808e8e62df0a9290d1072ae41
DIFF: https://github.com/llvm/llvm-project/commit/9d5c63f641c8318808e8e62df0a9290d1072ae41.diff

LOG: [mlir][NFC] GreedyPatternRewriteDriver: Merge region-based and multi-op-based drivers

Deduplicate large parts of the worklist processing (`GreedyPatternRewriteDriver::processWorklist`).

The new class hierarchy is as follows:
```
          GreedyPatternRewriteDriver (abstract)
                       ^
                       |
      -----------------------------------
      |                                 |
RegionPatternRewriteDriver         MultiOpPatternRewriteDriver
```
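The split is a template-method arrangement. The abstract base owns the worklist and the shared processing loop. Each subclass only decides how the worklist is seeded and how often the loop runs. The following is a minimal standalone C++ sketch of that shape, not MLIR code. `ToyOp`, `ToyGreedyDriver`, `ToyRegionDriver` and `ToyMultiOpDriver` are hypothetical names, and the "rewrite" is a placeholder that flips a flag.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR operation (the real driver uses
// mlir::Operation *). `rewritable` pretends a pattern can still fire once.
struct ToyOp {
  std::string name;
  bool rewritable;
};

// Base class: owns the worklist and the shared loop, playing the role of
// GreedyPatternRewriteDriver::processWorklist in this commit.
class ToyGreedyDriver {
public:
  virtual ~ToyGreedyDriver() = default;

protected:
  // Protected constructor: only subclasses can be instantiated.
  ToyGreedyDriver() = default;

  std::vector<ToyOp *> worklist;

  // Pops ops until the worklist is empty; returns true if anything changed.
  bool processWorklist() {
    bool changed = false;
    while (!worklist.empty()) {
      ToyOp *op = worklist.back();
      worklist.pop_back();
      assert(op && "this sketch never enqueues null ops");
      if (op->rewritable) {
        op->rewritable = false; // Placeholder for a successful rewrite.
        changed = true;
      }
    }
    return changed;
  }
};

// Region-style driver: re-seeds with every op in the container and sweeps
// until nothing changes or the iteration cap is hit.
class ToyRegionDriver : public ToyGreedyDriver {
public:
  ToyRegionDriver(std::vector<ToyOp> &region, int maxIterations)
      : region(region), maxIterations(maxIterations) {}

  // Returns true if the rewrite converged (the last sweep changed nothing).
  bool simplify() {
    bool changed = false;
    int iteration = 0;
    do {
      if (iteration++ >= maxIterations)
        break;
      worklist.clear();
      for (ToyOp &op : region)
        worklist.push_back(&op);
      changed = processWorklist();
    } while (changed);
    return !changed;
  }

private:
  std::vector<ToyOp> &region;
  int maxIterations;
};

// Multi-op-style driver: seeds the worklist with exactly `ops` and runs a
// single sweep. Returns true if the worklist was drained.
class ToyMultiOpDriver : public ToyGreedyDriver {
public:
  bool simplify(const std::vector<ToyOp *> &ops, bool *changed = nullptr) {
    for (ToyOp *op : ops)
      worklist.push_back(op);
    bool result = processWorklist();
    if (changed)
      *changed = result;
    return worklist.empty();
  }
};
```

As in the diff, the region-style driver reports convergence only if its last sweep changed nothing, so hitting the iteration cap while still making progress yields `false`. The multi-op-style driver runs one sweep and reports whether the worklist was drained.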

Also update the Markdown documentation.

Differential Revision: https://reviews.llvm.org/D141396

Added: 
    

Modified: 
    mlir/docs/PatternRewriter.md
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 3ae6547308b98..8428d4ba991ef 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -322,22 +322,48 @@ driver can be found [here](DialectConversion.md).
 
 ### Greedy Pattern Rewrite Driver
 
-This driver walks the provided operations and greedily applies the patterns that
-locally have the most benefit. The benefit of
-a pattern is decided solely by the benefit specified on the pattern, and the
-relative order of the pattern within the pattern list (when two patterns have
-the same local benefit). Patterns are iteratively applied to operations until a
-fixed point is reached, at which point the driver finishes. This driver may be
-used via the following: `applyPatternsAndFoldGreedily` and
-`applyOpPatternsAndFold`. The latter of which only applies patterns to the
-provided operation, and will not traverse the IR.
-
-The driver is configurable and supports two modes: 1) you may opt-in to a
-"top-down" traversal, which seeds the worklist with each operation top down and
-in a pre-order over the region tree.  This is generally more efficient in
-compile time.  2) the default is a "bottom up" traversal, which builds the
-initial worklist with a postorder traversal of the region tree.  This may
-match larger patterns with ambiguous pattern sets.
+This driver processes ops in a worklist-driven fashion and greedily applies the
+patterns that locally have the most benefit. The benefit of a pattern is decided
+solely by the benefit specified on the pattern, and the relative order of the
+pattern within the pattern list (when two patterns have the same local benefit).
+Patterns are iteratively applied to operations until a fixed point is reached or
+until the configurable maximum number of iterations is exhausted, at which point
+the driver finishes.
+
+This driver comes in two fashions:
+
+*   `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to
+    all ops in a given region or a given container op (but not the container op
+    itself). I.e., the worklist is initialized with all contained ops.
+*   `applyOpPatternsAndFold` ("op-based driver") applies patterns to the
+    provided list of operations. I.e., the worklist is initialized with the
+    specified list of ops.
+
+The driver is configurable via `GreedyRewriteConfig`. The region-based driver
+supports two modes for populating the initial worklist:
+
+*   Top-down traversal: Traverse the container op/region top down and in
+    pre-order. This is generally more efficient in compile time.
+*   Bottom-up traversal: This is the default setting. It builds the initial
+    worklist with a postorder traversal and then reverses the worklist. This may
+    match larger patterns with ambiguous pattern sets.
+
+By default, ops that were modified in-place and newly created are added back to
+the worklist. Ops that are outside of the configurable "scope" of the driver are
+not added to the worklist. Furthermore, "strict mode" can exclude certain ops
+from being added to the worklist throughout the rewrite process:
+
+*   `GreedyRewriteStrictness::AnyOp`: No ops are excluded (apart from the ones
+    that are out of scope).
+*   `GreedyRewriteStrictness::ExistingAndNewOps`: Only pre-existing ops (with
+    which the worklist was initialized) and newly created ops are added to the
+    worklist.
+*   `GreedyRewriteStrictness::ExistingOps`: Only pre-existing ops (with which
+    the worklist was initialized) are added to the worklist.
+
+Note: This driver listens for IR changes via the callbacks provided by
+`RewriterBase`. It is important that patterns announce all IR changes to the
+rewriter and do not bypass the rewriter API by modifying ops directly.
 
 Note: This driver is the one used by the [canonicalization](Canonicalization.md)
 [pass](Passes.md/#-canonicalize-canonicalize-operations) in MLIR.
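The two seeding modes in the documentation hunk above differ only in the order of the initial worklist. The following standalone sketch assumes, as the driver code in this commit does, that ops are popped from the back of the worklist. It computes the resulting visit order for the ops nested in a container. `Node` is a hypothetical type standing in for an op with nested regions, not MLIR API. The container itself is not enqueued, which matches the region-based driver. Constant handling (`insertKnownConstant`) is omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical op with nested ops (standing in for the ops in its regions).
struct Node {
  std::string name;
  std::vector<Node> children;
};

static void postOrder(const Node &n, std::vector<std::string> &out) {
  for (const Node &child : n.children)
    postOrder(child, out);
  out.push_back(n.name); // Parent after its nested ops.
}

static void preOrder(const Node &n, std::vector<std::string> &out) {
  out.push_back(n.name); // Parent before its nested ops.
  for (const Node &child : n.children)
    preOrder(child, out);
}

// Returns the order in which a pop-from-the-back worklist visits the ops
// nested in `container` (the container itself is not enqueued).
std::vector<std::string> processingOrder(const Node &container, bool topDown) {
  std::vector<std::string> worklist;
  for (const Node &child : container.children) {
    if (topDown)
      preOrder(child, worklist);
    else
      postOrder(child, worklist);
  }
  // Top-down: reverse the pre-order list so popping from the back yields
  // pre-order. Bottom-up: no reversal, so popping yields reverse post-order.
  if (topDown)
    std::reverse(worklist.begin(), worklist.end());

  const std::size_t total = worklist.size();
  std::vector<std::string> order;
  while (!worklist.empty()) {
    order.push_back(worklist.back());
    worklist.pop_back();
  }
  assert(order.size() == total && "each op is visited exactly once");
  return order;
}
```

For a container holding `a` (which nests `a1`) followed by `b`, this sketch visits `b, a, a1` in bottom-up mode and `a, a1, b` in top-down mode. In the real driver the order also shifts as rewrites enqueue further ops.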

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2ef46b1fdcfc1..4c5868aead3f1 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -34,59 +34,44 @@ using namespace mlir;
 
 namespace {
 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
-/// applies the locally optimal patterns in a roughly "bottom up" way.
+/// applies the locally optimal patterns.
+///
+/// This abstract class manages the worklist and contains helper methods for
+/// rewriting ops on the worklist. Derived classes specify how ops are added
+/// to the worklist in the beginning.
 class GreedyPatternRewriteDriver : public PatternRewriter {
-public:
+protected:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
                                       const GreedyRewriteConfig &config);
 
-  /// Simplify the ops within the given region.
-  bool simplify(Region &region) &&;
+  /// Add the given operation to the worklist.
+  void addSingleOpToWorklist(Operation *op);
 
   /// Add the given operation and its ancestors to the worklist.
   void addToWorklist(Operation *op);
 
-  /// Pop the next operation from the worklist.
-  Operation *popFromWorklist();
-
-  /// If the specified operation is in the worklist, remove it.
-  void removeFromWorklist(Operation *op);
-
-  /// Notifies the driver that the specified operation may have been modified
-  /// in-place.
+  /// Notify the driver that the specified operation may have been modified
+  /// in-place. The operation is added to the worklist.
   void finalizeRootUpdate(Operation *op) override;
 
-protected:
-  /// Add the given operation to the worklist.
-  void addSingleOpToWorklist(Operation *op);
-
-  // Implement the hook for inserting operations, and make sure that newly
-  // inserted ops are added to the worklist for processing.
+  /// Notify the driver that the specified operation was inserted. Update the
+  /// worklist as needed: The operation is enqueued depending on scope and
+  /// strict mode.
   void notifyOperationInserted(Operation *op) override;
 
-  // Look over the provided operands for any defining operations that should
-  // be re-added to the worklist. This function should be called when an
-  // operation is modified or removed, as it may trigger further
-  // simplifications.
-  void addOperandsToWorklist(ValueRange operands);
-
-  // If an operation is about to be removed, make sure it is not in our
-  // worklist anymore because we'd get dangling references to it.
+  /// Notify the driver that the specified operation was removed. Update the
+  /// worklist as needed: The operation and its children are removed from the
+  /// worklist.
   void notifyOperationRemoved(Operation *op) override;
 
-  // When the root of a pattern is about to be replaced, it can trigger
-  // simplifications to its users - make sure to add them to the worklist
-  // before the root is changed.
+  /// Notify the driver that the specified operation was replaced. Update the
+  /// worklist as needed: New users are enqueued.
   void notifyRootReplaced(Operation *op, ValueRange replacement) override;
 
-  /// PatternRewriter hook for notifying match failure reasons.
-  LogicalResult
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) override;
-
-  /// The low-level pattern applicator.
-  PatternApplicator matcher;
+  /// Process ops until the worklist is empty or `config.maxNumRewrites` is
+  /// reached. Return `true` if any IR was changed.
+  bool processWorklist();
 
   /// The worklist for this transformation keeps track of the operations that
   /// need to be revisited, plus their index in the worklist.  This allows us to
@@ -98,7 +83,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
-protected:
   /// Configuration information for how to simplify.
   const GreedyRewriteConfig config;
 
@@ -109,17 +93,37 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
 
 private:
+  /// Look over the provided operands for any defining operations that should
+  /// be re-added to the worklist. This function should be called when an
+  /// operation is modified or removed, as it may trigger further
+  /// simplifications.
+  void addOperandsToWorklist(ValueRange operands);
+
+  /// Pop the next operation from the worklist.
+  Operation *popFromWorklist();
+
+  /// For debugging only: Notify the driver of a pattern match failure.
+  LogicalResult
+  notifyMatchFailure(Location loc,
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
+
+  /// If the specified operation is in the worklist, remove it.
+  void removeFromWorklist(Operation *op);
+
 #ifndef NDEBUG
   /// A logger used to emit information during the application process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
 #endif
+
+  /// The low-level pattern applicator.
+  PatternApplicator matcher;
 };
 } // namespace
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+    : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) {
   assert(config.scope && "scope is not specified");
   worklist.reserve(64);
 
@@ -127,7 +131,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
   matcher.applyDefaultCostModel();
 }
 
-bool GreedyPatternRewriteDriver::simplify(Region &region) && {
+bool GreedyPatternRewriteDriver::processWorklist() {
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -146,130 +150,80 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
   };
 #endif
 
-  auto insertKnownConstant = [&](Operation *op) {
-    // Check for existing constants when populating the worklist. This avoids
-    // accidentally reversing the constant order during processing.
-    Attribute constValue;
-    if (matchPattern(op, m_Constant(&constValue)))
-      if (!folder.insertKnownConstant(op, constValue))
-        return true;
-    return false;
-  };
-
-  // Populate strict mode ops.
-  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
-    strictModeFilteredOps.clear();
-    region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
-  }
+  // These are scratch vectors used in the folding loop below.
+  SmallVector<Value, 8> originalOperands;
 
   bool changed = false;
-  int64_t iteration = 0;
-  do {
-    // Check if the iteration limit was reached.
-    if (iteration++ >= config.maxIterations &&
-        config.maxIterations != GreedyRewriteConfig::kNoLimit)
-      break;
+  int64_t numRewrites = 0;
+  while (!worklist.empty() &&
+         (numRewrites < config.maxNumRewrites ||
+          config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
+    auto *op = popFromWorklist();
 
-    worklist.clear();
-    worklistMap.clear();
+    // Nulls get added to the worklist when operations are removed, ignore
+    // them.
+    if (op == nullptr)
+      continue;
 
-    if (!config.useTopDownTraversal) {
-      // Add operations to the worklist in postorder.
-        region.walk([&](Operation *op) {
-          if (!insertKnownConstant(op))
-            addToWorklist(op);
-        });
-    } else {
-      // Add all nested operations to the worklist in preorder.
-        region.walk<WalkOrder::PreOrder>([&](Operation *op) {
-          if (!insertKnownConstant(op)) {
-            worklist.push_back(op);
-            return WalkResult::advance();
-          }
-          return WalkResult::skip();
-        });
+    LLVM_DEBUG({
+      logger.getOStream() << "\n";
+      logger.startLine() << logLineComment;
+      logger.startLine() << "Processing operation : '" << op->getName() << "'("
+                         << op << ") {\n";
+      logger.indent();
+
+      // If the operation has no regions, just print it here.
+      if (op->getNumRegions() == 0) {
+        op->print(
+            logger.startLine(),
+            OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
+        logger.getOStream() << "\n\n";
+      }
+    });
 
-      // Reverse the list so our pop-back loop processes them in-order.
-      std::reverse(worklist.begin(), worklist.end());
-      // Remember the reverse index.
-      for (size_t i = 0, e = worklist.size(); i != e; ++i)
-        worklistMap[worklist[i]] = i;
+    // If the operation is trivially dead - remove it.
+    if (isOpTriviallyDead(op)) {
+      notifyOperationRemoved(op);
+      op->erase();
+      changed = true;
+
+      LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
+      continue;
     }
 
-    // These are scratch vectors used in the folding loop below.
-    SmallVector<Value, 8> originalOperands, resultValues;
+    // Collects all the operands and result uses of the given `op` into work
+    // list. Also remove `op` and nested ops from worklist.
+    originalOperands.assign(op->operand_begin(), op->operand_end());
+    auto preReplaceAction = [&](Operation *op) {
+      // Add the operands to the worklist for visitation.
+      addOperandsToWorklist(originalOperands);
 
-    changed = false;
-    int64_t numRewrites = 0;
-    while (!worklist.empty() &&
-           (numRewrites < config.maxNumRewrites ||
-            config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
-      auto *op = popFromWorklist();
+      // Add all the users of the result to the worklist so we make sure
+      // to revisit them.
+      for (auto result : op->getResults())
+        for (auto *userOp : result.getUsers())
+          addToWorklist(userOp);
 
-      // Nulls get added to the worklist when operations are removed, ignore
-      // them.
-      if (op == nullptr)
-        continue;
+      notifyOperationRemoved(op);
+    };
 
-      LLVM_DEBUG({
-        logger.getOStream() << "\n";
-        logger.startLine() << logLineComment;
-        logger.startLine() << "Processing operation : '" << op->getName()
-                           << "'(" << op << ") {\n";
-        logger.indent();
-
-        // If the operation has no regions, just print it here.
-        if (op->getNumRegions() == 0) {
-          op->print(
-              logger.startLine(),
-              OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
-          logger.getOStream() << "\n\n";
-        }
-      });
+    // Add the given operation to the worklist.
+    auto collectOps = [this](Operation *op) { addToWorklist(op); };
 
-      // If the operation is trivially dead - remove it.
-      if (isOpTriviallyDead(op)) {
-        notifyOperationRemoved(op);
-        op->erase();
-        changed = true;
+    // Try to fold this op.
+    bool inPlaceUpdate;
+    if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
+                                    &inPlaceUpdate)))) {
+      LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
 
-        LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
+      changed = true;
+      if (!inPlaceUpdate)
         continue;
-      }
-
-      // Collects all the operands and result uses of the given `op` into work
-      // list. Also remove `op` and nested ops from worklist.
-      originalOperands.assign(op->operand_begin(), op->operand_end());
-      auto preReplaceAction = [&](Operation *op) {
-        // Add the operands to the worklist for visitation.
-        addOperandsToWorklist(originalOperands);
-
-        // Add all the users of the result to the worklist so we make sure
-        // to revisit them.
-        for (auto result : op->getResults())
-          for (auto *userOp : result.getUsers())
-            addToWorklist(userOp);
-
-        notifyOperationRemoved(op);
-      };
-
-      // Add the given operation to the worklist.
-      auto collectOps = [this](Operation *op) { addToWorklist(op); };
-
-      // Try to fold this op.
-      bool inPlaceUpdate;
-      if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
-                                      &inPlaceUpdate)))) {
-        LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-
-        changed = true;
-        if (!inPlaceUpdate)
-          continue;
-      }
+    }
 
-      // Try to match one of the patterns. The rewriter is automatically
-      // notified of any necessary changes, so there is nothing else to do
-      // here.
+    // Try to match one of the patterns. The rewriter is automatically
+    // notified of any necessary changes, so there is nothing else to do
+    // here.
 #ifndef NDEBUG
       auto canApply = [&](const Pattern &pattern) {
         LLVM_DEBUG({
@@ -304,16 +258,9 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
         changed = true;
         ++numRewrites;
       }
-    }
-
-    // After applying patterns, make sure that the CFG of each of the regions
-    // is kept up to date.
-    if (config.enableRegionSimplification)
-      changed |= succeeded(simplifyRegions(*this, region));
-  } while (changed);
+  }
 
-  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return !changed;
+  return changed;
 }
 
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
@@ -321,12 +268,12 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
   SmallVector<Operation *, 8> ancestors;
   ancestors.push_back(op);
   while (Region *region = op->getParentRegion()) {
-    if (config.scope == region) {
-      // All gathered ops are in fact ancestors.
-      for (Operation *op : ancestors)
-        addSingleOpToWorklist(op);
-      break;
-    }
+      if (config.scope == region) {
+        // All gathered ops are in fact ancestors.
+        for (Operation *op : ancestors)
+          addSingleOpToWorklist(op);
+        break;
+      }
     op = region->getParentOp();
     if (!op)
       break;
@@ -434,12 +381,96 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
   return failure();
 }
 
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions. Note: This does not apply patterns to the
-/// top-level operation itself.
-///
+//===----------------------------------------------------------------------===//
+// RegionPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This driver simplifies all ops in a region.
+class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+  explicit RegionPatternRewriteDriver(MLIRContext *ctx,
+                                      const FrozenRewritePatternSet &patterns,
+                                      const GreedyRewriteConfig &config,
+                                      Region &regions);
+
+  /// Simplify ops inside `region` and simplify the region itself. Return
+  /// success if the transformation converged.
+  LogicalResult simplify() &&;
+
+private:
+  /// The region that is simplified.
+  Region &region;
+};
+} // namespace
+
+RegionPatternRewriteDriver::RegionPatternRewriteDriver(
+    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    const GreedyRewriteConfig &config, Region &region)
+    : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
+  // Populate strict mode ops.
+  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
+    region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
+  }
+}
+
+LogicalResult RegionPatternRewriteDriver::simplify() && {
+  auto insertKnownConstant = [&](Operation *op) {
+    // Check for existing constants when populating the worklist. This avoids
+    // accidentally reversing the constant order during processing.
+    Attribute constValue;
+    if (matchPattern(op, m_Constant(&constValue)))
+      if (!folder.insertKnownConstant(op, constValue))
+        return true;
+    return false;
+  };
+
+  bool changed = false;
+  int64_t iteration = 0;
+  do {
+    // Check if the iteration limit was reached.
+    if (iteration++ >= config.maxIterations &&
+        config.maxIterations != GreedyRewriteConfig::kNoLimit)
+      break;
+
+    worklist.clear();
+    worklistMap.clear();
+
+    if (!config.useTopDownTraversal) {
+      // Add operations to the worklist in postorder.
+      region.walk([&](Operation *op) {
+        if (!insertKnownConstant(op))
+          addToWorklist(op);
+      });
+    } else {
+      // Add all nested operations to the worklist in preorder.
+      region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+        if (!insertKnownConstant(op)) {
+          worklist.push_back(op);
+          return WalkResult::advance();
+        }
+        return WalkResult::skip();
+      });
+
+      // Reverse the list so our pop-back loop processes them in-order.
+      std::reverse(worklist.begin(), worklist.end());
+      // Remember the reverse index.
+      for (size_t i = 0, e = worklist.size(); i != e; ++i)
+        worklistMap[worklist[i]] = i;
+    }
+
+    changed = processWorklist();
+
+    // After applying patterns, make sure that the CFG of each of the regions
+    // is kept up to date.
+    if (config.enableRegionSimplification)
+      changed |= succeeded(simplifyRegions(*this, region));
+  } while (changed);
+
+  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+  return success(!changed);
+}
+
 LogicalResult
 mlir::applyPatternsAndFoldGreedily(Region &region,
                                    const FrozenRewritePatternSet &patterns,
@@ -455,13 +486,14 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
     config.scope = &region;
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(region.getContext(), patterns, config);
-  bool converged = std::move(driver).simplify(region);
-  LLVM_DEBUG(if (!converged) {
+  RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
+                                    region);
+  LogicalResult converged = std::move(driver).simplify();
+  LLVM_DEBUG(if (failed(converged)) {
     llvm::dbgs() << "The pattern rewrite did not converge after scanning "
                  << config.maxIterations << " times\n";
   });
-  return success(converged);
+  return converged;
 }
 
 //===----------------------------------------------------------------------===//
@@ -469,32 +501,16 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
 //===----------------------------------------------------------------------===//
 
 namespace {
-
-/// This is a specialized GreedyPatternRewriteDriver to apply patterns and
-/// perform folding for a supplied set of ops. It repeatedly simplifies while
-/// restricting the rewrites to only the provided set of ops or optionally
-/// to those directly affected by it (result users or operand providers). Parent
-/// ops are not considered.
+/// This driver simplifies a list of ops.
 class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
 public:
   explicit MultiOpPatternRewriteDriver(
       MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
-      const GreedyRewriteConfig &config,
-      llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
-      : GreedyPatternRewriteDriver(ctx, patterns, config),
-        survivingOps(survivingOps) {}
-
-  /// Performs the specified rewrites on `ops` while also trying to fold these
-  /// ops. `strictMode` controls which other ops are simplified. Only ops
-  /// within the given scope region are added to the worklist. If no scope is
-  /// specified, it assumed to be closest common region of all `ops`.
-  ///
-  /// Note that ops in `ops` could be erased as a result of folding, becoming
-  /// dead, or via pattern rewrites. The return value indicates convergence.
-  ///
-  /// All erased ops are stored in `erased`.
-  LogicalResult simplifyLocally(ArrayRef<Operation *> op,
-                                bool *changed = nullptr) &&;
+      const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
+      llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
+
+  /// Simplify `ops`. Return `success` if the transformation converged.
+  LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
 
 private:
   void notifyOperationRemoved(Operation *op) override {
@@ -508,98 +524,33 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
   /// of ops.
   llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
 };
-
 } // namespace
 
-LogicalResult
-MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
-                                             bool *changed) && {
+MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
+    MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+    const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
+    llvm::SmallDenseSet<Operation *, 4> *survivingOps)
+    : GreedyPatternRewriteDriver(ctx, patterns, config),
+      survivingOps(survivingOps) {
+  if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+    strictModeFilteredOps.insert(ops.begin(), ops.end());
+
   if (survivingOps) {
     survivingOps->clear();
     survivingOps->insert(ops.begin(), ops.end());
   }
+}
 
-  if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
-    strictModeFilteredOps.clear();
-    strictModeFilteredOps.insert(ops.begin(), ops.end());
-  }
-
-  if (changed)
-    *changed = false;
-  worklist.clear();
-  worklistMap.clear();
+LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
+                                                    bool *changed) && {
+  // Populate the initial worklist.
   for (Operation *op : ops)
     addSingleOpToWorklist(op);
 
-  // These are scratch vectors used in the folding loop below.
-  SmallVector<Value, 8> originalOperands, resultValues;
-  int64_t numRewrites = 0;
-  while (!worklist.empty() &&
-         (numRewrites < config.maxNumRewrites ||
-          config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
-    Operation *op = popFromWorklist();
-
-    // Nulls get added to the worklist when operations are removed, ignore
-    // them.
-    if (op == nullptr)
-      continue;
-
-    assert((config.strictMode == GreedyRewriteStrictness::AnyOp ||
-            strictModeFilteredOps.contains(op)) &&
-           "unexpected op was inserted under strict mode");
-
-    // If the operation is trivially dead - remove it.
-    if (isOpTriviallyDead(op)) {
-      notifyOperationRemoved(op);
-      op->erase();
-      if (changed)
-        *changed = true;
-      continue;
-    }
-
-    // Collects all the operands and result uses of the given `op` into work
-    // list. Also remove `op` and nested ops from worklist.
-    originalOperands.assign(op->operand_begin(), op->operand_end());
-    auto preReplaceAction = [&](Operation *op) {
-      // Add the operands to the worklist for visitation.
-      addOperandsToWorklist(originalOperands);
-
-      // Add all the users of the result to the worklist so we make sure
-      // to revisit them.
-      for (Value result : op->getResults()) {
-        for (Operation *userOp : result.getUsers())
-          addToWorklist(userOp);
-      }
-
-      notifyOperationRemoved(op);
-    };
-
-    // Add the given operation generated by the folder to the worklist.
-    auto processGeneratedConstants = [this](Operation *op) {
-      notifyOperationInserted(op);
-    };
-
-    // Try to fold this op.
-    bool inPlaceUpdate;
-    if (succeeded(folder.tryToFold(op, processGeneratedConstants,
-                                   preReplaceAction, &inPlaceUpdate))) {
-      if (changed)
-        *changed = true;
-      if (!inPlaceUpdate) {
-        // Op has been erased.
-        continue;
-      }
-    }
-
-    // Try to match one of the patterns. The rewriter is automatically
-    // notified of any necessary changes, so there is nothing else to do
-    // here.
-    if (succeeded(matcher.matchAndRewrite(op, *this))) {
-      if (changed)
-        *changed = true;
-      ++numRewrites;
-    }
-  }
+  // Process ops on the worklist.
+  bool result = processWorklist();
+  if (changed)
+    *changed = result;
 
   return success(worklist.empty());
 }
@@ -658,8 +609,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
   // Start the pattern driver.
   llvm::SmallDenseSet<Operation *, 4> surviving;
   MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
-                                     config, allErased ? &surviving : nullptr);
-  LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
+                                     config, ops,
+                                     allErased ? &surviving : nullptr);
+  LogicalResult converged = std::move(driver).simplify(ops, changed);
   if (allErased)
     *allErased = surviving.empty();
   LLVM_DEBUG(if (failed(converged)) {

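The driver's worklist pairs a vector with a map from op to index (`worklist` and `worklistMap` in the diff). This lets an op that is erased while still queued be nulled out in O(1) instead of searched for. The processing loop then skips the resulting null entries and stops early once `maxNumRewrites` is reached. The sketch below combines that with the strict-mode filtering described in the documentation hunk. It is a standalone approximation under stated assumptions: `Op`, `Strictness`, and `ToyDriver` are hypothetical, `-1` stands in for `GreedyRewriteConfig::kNoLimit`, and rewriting is delegated to a caller-supplied callback.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an operation; not the MLIR class.
struct Op {
  int id;
};

// Mirrors the three GreedyRewriteStrictness settings from the docs.
enum class Strictness { AnyOp, ExistingAndNewOps, ExistingOps };

class ToyDriver {
public:
  ToyDriver(Strictness strictness, const std::vector<Op *> &initialOps)
      : strictness(strictness) {
    // Pre-existing ops are always eligible under strict mode.
    if (strictness != Strictness::AnyOp)
      allowed.insert(initialOps.begin(), initialOps.end());
    for (Op *op : initialOps)
      enqueue(op);
  }

  // Called when a rewrite creates an op.
  void notifyInserted(Op *op) {
    if (strictness == Strictness::ExistingAndNewOps)
      allowed.insert(op);
    enqueue(op);
  }

  // Called before an op is erased: null out its slot in O(1).
  void notifyRemoved(Op *op) {
    auto it = indexOf.find(op);
    if (it == indexOf.end())
      return;
    worklist[it->second] = nullptr;
    indexOf.erase(it);
  }

  // Pops ops until the worklist is empty or `maxNumRewrites` (-1 = no limit)
  // is reached. `rewrite(driver, op)` returns true if it changed `op`.
  template <typename RewriteFn>
  bool processWorklist(int64_t maxNumRewrites, RewriteFn rewrite) {
    bool changed = false;
    int64_t numRewrites = 0;
    while (!worklist.empty() &&
           (maxNumRewrites < 0 || numRewrites < maxNumRewrites)) {
      Op *op = worklist.back();
      worklist.pop_back();
      if (!op)
        continue; // Entry was removed while queued.
      indexOf.erase(op);
      assert((strictness == Strictness::AnyOp || allowed.count(op)) &&
             "unexpected op was inserted under strict mode");
      if (rewrite(*this, op)) {
        changed = true;
        ++numRewrites;
      }
    }
    return changed;
  }

  bool worklistEmpty() const { return worklist.empty(); }

private:
  void enqueue(Op *op) {
    if (strictness != Strictness::AnyOp && !allowed.count(op))
      return; // Filtered out by strict mode.
    if (indexOf.count(op))
      return; // Already queued.
    indexOf[op] = worklist.size();
    worklist.push_back(op);
  }

  Strictness strictness;
  std::unordered_set<Op *> allowed;
  std::vector<Op *> worklist;
  std::unordered_map<Op *, std::size_t> indexOf;
};
```

Under `ExistingOps`, an op created by a rewrite is never enqueued; under `ExistingAndNewOps` it is. A capped run can also leave ops on the worklist, which is why the multi-op driver in the diff reports convergence as `success(worklist.empty())`.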