[Mlir-commits] [mlir] [mlir][Transforms] `GreedyPatternRewriteDriver`: Better expensive checks encapsulation (PR #78175)
Matthias Springer
llvmlistbot at llvm.org
Mon Jan 15 07:45:21 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/78175
This change moves most IR verification logic (which is part of the expensive checks) into `DebugFingerPrints` and renames the struct to `ExpensiveChecks`. This better separates the debugging logic from the rest of the code.
This commit also removes a redundant check: the IR is no longer verified after a failed pattern application. The IR is known to be valid before the attempted pattern application, and we already assert that a failed pattern application did not change the IR, so verifying it again is unnecessary.
>From bf248455dcc4b77d4546f7a5cc0b8f6e8ba2db44 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 15 Jan 2024 15:37:32 +0000
Subject: [PATCH] [mlir][Transforms] `GreedyPatternRewriteDriver`: Better
expensive checks encapsulation
This change moves most IR verification logic (which is part of the expensive
checks) into `DebugFingerPrints` and renames the struct to
`ExpensiveChecks`. This better separates the debugging logic from the
rest of the code.
This commit also removes a redundant check: the IR is no longer
verified after a failed pattern application. The IR is known to be valid
before the attempted pattern application, and we already assert that a
failed pattern application did not change the IR.
---
.../Utils/GreedyPatternRewriteDriver.cpp | 66 ++++++++++++-------
1 file changed, 42 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index d31408d240ebd57..36d63d62bf10fc2 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -43,12 +43,18 @@ namespace {
//===----------------------------------------------------------------------===//
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-/// A helper struct that stores finger prints of ops in order to detect broken
-/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
-/// using the rewriter API or if it returns an inconsistent return value.
-struct DebugFingerPrints : public RewriterBase::ForwardingListener {
- DebugFingerPrints(RewriterBase::Listener *driver)
- : RewriterBase::ForwardingListener(driver) {}
+/// A helper struct that performs various "expensive checks" to detect broken
+/// rewrite patterns that misuse the rewriter API. A rewrite pattern is
+/// broken if:
+/// * IR does not verify after pattern application / folding.
+/// * Pattern returns "failure" but the IR has changed.
+/// * Pattern returns "success" but the IR has not changed.
+///
+/// This struct stores finger prints of ops to determine whether the IR has
+/// changed or not.
+struct ExpensiveChecks : public RewriterBase::ForwardingListener {
+ ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
+ : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
/// Compute finger prints of the given op and its nested ops.
void computeFingerPrints(Operation *topLevel) {
@@ -65,6 +71,13 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}
void notifyRewriteSuccess() {
+ if (!topLevel)
+ return;
+
+ // Make sure that the IR still verifies.
+ if (failed(verify(topLevel)))
+ llvm::report_fatal_error("IR failed to verify after pattern application");
+
// Pattern application success => IR must have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint == afterFingerPrint) {
@@ -90,6 +103,9 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}
void notifyRewriteFailure() {
+ if (!topLevel)
+ return;
+
// Pattern application failure => IR must not have changed.
OperationFingerPrint afterFingerPrint(topLevel);
if (*topLevelFingerPrint != afterFingerPrint) {
@@ -98,6 +114,15 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
}
}
+ void notifyFoldingSuccess() {
+ if (!topLevel)
+ return;
+
+ // Make sure that the IR still verifies.
+ if (failed(verify(topLevel)))
+ llvm::report_fatal_error("IR failed to verify after folding");
+ }
+
protected:
/// Invalidate the finger print of the given op, i.e., remove it from the map.
void invalidateFingerPrint(Operation *op) {
@@ -362,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
PatternApplicator matcher;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- DebugFingerPrints debugFingerPrints;
+ ExpensiveChecks expensiveChecks;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
} // namespace
@@ -373,7 +398,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
: PatternRewriter(ctx), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
- , debugFingerPrints(this)
+ , expensiveChecks(
+ /*driver=*/this,
+ /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
// clang-format on
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
{
@@ -384,7 +411,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Send IR notifications to the debug handler. This handler will then forward
// all notifications to this GreedyPatternRewriteDriver.
- setListener(&debugFingerPrints);
+ setListener(&expensiveChecks);
#else
setListener(this);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -458,8 +485,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after folding");
+ expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
@@ -513,8 +539,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
changed = true;
LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after folding");
+ expensiveChecks.notifyFoldingSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
continue;
}
@@ -551,33 +576,26 @@ bool GreedyPatternRewriteDriver::processWorklist() {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope) {
- debugFingerPrints.computeFingerPrints(config.scope->getParentOp());
+ expensiveChecks.computeFingerPrints(config.scope->getParentOp());
}
auto clearFingerprints =
- llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
+ llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after pattern application");
-#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope)
- debugFingerPrints.notifyRewriteSuccess();
+ expensiveChecks.notifyRewriteSuccess();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
changed = true;
++numRewrites;
} else {
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope)
- debugFingerPrints.notifyRewriteFailure();
+ expensiveChecks.notifyRewriteFailure();
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
}
More information about the Mlir-commits
mailing list