[Mlir-commits] [mlir] e24b91b - Add tracing for pattern application in an ApplyPatternAction

Mehdi Amini llvmlistbot at llvm.org
Mon Apr 10 17:42:59 PDT 2023


Author: Mehdi Amini
Date: 2023-04-10T18:42:45-06:00
New Revision: e24b91b063946a41afaba6c1ef0d38777fd8b601

URL: https://github.com/llvm/llvm-project/commit/e24b91b063946a41afaba6c1ef0d38777fd8b601
DIFF: https://github.com/llvm/llvm-project/commit/e24b91b063946a41afaba6c1ef0d38777fd8b601.diff

LOG: Add tracing for pattern application in an ApplyPatternAction

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D144816

Added: 
    

Modified: 
    mlir/include/mlir/Rewrite/PatternApplicator.h
    mlir/lib/Rewrite/PatternApplicator.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
index a2e2286ebda34..41ec95afcd9bd 100644
--- a/mlir/include/mlir/Rewrite/PatternApplicator.h
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -16,6 +16,8 @@
 
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 
+#include "mlir/IR/Action.h"
+
 namespace mlir {
 class PatternRewriter;
 
@@ -23,6 +25,36 @@ namespace detail {
 class PDLByteCodeMutableState;
 } // namespace detail
 
+/// The Action dispatched (via MLIRContext::executeAction) around each attempt
+/// to apply a single rewrite pattern. In addition to the IR units the action
+/// operates on, it records the pattern being applied so that action handlers
+/// (e.g. tracing or debugging tools) can inspect or report it.
+///
+/// The action only holds a reference to the pattern: it must not outlive the
+/// pattern it was created with.
+class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
+public:
+  using Base = tracing::ActionImpl<ApplyPatternAction>;
+  ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
+      : Base(irUnits), pattern(pattern) {}
+  static constexpr StringLiteral tag = "apply-pattern-action";
+  static constexpr StringLiteral desc =
+      "Encapsulate the application of rewrite patterns";
+
+  /// Return the pattern whose application this action encapsulates.
+  const Pattern &getPattern() const { return pattern; }
+
+  /// Print the action tag followed by the debug name of the pattern.
+  void print(raw_ostream &os) const override {
+    os << "`" << tag << "`\n"
+       << " pattern: " << pattern.getDebugName() << '\n';
+  }
+
+private:
+  /// The pattern being applied; not owned by the action.
+  const Pattern &pattern;
+};
+
 /// This class manages the application of a group of rewrite patterns, with a
 /// user-provided cost model.
 class PatternApplicator {

diff  --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 499a8506bc606..08d6ee618ac69 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -185,35 +185,47 @@ LogicalResult PatternApplicator::matchAndRewrite(
     // Try to match and rewrite this pattern. The patterns are sorted by
     // benefit, so if we match we can immediately rewrite. For PDL patterns, the
     // match has already been performed, we just need to rewrite.
-    rewriter.setInsertionPoint(op);
+    bool matched = false;
+    op->getContext()->executeAction<ApplyPatternAction>(
+        [&]() {
+          rewriter.setInsertionPoint(op);
 #ifndef NDEBUG
-    // Operation `op` may be invalidated after applying the rewrite pattern.
-    Operation *dumpRootOp = getDumpRootOp(op);
+          // Operation `op` may be invalidated after applying the rewrite
+          // pattern.
+          Operation *dumpRootOp = getDumpRootOp(op);
 #endif
-    if (pdlMatch) {
-      result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
-    } else {
-      LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
-                              << bestPattern->getDebugName() << "\"\n");
-
-      const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
-      result = pattern->matchAndRewrite(op, rewriter);
-
-      LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
-                              << "\" result " << succeeded(result) << "\n");
-    }
-
-    // Process the result of the pattern application.
-    if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
-      result = failure();
-    if (succeeded(result)) {
-      LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+          if (pdlMatch) {
+            result =
+                bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+          } else {
+            LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
+                                    << bestPattern->getDebugName() << "\"\n");
+
+            const auto *pattern =
+                static_cast<const RewritePattern *>(bestPattern);
+            result = pattern->matchAndRewrite(op, rewriter);
+
+            LLVM_DEBUG(llvm::dbgs()
+                       << "\"" << bestPattern->getDebugName() << "\" result "
+                       << succeeded(result) << "\n");
+          }
+
+          // Process the result of the pattern application.
+          if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
+            result = failure();
+          if (succeeded(result)) {
+            LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+            matched = true;
+            return;
+          }
+
+          // Perform any necessary cleanups.
+          if (onFailure)
+            onFailure(*bestPattern);
+        },
+        {op}, *bestPattern);
+    if (matched)
       break;
-    }
-
-    // Perform any necessary cleanups.
-    if (onFailure)
-      onFailure(*bestPattern);
   } while (true);
 
   if (mutableByteCodeState)


        


More information about the Mlir-commits mailing list