[Mlir-commits] [mlir] e24b91b - Add tracing for pattern application in a ApplyPatternAction
Mehdi Amini
llvmlistbot at llvm.org
Mon Apr 10 17:42:59 PDT 2023
Author: Mehdi Amini
Date: 2023-04-10T18:42:45-06:00
New Revision: e24b91b063946a41afaba6c1ef0d38777fd8b601
URL: https://github.com/llvm/llvm-project/commit/e24b91b063946a41afaba6c1ef0d38777fd8b601
DIFF: https://github.com/llvm/llvm-project/commit/e24b91b063946a41afaba6c1ef0d38777fd8b601.diff
LOG: Add tracing for pattern application in a ApplyPatternAction
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D144816
Added:
Modified:
mlir/include/mlir/Rewrite/PatternApplicator.h
mlir/lib/Rewrite/PatternApplicator.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h
index a2e2286ebda34..41ec95afcd9bd 100644
--- a/mlir/include/mlir/Rewrite/PatternApplicator.h
+++ b/mlir/include/mlir/Rewrite/PatternApplicator.h
@@ -16,6 +16,8 @@
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/IR/Action.h"
+
namespace mlir {
class PatternRewriter;
@@ -23,6 +25,26 @@ namespace detail {
class PDLByteCodeMutableState;
} // namespace detail
+/// This is the type of Action that is dispatched when a pattern is applied.
+/// It captures the pattern to apply on top of the IR units held by the base.
+class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
+public:
+ using Base = tracing::ActionImpl<ApplyPatternAction>;
+ ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
+ : Base(irUnits), pattern(pattern) {}
+ static constexpr StringLiteral tag = "apply-pattern-action"; // Also emitted by print().
+ static constexpr StringLiteral desc =
+ "Encapsulate the application of rewrite patterns";
+
+ void print(raw_ostream &os) const override { // Emits the tag, then the pattern's debug name.
+ os << "`" << tag << "`\n"
+ << " pattern: " << pattern.getDebugName() << '\n';
+ }
+
+private:
+ const Pattern &pattern; // Non-owning reference: the Pattern must outlive this action.
+};
+
/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index 499a8506bc606..08d6ee618ac69 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -185,35 +185,47 @@ LogicalResult PatternApplicator::matchAndRewrite(
// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite. For PDL patterns, the
// match has already been performed, we just need to rewrite.
- rewriter.setInsertionPoint(op);
+ bool matched = false;
+ op->getContext()->executeAction<ApplyPatternAction>(
+ [&]() {
+ rewriter.setInsertionPoint(op);
#ifndef NDEBUG
- // Operation `op` may be invalidated after applying the rewrite pattern.
- Operation *dumpRootOp = getDumpRootOp(op);
+ // Operation `op` may be invalidated after applying the rewrite
+ // pattern.
+ Operation *dumpRootOp = getDumpRootOp(op);
#endif
- if (pdlMatch) {
- result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
- } else {
- LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
- << bestPattern->getDebugName() << "\"\n");
-
- const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
- result = pattern->matchAndRewrite(op, rewriter);
-
- LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
- << "\" result " << succeeded(result) << "\n");
- }
-
- // Process the result of the pattern application.
- if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
- result = failure();
- if (succeeded(result)) {
- LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+ if (pdlMatch) {
+ result =
+ bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
+ } else {
+ LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
+ << bestPattern->getDebugName() << "\"\n");
+
+ const auto *pattern =
+ static_cast<const RewritePattern *>(bestPattern);
+ result = pattern->matchAndRewrite(op, rewriter);
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "\"" << bestPattern->getDebugName() << "\" result "
+ << succeeded(result) << "\n");
+ }
+
+ // Process the result of the pattern application.
+ if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
+ result = failure();
+ if (succeeded(result)) {
+ LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
+ matched = true;
+ return;
+ }
+
+ // Perform any necessary cleanups.
+ if (onFailure)
+ onFailure(*bestPattern);
+ },
+ {op}, *bestPattern);
+ if (matched)
break;
- }
-
- // Perform any necessary cleanups.
- if (onFailure)
- onFailure(*bestPattern);
} while (true);
if (mutableByteCodeState)
More information about the Mlir-commits
mailing list