[Mlir-commits] [mlir] cadd566 - [mlir][NFC] GreedyPatternRewriteDriver: Remove OpPatternRewriteDriver

Matthias Springer llvmlistbot at llvm.org
Fri Jan 27 01:47:41 PST 2023


Author: Matthias Springer
Date: 2023-01-27T10:43:18+01:00
New Revision: cadd5666a6d669fa0507bd4f86e9570af4250869

URL: https://github.com/llvm/llvm-project/commit/cadd5666a6d669fa0507bd4f86e9570af4250869
DIFF: https://github.com/llvm/llvm-project/commit/cadd5666a6d669fa0507bd4f86e9570af4250869.diff

LOG: [mlir][NFC] GreedyPatternRewriteDriver: Remove OpPatternRewriteDriver

The `MultiOpPatternRewriteDriver` can be reused. This gives us better debug messages and more code reuse. Debug messages such as `** Replace: (op name)` were previously not printed when using the `applyOpPatternsAndFold(Operation *, ...)` overload.

Differential Revision: https://reviews.llvm.org/D142613

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index dce47834547e..5a043775a01d 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -86,18 +86,6 @@ inline LogicalResult applyPatternsAndFoldGreedily(
   return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
 }
 
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`.
-///
-/// Returns success if the iterative process converged and no more patterns can
-/// be matched.
-LogicalResult applyOpPatternsAndFold(Operation *op,
-                                     const FrozenRewritePatternSet &patterns,
-                                     bool *erased = nullptr);
-
 /// Applies the specified rewrite patterns on `ops` while also trying to fold
 /// these ops.
 ///
@@ -132,6 +120,21 @@ LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
                                      bool *allErased = nullptr,
                                      Region *scope = nullptr);
 
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// greedily selecting the patterns with the highest benefit. If `erased` is
+/// non-null, it is set to true if `op` was folded away or erased as a result
+/// of becoming dead.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched.
+inline LogicalResult
+applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
+                       bool *erased = nullptr) {
+  return applyOpPatternsAndFold(ArrayRef(op), patterns,
+                                GreedyRewriteStrictness::ExistingOps,
+                                /*changed=*/nullptr, erased);
+}
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index ead229dacae8..aba7a7fd08b9 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -459,109 +459,6 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
   return success(converged);
 }
 
-//===----------------------------------------------------------------------===//
-// OpPatternRewriteDriver
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a simple driver for the PatternMatcher to apply patterns and perform
-/// folding on a single op. It repeatedly applies locally optimal patterns.
-class OpPatternRewriteDriver : public PatternRewriter {
-public:
-  explicit OpPatternRewriteDriver(MLIRContext *ctx,
-                                  const FrozenRewritePatternSet &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
-    // Apply a simple cost model based solely on pattern benefit.
-    matcher.applyDefaultCostModel();
-  }
-
-  LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
-                                bool &erased);
-
-  // These are hooks implemented for PatternRewriter.
-protected:
-  /// If an operation is about to be removed, mark it so that we can let clients
-  /// know.
-  void notifyOperationRemoved(Operation *op) override {
-    if (this->op == op)
-      opErasedViaPatternRewrites = true;
-  }
-
-  // When a root is going to be replaced, its removal will be notified as well.
-  // So there is nothing to do here.
-  void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
-
-private:
-  /// The low-level pattern applicator.
-  PatternApplicator matcher;
-
-  /// Non-pattern based folder for operations.
-  OperationFolder folder;
-
-  /// Op that is being processed.
-  Operation *op = nullptr;
-
-  /// Set to true if the operation has been erased via pattern rewrites.
-  bool opErasedViaPatternRewrites = false;
-};
-
-} // namespace
-
-/// Performs the rewrites and folding only on `op`. The simplification
-/// converges if the op is erased as a result of being folded, replaced, or
-/// becoming dead, or no more changes happen in an iteration. Returns success if
-/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
-/// gets erased.
-LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
-                                                      int64_t maxNumRewrites,
-                                                      bool &erased) {
-  this->op = op;
-  bool changed = false;
-  erased = false;
-  opErasedViaPatternRewrites = false;
-  int64_t numRewrites = 0;
-  // Iterate until convergence or until maxNumRewrites. Deletion of the op as
-  // a result of being dead or folded is convergence.
-  do {
-    if (numRewrites >= maxNumRewrites &&
-        maxNumRewrites != GreedyRewriteConfig::kNoLimit)
-      break;
-
-    changed = false;
-
-    // If the operation is trivially dead - remove it.
-    if (isOpTriviallyDead(op)) {
-      op->erase();
-      erased = true;
-      return success();
-    }
-
-    // Try to fold this op.
-    bool inPlaceUpdate;
-    if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
-                                   /*preReplaceAction=*/nullptr,
-                                   &inPlaceUpdate))) {
-      changed = true;
-      if (!inPlaceUpdate) {
-        erased = true;
-        return success();
-      }
-    }
-
-    // Try to match one of the patterns. The rewriter is automatically
-    // notified of any necessary changes, so there is nothing else to do here.
-    if (succeeded(matcher.matchAndRewrite(op, *this))) {
-      changed = true;
-      ++numRewrites;
-    }
-    if ((erased = opErasedViaPatternRewrites))
-      return success();
-  } while (changed);
-
-  // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
-  return failure(changed);
-}
-
 //===----------------------------------------------------------------------===//
 // MultiOpPatternRewriteDriver
 //===----------------------------------------------------------------------===//
@@ -734,23 +631,6 @@ LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
   return success(worklist.empty());
 }
 
-LogicalResult mlir::applyOpPatternsAndFold(
-    Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
-  // Start the pattern driver.
-  GreedyRewriteConfig config;
-  OpPatternRewriteDriver driver(op->getContext(), patterns);
-  bool opErased;
-  LogicalResult converged =
-      driver.simplifyLocally(op, config.maxNumRewrites, opErased);
-  if (erased)
-    *erased = opErased;
-  LLVM_DEBUG(if (failed(converged)) {
-    llvm::dbgs() << "The pattern rewrite did not converge after "
-                 << config.maxNumRewrites << " rewrites";
-  });
-  return converged;
-}
-
 /// Find the region that is the closest common ancestor of all given ops.
 static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   assert(!ops.empty() && "expected at least one op");
@@ -811,5 +691,9 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
       ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
   if (allErased)
     *allErased = surviving.empty();
+  LLVM_DEBUG(if (failed(converged)) {
+    llvm::dbgs() << "The pattern rewrite did not converge after "
+                 << GreedyRewriteConfig().maxNumRewrites << " rewrites";
+  });
   return converged;
 }


        


More information about the Mlir-commits mailing list