[Mlir-commits] [mlir] 9d5c63f - [mlir][NFC] GreedyPatternRewriteDriver: Merge region-based and multi-op-based drivers
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 08:32:09 PST 2023
Author: Matthias Springer
Date: 2023-01-27T17:32:00+01:00
New Revision: 9d5c63f641c8318808e8e62df0a9290d1072ae41
URL: https://github.com/llvm/llvm-project/commit/9d5c63f641c8318808e8e62df0a9290d1072ae41
DIFF: https://github.com/llvm/llvm-project/commit/9d5c63f641c8318808e8e62df0a9290d1072ae41.diff
LOG: [mlir][NFC] GreedyPatternRewriteDriver: Merge region-based and multi-op-based drivers
Deduplicate large parts of the worklist processing (`GreedyPatternRewriteDriver::processWorklist`).
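For readers following `processWorklist`, here is a minimal, standalone C++ sketch of the worklist bookkeeping it relies on. A vector is used as a LIFO stack together with an index map. Removing an op overwrites its slot with a null tombstone, and the processing loop later skips it (compare "Nulls get added to the worklist when operations are removed" in the diff). `Op`, `Worklist` and `drain` are hypothetical names, not MLIR API; the real driver keeps an equivalent `worklist`/`worklistMap` pair.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for mlir::Operation; only its identity matters here.
struct Op {
  std::string name;
};

// LIFO worklist with O(1) removal via null tombstones (illustrative only).
class Worklist {
public:
  // Enqueue `op` unless it is already present.
  void add(Op *op) {
    if (index.count(op))
      return;
    index[op] = items.size();
    items.push_back(op);
  }

  // If `op` is on the worklist, null out its slot instead of shifting.
  void remove(Op *op) {
    auto it = index.find(op);
    if (it == index.end())
      return;
    items[it->second] = nullptr;
    index.erase(it);
  }

  bool empty() const { return items.empty(); }

  // Pop from the back; may return nullptr for a removed entry.
  Op *pop() {
    Op *op = items.back();
    items.pop_back();
    if (op)
      index.erase(op);
    return op;
  }

private:
  std::vector<Op *> items;
  std::unordered_map<Op *, std::size_t> index;
};

// Drain the worklist, skipping tombstones, and return the visit order.
std::vector<std::string> drain(Worklist &wl) {
  std::vector<std::string> visited;
  while (!wl.empty()) {
    Op *op = wl.pop();
    if (op == nullptr) // removed ops leave nulls behind; ignore them
      continue;
    visited.push_back(op->name);
  }
  return visited;
}
```

Tombstoning keeps removal constant-time without reshuffling the vector; the price is that every consumer of the worklist has to tolerate null entries.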
The new class hierarchy is as follows:
```
          GreedyPatternRewriteDriver (abstract)
                            ^
                            |
             --------------------------------
             |                              |
RegionPatternRewriteDriver     MultiOpPatternRewriteDriver
```
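The hierarchy above follows the template-method idea: the abstract base owns the worklist and the shared loop, and each subclass only decides how to seed it. Below is a toy C++ model of that split. The class names (`DriverBase`, `RegionDriver`, `MultiOpDriver`) are hypothetical and strings stand in for ops; the real drivers also fold ops, apply patterns and, for regions, iterate to a fixed point.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Abstract base: owns the worklist and the shared processing loop.
class DriverBase {
public:
  virtual ~DriverBase() = default;
  // Order in which ops were visited by the shared loop.
  std::vector<std::string> visited;

protected:
  // Stand-in for processWorklist(): pop from the back until empty.
  // The real driver folds ops and applies patterns at this point.
  void processWorklist() {
    while (!worklist.empty()) {
      visited.push_back(worklist.back());
      worklist.pop_back();
    }
  }
  std::vector<std::string> worklist;
};

// Seeds the worklist with every op of a "region" (a flat list in program
// order), reversed so that popping from the back visits them in order.
class RegionDriver : public DriverBase {
public:
  explicit RegionDriver(std::vector<std::string> regionOps)
      : regionOps(std::move(regionOps)) {}
  void simplify() {
    worklist.assign(regionOps.rbegin(), regionOps.rend());
    processWorklist();
  }

private:
  std::vector<std::string> regionOps;
};

// Seeds the worklist with exactly the ops it was given.
class MultiOpDriver : public DriverBase {
public:
  explicit MultiOpDriver(std::vector<std::string> ops) : ops(std::move(ops)) {}
  void simplify() {
    for (const std::string &op : ops)
      worklist.push_back(op);
    processWorklist();
  }

private:
  std::vector<std::string> ops;
};
```

Because only the seeding differs, a fix to the shared loop automatically applies to both drivers, which is the deduplication this commit is after.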
Also update the Markdown documentation.
Differential Revision: https://reviews.llvm.org/D141396
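`addToWorklist` (reindented in the diff below) only enqueues an op if walking up its ancestor chain reaches `config.scope`; in that case the op and every ancestor gathered on the way are enqueued. The following is a standalone model with hypothetical `Op`/`Region` structs. The hunk is cut off after `break;`, so pushing each ancestor onto the list is an assumption here.

```cpp
#include <cassert>
#include <string>
#include <vector>

struct Region;

// Hypothetical stand-ins for mlir::Operation / mlir::Region: an op lives in
// at most one parent region, and a region belongs to at most one parent op.
struct Op {
  std::string name;
  Region *parentRegion = nullptr;
};

struct Region {
  Op *parentOp = nullptr;
};

// Return `op` plus its ancestors if the walk reaches `scope`, else nothing.
// Mirrors the ancestor loop in addToWorklist (illustrative only).
std::vector<Op *> opsToEnqueue(Op *op, Region *scope) {
  std::vector<Op *> ancestors{op};
  while (Region *region = op->parentRegion) {
    if (region == scope)
      return ancestors; // all gathered ops are in scope
    op = region->parentOp;
    if (!op)
      break; // reached a detached region without finding the scope
    ancestors.push_back(op); // assumed: the hunk is truncated here
  }
  return {};
}
```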
Added:
Modified:
mlir/docs/PatternRewriter.md
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md
index 3ae6547308b98..8428d4ba991ef 100644
--- a/mlir/docs/PatternRewriter.md
+++ b/mlir/docs/PatternRewriter.md
@@ -322,22 +322,48 @@ driver can be found [here](DialectConversion.md).
### Greedy Pattern Rewrite Driver
-This driver walks the provided operations and greedily applies the patterns that
-locally have the most benefit. The benefit of
-a pattern is decided solely by the benefit specified on the pattern, and the
-relative order of the pattern within the pattern list (when two patterns have
-the same local benefit). Patterns are iteratively applied to operations until a
-fixed point is reached, at which point the driver finishes. This driver may be
-used via the following: `applyPatternsAndFoldGreedily` and
-`applyOpPatternsAndFold`. The latter of which only applies patterns to the
-provided operation, and will not traverse the IR.
-
-The driver is configurable and supports two modes: 1) you may opt-in to a
-"top-down" traversal, which seeds the worklist with each operation top down and
-in a pre-order over the region tree. This is generally more efficient in
-compile time. 2) the default is a "bottom up" traversal, which builds the
-initial worklist with a postorder traversal of the region tree. This may
-match larger patterns with ambiguous pattern sets.
+This driver processes ops in a worklist-driven fashion and greedily applies the
+patterns that locally have the most benefit. The benefit of a pattern is decided
+solely by the benefit specified on the pattern, and the relative order of the
+pattern within the pattern list (when two patterns have the same local benefit).
+Patterns are iteratively applied to operations until a fixed point is reached or
+until the configurable maximum number of iterations is exhausted, at which point
+the driver finishes.
+
+This driver comes in two flavors:
+
+* `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to
+ all ops in a given region or a given container op (but not the container op
+ itself). I.e., the worklist is initialized with all contained ops.
+* `applyOpPatternsAndFold` ("op-based driver") applies patterns to the
+ provided list of operations. I.e., the worklist is initialized with the
+ specified list of ops.
+
+The driver is configurable via `GreedyRewriteConfig`. The region-based driver
+supports two modes for populating the initial worklist:
+
+* Top-down traversal: Traverse the container op/region top down and in
+ pre-order. This is generally more efficient in compile time.
+* Bottom-up traversal: This is the default setting. It builds the initial
+ worklist with a postorder traversal and then reverses the worklist. This may
+ match larger patterns with ambiguous pattern sets.
+
+By default, ops that were modified in-place and newly created ops are added back to
+the worklist. Ops that are outside of the configurable "scope" of the driver are
+not added to the worklist. Furthermore, "strict mode" can exclude certain ops
+from being added to the worklist throughout the rewrite process:
+
+* `GreedyRewriteStrictness::AnyOp`: No ops are excluded (apart from the ones
+ that are out of scope).
+* `GreedyRewriteStrictness::ExistingAndNewOps`: Only pre-existing ops (with
+ which the worklist was initialized) and newly created ops are added to the
+ worklist.
+* `GreedyRewriteStrictness::ExistingOps`: Only pre-existing ops (with which
+ the worklist was initialized) are added to the worklist.
+
+Note: This driver listens for IR changes via the callbacks provided by
+`RewriterBase`. It is important that patterns announce all IR changes to the
+rewriter and do not bypass the rewriter API by modifying ops directly.
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize-canonicalize-operations) in MLIR.
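The three strictness levels can be modeled as a filter set that is seeded with the initial ops and, only under `ExistingAndNewOps`, grows as ops are inserted. This is a hypothetical sketch loosely modeled on `strictModeFilteredOps` in the diff below; `StrictFilter` and its methods are illustrative names, and scope checks are left out.

```cpp
#include <cassert>
#include <unordered_set>

enum class Strictness { AnyOp, ExistingAndNewOps, ExistingOps };

// Hypothetical op type; only its address is used.
struct Op {};

class StrictFilter {
public:
  explicit StrictFilter(Strictness mode) : mode(mode) {}

  // Called for each op the worklist is initialized with.
  void seed(Op *op) {
    if (mode != Strictness::AnyOp)
      filtered.insert(op);
  }

  // Called when an op is created during the rewrite.
  void notifyInserted(Op *op) {
    if (mode == Strictness::ExistingAndNewOps)
      filtered.insert(op);
  }

  // Whether `op` may be (re-)added to the worklist.
  bool mayEnqueue(Op *op) const {
    return mode == Strictness::AnyOp || filtered.count(op) != 0;
  }

private:
  Strictness mode;
  std::unordered_set<Op *> filtered;
};
```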
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2ef46b1fdcfc1..4c5868aead3f1 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -34,59 +34,44 @@ using namespace mlir;
namespace {
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
-/// applies the locally optimal patterns in a roughly "bottom up" way.
+/// applies the locally optimal patterns.
+///
+/// This abstract class manages the worklist and contains helper methods for
+/// rewriting ops on the worklist. Derived classes specify how ops are added
+/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public PatternRewriter {
-public:
+protected:
explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config);
- /// Simplify the ops within the given region.
- bool simplify(Region &region) &&;
+ /// Add the given operation to the worklist.
+ void addSingleOpToWorklist(Operation *op);
/// Add the given operation and its ancestors to the worklist.
void addToWorklist(Operation *op);
- /// Pop the next operation from the worklist.
- Operation *popFromWorklist();
-
- /// If the specified operation is in the worklist, remove it.
- void removeFromWorklist(Operation *op);
-
- /// Notifies the driver that the specified operation may have been modified
- /// in-place.
+ /// Notify the driver that the specified operation may have been modified
+ /// in-place. The operation is added to the worklist.
void finalizeRootUpdate(Operation *op) override;
-protected:
- /// Add the given operation to the worklist.
- void addSingleOpToWorklist(Operation *op);
-
- // Implement the hook for inserting operations, and make sure that newly
- // inserted ops are added to the worklist for processing.
+ /// Notify the driver that the specified operation was inserted. Update the
+ /// worklist as needed: The operation is enqueued depending on scope and
+ /// strict mode.
void notifyOperationInserted(Operation *op) override;
- // Look over the provided operands for any defining operations that should
- // be re-added to the worklist. This function should be called when an
- // operation is modified or removed, as it may trigger further
- // simplifications.
- void addOperandsToWorklist(ValueRange operands);
-
- // If an operation is about to be removed, make sure it is not in our
- // worklist anymore because we'd get dangling references to it.
+ /// Notify the driver that the specified operation was removed. Update the
+ /// worklist as needed: The operation and its children are removed from the
+ /// worklist.
void notifyOperationRemoved(Operation *op) override;
- // When the root of a pattern is about to be replaced, it can trigger
- // simplifications to its users - make sure to add them to the worklist
- // before the root is changed.
+ /// Notify the driver that the specified operation was replaced. Update the
+ /// worklist as needed: New users are enqueued.
void notifyRootReplaced(Operation *op, ValueRange replacement) override;
- /// PatternRewriter hook for notifying match failure reasons.
- LogicalResult
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) override;
-
- /// The low-level pattern applicator.
- PatternApplicator matcher;
+ /// Process ops until the worklist is empty or `config.maxNumRewrites` is
+ /// reached. Return `true` if any IR was changed.
+ bool processWorklist();
/// The worklist for this transformation keeps track of the operations that
/// need to be revisited, plus their index in the worklist. This allows us to
@@ -98,7 +83,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Non-pattern based folder for operations.
OperationFolder folder;
-protected:
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
@@ -109,17 +93,37 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
private:
+ /// Look over the provided operands for any defining operations that should
+ /// be re-added to the worklist. This function should be called when an
+ /// operation is modified or removed, as it may trigger further
+ /// simplifications.
+ void addOperandsToWorklist(ValueRange operands);
+
+ /// Pop the next operation from the worklist.
+ Operation *popFromWorklist();
+
+ /// For debugging only: Notify the driver of a pattern match failure.
+ LogicalResult
+ notifyMatchFailure(Location loc,
+ function_ref<void(Diagnostic &)> reasonCallback) override;
+
+ /// If the specified operation is in the worklist, remove it.
+ void removeFromWorklist(Operation *op);
+
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
#endif
+
+ /// The low-level pattern applicator.
+ PatternApplicator matcher;
};
} // namespace
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) {
+ : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) {
assert(config.scope && "scope is not specified");
worklist.reserve(64);
@@ -127,7 +131,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
matcher.applyDefaultCostModel();
}
-bool GreedyPatternRewriteDriver::simplify(Region &region) && {
+bool GreedyPatternRewriteDriver::processWorklist() {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -146,130 +150,80 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
};
#endif
- auto insertKnownConstant = [&](Operation *op) {
- // Check for existing constants when populating the worklist. This avoids
- // accidentally reversing the constant order during processing.
- Attribute constValue;
- if (matchPattern(op, m_Constant(&constValue)))
- if (!folder.insertKnownConstant(op, constValue))
- return true;
- return false;
- };
-
- // Populate strict mode ops.
- if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
- strictModeFilteredOps.clear();
- region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
- }
+ // This is a scratch vector used in the folding loop below.
+ SmallVector<Value, 8> originalOperands;
bool changed = false;
- int64_t iteration = 0;
- do {
- // Check if the iteration limit was reached.
- if (iteration++ >= config.maxIterations &&
- config.maxIterations != GreedyRewriteConfig::kNoLimit)
- break;
+ int64_t numRewrites = 0;
+ while (!worklist.empty() &&
+ (numRewrites < config.maxNumRewrites ||
+ config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
+ auto *op = popFromWorklist();
- worklist.clear();
- worklistMap.clear();
+ // Nulls get added to the worklist when operations are removed, ignore
+ // them.
+ if (op == nullptr)
+ continue;
- if (!config.useTopDownTraversal) {
- // Add operations to the worklist in postorder.
- region.walk([&](Operation *op) {
- if (!insertKnownConstant(op))
- addToWorklist(op);
- });
- } else {
- // Add all nested operations to the worklist in preorder.
- region.walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (!insertKnownConstant(op)) {
- worklist.push_back(op);
- return WalkResult::advance();
- }
- return WalkResult::skip();
- });
+ LLVM_DEBUG({
+ logger.getOStream() << "\n";
+ logger.startLine() << logLineComment;
+ logger.startLine() << "Processing operation : '" << op->getName() << "'("
+ << op << ") {\n";
+ logger.indent();
+
+ // If the operation has no regions, just print it here.
+ if (op->getNumRegions() == 0) {
+ op->print(
+ logger.startLine(),
+ OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
+ logger.getOStream() << "\n\n";
+ }
+ });
- // Reverse the list so our pop-back loop processes them in-order.
- std::reverse(worklist.begin(), worklist.end());
- // Remember the reverse index.
- for (size_t i = 0, e = worklist.size(); i != e; ++i)
- worklistMap[worklist[i]] = i;
+ // If the operation is trivially dead - remove it.
+ if (isOpTriviallyDead(op)) {
+ notifyOperationRemoved(op);
+ op->erase();
+ changed = true;
+
+ LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
+ continue;
}
- // These are scratch vectors used in the folding loop below.
- SmallVector<Value, 8> originalOperands, resultValues;
+ // Collects all the operands and result uses of the given `op` into work
+ // list. Also remove `op` and nested ops from worklist.
+ originalOperands.assign(op->operand_begin(), op->operand_end());
+ auto preReplaceAction = [&](Operation *op) {
+ // Add the operands to the worklist for visitation.
+ addOperandsToWorklist(originalOperands);
- changed = false;
- int64_t numRewrites = 0;
- while (!worklist.empty() &&
- (numRewrites < config.maxNumRewrites ||
- config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
- auto *op = popFromWorklist();
+ // Add all the users of the result to the worklist so we make sure
+ // to revisit them.
+ for (auto result : op->getResults())
+ for (auto *userOp : result.getUsers())
+ addToWorklist(userOp);
- // Nulls get added to the worklist when operations are removed, ignore
- // them.
- if (op == nullptr)
- continue;
+ notifyOperationRemoved(op);
+ };
- LLVM_DEBUG({
- logger.getOStream() << "\n";
- logger.startLine() << logLineComment;
- logger.startLine() << "Processing operation : '" << op->getName()
- << "'(" << op << ") {\n";
- logger.indent();
-
- // If the operation has no regions, just print it here.
- if (op->getNumRegions() == 0) {
- op->print(
- logger.startLine(),
- OpPrintingFlags().printGenericOpForm().elideLargeElementsAttrs());
- logger.getOStream() << "\n\n";
- }
- });
+ // Add the given operation to the worklist.
+ auto collectOps = [this](Operation *op) { addToWorklist(op); };
- // If the operation is trivially dead - remove it.
- if (isOpTriviallyDead(op)) {
- notifyOperationRemoved(op);
- op->erase();
- changed = true;
+ // Try to fold this op.
+ bool inPlaceUpdate;
+ if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
+ &inPlaceUpdate)))) {
+ LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
+ changed = true;
+ if (!inPlaceUpdate)
continue;
- }
-
- // Collects all the operands and result uses of the given `op` into work
- // list. Also remove `op` and nested ops from worklist.
- originalOperands.assign(op->operand_begin(), op->operand_end());
- auto preReplaceAction = [&](Operation *op) {
- // Add the operands to the worklist for visitation.
- addOperandsToWorklist(originalOperands);
-
- // Add all the users of the result to the worklist so we make sure
- // to revisit them.
- for (auto result : op->getResults())
- for (auto *userOp : result.getUsers())
- addToWorklist(userOp);
-
- notifyOperationRemoved(op);
- };
-
- // Add the given operation to the worklist.
- auto collectOps = [this](Operation *op) { addToWorklist(op); };
-
- // Try to fold this op.
- bool inPlaceUpdate;
- if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
- &inPlaceUpdate)))) {
- LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-
- changed = true;
- if (!inPlaceUpdate)
- continue;
- }
+ }
- // Try to match one of the patterns. The rewriter is automatically
- // notified of any necessary changes, so there is nothing else to do
- // here.
+ // Try to match one of the patterns. The rewriter is automatically
+ // notified of any necessary changes, so there is nothing else to do
+ // here.
#ifndef NDEBUG
auto canApply = [&](const Pattern &pattern) {
LLVM_DEBUG({
@@ -304,16 +258,9 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
changed = true;
++numRewrites;
}
- }
-
- // After applying patterns, make sure that the CFG of each of the regions
- // is kept up to date.
- if (config.enableRegionSimplification)
- changed |= succeeded(simplifyRegions(*this, region));
- } while (changed);
+ }
- // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
- return !changed;
+ return changed;
}
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
@@ -321,12 +268,12 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
SmallVector<Operation *, 8> ancestors;
ancestors.push_back(op);
while (Region *region = op->getParentRegion()) {
- if (config.scope == region) {
- // All gathered ops are in fact ancestors.
- for (Operation *op : ancestors)
- addSingleOpToWorklist(op);
- break;
- }
+ if (config.scope == region) {
+ // All gathered ops are in fact ancestors.
+ for (Operation *op : ancestors)
+ addSingleOpToWorklist(op);
+ break;
+ }
op = region->getParentOp();
if (!op)
break;
@@ -434,12 +381,96 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
return failure();
}
-/// Rewrite the regions of the specified operation, which must be isolated from
-/// above, by repeatedly applying the highest benefit patterns in a greedy
-/// work-list driven manner. Return success if no more patterns can be matched
-/// in the result operation regions. Note: This does not apply patterns to the
-/// top-level operation itself.
-///
+//===----------------------------------------------------------------------===//
+// RegionPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This driver simplifies all ops in a region.
+class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+ explicit RegionPatternRewriteDriver(MLIRContext *ctx,
+ const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config,
+ Region &regions);
+
+ /// Simplify ops inside `region` and simplify the region itself. Return
+ /// success if the transformation converged.
+ LogicalResult simplify() &&;
+
+private:
+ /// The region that is simplified.
+ Region &region;
+};
+} // namespace
+
+RegionPatternRewriteDriver::RegionPatternRewriteDriver(
+ MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config, Region &region)
+ : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
+ // Populate strict mode ops.
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
+ region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
+ }
+}
+
+LogicalResult RegionPatternRewriteDriver::simplify() && {
+ auto insertKnownConstant = [&](Operation *op) {
+ // Check for existing constants when populating the worklist. This avoids
+ // accidentally reversing the constant order during processing.
+ Attribute constValue;
+ if (matchPattern(op, m_Constant(&constValue)))
+ if (!folder.insertKnownConstant(op, constValue))
+ return true;
+ return false;
+ };
+
+ bool changed = false;
+ int64_t iteration = 0;
+ do {
+ // Check if the iteration limit was reached.
+ if (iteration++ >= config.maxIterations &&
+ config.maxIterations != GreedyRewriteConfig::kNoLimit)
+ break;
+
+ worklist.clear();
+ worklistMap.clear();
+
+ if (!config.useTopDownTraversal) {
+ // Add operations to the worklist in postorder.
+ region.walk([&](Operation *op) {
+ if (!insertKnownConstant(op))
+ addToWorklist(op);
+ });
+ } else {
+ // Add all nested operations to the worklist in preorder.
+ region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+ if (!insertKnownConstant(op)) {
+ worklist.push_back(op);
+ return WalkResult::advance();
+ }
+ return WalkResult::skip();
+ });
+
+ // Reverse the list so our pop-back loop processes them in-order.
+ std::reverse(worklist.begin(), worklist.end());
+ // Remember the reverse index.
+ for (size_t i = 0, e = worklist.size(); i != e; ++i)
+ worklistMap[worklist[i]] = i;
+ }
+
+ changed = processWorklist();
+
+ // After applying patterns, make sure that the CFG of each of the regions
+ // is kept up to date.
+ if (config.enableRegionSimplification)
+ changed |= succeeded(simplifyRegions(*this, region));
+ } while (changed);
+
+ // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
+ return success(!changed);
+}
+
LogicalResult
mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
@@ -455,13 +486,14 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
config.scope = &region;
// Start the pattern driver.
- GreedyPatternRewriteDriver driver(region.getContext(), patterns, config);
- bool converged = std::move(driver).simplify(region);
- LLVM_DEBUG(if (!converged) {
+ RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
+ region);
+ LogicalResult converged = std::move(driver).simplify();
+ LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
<< config.maxIterations << " times\n";
});
- return success(converged);
+ return converged;
}
//===----------------------------------------------------------------------===//
@@ -469,32 +501,16 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
//===----------------------------------------------------------------------===//
namespace {
-
-/// This is a specialized GreedyPatternRewriteDriver to apply patterns and
-/// perform folding for a supplied set of ops. It repeatedly simplifies while
-/// restricting the rewrites to only the provided set of ops or optionally
-/// to those directly affected by it (result users or operand providers). Parent
-/// ops are not considered.
+/// This driver simplifies a list of ops.
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- const GreedyRewriteConfig &config,
- llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
- : GreedyPatternRewriteDriver(ctx, patterns, config),
- survivingOps(survivingOps) {}
-
- /// Performs the specified rewrites on `ops` while also trying to fold these
- /// ops. `strictMode` controls which other ops are simplified. Only ops
- /// within the given scope region are added to the worklist. If no scope is
- /// specified, it assumed to be closest common region of all `ops`.
- ///
- /// Note that ops in `ops` could be erased as a result of folding, becoming
- /// dead, or via pattern rewrites. The return value indicates convergence.
- ///
- /// All erased ops are stored in `erased`.
- LogicalResult simplifyLocally(ArrayRef<Operation *> op,
- bool *changed = nullptr) &&;
+ const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
+ llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
+
+ /// Simplify `ops`. Return `success` if the transformation converged.
+ LogicalResult simplify(ArrayRef<Operation *> ops, bool *changed = nullptr) &&;
private:
void notifyOperationRemoved(Operation *op) override {
@@ -508,98 +524,33 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
/// of ops.
llvm::SmallDenseSet<Operation *, 4> *const survivingOps = nullptr;
};
-
} // namespace
-LogicalResult
-MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
- bool *changed) && {
+MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
+ MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
+ llvm::SmallDenseSet<Operation *, 4> *survivingOps)
+ : GreedyPatternRewriteDriver(ctx, patterns, config),
+ survivingOps(survivingOps) {
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ strictModeFilteredOps.insert(ops.begin(), ops.end());
+
if (survivingOps) {
survivingOps->clear();
survivingOps->insert(ops.begin(), ops.end());
}
+}
- if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
- strictModeFilteredOps.clear();
- strictModeFilteredOps.insert(ops.begin(), ops.end());
- }
-
- if (changed)
- *changed = false;
- worklist.clear();
- worklistMap.clear();
+LogicalResult MultiOpPatternRewriteDriver::simplify(ArrayRef<Operation *> ops,
+ bool *changed) && {
+ // Populate the initial worklist.
for (Operation *op : ops)
addSingleOpToWorklist(op);
- // These are scratch vectors used in the folding loop below.
- SmallVector<Value, 8> originalOperands, resultValues;
- int64_t numRewrites = 0;
- while (!worklist.empty() &&
- (numRewrites < config.maxNumRewrites ||
- config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
- Operation *op = popFromWorklist();
-
- // Nulls get added to the worklist when operations are removed, ignore
- // them.
- if (op == nullptr)
- continue;
-
- assert((config.strictMode == GreedyRewriteStrictness::AnyOp ||
- strictModeFilteredOps.contains(op)) &&
- "unexpected op was inserted under strict mode");
-
- // If the operation is trivially dead - remove it.
- if (isOpTriviallyDead(op)) {
- notifyOperationRemoved(op);
- op->erase();
- if (changed)
- *changed = true;
- continue;
- }
-
- // Collects all the operands and result uses of the given `op` into work
- // list. Also remove `op` and nested ops from worklist.
- originalOperands.assign(op->operand_begin(), op->operand_end());
- auto preReplaceAction = [&](Operation *op) {
- // Add the operands to the worklist for visitation.
- addOperandsToWorklist(originalOperands);
-
- // Add all the users of the result to the worklist so we make sure
- // to revisit them.
- for (Value result : op->getResults()) {
- for (Operation *userOp : result.getUsers())
- addToWorklist(userOp);
- }
-
- notifyOperationRemoved(op);
- };
-
- // Add the given operation generated by the folder to the worklist.
- auto processGeneratedConstants = [this](Operation *op) {
- notifyOperationInserted(op);
- };
-
- // Try to fold this op.
- bool inPlaceUpdate;
- if (succeeded(folder.tryToFold(op, processGeneratedConstants,
- preReplaceAction, &inPlaceUpdate))) {
- if (changed)
- *changed = true;
- if (!inPlaceUpdate) {
- // Op has been erased.
- continue;
- }
- }
-
- // Try to match one of the patterns. The rewriter is automatically
- // notified of any necessary changes, so there is nothing else to do
- // here.
- if (succeeded(matcher.matchAndRewrite(op, *this))) {
- if (changed)
- *changed = true;
- ++numRewrites;
- }
- }
+ // Process ops on the worklist.
+ bool result = processWorklist();
+ if (changed)
+ *changed = result;
return success(worklist.empty());
}
@@ -658,8 +609,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- config, allErased ? &surviving : nullptr);
- LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
+ config, ops,
+ allErased ? &surviving : nullptr);
+ LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
*allErased = surviving.empty();
LLVM_DEBUG(if (failed(converged)) {
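`RegionPatternRewriteDriver::simplify()` wraps the shared `processWorklist()` in an outer fixed-point loop bounded by `config.maxIterations`, and reports convergence when the last round changed nothing. The sketch below reproduces only that control flow. `runToFixedPoint` and its callback are hypothetical, `kNoLimit = -1` is an assumed sentinel value, and worklist seeding and region simplification are omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>

// Assumed sentinel meaning "no iteration limit".
constexpr int64_t kNoLimit = -1;

// Call `processWorklist` (returns true if it changed the IR) until a round
// makes no change or the iteration budget is spent. Returns true on
// convergence, i.e. when the last executed round changed nothing.
bool runToFixedPoint(const std::function<bool()> &processWorklist,
                     int64_t maxIterations) {
  bool changed = false;
  int64_t iteration = 0;
  do {
    // Check the iteration limit before each round, as in the diff.
    if (iteration++ >= maxIterations && maxIterations != kNoLimit)
      break;
    changed = processWorklist();
  } while (changed);
  return !changed;
}
```

Because the limit is checked before each round, a rewrite that never settles performs exactly `maxIterations` rounds and is then reported as not converged.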