[Mlir-commits] [mlir] e195e6b - [mlir] GreedyPatternRewriteDriver: Enqueue ancestors in MultiOpPatternRewriteDriver
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 01:39:19 PST 2023
Author: Matthias Springer
Date: 2023-01-27T10:39:10+01:00
New Revision: e195e6bad6706230a4b5fd4b5cc13de1f16f25cc
URL: https://github.com/llvm/llvm-project/commit/e195e6bad6706230a4b5fd4b5cc13de1f16f25cc
DIFF: https://github.com/llvm/llvm-project/commit/e195e6bad6706230a4b5fd4b5cc13de1f16f25cc.diff
LOG: [mlir] GreedyPatternRewriteDriver: Enqueue ancestors in MultiOpPatternRewriteDriver
The `GreedyPatternRewriteDriver` was extended to enqueue ancestors in D140304. With this change, `MultiOpPatternRewriteDriver` behaves the same way.
Note: `MultiOpPatternRewriteDriver` now also has a scope that limits how far we go when checking ancestors. By default, this is the closest common enclosing region of all given ops.
Differential Revision: https://reviews.llvm.org/D141945
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index f72dbb7ff2986..dce47834547e8 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -112,6 +112,12 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
/// simplified. All other ops are excluded.
///
+/// In addition to strictness, a region scope can be specified. Only ops within
+/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
+/// where only ops within the given regions are simplified. If no scope is
+/// specified, it is assumed to be the closest common enclosing region of the
+/// given ops.
+///
/// Note that ops in `ops` could be erased as result of folding, becoming dead,
/// or via pattern rewrites. If more far reaching simplification is desired,
/// applyPatternsAndFoldGreedily should be used.
@@ -123,7 +129,8 @@ LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteStrictness strictMode,
bool *changed = nullptr,
- bool *allErased = nullptr);
+ bool *allErased = nullptr,
+ Region *scope = nullptr);
} // namespace mlir
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 2b3a796dee93f..ead229dacae80 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -16,6 +16,7 @@
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
@@ -43,10 +44,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Simplify the operations within the given regions.
bool simplify(MutableArrayRef<Region> regions);
- /// Add the given operation to the worklist. Parent ops may or may not be
- /// added to the worklist, depending on the type of rewrite driver. By
- /// default, parent ops are added.
- virtual void addToWorklist(Operation *op);
+ /// Add the given operation and its ancestors to the worklist.
+ void addToWorklist(Operation *op);
/// Pop the next operation from the worklist.
Operation *popFromWorklist();
@@ -60,7 +59,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
protected:
/// Add the given operation to the worklist.
- void addSingleOpToWorklist(Operation *op);
+ virtual void addSingleOpToWorklist(Operation *op);
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
@@ -103,11 +102,12 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Configuration information for how to simplify.
GreedyRewriteConfig config;
-private:
/// Only ops within this scope are simplified. This is set at the beginning
- /// of `simplify()` to the current scope the rewriter operates on.
+ /// of `simplify()` and `simplifyLocally()` to the current scope the rewriter
+ /// operates on.
DenseSet<Region *> scope;
+private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -126,6 +126,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
}
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+ scope.clear();
for (Region &r : regions)
scope.insert(&r);
@@ -581,7 +582,9 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
strictMode(strictMode) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
- /// ops. `strictMode` controls which other ops are simplified.
+ /// ops. `strictMode` controls which other ops are simplified. Only ops
+ /// within the given scope region are added to the worklist. If no scope is
+  /// specified, it is assumed to be the closest common region of all `ops`.
///
/// Note that ops in `ops` could be erased as a result of folding, becoming
/// dead, or via pattern rewrites. The return value indicates convergence.
@@ -589,9 +592,11 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
/// All `ops` that survived the rewrite are stored in `surviving`.
LogicalResult
simplifyLocally(ArrayRef<Operation *> ops, bool *changed = nullptr,
- llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr);
+ llvm::SmallDenseSet<Operation *, 4> *surviving = nullptr,
+ Region *scope = nullptr);
- void addToWorklist(Operation *op) override {
+protected:
+ void addSingleOpToWorklist(Operation *op) override {
if (strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op))
GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
@@ -632,7 +637,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
ArrayRef<Operation *> ops, bool *changed,
- llvm::SmallDenseSet<Operation *, 4> *surviving) {
+ llvm::SmallDenseSet<Operation *, 4> *surviving, Region *scope) {
auto cleanup = llvm::make_scope_exit([&]() { survivingOps = nullptr; });
if (surviving) {
survivingOps = surviving;
@@ -645,12 +650,16 @@ LogicalResult MultiOpPatternRewriteDriver::simplifyLocally(
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
+ assert(scope && "scope is mandatory");
+ this->scope.clear();
+ this->scope.insert(scope);
+
if (changed)
*changed = false;
worklist.clear();
worklistMap.clear();
for (Operation *op : ops)
- addToWorklist(op);
+ addSingleOpToWorklist(op);
// These are scratch vectors used in the folding loop below.
SmallVector<Value, 8> originalOperands, resultValues;
@@ -742,9 +751,37 @@ LogicalResult mlir::applyOpPatternsAndFold(
return converged;
}
-LogicalResult mlir::applyOpPatternsAndFold(
- ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, bool *changed, bool *allErased) {
+/// Find the region that is the closest common ancestor of all given ops.
+///
+/// The search starts at the parent region of the first op and walks outwards
+/// through enclosing regions until one is found that (transitively) contains
+/// every other op. May return nullptr when there is no common region, e.g. if
+/// an op has no parent region; with more than one op this additionally
+/// triggers an assertion in debug builds.
+static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
+  assert(!ops.empty() && "expected at least one op");
+  // Fast path in case there is only one op.
+  if (ops.size() == 1)
+    return ops.front()->getParentRegion();
+
+  Region *region = ops.front()->getParentRegion();
+  ops = ops.drop_front();
+  int sz = ops.size();
+  // Bit `i` stays set until `ops[i]` has been found inside `region`.
+  llvm::BitVector remainingOps(sz, true);
+  // Test `region` before dereferencing it: the first op may be detached or
+  // top-level and thus have no parent region at all.
+  while (region) {
+    int pos = -1;
+    // Iterate over all remaining ops.
+    while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
+      // Is this op contained in `region`? If so, it is also contained in every
+      // enclosing region, so it never needs to be checked again.
+      if (region->findAncestorOpInRegion(*ops[pos]))
+        remainingOps.reset(pos);
+    }
+    if (remainingOps.none())
+      break;
+    region = region->getParentRegion();
+  }
+  assert(region && "could not find common parent region");
+  return region;
+}
+
+LogicalResult
+mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+ const FrozenRewritePatternSet &patterns,
+ GreedyRewriteStrictness strictMode, bool *changed,
+ bool *allErased, Region *scope) {
if (ops.empty()) {
if (changed)
*changed = false;
@@ -753,12 +790,25 @@ LogicalResult mlir::applyOpPatternsAndFold(
return success();
}
+ if (!scope) {
+ // Compute scope if none was provided.
+ scope = findCommonAncestor(ops);
+ } else {
+ // If a scope was provided, make sure that all ops are in scope.
+#ifndef NDEBUG
+ bool allOpsInScope = llvm::all_of(ops, [&](Operation *op) {
+ return static_cast<bool>(scope->findAncestorOpInRegion(*op));
+ });
+ assert(allOpsInScope && "ops must be within the specified scope");
+#endif // NDEBUG
+ }
+
// Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
strictMode);
llvm::SmallDenseSet<Operation *, 4> surviving;
- LogicalResult converged =
- driver.simplifyLocally(ops, changed, allErased ? &surviving : nullptr);
+ LogicalResult converged = driver.simplifyLocally(
+ ops, changed, allErased ? &surviving : nullptr, /*scope=*/scope);
if (allErased)
*allErased = surviving.empty();
return converged;
More information about the Mlir-commits
mailing list