[Mlir-commits] [mlir] [mlir] Add fast walk-based pattern rewrite driver (PR #113825)

Mehdi Amini llvmlistbot at llvm.org
Sun Oct 27 12:35:03 PDT 2024


================
@@ -0,0 +1,86 @@
+//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements mlir::walkAndApplyPatterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#define DEBUG_TYPE "walk-rewriter"
+
+namespace mlir {
+
+namespace {
+// Forwarding listener to guard against unsupported erasures. Because we use
+// walk-based pattern application, erasing an op the walk may still visit
+// (e.g., a user of the visited op) can invalidate the walk, so it is rejected.
+// Erasing the visited (matched) op itself is allowed, and so is erasing any op
+// nested inside it: erasing an op with regions also reports the erasure of its
+// nested ops to the listener, and those ops have already been visited by the
+// (post-order) walk.
+struct ErasedOpsListener final : RewriterBase::ForwardingListener {
+  using RewriterBase::ForwardingListener::ForwardingListener;
+
+  void notifyOperationErased(Operation *op) override {
+    // `isAncestor` treats an op as its own ancestor, so this accepts both the
+    // matched op and its descendants. Before the walk assigns `visitedOp` it
+    // is null, and any erasure is rejected.
+    if (!visitedOp || !visitedOp->isAncestor(op))
+      llvm::report_fatal_error(
+          "unsupported op erased in WalkPatternRewriter; "
+          "erasure is only supported for matched ops and their descendants");
+
+    ForwardingListener::notifyOperationErased(op);
+  }
+
+  /// The op currently handed to the pattern applicator; the walk updates it
+  /// before each `matchAndRewrite` call.
+  Operation *visitedOp = nullptr;
+};
+} // namespace
+
+void walkAndApplyPatterns(Operation *op,
+                          const FrozenRewritePatternSet &patterns,
+                          RewriterBase::Listener *listener) {
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  if (failed(verify(op)))
+    llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+  PatternRewriter rewriter(op->getContext());
+  ErasedOpsListener erasedListener(listener);
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  rewriter.setListener(&erasedListener);
+#else
+  (void)erasedListener;
+  rewriter.setListener(listener);
+#endif
+
+  PatternApplicator applicator(patterns);
+  applicator.applyDefaultCostModel();
+
+  op->walk([&](Operation *visitedOp) {
+    if (visitedOp == op)
+      return;
+
+    LLVM_DEBUG(llvm::dbgs() << "Visiting op: ";
+               visitedOp->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
+               llvm::dbgs() << "\n";);
+    erasedListener.visitedOp = visitedOp;
+    if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
----------------
joker-eph wrote:

Can we wrap this in an action like the greedy driver?

https://github.com/llvm/llvm-project/pull/113825


More information about the Mlir-commits mailing list