[Mlir-commits] [mlir] a02a0e8 - [mlir][Transforms] `GreedyPatternRewriteDriver`: Better expensive checks encapsulation (#78175)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 15 23:55:29 PST 2024
Author: Matthias Springer
Date: 2024-01-16T08:55:25+01:00
New Revision: a02a0e806fab01f4cf4307443cdaed76a2488752
URL: https://github.com/llvm/llvm-project/commit/a02a0e806fab01f4cf4307443cdaed76a2488752
DIFF: https://github.com/llvm/llvm-project/commit/a02a0e806fab01f4cf4307443cdaed76a2488752.diff
LOG: [mlir][Transforms] `GreedyPatternRewriteDriver`: Better expensive checks encapsulation (#78175)
This change moves most IR verification logic (which is part of the
expensive checks) into `DebugFingerPrints` and renames the struct to
`ExpensiveChecks`. This isolates the debugging logic better from the
remaining code.
This commit also removes a redundant check: the IR is no longer verified
after a failed pattern application. We already assert that the IR did
not change. (We know that the IR was valid before the attempted pattern
application.)
Added:
Modified:
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index d31408d240ebd5..36d63d62bf10fc 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -43,12 +43,18 @@ namespace {
//===----------------------------------------------------------------------===//
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-/// A helper struct that stores finger prints of ops in order to detect broken
-/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
-/// using the rewriter API or if it returns an inconsistent return value.
-struct DebugFingerPrints : public RewriterBase::ForwardingListener {
- DebugFingerPrints(RewriterBase::Listener *driver)
- : RewriterBase::ForwardingListener(driver) {}
+/// A helper struct that performs various "expensive checks" to detect broken
+/// rewrite patterns that use the rewriter API incorrectly. A rewrite pattern
+/// is broken if:
+/// * IR does not verify after pattern application / folding.
+/// * Pattern returns "failure" but the IR has changed.
+/// * Pattern returns "success" but the IR has not changed.
+///
+/// This struct stores finger prints of ops to determine whether the IR has
+/// changed or not.
+struct ExpensiveChecks : public RewriterBase::ForwardingListener {
+ ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
+ : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
/// Compute finger prints of the given op and its nested ops.
void computeFingerPrints(Operation *topLevel) {
@@ -65,6 +71,13 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}
void notifyRewriteSuccess() {
+ if (!topLevel)
+ return;
+
+ // Make sure that the IR still verifies.
+ if (failed(verify(topLevel)))
+ llvm::report_fatal_error("IR failed to verify after pattern application");
+
// Pattern application success => IR must have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint == afterFingerPrint) {
@@ -90,6 +103,9 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}
void notifyRewriteFailure() {
+ if (!topLevel)
+ return;
+
// Pattern application failure => IR must not have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint != afterFingerPrint) {
@@ -98,6 +114,15 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}
}
+ void notifyFoldingSuccess() {
+ if (!topLevel)
+ return;
+
+ // Make sure that the IR still verifies.
+ if (failed(verify(topLevel)))
+ llvm::report_fatal_error("IR failed to verify after folding");
+ }
+
protected:
/// Invalidate the finger print of the given op, i.e., remove it from the map.
void invalidateFingerPrint(Operation *op) {
@@ -362,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
PatternApplicator matcher;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- DebugFingerPrints debugFingerPrints;
+ ExpensiveChecks expensiveChecks;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
} // namespace
@@ -373,7 +398,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
: PatternRewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
- , debugFingerPrints(this)
+ , expensiveChecks(
+ /*driver=*/this,
+ /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
@@ -384,7 +411,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
- setListener(&debugFingerPrints);
+ setListener(&expensiveChecks);
#else
setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -458,8 +485,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after folding");
+ expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
@@ -513,8 +539,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after folding");
+ expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
@@ -551,33 +576,26 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope) {
- debugFingerPrints.computeFingerPrints(config.scope->getParentOp());
+ expensiveChecks.computeFingerPrints(config.scope->getParentOp());
}
auto clearFingerprints =
- llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
+ llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after pattern application");
-#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope)
- debugFingerPrints.notifyRewriteSuccess();
+ expensiveChecks.notifyRewriteSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
changed = true;
++numRewrites;
} else {
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope)
- debugFingerPrints.notifyRewriteFailure();
+ expensiveChecks.notifyRewriteFailure();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
}
More information about the Mlir-commits
mailing list