[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Mar 6 00:08:42 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds two new notifications to `RewriterBase::Listener`:
* `notifyPatternBegin`: Called when a pattern application begins during a greedy pattern rewrite or dialect conversion.
* `notifyPatternEnd`: Called when a pattern application finishes during a greedy pattern rewrite or dialect conversion.
The listener infrastructure already provides a `notifyMatchFailure` callback that reports the reason for a pattern match failure. The two new notifications additionally tell the listener when each pattern application starts and whether it succeeded.
This change is in preparation for improving the handle update mechanism in the `apply_conversion_patterns` transform op.
---
Full diff: https://github.com/llvm/llvm-project/pull/84131.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/PatternMatch.h (+25-5)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+21-8)
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+31-22)
``````````diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..838b4947648f5e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
/// Note: This notification is not triggered when unlinking an operation.
virtual void notifyOperationErased(Operation *op) {}
- /// Notify the listener that the pattern failed to match the given
- /// operation, and provide a callback to populate a diagnostic with the
- /// reason why the failure occurred. This method allows for derived
- /// listeners to optionally hook into the reason why a rewrite failed, and
- /// display it to users.
+ /// Notify the listener that the specified pattern is about to be applied
+ /// at the specified root operation.
+ virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
+
+ /// Notify the listener that a pattern application (previously announced
+ /// with `notifyPatternBegin`) finished with the specified status. "success"
+ /// indicates that the pattern was applied successfully. "failure" indicates
+ /// that the pattern could not be applied. The pattern may have communicated
+ /// the reason for the failure with `notifyMatchFailure`.
+ virtual void notifyPatternEnd(const Pattern &pattern,
+ LogicalResult status) {}
+
+ /// Notify the listener that the pattern failed to match, and provide a
+ /// callback to populate a diagnostic with the reason why the failure
+ /// occurred. This method allows for derived listeners to optionally hook
+ /// into the reason why a rewrite failed, and display it to users.
virtual void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyOperationErased(op);
}
+ void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyPatternBegin(pattern, op);
+ }
+ void notifyPatternEnd(const Pattern &pattern,
+ LogicalResult status) override {
+ if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+ rewriteListener->notifyPatternEnd(pattern, status);
+ }
void notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a5145246bc30c4..587fbe209b58af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1863,7 +1863,8 @@ class OperationLegalizer {
using LegalizationAction = ConversionTarget::LegalizationAction;
OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ const ConversionConfig &config);
/// Returns true if the given operation is known to be illegal on the target.
bool isIllegal(Operation *op) const;
@@ -1955,12 +1956,16 @@ class OperationLegalizer {
/// The pattern applicator to use for conversions.
PatternApplicator applicator;
+
+ /// Dialect conversion configuration (provides the optional listener).
+ const ConversionConfig &config;
};
} // namespace
OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
- const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ const FrozenRewritePatternSet &patterns,
+ const ConversionConfig &config)
+ : target(targetInfo), applicator(patterns), config(config) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2105,7 +2110,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- return canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern, rewriter);
+ if (canApply && config.listener)
+ config.listener->notifyPatternBegin(pattern, op);
+ return canApply;
};
// Functor that cleans up the rewriter state after a pattern failed to match.
@@ -2122,6 +2130,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
rewriterImpl.config.notifyCallback(diag);
}
});
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, failure());
rewriterImpl.resetState(curState);
appliedPatterns.erase(&pattern);
};
@@ -2134,6 +2144,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
appliedPatterns.erase(&pattern);
if (failed(result))
rewriterImpl.resetState(curState);
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, result);
return result;
};
@@ -2509,7 +2521,8 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
- : opLegalizer(target, patterns), config(config), mode(mode) {}
+ : config(config), opLegalizer(target, patterns, this->config),
+ mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
@@ -2546,12 +2559,12 @@ struct OperationConverter {
ConversionPatternRewriterImpl &rewriterImpl,
const DenseMap<Value, SmallVector<Value>> &inverseMapping);
- /// The legalizer to use when converting operations.
- OperationLegalizer opLegalizer;
-
/// Dialect conversion configuration.
ConversionConfig config;
+ /// The legalizer to use when converting operations (references `config`).
+ OperationLegalizer opLegalizer;
+
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
};
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 51d2f5e01b7235..5fda6f87196f94 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -562,30 +562,46 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Try to match one of the patterns. The rewriter is automatically
// notified of any necessary changes, so there is nothing else to do
// here.
-#ifndef NDEBUG
- auto canApply = [&](const Pattern &pattern) {
- LLVM_DEBUG({
- logger.getOStream() << "\n";
- logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
- << op->getName() << " -> (";
- llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
- logger.getOStream() << ")' {\n";
- logger.indent();
- });
- return true;
- };
- auto onFailure = [&](const Pattern &pattern) {
- LLVM_DEBUG(logResult("failure", "pattern failed to match"));
- };
- auto onSuccess = [&](const Pattern &pattern) {
- LLVM_DEBUG(logResult("success", "pattern applied successfully"));
- return success();
- };
-#else
function_ref<bool(const Pattern &)> canApply = {};
function_ref<void(const Pattern &)> onFailure = {};
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
-#endif
+ // `function_ref` does not own its callable, so the callbacks below are
+ // named locals that outlive every use of the `function_ref`s.
+ auto canApplyCallback = [&](const Pattern &pattern) {
+ LLVM_DEBUG({
+ logger.getOStream() << "\n";
+ logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
+ << op->getName() << " -> (";
+ llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
+ logger.getOStream() << ")' {\n";
+ logger.indent();
+ });
+ if (config.listener)
+ config.listener->notifyPatternBegin(pattern, op);
+ return true;
+ };
+ auto onFailureCallback = [&](const Pattern &pattern) {
+ LLVM_DEBUG(logResult("failure", "pattern failed to match"));
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, failure());
+ };
+ auto onSuccessCallback = [&](const Pattern &pattern) {
+ LLVM_DEBUG(logResult("success", "pattern applied successfully"));
+ if (config.listener)
+ config.listener->notifyPatternEnd(pattern, success());
+ return success();
+ };
+ // Install the callbacks only when they have an observable effect: debug
+ // logging (builds without NDEBUG) or a listener.
+ bool debugBuild = false;
+#ifndef NDEBUG
+ debugBuild = true;
+#endif // NDEBUG
+ if (debugBuild || config.listener) {
+ canApply = canApplyCallback;
+ onFailure = onFailureCallback;
+ onSuccess = onSuccessCallback;
+ }
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope) {
@@ -731,7 +747,7 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
- logger.startLine() << "** Failure : " << diag.str() << "\n";
+ logger.startLine() << "** Match Failure : " << diag.str() << "\n";
});
if (config.listener)
config.listener->notifyMatchFailure(loc, reasonCallback);
``````````
</details>
https://github.com/llvm/llvm-project/pull/84131
More information about the llvm-branch-commits mailing list