[Mlir-commits] [mlir] a2b837a - [mlir] GreedyPatternRewriteDriver: Entry point takes single region
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 02:24:37 PST 2023
Author: Matthias Springer
Date: 2023-01-27T11:23:04+01:00
New Revision: a2b837ab0448869c74cc042155dd454833c60d62
URL: https://github.com/llvm/llvm-project/commit/a2b837ab0448869c74cc042155dd454833c60d62
DIFF: https://github.com/llvm/llvm-project/commit/a2b837ab0448869c74cc042155dd454833c60d62.diff
LOG: [mlir] GreedyPatternRewriteDriver: Entry point takes single region
The rewrite driver is typically applied to a single region or all regions of the same op. There is no longer an overload to apply the rewrite driver to a list of regions.
This simplifies the rewrite driver implementation because the scope is now a single region as opposed to a list of regions.
Note: This change is not NFC because `config.maxIterations` and `config.maxNumRewrites` are now counted for each region separately. Furthermore, worklist filtering (`scope`) is now applied to each region separately.
Differential Revision: https://reviews.llvm.org/D142611
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 5a043775a01d2..6ee565ffe5ec4 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -59,9 +59,9 @@ class GreedyRewriteConfig {
// applyPatternsGreedily
//===----------------------------------------------------------------------===//
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner.
+/// Rewrite ops in the given region, which must be isolated from above, by
+/// repeatedly applying the highest benefit patterns in a greedy work-list
+/// driven manner.
///
/// This variant may stop after a predefined number of iterations, see the
/// alternative below to provide a specific number of iterations before stopping
@@ -76,14 +76,18 @@ class GreedyRewriteConfig {
///
/// You may configure several aspects of this with GreedyRewriteConfig.
LogicalResult applyPatternsAndFoldGreedily(
- MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+ Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig());
-/// Rewrite the given regions, which must be isolated from above.
+/// Rewrite ops in all regions of the given op, which must be isolated from
+/// above.
inline LogicalResult applyPatternsAndFoldGreedily(
Operation *op, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig()) {
- return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config);
+ bool failed = false;
+ for (Region &region : op->getRegions())
+ failed |= applyPatternsAndFoldGreedily(region, patterns, config).failed();
+ return failure(failed);
}
/// Applies the specified rewrite patterns on `ops` while also trying to fold
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e9a8c8326fee2..2504a2ab0c9bd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1867,8 +1867,7 @@ struct LinalgElementwiseOpFusionPass
// Use TopDownTraversal for compile time reasons
GreedyRewriteConfig grc;
grc.useTopDownTraversal = true;
- (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
- grc);
+ (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 71c99e9174ea6..8ef16d5eeaec4 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -781,8 +781,7 @@ struct ExpandStridedMetadataPass final
void ExpandStridedMetadataPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateExpandStridedMetadataPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
- std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 92f02c068d2b9..33e9ee71ee3b5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -605,8 +605,7 @@ struct FoldMemRefAliasOpsPass final
void FoldMemRefAliasOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateFoldMemRefAliasOpPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
- std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
index 7d3f1fbd5293d..650d71e732a7c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -149,8 +149,7 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
- std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
@@ -158,8 +157,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
- if (failed(applyPatternsAndFoldGreedily(getOperation()->getRegions(),
- std::move(patterns))))
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index b225662e58c5b..eaba09753f7f7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -120,8 +120,7 @@ struct LowerVectorMaskPass
RewritePatternSet loweringPatterns(context);
populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
- if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
- std::move(loweringPatterns))))
+ if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index a5ddd9138873b..36317e039ef2f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -40,10 +40,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config,
- const DenseSet<Region *> &scope);
+ const Region &scope);
- /// Simplify the operations within the given regions.
- bool simplify(MutableArrayRef<Region> regions) &&;
+ /// Simplify the ops within the given region.
+ bool simplify(Region &region) &&;
/// Add the given operation and its ancestors to the worklist.
void addToWorklist(Operation *op);
@@ -104,7 +104,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
const GreedyRewriteConfig config;
/// Only ops within this scope are simplified.
- const DenseSet<Region *> scope;
+ const Region &scope;
private:
#ifndef NDEBUG
@@ -116,7 +116,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config, const DenseSet<Region *> &scope)
+ const GreedyRewriteConfig &config, const Region &scope)
: PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
scope(scope) {
worklist.reserve(64);
@@ -125,7 +125,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
matcher.applyDefaultCostModel();
}
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
+bool GreedyPatternRewriteDriver::simplify(Region &region) && {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -167,15 +167,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
- for (auto &region : regions) {
region.walk([&](Operation *op) {
if (!insertKnownConstant(op))
addToWorklist(op);
});
- }
} else {
// Add all nested operations to the worklist in preorder.
- for (auto &region : regions) {
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (!insertKnownConstant(op)) {
worklist.push_back(op);
@@ -183,7 +180,6 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
}
return WalkResult::skip();
});
- }
// Reverse the list so our pop-back loop processes them in-order.
std::reverse(worklist.begin(), worklist.end());
@@ -305,7 +301,7 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
// After applying patterns, make sure that the CFG of each of the regions
// is kept up to date.
if (config.enableRegionSimplification)
- changed |= succeeded(simplifyRegions(*this, regions));
+ changed |= succeeded(simplifyRegions(*this, region));
} while (changed);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
@@ -317,7 +313,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
SmallVector<Operation *, 8> ancestors;
ancestors.push_back(op);
while (Region *region = op->getParentRegion()) {
- if (scope.contains(region)) {
+ if (&scope == region) {
// All gathered ops are in fact ancestors.
for (Operation *op : ancestors)
addSingleOpToWorklist(op);
@@ -429,31 +425,19 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
/// top-level operation itself.
///
LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
+mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config) {
- if (regions.empty())
- return success();
-
// The top-level operation must be known to be isolated from above to
// prevent performing canonicalizations on operations defined at or above
// the region containing 'op'.
- auto regionIsIsolated = [](Region &region) {
- return region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>();
- };
- (void)regionIsIsolated;
- assert(llvm::all_of(regions, regionIsIsolated) &&
+ assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"patterns can only be applied to operations IsolatedFromAbove");
- // Limit ops on the worklist to this scope.
- DenseSet<Region *> scope;
- for (Region &r : regions)
- scope.insert(&r);
-
// Start the pattern driver.
- GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
- scope);
- bool converged = std::move(driver).simplify(regions);
+ GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
+ region);
+ bool converged = std::move(driver).simplify(region);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
@@ -476,7 +460,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const DenseSet<Region *> &scope, GreedyRewriteStrictness strictMode,
+ const Region &scope, GreedyRewriteStrictness strictMode,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
: GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
strictMode(strictMode), survivingOps(survivingOps) {}
@@ -680,10 +664,8 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
- DenseSet<Region *> scopeSet;
- scopeSet.insert(scope);
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- scopeSet, strictMode,
+ *scope, strictMode,
allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 98896c736a3cb..c47c8f139e406 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1633,8 +1633,7 @@ struct TestSelectiveReplacementPatternDriver
MLIRContext *context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns.add<TestSelectiveOpReplacementPattern>(context);
- (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
- std::move(patterns));
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
} // namespace
More information about the Mlir-commits mailing list