[Mlir-commits] [mlir] a02a0e8 - [mlir][Transforms] `GreedyPatternRewriteDriver`: Better expensive checks encapsulation (#78175)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 15 23:55:29 PST 2024


Author: Matthias Springer
Date: 2024-01-16T08:55:25+01:00
New Revision: a02a0e806fab01f4cf4307443cdaed76a2488752

URL: https://github.com/llvm/llvm-project/commit/a02a0e806fab01f4cf4307443cdaed76a2488752
DIFF: https://github.com/llvm/llvm-project/commit/a02a0e806fab01f4cf4307443cdaed76a2488752.diff

LOG: [mlir][Transforms] `GreedyPatternRewriteDriver`: Better expensive checks encapsulation (#78175)

This change moves most IR verification logic (which is part of the
expensive checks) into `DebugFingerPrints` and renames the struct to
`ExpensiveChecks`. This isolates the debugging logic better from the
remaining code.

This commit also removes a redundant check: the IR is no longer verified
after a failed pattern application. We already assert that the IR did
not change. (We know that the IR was valid before the attempted pattern
application.)

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index d31408d240ebd5..36d63d62bf10fc 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -43,12 +43,18 @@ namespace {
 //===----------------------------------------------------------------------===//
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-/// A helper struct that stores finger prints of ops in order to detect broken
-/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
-/// using the rewriter API or if it returns an inconsistent return value.
-struct DebugFingerPrints : public RewriterBase::ForwardingListener {
-  DebugFingerPrints(RewriterBase::Listener *driver)
-      : RewriterBase::ForwardingListener(driver) {}
+/// A helper struct that performs various "expensive checks" to detect broken
+/// rewrite patterns that use the rewriter API incorrectly. A rewrite pattern
+/// is broken if:
+/// * IR does not verify after pattern application / folding.
+/// * Pattern returns "failure" but the IR has changed.
+/// * Pattern returns "success" but the IR has not changed.
+///
+/// This struct stores finger prints of ops to determine whether the IR has
+/// changed or not.
+struct ExpensiveChecks : public RewriterBase::ForwardingListener {
+  ExpensiveChecks(RewriterBase::Listener *driver, Operation *topLevel)
+      : RewriterBase::ForwardingListener(driver), topLevel(topLevel) {}
 
   /// Compute finger prints of the given op and its nested ops.
   void computeFingerPrints(Operation *topLevel) {
@@ -65,6 +71,13 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
   }
 
   void notifyRewriteSuccess() {
+    if (!topLevel)
+      return;
+
+    // Make sure that the IR still verifies.
+    if (failed(verify(topLevel)))
+      llvm::report_fatal_error("IR failed to verify after pattern application");
+
     // Pattern application success => IR must have changed.
     OperationFingerPrint afterFingerPrint(topLevel);
     if (*topLevelFingerPrint == afterFingerPrint) {
@@ -90,6 +103,9 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
   }
 
   void notifyRewriteFailure() {
+    if (!topLevel)
+      return;
+
     // Pattern application failure => IR must not have changed.
     OperationFingerPrint afterFingerPrint(topLevel);
     if (*topLevelFingerPrint != afterFingerPrint) {
@@ -98,6 +114,15 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
     }
   }
 
+  void notifyFoldingSuccess() {
+    if (!topLevel)
+      return;
+
+    // Report a fatal error if the IR no longer verifies after folding.
+    if (failed(verify(topLevel)))
+      llvm::report_fatal_error("IR failed to verify after folding");
+  }
+
 protected:
   /// Invalidate the finger print of the given op, i.e., remove it from the map.
   void invalidateFingerPrint(Operation *op) {
@@ -362,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   PatternApplicator matcher;
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-  DebugFingerPrints debugFingerPrints;
+  ExpensiveChecks expensiveChecks;
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 };
 } // namespace
@@ -373,7 +398,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     : PatternRewriter(ctx), config(config), matcher(patterns)
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
       // clang-format off
-      , debugFingerPrints(this)
+      , expensiveChecks(
+          /*driver=*/this,
+          /*topLevel=*/config.scope ? config.scope->getParentOp() : nullptr)
 // clang-format on
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 {
@@ -384,7 +411,7 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // Send IR notifications to the debug handler. This handler will then forward
   // all notifications to this GreedyPatternRewriteDriver.
-  setListener(&debugFingerPrints);
+  setListener(&expensiveChecks);
 #else
   setListener(this);
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -458,8 +485,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           changed = true;
           LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-          if (config.scope && failed(verify(config.scope->getParentOp())))
-            llvm::report_fatal_error("IR failed to verify after folding");
+          expensiveChecks.notifyFoldingSuccess();
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
           continue;
         }
@@ -513,8 +539,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           changed = true;
           LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-          if (config.scope && failed(verify(config.scope->getParentOp())))
-            llvm::report_fatal_error("IR failed to verify after folding");
+          expensiveChecks.notifyFoldingSuccess();
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
           continue;
         }
@@ -551,33 +576,26 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
     if (config.scope) {
-      debugFingerPrints.computeFingerPrints(config.scope->getParentOp());
+      expensiveChecks.computeFingerPrints(config.scope->getParentOp());
     }
     auto clearFingerprints =
-        llvm::make_scope_exit([&]() { debugFingerPrints.clear(); });
+        llvm::make_scope_exit([&]() { expensiveChecks.clear(); });
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
     LogicalResult matchResult =
         matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
 
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-    if (config.scope && failed(verify(config.scope->getParentOp())))
-      llvm::report_fatal_error("IR failed to verify after pattern application");
-#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-
     if (succeeded(matchResult)) {
       LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-      if (config.scope)
-        debugFingerPrints.notifyRewriteSuccess();
+      expensiveChecks.notifyRewriteSuccess();
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
       changed = true;
       ++numRewrites;
     } else {
       LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-      if (config.scope)
-        debugFingerPrints.notifyRewriteFailure();
+      expensiveChecks.notifyRewriteFailure();
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
     }
   }


        


More information about the Mlir-commits mailing list