[Mlir-commits] [mlir] 6bdecbc - [mlir] GreedyPatternRewriteDriver: Move strict mode to GreedyPatternRewriteDriver
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 27 06:58:16 PST 2023
Author: Matthias Springer
Date: 2023-01-27T15:52:01+01:00
New Revision: 6bdecbcb99bc6b8fa25a2841bf2087bdbb91b4aa
URL: https://github.com/llvm/llvm-project/commit/6bdecbcb99bc6b8fa25a2841bf2087bdbb91b4aa
DIFF: https://github.com/llvm/llvm-project/commit/6bdecbcb99bc6b8fa25a2841bf2087bdbb91b4aa.diff
LOG: [mlir] GreedyPatternRewriteDriver: Move strict mode to GreedyPatternRewriteDriver
`strictMode` is moved to GreedyRewriteConfig to simplify the API and state of rewriter classes. The region-based GreedyPatternRewriteDriver now also supports strict mode.
MultiOpPatternRewriteDriver becomes simpler: fewer methods must be overridden.
Differential Revision: https://reviews.llvm.org/D142623
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/Affine/Utils/Utils.cpp
mlir/lib/Reducer/ReductionTreePass.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index e4251103d6d58..217953493c41d 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -63,6 +63,18 @@ class GreedyRewriteConfig {
/// Only ops within the scope are added to the worklist. If no scope is
/// specified, the closest enclosing region is used as a scope.
Region *scope = nullptr;
+
+ /// Strict mode can restrict the ops that are added to the worklist during
+ /// the rewrite.
+ ///
+ /// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
+ /// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing ops (that
+ /// were on the worklist at the very beginning) and newly created ops are
+ /// enqueued. All other ops are excluded.
+ /// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
+ /// on the worklist at the very beginning) are enqueued. All other ops are
+ /// excluded.
+ GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
};
//===----------------------------------------------------------------------===//
@@ -105,14 +117,8 @@ inline LogicalResult applyPatternsAndFoldGreedily(
///
/// Newly created ops and other pre-existing ops that use results of rewritten
/// ops or supply operands to such ops are simplified, unless such ops are
-/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless
-/// of `strictMode`).
-///
-/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
-/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing and newly
-/// created ops are simplified. All other ops are excluded.
-/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
-/// simplified. All other ops are excluded.
+/// excluded via `config.strictMode`. Any other ops remain unmodified (i.e.,
+/// regardless of `strictMode`).
///
/// In addition to strictness, a region scope can be specified. Only ops within
/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
@@ -130,23 +136,17 @@ inline LogicalResult applyPatternsAndFoldGreedily(
LogicalResult
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
-/// Applies the specified patterns on `op` alone while also trying to fold it,
-/// by selecting the highest benefits patterns in a greedy manner. Returns
-/// success if no more patterns can be matched. `erased` is set to true if `op`
-/// was folded away or erased as a result of becoming dead.
-///
-/// Returns success if the iterative process converged and no more patterns can
-/// be matched.
+/// Applies the specified patterns on `op` while also trying to fold it.
+/// This function is a shortcut for the ArrayRef<Operation *> overload and
+/// behaves the same way.
inline LogicalResult
applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *erased = nullptr) {
- return applyOpPatternsAndFold(ArrayRef(op), patterns,
- GreedyRewriteStrictness::ExistingOps, config,
+ return applyOpPatternsAndFold(ArrayRef(op), patterns, config,
/*changed=*/nullptr, erased);
}
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index d516de8ee6d26..19ee04d10b19a 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -130,10 +130,10 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
- if (failed(
- applyOpPatternsAndFold(targets, frozenPatterns,
- GreedyRewriteStrictness::ExistingAndNewOps))) {
+ if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index a9d6f940200b0..716f95314eac9 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -239,6 +239,7 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- (void)applyOpPatternsAndFold(copyOps, frozenPatterns,
- GreedyRewriteStrictness::ExistingAndNewOps);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ (void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 6cb0a30dce39f..8039484ce8d62 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -105,6 +105,7 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
- (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns,
- GreedyRewriteStrictness::ExistingAndNewOps);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config);
}
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index a146d5adfa637..54d2bd469db66 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -321,9 +321,10 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
// Simplify/canonicalize the affine.for.
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
- (void)applyOpPatternsAndFold(res, std::move(patterns),
- GreedyRewriteConfig(), &erased);
+ (void)applyOpPatternsAndFold(res, std::move(patterns), config, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 286c38ec18ee5..24acc60cd1832 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -413,10 +413,11 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
// in which case we return with `folded` being set.
RewritePatternSet patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
- bool erased;
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- (void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(),
- &erased);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ bool erased;
+ (void)applyOpPatternsAndFold(ifOp, frozenPatterns, config, &erased);
if (erased) {
if (folded)
*folded = true;
diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp
index a30b6d19413e0..b00045a3a41b7 100644
--- a/mlir/lib/Reducer/ReductionTreePass.cpp
+++ b/mlir/lib/Reducer/ReductionTreePass.cpp
@@ -60,10 +60,13 @@ static void applyPatterns(Region &region,
// matching in above iteration. Besides, erase op not-in-range may end up in
// invalid module, so `applyOpPatternsAndFold` should come before that
// transform.
- for (Operation *op : opsInRange)
+ for (Operation *op : opsInRange) {
// `applyOpPatternsAndFold` returns whether the op is converged. Omit it
// because we have no expectation of whether this reduction will succeed.
- (void)applyOpPatternsAndFold(op, patterns);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
+ (void)applyOpPatternsAndFold(op, patterns, config);
+ }
if (eraseOpNotInRange)
for (Operation *op : opsNotInRange) {
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 4a37730129691..2ef46b1fdcfc1 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -59,7 +59,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
protected:
/// Add the given operation to the worklist.
- virtual void addSingleOpToWorklist(Operation *op);
+ void addSingleOpToWorklist(Operation *op);
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
@@ -102,6 +102,12 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
+ /// The list of ops we are restricting our rewrites to. These include the
+ /// supplied set of ops as well as new ops created while rewriting those ops
+ /// depending on `strictMode`. This set is not maintained when
+ /// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
+ llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+
private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
@@ -150,6 +156,12 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
return false;
};
+ // Populate strict mode ops.
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
+ strictModeFilteredOps.clear();
+ region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
+ }
+
bool changed = false;
int64_t iteration = 0;
do {
@@ -323,12 +335,15 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
}
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
- // Check to see if the worklist already contains this op.
- if (worklistMap.count(op))
- return;
-
- worklistMap[op] = worklist.size();
- worklist.push_back(op);
+ if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
+ strictModeFilteredOps.contains(op)) {
+ // Check to see if the worklist already contains this op.
+ if (worklistMap.count(op))
+ return;
+
+ worklistMap[op] = worklist.size();
+ worklist.push_back(op);
+ }
}
Operation *GreedyPatternRewriteDriver::popFromWorklist() {
@@ -355,6 +370,8 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
+ if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
+ strictModeFilteredOps.insert(op);
addToWorklist(op);
}
@@ -391,6 +408,9 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
removeFromWorklist(operation);
folder.notifyRemoval(operation);
});
+
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ strictModeFilteredOps.erase(op);
}
void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
@@ -459,10 +479,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
+ const GreedyRewriteConfig &config,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
: GreedyPatternRewriteDriver(ctx, patterns, config),
- strictMode(strictMode), survivingOps(survivingOps) {}
+ survivingOps(survivingOps) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
/// ops. `strictMode` controls which other ops are simplified. Only ops
@@ -476,38 +496,13 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
LogicalResult simplifyLocally(ArrayRef<Operation *> op,
bool *changed = nullptr) &&;
-protected:
- void addSingleOpToWorklist(Operation *op) override {
- if (strictMode == GreedyRewriteStrictness::AnyOp ||
- strictModeFilteredOps.contains(op))
- GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
- }
-
private:
- void notifyOperationInserted(Operation *op) override {
- if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
- strictModeFilteredOps.insert(op);
- GreedyPatternRewriteDriver::notifyOperationInserted(op);
- }
-
void notifyOperationRemoved(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationRemoved(op);
if (survivingOps)
survivingOps->erase(op);
- if (strictMode != GreedyRewriteStrictness::AnyOp)
- strictModeFilteredOps.erase(op);
}
- /// `strictMode` control which ops are added to the worklist during
- /// simplification.
- const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
-
- /// The list of ops we are restricting our rewrites to. These include the
- /// supplied set of ops as well as new ops created while rewriting those ops
- /// depending on `strictMode`. This set is not maintained when `strictMode`
- /// is GreedyRewriteStrictness::AnyOp.
- llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
-
/// An optional set of ops that survived the rewrite. This set is populated
/// at the beginning of `simplifyLocally` with the initially provided list
/// of ops.
@@ -524,7 +519,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
survivingOps->insert(ops.begin(), ops.end());
}
- if (strictMode != GreedyRewriteStrictness::AnyOp) {
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
@@ -549,7 +544,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
if (op == nullptr)
continue;
- assert((strictMode == GreedyRewriteStrictness::AnyOp ||
+ assert((config.strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op)) &&
"unexpected op was inserted under strict mode");
@@ -637,8 +632,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
LogicalResult mlir::applyOpPatternsAndFold(
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
- GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
- bool *changed, bool *allErased) {
+ GreedyRewriteConfig config, bool *changed, bool *allErased) {
if (ops.empty()) {
if (changed)
*changed = false;
@@ -664,8 +658,7 @@ LogicalResult mlir::applyOpPatternsAndFold(
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- strictMode, config,
- allErased ? &surviving : nullptr);
+ config, allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)
*allErased = surviving.empty();
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 7dc478c8b9cf1..62671dde34b0e 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -132,8 +132,9 @@ void TestAffineDataCopy::runOnOperation() {
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
}
}
- (void)applyOpPatternsAndFold(copyOps, std::move(patterns),
- GreedyRewriteStrictness::ExistingAndNewOps);
+ GreedyRewriteConfig config;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
+ (void)applyOpPatternsAndFold(copyOps, std::move(patterns), config);
}
namespace mlir {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index da018708c702e..de4fa39218844 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -266,13 +266,13 @@ struct TestStrictPatternDriver
}
});
- GreedyRewriteStrictness mode;
+ GreedyRewriteConfig config;
if (strictMode == "AnyOp") {
- mode = GreedyRewriteStrictness::AnyOp;
+ config.strictMode = GreedyRewriteStrictness::AnyOp;
} else if (strictMode == "ExistingAndNewOps") {
- mode = GreedyRewriteStrictness::ExistingAndNewOps;
+ config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
} else if (strictMode == "ExistingOps") {
- mode = GreedyRewriteStrictness::ExistingOps;
+ config.strictMode = GreedyRewriteStrictness::ExistingOps;
} else {
llvm_unreachable("invalid strictness option");
}
@@ -282,8 +282,8 @@ struct TestStrictPatternDriver
// operation will trigger the assertion while processing.
bool changed = false;
bool allErased = false;
- (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
- GreedyRewriteConfig(), &changed, &allErased);
+ (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
+ &changed, &allErased);
Builder b(ctx);
getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
getOperation()->setAttr("pattern_driver_all_erased",
More information about the Mlir-commits
mailing list