[Mlir-commits] [mlir] 67760d7 - [mlir] GreedyPatternRewriteDriver: Make classes single-use
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 02:02:53 PST 2023
Author: Matthias Springer
Date: 2023-01-27T10:55:16+01:00
New Revision: 67760d7e315ff90198bccfd4b0a3934f7a30e6ce
URL: https://github.com/llvm/llvm-project/commit/67760d7e315ff90198bccfd4b0a3934f7a30e6ce
DIFF: https://github.com/llvm/llvm-project/commit/67760d7e315ff90198bccfd4b0a3934f7a30e6ce.diff
LOG: [mlir] GreedyPatternRewriteDriver: Make classes single-use
Less mutable state, more `const`. This addresses a concern about the complexity of state raised in D140304.
Differential Revision: https://reviews.llvm.org/D141949
Added:
Modified:
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index aba7a7fd08b9..a5ddd9138873 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -39,10 +39,11 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
public:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config);
+ const GreedyRewriteConfig &config,
+ const DenseSet<Region *> &scope);
/// Simplify the operations within the given regions.
- bool simplify(MutableArrayRef<Region> regions);
+ bool simplify(MutableArrayRef<Region> regions) &&;
/// Add the given operation and its ancestors to the worklist.
void addToWorklist(Operation *op);
@@ -100,12 +101,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
protected:
/// Configuration information for how to simplify.
- GreedyRewriteConfig config;
+ const GreedyRewriteConfig config;
- /// Only ops within this scope are simplified. This is set at the beginning
- /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter
- /// operates on.
- DenseSet<Region *> scope;
+ /// Only ops within this scope are simplified.
+ const DenseSet<Region *> scope;
private:
#ifndef NDEBUG
@@ -117,19 +116,16 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config)
- : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+ const GreedyRewriteConfig &config, const DenseSet<Region *> &scope)
+ : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config),
+ scope(scope) {
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel();
}
-bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
- scope.clear();
- for (Region &r : regions)
- scope.insert(&r);
-
+bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) && {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -449,9 +445,15 @@ mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
assert(llvm::all_of(regions, regionIsIsolated) &&
"patterns can only be applied to operations IsolatedFromAbove");
+ // Limit ops on the worklist to this scope.
+ DenseSet<Region *> scope;
+ for (Region &r : regions)
+ scope.insert(&r);
+
// Start the pattern driver.
- GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
- bool converged = driver.simplify(regions);
+ GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
+ scope);
+ bool converged = std::move(driver).simplify(regions);
LLVM_DEBUG(if (!converged) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
@@ -472,11 +474,12 @@ namespace {
/// ops are not considered.
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
- explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
- const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode)
- : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
- strictMode(strictMode) {}
+ explicit MultiOpPatternRewriteDriver(
+ MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ const DenseSet<Region *> &scope, GreedyRewriteStrictness strictMode,
+ llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
+ : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(), scope),
+ strictMode(strictMode), survivingOps(survivingOps) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
/// ops. `strictMode` controls which other ops are simplified. Only ops
@@ -486,11 +489,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
/// Note that ops in `ops` could be erased as a result of folding, becoming
/// dead, or via pattern rewrites. The return value indicates convergence.
///
- /// All `ops` that survived the rewrite are stored in `surviving`.
- LogicalResult
- simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
- llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr,
- Region *scope = nullptr);
+ /// All erased ops are stored in `erased`.
+ LogicalResult simplifyLocally(ArrayRef<Operation *> op,
+ bool *changed = nullptr) &&;
protected:
void addSingleOpToWorklist(Operation *op) override {
@@ -516,7 +517,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
/// `strictMode` control which ops are added to the worklist during
/// simplification.
- GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+ const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
/// The list of ops we are restricting our rewrites to. These include the
/// supplied set of ops as well as new ops created while rewriting those ops
@@ -527,17 +528,15 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
/// An optional set of ops that survived the rewrite. This set is populated
/// at the beginning of `simplifyLocally` with the inititally provided list
/// of ops.
- llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr;
+ llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
};
} // namespace
-LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
- ArrayRef<Operation *> ops, bool *changed,
- llvm::SmallDenseSet<Operation *, 4> *surviving, Region *scope) {
- auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
- if (surviving) {
- survivingOps = surviving;
+LogicalResult
+MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
+ bool *changed) && {
+ if (survivingOps) {
survivingOps->clear();
survivingOps->insert(ops.begin(), ops.end());
}
@@ -547,10 +546,6 @@ LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
- assert(scope && "scope is mandatory");
- this->scope.clear();
- this->scope.insert(scope);
-
if (changed)
*changed = false;
worklist.clear();
@@ -684,11 +679,13 @@ mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
}
// Start the pattern driver.
- MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- strictMode);
llvm::SmallDenseSet<Operation *, 4> surviving;
- LogicalResult converged = driver.simplifyLocally(
- ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
+ DenseSet<Region *> scopeSet;
+ scopeSet.insert(scope);
+ MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
+ scopeSet, strictMode,
+ allErased ? &surviving : nullptr);
+ LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)
*allErased = surviving.empty();
LLVM_DEBUG(if (failed(converged)) {
More information about the Mlir-commits mailing list