[Mlir-commits] [mlir] cadd566 - [mlir][NFC] GreedyPatternRewriteDriver: Remove OpPatternRewriteDriver
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 01:47:41 PST 2023
Author: Matthias Springer
Date: 2023-01-27T10:43:18+01:00
New Revision: cadd5666a6d669fa0507bd4f86e9570af4250869
URL: https://github.com/llvm/llvm-project/commit/cadd5666a6d669fa0507bd4f86e9570af4250869
DIFF: https://github.com/llvm/llvm-project/commit/cadd5666a6d669fa0507bd4f86e9570af4250869.diff
LOG: [mlir][NFC] GreedyPatternRewriteDriver: Remove OpPatternRewriteDriver
The `MultiOpPatternRewriteDriver` can be reused for the single-op case, so the separate `OpPatternRewriteDriver` is removed. This gives us better debug messages and more code reuse: debug messages such as `** Replace: (op name)` were previously not printed when using the `applyOpPatternsAndFold(Operation *, ...)` overload.
Differential Revision: https://reviews.llvm.org/D142613
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index dce47834547e..5a043775a01d 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -86,18 +86,6 @@ inline LogicalResult applyPatternsAndFoldGreedily(
return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
}
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead. Note: This does not
-/// apply any patterns recursively to the regions of `op`.
-///
-/// Returns success if the iterative process converged and no more patterns can
-/// be matched.
-LogicalResult applyOpPatternsAndFold(Operation *op,
- const FrozenRewritePatternSet &patterns,
- bool *erased = nullptr);
-
/// Applies the specified rewrite patterns on `ops` while also trying to fold
/// these ops.
///
@@ -132,6 +120,21 @@ LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
bool *allErased = nullptr,
Region *scope = nullptr);
+/// Applies the specified patterns on `op` alone while also trying to fold it,
+/// by selecting the highest benefits patterns in a greedy manner. Returns
+/// success if no more patterns can be matched. `erased` is set to true if `op`
+/// was folded away or erased as a result of becoming dead.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched.
+inline LogicalResult
+applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
+ bool *erased = nullptr) {
+ return applyOpPatternsAndFold(ArrayRef(op), patterns,
+ GreedyRewriteStrictness::ExistingOps,
+ /*changed=*/nullptr, erased);
+}
+
} // namespace mlir
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index ead229dacae8..aba7a7fd08b9 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -459,109 +459,6 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
return success(converged);
}
-//===----------------------------------------------------------------------===//
-// OpPatternRewriteDriver
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This is a simple driver for the PatternMatcher to apply patterns and perform
-/// folding on a single op. It repeatedly applies locally optimal patterns.
-class OpPatternRewriteDriver : public PatternRewriter {
-public:
- explicit OpPatternRewriteDriver(MLIRContext *ctx,
- const FrozenRewritePatternSet &patterns)
- : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
- // Apply a simple cost model based solely on pattern benefit.
- matcher.applyDefaultCostModel();
- }
-
- LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
- bool &erased);
-
- // These are hooks implemented for PatternRewriter.
-protected:
- /// If an operation is about to be removed, mark it so that we can let clients
- /// know.
- void notifyOperationRemoved(Operation *op) override {
- if (this->op == op)
- opErasedViaPatternRewrites = true;
- }
-
- // When a root is going to be replaced, its removal will be notified as well.
- // So there is nothing to do here.
- void notifyRootReplaced(Operation *op, ValueRange replacement) override {}
-
-private:
- /// The low-level pattern applicator.
- PatternApplicator matcher;
-
- /// Non-pattern based folder for operations.
- OperationFolder folder;
-
- /// Op that is being processed.
- Operation *op = nullptr;
-
- /// Set to true if the operation has been erased via pattern rewrites.
- bool opErasedViaPatternRewrites = false;
-};
-
-} // namespace
-
-/// Performs the rewrites and folding only on `op`. The simplification
-/// converges if the op is erased as a result of being folded, replaced, or
-/// becoming dead, or no more changes happen in an iteration. Returns success if
-/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
-/// gets erased.
-LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
- int64_t maxNumRewrites,
- bool &erased) {
- this->op = op;
- bool changed = false;
- erased = false;
- opErasedViaPatternRewrites = false;
- int64_t numRewrites = 0;
- // Iterate until convergence or until maxNumRewrites. Deletion of the op as
- // a result of being dead or folded is convergence.
- do {
- if (numRewrites >= maxNumRewrites &&
- maxNumRewrites != GreedyRewriteConfig::kNoLimit)
- break;
-
- changed = false;
-
- // If the operation is trivially dead - remove it.
- if (isOpTriviallyDead(op)) {
- op->erase();
- erased = true;
- return success();
- }
-
- // Try to fold this op.
- bool inPlaceUpdate;
- if (succeeded(folder.tryToFold(op, /*processGeneratedConstants=*/nullptr,
- /*preReplaceAction=*/nullptr,
- &inPlaceUpdate))) {
- changed = true;
- if (!inPlaceUpdate) {
- erased = true;
- return success();
- }
- }
-
- // Try to match one of the patterns. The rewriter is automatically
- // notified of any necessary changes, so there is nothing else to do here.
- if (succeeded(matcher.matchAndRewrite(op, *this))) {
- changed = true;
- ++numRewrites;
- }
- if ((erased = opErasedViaPatternRewrites))
- return success();
- } while (changed);
-
- // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
- return failure(changed);
-}
-
//===----------------------------------------------------------------------===//
// MultiOpPatternRewriteDriver
//===----------------------------------------------------------------------===//
@@ -734,23 +631,6 @@ LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
return success(worklist.empty());
}
-LogicalResult mlir::applyOpPatternsAndFold(
- Operation *op, const FrozenRewritePatternSet &patterns, bool *erased) {
- // Start the pattern driver.
- GreedyRewriteConfig config;
- OpPatternRewriteDriver driver(op->getContext(), patterns);
- bool opErased;
- LogicalResult converged =
- driver.simplifyLocally(op, config.maxNumRewrites, opErased);
- if (erased)
- *erased = opErased;
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after "
- << config.maxNumRewrites << " rewrites";
- });
- return converged;
-}
-
/// Find the region that is the closest common ancestor of all given ops.
static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
assert(!ops.empty() && "expected at least one op");
@@ -811,5 +691,9 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
if (allErased)
*allErased = surviving.empty();
+ LLVM_DEBUG(if (failed(converged)) {
+ llvm::dbgs() << "The pattern rewrite did not converge after "
+ << GreedyRewriteConfig().maxNumRewrites << " rewrites";
+ });
return converged;
}
More information about the Mlir-commits mailing list