[Mlir-commits] [mlir] 977cddb - [mlir] GreedyPatternRewriteDriver: All entry points take a config
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 05:34:11 PST 2023
Author: Matthias Springer
Date: 2023-01-27T14:33:54+01:00
New Revision: 977cddb95eac67a6dc6680a7d0fadee81114de11
URL: https://github.com/llvm/llvm-project/commit/977cddb95eac67a6dc6680a7d0fadee81114de11
DIFF: https://github.com/llvm/llvm-project/commit/977cddb95eac67a6dc6680a7d0fadee81114de11.diff
LOG: [mlir] GreedyPatternRewriteDriver: All entry points take a config
The multi-op entry point now also takes a GreedyRewriteConfig and respects config.maxNumRewrites. The scope is now also part of the config.
Differential Revision: https://reviews.llvm.org/D142614
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 6ee565ffe5ec4..e4251103d6d58 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -37,15 +37,21 @@ class GreedyRewriteConfig {
/// generally more efficient in compile time. When set to false, its initial
/// traversal of the region tree is bottom up on each block, which may match
/// larger patterns when given an ambiguous pattern set.
+ ///
+ /// Note: Only applicable when simplifying entire regions.
bool useTopDownTraversal = false;
- // Perform control flow optimizations to the region tree after applying all
- // patterns.
+ /// Perform control flow optimizations to the region tree after applying all
+ /// patterns.
+ ///
+ /// Note: Only applicable when simplifying entire regions.
bool enableRegionSimplification = true;
/// This specifies the maximum number of times the rewriter will iterate
/// between applying patterns and simplifying regions. Use `kNoLimit` to
/// disable this iteration limit.
+ ///
+ /// Note: Only applicable when simplifying entire regions.
int64_t maxIterations = 10;
/// This specifies the maximum number of rewrites within an iteration. Use
@@ -53,6 +59,10 @@ class GreedyRewriteConfig {
int64_t maxNumRewrites = kNoLimit;
static constexpr int64_t kNoLimit = -1;
+
+ /// Only ops within the scope are added to the worklist. If no scope is
+ /// specified, the closest enclosing region is used as a scope.
+ Region *scope = nullptr;
};
//===----------------------------------------------------------------------===//
@@ -117,12 +127,12 @@ inline LogicalResult applyPatternsAndFoldGreedily(
/// Returns success if the iterative process converged and no more patterns can
/// be matched. `changed` is set to true if the IR was modified at all.
/// `allOpsErased` is set to true if all ops in `ops` were erased.
-LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
- const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode,
- bool *changed = nullptr,
- bool *allErased = nullptr,
- Region *scope = nullptr);
+LogicalResult
+applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+ const FrozenRewritePatternSet &patterns,
+ GreedyRewriteStrictness strictMode,
+ GreedyRewriteConfig config = GreedyRewriteConfig(),
+ bool *changed = nullptr, bool *allErased = nullptr);
/// Applies the specified patterns on `op` alone while also trying to fold it,
/// by selecting the highest benefits patterns in a greedy manner. Returns
@@ -133,9 +143,10 @@ LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
/// be matched.
inline LogicalResult
applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
+ GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *erased = nullptr) {
return applyOpPatternsAndFold(ArrayRef(op), patterns,
- GreedyRewriteStrictness::ExistingOps,
+ GreedyRewriteStrictness::ExistingOps, config,
/*changed=*/nullptr, erased);
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index ec85e566e778b..a146d5adfa637 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -322,8 +322,8 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
bool erased;
- (void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
-
+ (void)applyOpPatternsAndFold(res, std::move(patterns),
+ GreedyRewriteConfig(), &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 180fef853e202..286c38ec18ee5 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -415,7 +415,8 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
bool erased;
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- (void)applyOpPatternsAndFold(ifOp, frozenPatterns, &erased);
+ (void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(),
+ &erased);
if (erased) {
if (folded)
*folded = true;
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 36317e039ef2f..4a37730129691 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,8 +39,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config,
- const Region &scope);
+ const GreedyRewriteConfig &config);
/// Simplify the ops within the given region.
bool simplify(Region &region) &&;
@@ -103,9 +102,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
- /// Only ops within this scope are simplified.
- const Region &scope;
-
private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
@@ -116,9 +112,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config, const Region &scope)
- : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
- scope(scope) {
+ const GreedyRewriteConfig &config)
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+ assert(config.scope && "scope is not specified");
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
@@ -313,7 +309,7 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
SmallVector<Operation *, 8> ancestors;
ancestors.push_back(op);
while (Region *region = op->getParentRegion()) {
- if (&scope == region) {
+ if (config.scope == region) {
// All gathered ops are in fact ancestors.
for (Operation *op : ancestors)
addSingleOpToWorklist(op);
@@ -434,9 +430,12 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"patterns can only be applied to operations IsolatedFromAbove");
+ // Set scope if not specified.
+ if (!config.scope)
+ config.scope = &region;
+
// Start the pattern driver.
- GreedyPatternRewriteDriver driver(region.getContext(), patterns, config,
- region);
+ GreedyPatternRewriteDriver driver(region.getContext(), patterns, config);
bool converged = std::move(driver).simplify(region);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -460,9 +459,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const Region &scope, GreedyRewriteStrictness strictMode,
+ GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
- : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
+ : GreedyPatternRewriteDriver(ctx, patterns, config),
strictMode(strictMode), survivingOps(survivingOps) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
@@ -636,11 +635,10 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
return region;
}
-LogicalResult
-mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
- const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, bool *changed,
- bool *allErased, Region *scope) {
+LogicalResult mlir::applyOpPatternsAndFold(
+ ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
+ GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
+ bool *changed, bool *allErased) {
if (ops.empty()) {
if (changed)
*changed = false;
@@ -649,14 +647,15 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
return success();
}
- if (!scope) {
+ // Determine scope of rewrite.
+ if (!config.scope) {
// Compute scope if none was provided.
- scope = findCommonAncestor(ops);
+ config.scope = findCommonAncestor(ops);
} else {
// If a scope was provided, make sure that all ops are in scope.
#ifndef NDEBUG
bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
- return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+ return static_cast<bool>(config.scope->findAncestorOpInRegion(*op));
});
assert(allOpsInScope && "ops must be within the specified scope");
#endif // NDEBUG
@@ -665,14 +664,14 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- *scope, strictMode,
+ strictMode, config,
allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)
*allErased = surviving.empty();
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after "
- << GreedyRewriteConfig().maxNumRewrites << " rewrites";
+ << config.maxNumRewrites << " rewrites";
});
return converged;
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index c47c8f139e406..da018708c702e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -283,7 +283,7 @@ struct TestStrictPatternDriver
bool changed = false;
bool allErased = false;
(void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
- &changed, &allErased);
+ GreedyRewriteConfig(), &changed, &allErased);
Builder b(ctx);
getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
getOperation()->setAttr("pattern_driver_all_erased",
More information about the Mlir-commits
mailing list