[Mlir-commits] [mlir] e6d90a0 - [mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns

Matthias Springer llvmlistbot at llvm.org
Wed May 24 07:24:19 PDT 2023


Author: Matthias Springer
Date: 2023-05-24T16:22:08+02:00
New Revision: e6d90a0d5e202166a9846f1845196086aa02f35e

URL: https://github.com/llvm/llvm-project/commit/e6d90a0d5e202166a9846f1845196086aa02f35e
DIFF: https://github.com/llvm/llvm-project/commit/e6d90a0d5e202166a9846f1845196086aa02f35e.diff

LOG: [mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns

Compute operation finger prints to detect incorrect API usage in RewritePatterns. Does not work for dialect conversion patterns.

Detect patterns that:
* Returned `failure` but changed the IR.
* Returned `success` but did not change the IR.
* Inserted/removed/modified ops, bypassing the rewriter. Not all cases are detected.

These new checks are quite expensive, so they are only enabled with `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON`. Failures manifest as fatal errors (`llvm::report_fatal_error`) or crashes (accessing deallocated memory). To get better debugging information, run `mlir-opt -debug` (to see which pattern is broken) with ASAN (to see where memory was deallocated).

Differential Revision: https://reviews.llvm.org/D144552

Added: 
    mlir/include/mlir/Config/mlir-config.h.cmake

Modified: 
    mlir/CMakeLists.txt
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index c9b0d53bc3e94..cd38836e21ec9 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -141,6 +141,10 @@ set(MLIR_INSTALL_AGGREGATE_OBJECTS 1 CACHE BOOL
 
 set(MLIR_BUILD_MLIR_C_DYLIB 0 CACHE BOOL "Builds libMLIR-C shared library.")
 
+configure_file(
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Config/mlir-config.h.cmake
+  ${MLIR_INCLUDE_DIR}/mlir/Config/mlir-config.h)
+
 #-------------------------------------------------------------------------------
 # Python Bindings Configuration
 # Requires:

diff  --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
new file mode 100644
index 0000000000000..2bcc9bf9f6b09
--- /dev/null
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -0,0 +1,22 @@
+//===- mlir-config.h - MLIR configuration ------------------------*- C -*-===*//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+/* This file enumerates variables from the MLIR configuration so that they
+   can be in exported headers and won't override package specific directives.
+   This is a C header that can be included in the mlir-c headers. */
+
+#ifndef MLIR_CONFIG_H
+#define MLIR_CONFIG_H
+
+/* Enable expensive checks to detect invalid pattern API usage. Failed checks
+   manifest as fatal errors or invalid memory accesses (e.g., accessing
+   deallocated memory) that cause a crash. Running with ASAN is recommended for
+   easier debugging. Always defined (to 0 or 1): test with #if, not #ifdef. */
+#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+#endif

diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 600ace4882734..4614649caae12 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -429,6 +429,38 @@ class RewriterBase : public OpBuilder {
     static bool classof(const OpBuilder::Listener *base);
   };
 
+  /// A listener that forwards every notification to another listener, which
+  /// must be non-null. Use this struct as a base to create listener chains, so
+  /// that multiple listeners can be notified of IR changes.
+  struct ForwardingListener : public RewriterBase::Listener {
+    ForwardingListener(Listener *listener) : listener(listener) {}
+
+    void notifyOperationInserted(Operation *op) override {
+      listener->notifyOperationInserted(op);
+    }
+    void notifyBlockCreated(Block *block) override {
+      listener->notifyBlockCreated(block);
+    }
+    void notifyOperationModified(Operation *op) override {
+      listener->notifyOperationModified(op);
+    }
+    void notifyOperationReplaced(Operation *op,
+                                 ValueRange replacement) override {
+      listener->notifyOperationReplaced(op, replacement);
+    }
+    void notifyOperationRemoved(Operation *op) override {
+      listener->notifyOperationRemoved(op);
+    }
+    LogicalResult notifyMatchFailure(
+        Location loc,
+        function_ref<void(Diagnostic &)> reasonCallback) override {
+      return listener->notifyMatchFailure(loc, reasonCallback);
+    }
+
+  private:
+    Listener *listener; // Forwarding target; dereferenced without null check.
+  };
+
   /// Move the blocks that belong to "region" before the given position in
   /// another region "parent". The two regions must be different. The caller
   /// is responsible for creating or updating the operation transferring flow

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index f6e7fa18a7789..c05b6398496c2 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Config/mlir-config.h"
 #include "mlir/IR/Action.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -30,10 +32,108 @@ using namespace mlir;
 #define DEBUG_TYPE "greedy-rewriter"
 
 //===----------------------------------------------------------------------===//
-// GreedyPatternRewriteDriver
+// Debugging Infrastructure
 //===----------------------------------------------------------------------===//
 
 namespace {
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+/// A helper struct that stores finger prints of ops in order to detect broken
+/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
+/// using the rewriter API or if it returns an inconsistent return value.
+struct DebugFingerPrints : public RewriterBase::ForwardingListener {
+  DebugFingerPrints(RewriterBase::Listener *driver)
+      : RewriterBase::ForwardingListener(driver) {}
+
+  /// Remember `topLevel` and record finger prints of it and its nested ops.
+  void computeFingerPrints(Operation *topLevel) {
+    this->topLevel = topLevel;
+    this->topLevelFingerPrint.emplace(topLevel);
+    topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
+  }
+
+  /// Drop all stored finger prints and forget the top-level op.
+  void clear() {
+    topLevel = nullptr;
+    topLevelFingerPrint.reset();
+    fingerprints.clear();
+  }
+
+  void notifyRewriteSuccess() {
+    // Pattern application success => top-level finger print must have changed.
+    OperationFingerPrint afterFingerPrint(topLevel);
+    if (*topLevelFingerPrint == afterFingerPrint) {
+      // Note: Run "mlir-opt -debug" to see which pattern is broken.
+      llvm::report_fatal_error(
+          "pattern returned success but IR did not change");
+    }
+    for (const auto &it : fingerprints) {
+      // Skip top-level op, its finger print is never invalidated.
+      if (it.first == topLevel)
+        continue;
+      // Note: Finger print computation may crash when an op was erased
+      // without notifying the rewriter. (Run with ASAN to see where the op was
+      // erased; the op was probably erased directly, bypassing the rewriter
+      // API.) Finger print computation may not crash if a new op was created
+      // at the same memory location. (But then the finger print should have
+      // changed.)
+      if (it.second != OperationFingerPrint(it.first)) {
+        // Note: Run "mlir-opt -debug" to see which pattern is broken.
+        llvm::report_fatal_error("operation finger print changed");
+      }
+    }
+  }
+
+  void notifyRewriteFailure() {
+    // Pattern application failure => top-level finger print must be unchanged.
+    OperationFingerPrint afterFingerPrint(topLevel);
+    if (*topLevelFingerPrint != afterFingerPrint) {
+      // Note: Run "mlir-opt -debug" to see which pattern is broken.
+      llvm::report_fatal_error("pattern returned failure but IR did change");
+    }
+  }
+
+protected:
+  /// Erase the finger prints of `op` and of its ancestors below `topLevel`.
+  void invalidateFingerPrint(Operation *op) {
+    // Walk up the parent chain; stop at the top-level op (or at null).
+    while (op && op != topLevel) {
+      fingerprints.erase(op);
+      op = op->getParentOp();
+    }
+  }
+
+  void notifyOperationInserted(Operation *op) override {
+    RewriterBase::ForwardingListener::notifyOperationInserted(op);
+    invalidateFingerPrint(op->getParentOp()); // Starts at the parent op.
+  }
+
+  void notifyOperationModified(Operation *op) override {
+    RewriterBase::ForwardingListener::notifyOperationModified(op);
+    invalidateFingerPrint(op); // Starts at the modified op itself.
+  }
+
+  void notifyOperationRemoved(Operation *op) override {
+    RewriterBase::ForwardingListener::notifyOperationRemoved(op);
+    op->walk([this](Operation *op) { invalidateFingerPrint(op); });
+  }
+
+  /// Operation finger prints to detect invalid pattern API usage. IR is checked
+  /// against these finger prints after pattern application to detect cases
+  /// where IR was modified directly, bypassing the rewriter API.
+  DenseMap<Operation *, OperationFingerPrint> fingerprints;
+
+  /// Top-level operation passed to the last `computeFingerPrints` call.
+  Operation *topLevel = nullptr;
+
+  /// Finger print of the top-level operation; empty after `clear`.
+  std::optional<OperationFingerPrint> topLevelFingerPrint;
+};
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+//===----------------------------------------------------------------------===//
+// GreedyPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
 /// This is a worklist-driven driver for the PatternMatcher, which repeatedly
 /// applies the locally optimal patterns.
 ///
@@ -122,21 +222,36 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
 
   /// The low-level pattern applicator.
   PatternApplicator matcher;
+
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  DebugFingerPrints debugFingerPrints;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 };
 } // namespace
 
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), folder(ctx, this), config(config),
-      matcher(patterns) {
+    : PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+      // clang-format off
+      , debugFingerPrints(this)
+// clang-format on
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+{
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
   matcher.applyDefaultCostModel();
 
   // Set up listener.
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  // Send IR notifications to the debug handler. This handler will then forward
+  // all notifications to this GreedyPatternRewriteDriver.
+  setListener(&debugFingerPrints);
+#else
   setListener(this);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 }
 
 bool GreedyPatternRewriteDriver::processWorklist() {
@@ -231,15 +346,28 @@ bool GreedyPatternRewriteDriver::processWorklist() {
     function_ref<LogicalResult(const Pattern &)> onSuccess = {};
 #endif
 
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+    debugFingerPrints.computeFingerPrints(
+        /*topLevel=*/config.scope ? config.scope->getParentOp() : op);
+    auto clearFingerprints =
+        llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
     LogicalResult matchResult =
         matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
 
     if (succeeded(matchResult)) {
       LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+      debugFingerPrints.notifyRewriteSuccess();
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
       changed = true;
       ++numRewrites;
     } else {
       LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+#ifdef MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+      debugFingerPrints.notifyRewriteFailure();
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
     }
   }
 
@@ -247,6 +375,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 }
 
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+  assert(op && "expected valid op");
   // Gather potential ancestors while looking for a "scope" parent region.
   SmallVector<Operation *, 8> ancestors;
   Region *region = nullptr;


        


More information about the Mlir-commits mailing list