[Mlir-commits] [mlir] 8498c9e - [mlir][GreedyPatternRewriter] Add out param to detect changes in IR in `applyPatternsAndFoldGreedily`
Matthias Springer
llvmlistbot at llvm.org
Thu Jun 29 03:55:41 PDT 2023
Author: Joel Wee
Date: 2023-06-29T12:48:00+02:00
New Revision: 8498c9e9489f57c6eb59b464d2409b4974a22a02
URL: https://github.com/llvm/llvm-project/commit/8498c9e9489f57c6eb59b464d2409b4974a22a02
DIFF: https://github.com/llvm/llvm-project/commit/8498c9e9489f57c6eb59b464d2409b4974a22a02.diff
LOG: [mlir][GreedyPatternRewriter] Add out param to detect changes in IR in `applyPatternsAndFoldGreedily`
This allows users of `applyPatternsAndFoldGreedily` to detect if any MLIR changes have occurred. An example use-case is where we expect the `applyPatternsAndFoldGreedily` to change the IR and want to validate that it indeed does change it.
Differential Revision: https://reviews.llvm.org/D153986
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 2c5c63e0ac4dc4..aaf6f0a8951257 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -94,25 +94,36 @@ class GreedyRewriteConfig {
/// in absence of convergence.
///
/// Return success if the iterative process converged and no more patterns can
-/// be matched in the result operation regions.
+/// be matched in the result operation regions. `changed` is set to true if the
+/// IR was modified at all.
///
/// Note: This does not apply patterns to the top-level operation itself.
/// These methods also perform folding and simple dead-code elimination
/// before attempting to match any of the provided patterns.
///
/// You may configure several aspects of this with GreedyRewriteConfig.
-LogicalResult applyPatternsAndFoldGreedily(
- Region &region, const FrozenRewritePatternSet &patterns,
- GreedyRewriteConfig config = GreedyRewriteConfig());
+LogicalResult
+applyPatternsAndFoldGreedily(Region &region,
+ const FrozenRewritePatternSet &patterns,
+ GreedyRewriteConfig config = GreedyRewriteConfig(),
+ bool *changed = nullptr);
/// Rewrite ops in all regions of the given op, which must be isolated from
/// above.
-inline LogicalResult applyPatternsAndFoldGreedily(
- Operation *op, const FrozenRewritePatternSet &patterns,
- GreedyRewriteConfig config = GreedyRewriteConfig()) {
+inline LogicalResult
+applyPatternsAndFoldGreedily(Operation *op,
+ const FrozenRewritePatternSet &patterns,
+ GreedyRewriteConfig config = GreedyRewriteConfig(),
+ bool *changed = nullptr) {
bool failed = false;
- for (Region &region : op->getRegions())
- failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed();
+ for (Region &region : op->getRegions()) {
+ bool regionChanged;
+ failed |=
+ applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
+ .failed();
+ if (changed)
+ *changed |= regionChanged;
+ }
return failure(failed);
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2a39cccfc580d4..fba4944f130c23 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -616,7 +616,7 @@ class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
/// Simplify ops inside `region` and simplify the region itself. Return
/// success if the transformation converged.
- LogicalResult simplify() &&;
+ LogicalResult simplify(bool *changed) &&;
private:
/// The region that is simplified.
@@ -652,7 +652,7 @@ class GreedyPatternRewriteIteration
};
} // namespace
-LogicalResult RegionPatternRewriteDriver::simplify() && {
+LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
auto insertKnownConstant = [&](Operation *op) {
// Check for existing constants when populating the worklist. This avoids
// accidentally reversing the constant order during processing.
@@ -663,12 +663,12 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
return false;
};
- bool changed = false;
+ bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
do {
// Check if the iteration limit was reached.
- if (iteration++ >= config.maxIterations &&
+ if (++iteration > config.maxIterations &&
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;
@@ -696,24 +696,27 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
ctx->executeAction<GreedyPatternRewriteIteration>(
[&] {
- changed = processWorklist();
+ continueRewrites = processWorklist();
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
if (config.enableRegionSimplification)
- changed |= succeeded(simplifyRegions(*this, region));
+ continueRewrites |= succeeded(simplifyRegions(*this, region));
},
{&region}, iteration);
- } while (changed);
+ } while (continueRewrites);
+
+ if (changed)
+ *changed = iteration > 1;
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
- return success(!changed);
+ return success(!continueRewrites);
}
LogicalResult
mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
- GreedyRewriteConfig config) {
+ GreedyRewriteConfig config, bool *changed) {
// The top-level operation must be known to be isolated from above to
// prevent performing canonicalizations on operations defined at or above
// the region containing 'op'.
@@ -727,7 +730,7 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
// Start the pattern driver.
RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
region);
- LogicalResult converged = std::move(driver).simplify();
+ LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
More information about the Mlir-commits
mailing list