[Mlir-commits] [mlir] [mlir] Add fast walk-based pattern rewrite driver (PR #113825)
Mehdi Amini
llvmlistbot at llvm.org
Sun Oct 27 12:35:03 PDT 2024
================
@@ -0,0 +1,86 @@
+//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements mlir::walkAndApplyPatterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "walk-rewriter"
+
+namespace mlir {
+
+namespace {
+/// Forwarding listener that guards against erasures the walk cannot survive.
+///
+/// Patterns are applied while walking the IR, so erasing an op or block that
+/// the walk has not reached yet (e.g., a user of the visited op) is not valid
+/// and is reported as a fatal error.
+///
+/// Erasing the visited op itself, or any op/block nested inside it, is
+/// allowed. `RewriterBase::eraseOp` notifies the listener for every nested op
+/// and block before erasing the root, so rejecting those notifications would
+/// make it impossible to erase a matched op that has regions. With the default
+/// post-order walk, those descendants have already been visited.
+///
+/// NOTE: walkAndApplyPatterns only installs this listener when
+/// MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS is enabled.
+struct ErasedOpsListener final : RewriterBase::ForwardingListener {
+  using RewriterBase::ForwardingListener::ForwardingListener;
+
+  void notifyOperationErased(Operation *op) override {
+    checkErasure(op);
+    ForwardingListener::notifyOperationErased(op);
+  }
+
+  void notifyBlockErased(Block *block) override {
+    // A block may be erased only if its parent op may be.
+    checkErasure(block->getParentOp());
+    ForwardingListener::notifyBlockErased(block);
+  }
+
+  /// Reports a fatal error unless `op` is `visitedOp` or nested inside it.
+  void checkErasure(Operation *op) const {
+    Operation *ancestorOp = op;
+    while (ancestorOp && ancestorOp != visitedOp)
+      ancestorOp = ancestorOp->getParentOp();
+
+    if (ancestorOp != visitedOp)
+      llvm::report_fatal_error(
+          "unsupported erasure in WalkPatternRewriter; "
+          "erasure is only supported for matched ops and their descendants");
+  }
+
+  /// The op currently being matched/rewritten; set by the walk callback.
+  Operation *visitedOp = nullptr;
+};
+} // namespace
+
+void walkAndApplyPatterns(Operation *op,
+ const FrozenRewritePatternSet &patterns,
+ RewriterBase::Listener *listener) {
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ if (failed(verify(op)))
+ llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+ PatternRewriter rewriter(op->getContext());
+ ErasedOpsListener erasedListener(listener);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ rewriter.setListener(&erasedListener);
+#else
+ (void)erasedListener;
+ rewriter.setListener(listener);
+#endif
+
+ PatternApplicator applicator(patterns);
+ applicator.applyDefaultCostModel();
+
+ op->walk([&](Operation *visitedOp) {
+ if (visitedOp == op)
+ return;
+
+ LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
+ visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
+ llvm::dbgs() << "\n";);
+ erasedListener.visitedOp = visitedOp;
+ if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
----------------
joker-eph wrote:
Can we wrap this in an action like the greedy driver?
https://github.com/llvm/llvm-project/pull/113825
More information about the Mlir-commits
mailing list