[llvm-branch-commits] [mlir] [mlir][IR] Add listener notifications for pattern begin/end (PR #84131)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Mar 6 00:08:13 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/84131

This commit adds two new notifications to `RewriterBase::Listener`:
* `notifyPatternBegin`: Called when a pattern application begins during a greedy pattern rewrite or dialect conversion.
* `notifyPatternEnd`: Called when a pattern application finishes during a greedy pattern rewrite or dialect conversion.

The listener infrastructure already provides a `notifyMatchFailure` callback that notifies about the reason for a pattern match failure. The two new notifications provide additional information about pattern applications.

This change is in preparation for improving the handle update mechanism in the `apply_conversion_patterns` transform op.


>From 3ecec094d43a4eb9405ad057c32e4579f1fbe680 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 6 Mar 2024 08:00:35 +0000
Subject: [PATCH] [mlir][IR] Add listener notifications for pattern begin/end

---
 mlir/include/mlir/IR/PatternMatch.h           | 30 +++++++++--
 .../Transforms/Utils/DialectConversion.cpp    | 29 +++++++---
 .../Utils/GreedyPatternRewriteDriver.cpp      | 53 +++++++++++--------
 3 files changed, 77 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..838b4947648f5e 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -432,11 +432,22 @@ class RewriterBase : public OpBuilder {
     /// Note: This notification is not triggered when unlinking an operation.
     virtual void notifyOperationErased(Operation *op) {}
 
-    /// Notify the listener that the pattern failed to match the given
-    /// operation, and provide a callback to populate a diagnostic with the
-    /// reason why the failure occurred. This method allows for derived
-    /// listeners to optionally hook into the reason why a rewrite failed, and
-    /// display it to users.
+    /// Notify the listener that the specified pattern is about to be applied
+    /// at the specified root operation.
+    virtual void notifyPatternBegin(const Pattern &pattern, Operation *op) {}
+
+    /// Notify the listener that a pattern application finished with the
+    /// specified status. "success" indicates that the pattern was applied
+    /// successfully. "failure" indicates that the pattern could not be
+    /// applied. The pattern may have communicated the reason for the failure
+    /// with `notifyMatchFailure`.
+    virtual void notifyPatternEnd(const Pattern &pattern,
+                                  LogicalResult status) {}
+
+    /// Notify the listener that the pattern failed to match, and provide a
+    /// callback to populate a diagnostic with the reason why the failure
+    /// occurred. This method allows for derived listeners to optionally hook
+    /// into the reason why a rewrite failed, and display it to users.
     virtual void
     notifyMatchFailure(Location loc,
                        function_ref<void(Diagnostic &)> reasonCallback) {}
@@ -478,6 +489,15 @@ class RewriterBase : public OpBuilder {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationErased(op);
     }
+    void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyPatternBegin(pattern, op);
+    }
+    void notifyPatternEnd(const Pattern &pattern,
+                          LogicalResult status) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyPatternEnd(pattern, status);
+    }
     void notifyMatchFailure(
         Location loc,
         function_ref<void(Diagnostic &)> reasonCallback) override {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a5145246bc30c4..587fbe209b58af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1863,7 +1863,8 @@ class OperationLegalizer {
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
   OperationLegalizer(const ConversionTarget &targetInfo,
-                     const FrozenRewritePatternSet &patterns);
+                     const FrozenRewritePatternSet &patterns,
+                     const ConversionConfig &config);
 
   /// Returns true if the given operation is known to be illegal on the target.
   bool isIllegal(Operation *op) const;
@@ -1955,12 +1956,16 @@ class OperationLegalizer {
 
   /// The pattern applicator to use for conversions.
   PatternApplicator applicator;
+
+  /// Dialect conversion configuration.
+  const ConversionConfig &config;
 };
 } // namespace
 
 OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
-                                       const FrozenRewritePatternSet &patterns)
-    : target(targetInfo), applicator(patterns) {
+                                       const FrozenRewritePatternSet &patterns,
+                                       const ConversionConfig &config)
+    : target(targetInfo), applicator(patterns), config(config) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2105,7 +2110,10 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern &pattern) {
-    return canApplyPattern(op, pattern, rewriter);
+    bool canApply = canApplyPattern(op, pattern, rewriter);
+    if (canApply && config.listener)
+      config.listener->notifyPatternBegin(pattern, op);
+    return canApply;
   };
 
   // Functor that cleans up the rewriter state after a pattern failed to match.
@@ -2122,6 +2130,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
         rewriterImpl.config.notifyCallback(diag);
       }
     });
+    if (config.listener)
+      config.listener->notifyPatternEnd(pattern, failure());
     rewriterImpl.resetState(curState);
     appliedPatterns.erase(&pattern);
   };
@@ -2134,6 +2144,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     appliedPatterns.erase(&pattern);
     if (failed(result))
       rewriterImpl.resetState(curState);
+    if (config.listener)
+      config.listener->notifyPatternEnd(pattern, result);
     return result;
   };
 
@@ -2509,7 +2521,8 @@ struct OperationConverter {
                               const FrozenRewritePatternSet &patterns,
                               const ConversionConfig &config,
                               OpConversionMode mode)
-      : opLegalizer(target, patterns), config(config), mode(mode) {}
+      : config(config), opLegalizer(target, patterns, this->config),
+        mode(mode) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops);
@@ -2546,12 +2559,12 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       const DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
-  /// The legalizer to use when converting operations.
-  OperationLegalizer opLegalizer;
-
   /// Dialect conversion configuration.
   ConversionConfig config;
 
+  /// The legalizer to use when converting operations.
+  OperationLegalizer opLegalizer;
+
   /// The conversion mode to use when legalizing operations.
   OpConversionMode mode;
 };
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 51d2f5e01b7235..5fda6f87196f94 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -562,30 +562,39 @@ bool GreedyPatternRewriteDriver::processWorklist() {
     // Try to match one of the patterns. The rewriter is automatically
     // notified of any necessary changes, so there is nothing else to do
     // here.
-#ifndef NDEBUG
-    auto canApply = [&](const Pattern &pattern) {
-      LLVM_DEBUG({
-        logger.getOStream() << "\n";
-        logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
-                           << op->getName() << " -> (";
-        llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
-        logger.getOStream() << ")' {\n";
-        logger.indent();
-      });
-      return true;
-    };
-    auto onFailure = [&](const Pattern &pattern) {
-      LLVM_DEBUG(logResult("failure", "pattern failed to match"));
-    };
-    auto onSuccess = [&](const Pattern &pattern) {
-      LLVM_DEBUG(logResult("success", "pattern applied successfully"));
-      return success();
-    };
-#else
     function_ref<bool(const Pattern &)> canApply = {};
     function_ref<void(const Pattern &)> onFailure = {};
     function_ref<LogicalResult(const Pattern &)> onSuccess = {};
-#endif
+    bool debugBuild = false;
+#ifndef NDEBUG
+    debugBuild = true; // NDEBUG is not defined: this is a debug build.
+#endif // NDEBUG
+    if (debugBuild || config.listener) {
+      canApply = [&](const Pattern &pattern) {
+        LLVM_DEBUG({
+          logger.getOStream() << "\n";
+          logger.startLine() << "* Pattern " << pattern.getDebugName() << " : '"
+                             << op->getName() << " -> (";
+          llvm::interleaveComma(pattern.getGeneratedOps(), logger.getOStream());
+          logger.getOStream() << ")' {\n";
+          logger.indent();
+        });
+        if (config.listener)
+          config.listener->notifyPatternBegin(pattern, op);
+        return true;
+      };
+      onFailure = [&](const Pattern &pattern) {
+        LLVM_DEBUG(logResult("failure", "pattern failed to match"));
+        if (config.listener)
+          config.listener->notifyPatternEnd(pattern, failure());
+      };
+      onSuccess = [&](const Pattern &pattern) {
+        LLVM_DEBUG(logResult("success", "pattern applied successfully"));
+        if (config.listener)
+          config.listener->notifyPatternEnd(pattern, success());
+        return success();
+      };
+    }
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
     if (config.scope) {
@@ -731,7 +740,7 @@ void GreedyPatternRewriteDriver::notifyMatchFailure(
   LLVM_DEBUG({
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
-    logger.startLine() << "** Failure : " << diag.str() << "\n";
+    logger.startLine() << "** Match Failure : " << diag.str() << "\n";
   });
   if (config.listener)
     config.listener->notifyMatchFailure(loc, reasonCallback);



More information about the llvm-branch-commits mailing list