[Mlir-commits] [mlir] e6d90a0 - [mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns
Matthias Springer
llvmlistbot at llvm.org
Wed May 24 07:24:19 PDT 2023
Author: Matthias Springer
Date: 2023-05-24T16:22:08+02:00
New Revision: e6d90a0d5e202166a9846f1845196086aa02f35e
URL: https://github.com/llvm/llvm-project/commit/e6d90a0d5e202166a9846f1845196086aa02f35e
DIFF: https://github.com/llvm/llvm-project/commit/e6d90a0d5e202166a9846f1845196086aa02f35e.diff
LOG: [mlir][Transforms] GreedyPatternRewriteDriver debugging: Detect faulty patterns
Compute operation finger prints to detect incorrect API usage in RewritePatterns. Does not work for dialect conversion patterns.
Detect patterns that:
* Returned `failure` but changed the IR.
* Returned `success` but did not change the IR.
* Inserted/removed/modified ops, bypassing the rewriter. Not all cases are detected.
These new checks are quite expensive, so they are only enabled with `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON`. Failures manifest as fatal errors (`llvm::report_fatal_error`) or crashes (accessing deallocated memory). To get better debugging information, run `mlir-opt -debug` (to see which pattern is broken) with ASAN (to see where memory was deallocated).
Differential Revision: https://reviews.llvm.org/D144552
Added:
mlir/include/mlir/Config/mlir-config.h.cmake
Modified:
mlir/CMakeLists.txt
mlir/include/mlir/IR/PatternMatch.h
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index c9b0d53bc3e94..cd38836e21ec9 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -141,6 +141,14 @@ set(MLIR_INSTALL_AGGREGATE_OBJECTS 1 CACHE BOOL
set(MLIR_BUILD_MLIR_C_DYLIB 0 CACHE BOOL "Builds libMLIR-C shared library.")
+# Off by default. Exported to C/C++ as a 0/1 macro through mlir-config.h.
+set(MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS 0 CACHE BOOL
+ "Enable expensive checks to detect invalid pattern API usage. Failed checks manifest as fatal errors or invalid memory accesses (e.g., accessing deallocated memory) that cause a crash. Running with ASAN is recommended for easier debugging.")
+
+configure_file(
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Config/mlir-config.h.cmake
+ ${MLIR_INCLUDE_DIR}/mlir/Config/mlir-config.h)
+
#-------------------------------------------------------------------------------
# Python Bindings Configuration
# Requires:
diff --git a/mlir/include/mlir/Config/mlir-config.h.cmake b/mlir/include/mlir/Config/mlir-config.h.cmake
new file mode 100644
index 0000000000000..2bcc9bf9f6b09
--- /dev/null
+++ b/mlir/include/mlir/Config/mlir-config.h.cmake
@@ -0,0 +1,22 @@
+//===- mlir-config.h - MLIR configuration ------------------------*- C -*-===*//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+/* This file enumerates variables from the MLIR configuration so that they
+ can be in exported headers and won't override package specific directives.
+ This is a C header that can be included in the mlir-c headers. */
+
+#ifndef MLIR_CONFIG_H
+#define MLIR_CONFIG_H
+
+/* Enable expensive checks to detect invalid pattern API usage. Always defined
+ (to 0 or 1), so test it with #if, never #ifdef. Failed checks manifest as
+ fatal errors or invalid memory accesses (e.g., accessing deallocated memory)
+ that cause a crash; running with ASAN is recommended for easier debugging. */
+#cmakedefine01 MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+#endif /* MLIR_CONFIG_H */
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 600ace4882734..4614649caae12 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -429,6 +429,38 @@ class RewriterBase : public OpBuilder {
static bool classof(const OpBuilder::Listener *base);
};
+ /// Relays each notification to a wrapped listener (dereferenced on every
+ /// callback, so it must be non-null). Derive from this struct to chain
+ /// listeners, so that several of them observe the same IR changes.
+ struct ForwardingListener : public RewriterBase::Listener {
+ ForwardingListener(Listener *next) : listener(next) {}
+
+ void notifyBlockCreated(Block *block) override {
+ listener->notifyBlockCreated(block);
+ }
+ void notifyOperationInserted(Operation *op) override {
+ listener->notifyOperationInserted(op);
+ }
+ void notifyOperationModified(Operation *op) override {
+ listener->notifyOperationModified(op);
+ }
+ void notifyOperationRemoved(Operation *op) override {
+ listener->notifyOperationRemoved(op);
+ }
+ void notifyOperationReplaced(Operation *op,
+ ValueRange newValues) override {
+ listener->notifyOperationReplaced(op, newValues);
+ }
+ LogicalResult notifyMatchFailure(
+ Location loc,
+ function_ref<void(Diagnostic &)> reasonFn) override {
+ return listener->notifyMatchFailure(loc, reasonFn);
+ }
+
+ private:
+ Listener *listener;
+ };
+
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller
/// is responsible for creating or updating the operation transferring flow
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index f6e7fa18a7789..c05b6398496c2 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -30,10 +32,117 @@ using namespace mlir;
#define DEBUG_TYPE "greedy-rewriter"
//===----------------------------------------------------------------------===//
-// GreedyPatternRewriteDriver
+// Debugging Infrastructure
//===----------------------------------------------------------------------===//
namespace {
+// Note: mlir-config.h always defines this macro (to 0 or 1), so it must be
+// tested with `#if`; `#ifdef` would enable the checks unconditionally.
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+/// A helper struct that stores finger prints of ops in order to detect broken
+/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
+/// using the rewriter API or if it returns an inconsistent return value.
+struct DebugFingerPrints : public RewriterBase::ForwardingListener {
+ explicit DebugFingerPrints(RewriterBase::Listener *driver)
+ : RewriterBase::ForwardingListener(driver) {}
+
+ /// Compute finger prints of the given op and its nested ops.
+ void computeFingerPrints(Operation *topLevel) {
+ this->topLevel = topLevel;
+ this->topLevelFingerPrint.emplace(topLevel);
+ topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
+ }
+
+ /// Clear all finger prints.
+ void clear() {
+ topLevel = nullptr;
+ topLevelFingerPrint.reset();
+ fingerprints.clear();
+ }
+
+ void notifyRewriteSuccess() {
+ // Pattern application success => IR must have changed.
+ OperationFingerPrint afterFingerPrint(topLevel);
+ if (*topLevelFingerPrint == afterFingerPrint) {
+ // Note: Run "mlir-opt -debug" to see which pattern is broken.
+ llvm::report_fatal_error(
+ "pattern returned success but IR did not change");
+ }
+ for (const auto &it : fingerprints) {
+ // Skip top-level op, its finger print is never invalidated.
+ if (it.first == topLevel)
+ continue;
+ // Note: Finger print computation may crash when an op was erased
+ // without notifying the rewriter. (Run with ASAN to see where the op was
+ // erased; the op was probably erased directly, bypassing the rewriter
+ // API.) Finger print computation may not crash if a new op was
+ // created at the same memory location. (But then the finger print should
+ // have changed.)
+ if (it.second != OperationFingerPrint(it.first)) {
+ // Note: Run "mlir-opt -debug" to see which pattern is broken.
+ llvm::report_fatal_error("operation finger print changed");
+ }
+ }
+ }
+
+ void notifyRewriteFailure() {
+ // Pattern application failure => IR must not have changed.
+ OperationFingerPrint afterFingerPrint(topLevel);
+ if (*topLevelFingerPrint != afterFingerPrint) {
+ // Note: Run "mlir-opt -debug" to see which pattern is broken.
+ llvm::report_fatal_error("pattern returned failure but IR did change");
+ }
+ }
+
+protected:
+ /// Invalidate the finger print of the given op, i.e., remove it from the map.
+ void invalidateFingerPrint(Operation *op) {
+ // Invalidate all finger prints until the top level.
+ while (op && op != topLevel) {
+ fingerprints.erase(op);
+ op = op->getParentOp();
+ }
+ }
+
+ void notifyBlockCreated(Block *block) override {
+ RewriterBase::ForwardingListener::notifyBlockCreated(block);
+ // Creating a block modifies the op that owns the enclosing region.
+ // Conservatively drop the finger prints of that op and its ancestors.
+ invalidateFingerPrint(block->getParentOp());
+ }
+
+ void notifyOperationInserted(Operation *op) override {
+ RewriterBase::ForwardingListener::notifyOperationInserted(op);
+ invalidateFingerPrint(op->getParentOp());
+ }
+
+ void notifyOperationModified(Operation *op) override {
+ RewriterBase::ForwardingListener::notifyOperationModified(op);
+ invalidateFingerPrint(op);
+ }
+
+ void notifyOperationRemoved(Operation *op) override {
+ RewriterBase::ForwardingListener::notifyOperationRemoved(op);
+ op->walk([this](Operation *op) { invalidateFingerPrint(op); });
+ }
+
+ /// Operation finger prints to detect invalid pattern API usage. IR is checked
+ /// against these finger prints after pattern application to detect cases
+ /// where IR was modified directly, bypassing the rewriter API.
+ DenseMap<Operation *, OperationFingerPrint> fingerprints;
+
+ /// Top-level operation of the current greedy rewrite.
+ Operation *topLevel = nullptr;
+
+ /// Finger print of the top-level operation.
+ std::optional<OperationFingerPrint> topLevelFingerPrint;
+};
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
+//===----------------------------------------------------------------------===//
+// GreedyPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
/// This is a worklist-driven driver for the PatternMatcher, which repeatedly
/// applies the locally optimal patterns.
///
@@ -122,21 +231,36 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
/// The low-level pattern applicator.
PatternApplicator matcher;
+
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ DebugFingerPrints debugFingerPrints;
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
};
} // namespace
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : PatternRewriter(ctx), folder(ctx, this), config(config),
- matcher(patterns) {
+ : PatternRewriter(ctx), folder(ctx, this), config(config), matcher(patterns)
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // clang-format off
+ , debugFingerPrints(this)
+// clang-format on
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+{
worklist.reserve(64);
// Apply a simple cost model based solely on pattern benefit.
matcher.applyDefaultCostModel();
// Set up listener.
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // Send IR notifications to the debug handler. This handler will then forward
+ // all notifications to this GreedyPatternRewriteDriver.
+ setListener(&debugFingerPrints);
+#else
setListener(this);
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
bool GreedyPatternRewriteDriver::processWorklist() {
@@ -231,15 +355,28 @@ bool GreedyPatternRewriteDriver::processWorklist() {
function_ref<LogicalResult(const Pattern &)> onSuccess = {};
#endif
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ debugFingerPrints.computeFingerPrints(
+ /*topLevel=*/config.scope ? config.scope->getParentOp() : op);
+ auto clearFingerprints =
+ llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+
LogicalResult matchResult =
matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
if (succeeded(matchResult)) {
LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ debugFingerPrints.notifyRewriteSuccess();
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
changed = true;
++numRewrites;
} else {
LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ debugFingerPrints.notifyRewriteFailure();
+#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
}
@@ -247,6 +384,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
}
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+ assert(op && "expected valid op");
// Gather potential ancestors while looking for a "scope" parent region.
SmallVector<Operation *, 8> ancestors;
Region *region = nullptr;
More information about the Mlir-commits
mailing list