[Mlir-commits] [mlir] 0ff3cf0 - [mlir] GreedyPatternRewriter: fix counting of iterations
Matthias Springer
llvmlistbot at llvm.org
Tue Jan 10 03:25:37 PST 2023
Author: Matthias Springer
Date: 2023-01-10T12:21:08+01:00
New Revision: 0ff3cf0c0cda58b139ddbf6954befeebdda7ed52
URL: https://github.com/llvm/llvm-project/commit/0ff3cf0c0cda58b139ddbf6954befeebdda7ed52
DIFF: https://github.com/llvm/llvm-project/commit/0ff3cf0c0cda58b139ddbf6954befeebdda7ed52.diff
LOG: [mlir] GreedyPatternRewriter: fix counting of iterations
The GreedyPatternRewriteDriver previously did not count the first iteration. I.e., when setting `config.maxIterations = 1`, two iterations were performed. In practice, this number is not really important; we usually just need a limit in some reasonable order of magnitude. However, this fix allows us to write better convergence/worklist tests with carefully crafted test patterns that purposely trigger edge cases in the driver.
Similarly, the first rewrite was previously not counted towards `config.maxNumRewrites`.
For consistency, `OpPatternRewriteDriver` now uses `config.maxNumRewrites` instead of `config.maxIterations`. This driver does not have "iterations": it consists of a single loop (corresponding to the inner loop in the GreedyPatternRewriteDriver).
Differential Revision: https://reviews.llvm.org/D141365
Added:
Modified:
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 0d6fdaf3039cf..5005a08bc29bb 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -96,10 +96,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Non-pattern based folder for operations.
OperationFolder folder;
-private:
+protected:
/// Configuration information for how to simplify.
GreedyRewriteConfig config;
+private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -147,8 +148,13 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
};
bool changed = false;
- unsigned iteration = 0;
+ int64_t iteration = 0;
do {
+ // Check if the iteration limit was reached.
+ if (iteration++ >= config.maxIterations &&
+ config.maxIterations != GreedyRewriteConfig::kNoLimit)
+ break;
+
worklist.clear();
worklistMap.clear();
@@ -184,7 +190,9 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
changed = false;
int64_t numRewrites = 0;
- while (!worklist.empty()) {
+ while (!worklist.empty() &&
+ (numRewrites < config.maxNumRewrites ||
+ config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
auto *op = popFromWorklist();
// Nulls get added to the worklist when operations are removed, ignore
@@ -280,11 +288,10 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
#else
LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
#endif
+
if (succeeded(matchResult)) {
changed = true;
- if (numRewrites++ >= config.maxNumRewrites &&
- config.maxNumRewrites != GreedyRewriteConfig::kNoLimit)
- break;
+ ++numRewrites;
}
}
@@ -292,8 +299,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
// is kept up to date.
if (config.enableRegionSimplification)
changed |= succeeded(simplifyRegions(*this, regions));
- } while (changed && (iteration++ < config.maxIterations ||
- config.maxIterations == GreedyRewriteConfig::kNoLimit));
+ } while (changed);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return !changed;
@@ -421,7 +427,7 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
bool converged = driver.simplify(regions);
LLVM_DEBUG(if (!converged) {
- llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
+ llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
});
return success(converged);
@@ -443,7 +449,8 @@ class OpPatternRewriteDriver : public PatternRewriter {
matcher.applyDefaultCostModel();
}
- LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
+ LogicalResult simplifyLocally(Operation *op, int64_t maxNumRewrites,
+ bool &erased);
// These are hooks implemented for PatternRewriter.
protected:
@@ -473,18 +480,22 @@ class OpPatternRewriteDriver : public PatternRewriter {
/// Performs the rewrites and folding only on `op`. The simplification
/// converges if the op is erased as a result of being folded, replaced, or
/// becoming dead, or no more changes happen in an iteration. Returns success if
-/// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
+/// the rewrite converges in `maxNumRewrites`. `erased` is set to true if `op`
/// gets erased.
LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
- int maxIterations,
+ int64_t maxNumRewrites,
bool &erased) {
bool changed = false;
erased = false;
opErasedViaPatternRewrites = false;
- int iterations = 0;
- // Iterate until convergence or until maxIterations. Deletion of the op as
+ int64_t numRewrites = 0;
+ // Iterate until convergence or until maxNumRewrites. Deletion of the op as
// a result of being dead or folded is convergence.
do {
+ if (numRewrites >= maxNumRewrites &&
+ maxNumRewrites != GreedyRewriteConfig::kNoLimit)
+ break;
+
changed = false;
// If the operation is trivially dead - remove it.
@@ -508,11 +519,13 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
// Try to match one of the patterns. The rewriter is automatically
// notified of any necessary changes, so there is nothing else to do here.
- changed |= succeeded(matcher.matchAndRewrite(op, *this));
+ if (succeeded(matcher.matchAndRewrite(op, *this))) {
+ changed = true;
+ ++numRewrites;
+ }
if ((erased = opErasedViaPatternRewrites))
return success();
- } while (changed && (++iterations < maxIterations ||
- maxIterations == GreedyRewriteConfig::kNoLimit));
+ } while (changed);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return failure(changed);
@@ -601,7 +614,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// These are scratch vectors used in the folding loop below.
SmallVector<Value, 8> originalOperands, resultValues;
- while (!worklist.empty()) {
+ int64_t numRewrites = 0;
+ while (!worklist.empty() &&
+ (numRewrites < config.maxNumRewrites ||
+ config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
Operation *op = popFromWorklist();
// Nulls get added to the worklist when operations are removed, ignore
@@ -656,7 +672,10 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// Try to match one of the patterns. The rewriter is automatically
// notified of any necessary changes, so there is nothing else to do
// here.
- changed |= succeeded(matcher.matchAndRewrite(op, *this));
+ if (succeeded(matcher.matchAndRewrite(op, *this))) {
+ changed = true;
+ ++numRewrites;
+ }
}
return changed;
@@ -672,12 +691,12 @@ LogicalResult mlir::applyOpPatternsAndFold(
OpPatternRewriteDriver driver(op->getContext(), patterns);
bool opErased;
LogicalResult converged =
- driver.simplifyLocally(op, config.maxIterations, opErased);
+ driver.simplifyLocally(op, config.maxNumRewrites, opErased);
if (erased)
*erased = opErased;
LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
- << config.maxIterations << " times";
+ llvm::dbgs() << "The pattern rewrite did not converge after "
+ << config.maxNumRewrites << " rewrites";
});
return converged;
}
More information about the Mlir-commits
mailing list