[Mlir-commits] [mlir] fefe655 - [mlir][NFC] GreedyPatternRewriteDriver: Consistent return values
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 16 07:34:40 PST 2023
Author: Matthias Springer
Date: 2023-01-16T16:30:12+01:00
New Revision: fefe655baafb9aa11ae3e2a34b19aef1f47e2b8d
URL: https://github.com/llvm/llvm-project/commit/fefe655baafb9aa11ae3e2a34b19aef1f47e2b8d
DIFF: https://github.com/llvm/llvm-project/commit/fefe655baafb9aa11ae3e2a34b19aef1f47e2b8d.diff
LOG: [mlir][NFC] GreedyPatternRewriteDriver: Consistent return values
All `apply...` functions now return a LogicalResult indicating whether the iterative process converged or not.
Differential Revision: https://reviews.llvm.org/D141845
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 5478587dcc43..aaafebf252fa 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -80,6 +80,9 @@ inline LogicalResult applyPatternsAndFoldGreedily(
/// success if no more patterns can be matched. `erased` is set to true if `op`
/// was folded away or erased as a result of becoming dead. Note: This does not
/// apply any patterns recursively to the regions of `op`.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched.
LogicalResult applyOpPatternsAndFold(Operation *op,
const FrozenRewritePatternSet &patterns,
bool *erased = nullptr);
@@ -93,10 +96,13 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a
/// result of folding, becoming dead, or via pattern rewrites. If more far
/// reaching simplification is desired, applyPatternsAndFoldGreedily should be
-/// used. Returns true if at all any IR was rewritten.
-bool applyOpPatternsAndFold(ArrayRef<Operation *> ops,
- const FrozenRewritePatternSet &patterns,
- bool strict);
+/// used.
+///
+/// Returns success if the iterative process converged and no more patterns can
+/// be matched. `changed` is set to true if the IR was modified at all.
+LogicalResult applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+ const FrozenRewritePatternSet &patterns,
+ bool strict, bool *changed = nullptr);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
index 24ed10eb593d..282a35b45c53 100644
--- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
+++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp
@@ -131,7 +131,12 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
// Apply the simplification pattern to a fixpoint.
- (void)applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true);
+ if (failed(
+ applyOpPatternsAndFold(targets, frozenPatterns, /*strict=*/true))) {
+ auto diag = emitDefiniteFailure()
+ << "affine.min/max simplification did not converge";
+ return diag;
+ }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index f0794f8dec3f..6bd3994c4313 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -574,7 +574,8 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
: GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
strictMode(strict) {}
- bool simplifyLocally(ArrayRef<Operation *> op);
+ LogicalResult simplifyLocally(ArrayRef<Operation *> op,
+ bool *changed = nullptr);
void addToWorklist(Operation *op) override {
if (!strictMode || strictModeFilteredOps.contains(op))
@@ -625,13 +626,16 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
// there is no strong rationale to re-add all operations into the worklist and
// rerun until an iteration changes nothing. If more widereaching simplification
// is desired, GreedyPatternRewriteDriver should be used.
-bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
+LogicalResult
+MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
+ bool *changed) {
if (strictMode) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
- bool changed = false;
+ if (changed)
+ *changed = false;
worklist.clear();
worklistMap.clear();
for (Operation *op : ops)
@@ -657,7 +661,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
if (isOpTriviallyDead(op)) {
notifyOperationRemoved(op);
op->erase();
- changed = true;
+ if (changed)
+ *changed = true;
continue;
}
@@ -687,7 +692,8 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
bool inPlaceUpdate;
if (succeeded(folder.tryToFold(op, processGeneratedConstants,
preReplaceAction, &inPlaceUpdate))) {
- changed = true;
+ if (changed)
+ *changed = true;
if (!inPlaceUpdate) {
// Op has been erased.
continue;
@@ -698,12 +704,13 @@ bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
// notified of any necessary changes, so there is nothing else to do
// here.
if (succeeded(matcher.matchAndRewrite(op, *this))) {
- changed = true;
+ if (changed)
+ *changed = true;
++numRewrites;
}
}
- return changed;
+ return success(worklist.empty());
}
/// Rewrites only `op` using the supplied canonicalization patterns and
@@ -726,14 +733,18 @@ LogicalResult mlir::applyOpPatternsAndFold(
return converged;
}
-bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
- const FrozenRewritePatternSet &patterns,
- bool strict) {
- if (ops.empty())
- return false;
+LogicalResult
+mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+ const FrozenRewritePatternSet &patterns,
+ bool strict, bool *changed) {
+ if (ops.empty()) {
+ if (changed)
+ *changed = false;
+ return success();
+ }
// Start the pattern driver.
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
strict);
- return driver.simplifyLocally(ops);
+ return driver.simplifyLocally(ops, changed);
}
More information about the Mlir-commits
mailing list