[Mlir-commits] [mlir] [mlir][Transforms] `GreedyPatternRewriteDriver`: Hash ops separately (PR #78312)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 16 08:48:25 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The greedy pattern rewrite driver has multiple "expensive checks" to detect invalid rewrite pattern API usage. As part of these checks, it computes fingerprints for every op that is in scope, and compares the fingerprints before and after an attempted pattern application.

Until now, each computed fingerprint took into account all nested operations. That is quite expensive because it walks the entire IR subtree. It is also redundant in the expensive checks because we already compute a fingerprint for every op.

This commit significantly improves the running time of the "expensive checks" in the greedy pattern rewrite driver.

Depends on #<!-- -->78306. Review only the top commit.

---
Full diff: https://github.com/llvm/llvm-project/pull/78312.diff


6 Files Affected:

- (modified) mlir/include/mlir/IR/OperationSupport.h (+2-2) 
- (modified) mlir/include/mlir/IR/PatternMatch.h (+8) 
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+1-1) 
- (modified) mlir/lib/IR/OperationSupport.cpp (+12-4) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+9-2) 
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+24-9) 


``````````diff
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 478504b8f76e7ae..f2aa6cee840308e 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -1316,10 +1316,10 @@ LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
 //===----------------------------------------------------------------------===//
 
 /// A unique fingerprint for a specific operation, and all of it's internal
-/// operations.
+/// operations (if `includeNested` is set).
 class OperationFingerPrint {
 public:
-  OperationFingerPrint(Operation *topOp);
+  OperationFingerPrint(Operation *topOp, bool includeNested = true);
   OperationFingerPrint(const OperationFingerPrint &) = default;
   OperationFingerPrint &operator=(const OperationFingerPrint &) = default;
 
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 9b4fa65bff49e12..6c5d317a79b7ab5 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -402,6 +402,10 @@ class RewriterBase : public OpBuilder {
     Listener()
         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
 
+    /// Notify the listener that the specified block is about to be erased.
+    /// At this point, the block has zero uses.
+    virtual void notifyBlockRemoved(Block *block) {}
+
     /// Notify the listener that the specified operation was modified in-place.
     virtual void notifyOperationModified(Operation *op) {}
 
@@ -452,6 +456,10 @@ class RewriterBase : public OpBuilder {
     void notifyBlockCreated(Block *block) override {
       listener->notifyBlockCreated(block);
     }
+    void notifyBlockRemoved(Block *block) override {
+      if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
+        rewriteListener->notifyBlockRemoved(block);
+    }
     void notifyOperationModified(Operation *op) override {
       if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
         rewriteListener->notifyOperationModified(op);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index abb65bc3c38f221..b63baf330c86457 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1114,7 +1114,7 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
   scf::ForOp newLoop = rewriter.create<scf::ForOp>(
       loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
       operands);
-  newLoop.getBody()->erase();
+  rewriter.eraseBlock(newLoop.getBody());
 
   newLoop.getRegion().getBlocks().splice(
       newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index e10cd748e03ba58..0ab3538de259cbe 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -911,11 +911,12 @@ static void addDataToHash(llvm::SHA1 &hasher, const T &data) {
       ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
 }
 
-OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
+OperationFingerPrint::OperationFingerPrint(Operation *topOp,
+                                           bool includeNested) {
   llvm::SHA1 hasher;
 
-  // Hash each of the operations based upon their mutable bits:
-  topOp->walk([&](Operation *op) {
+  // Helper function that hashes an operation based on its mutable bits:
+  auto addOperationToHash = [&](Operation *op) {
     //   - Operation pointer
     addDataToHash(hasher, op);
     //   - Parent operation pointer (to take into account the nesting structure)
@@ -944,6 +945,13 @@ OperationFingerPrint::OperationFingerPrint(Operation *topOp) {
     //   - Result types
     for (Type t : op->getResultTypes())
       addDataToHash(hasher, t);
-  });
+  };
+
+  if (includeNested) {
+    topOp->walk(addOperationToHash);
+  } else {
+    addOperationToHash(topOp);
+  }
+
   hash = hasher.result();
 }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5e788cdb4897d3d..de226f40cb3e9ad 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -244,7 +244,7 @@ void RewriterBase::eraseOp(Operation *op) {
           for (BlockArgument bbArg : b->getArguments())
             bbArg.dropAllUses();
           b->dropAllUses();
-          b->erase();
+          eraseBlock(b);
         }
       }
     }
@@ -256,10 +256,17 @@ void RewriterBase::eraseOp(Operation *op) {
 }
 
 void RewriterBase::eraseBlock(Block *block) {
+  assert(block->use_empty() && "expected 'block' to have no uses");
+
   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
     assert(op.use_empty() && "expected 'op' to have no uses");
     eraseOp(&op);
   }
+
+  // Notify the listener that the block is about to be removed.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyBlockRemoved(block);
+
   block->erase();
 }
 
@@ -311,7 +318,7 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
   // Move operations from the source block to the dest block and erase the
   // source block.
   dest->getOperations().splice(before, source->getOperations());
-  source->erase();
+  eraseBlock(source);
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Operation *op,
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 36d63d62bf10fc2..090c833b115b39f 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -60,7 +60,9 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
   void computeFingerPrints(Operation *topLevel) {
     this->topLevel = topLevel;
     this->topLevelFingerPrint.emplace(topLevel);
-    topLevel->walk([&](Operation *op) { fingerprints.try_emplace(op, op); });
+    topLevel->walk([&](Operation *op) {
+      fingerprints.try_emplace(op, op, /*includeNested=*/false);
+    });
   }
 
   /// Clear all finger prints.
@@ -95,7 +97,8 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
       // API.) Finger print computation does may not crash if a new op was
       // created at the same memory location. (But then the finger print should
       // have changed.)
-      if (it.second != OperationFingerPrint(it.first)) {
+      if (it.second !=
+          OperationFingerPrint(it.first, /*includeNested=*/false)) {
         // Note: Run "mlir-opt -debug" to see which pattern is broken.
         llvm::report_fatal_error("operation finger print changed");
       }
@@ -125,14 +128,18 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
 
 protected:
   /// Invalidate the finger print of the given op, i.e., remove it from the map.
-  void invalidateFingerPrint(Operation *op) {
-    // Invalidate all finger prints until the top level.
-    while (op && op != topLevel) {
-      fingerprints.erase(op);
-      op = op->getParentOp();
-    }
-  }
+  void invalidateFingerPrint(Operation *op) { fingerprints.erase(op); }
 
+  void notifyBlockRemoved(Block *block) override {
+    RewriterBase::ForwardingListener::notifyBlockRemoved(block);
+
+    // The block structure (number of blocks, types of block arguments, etc.)
+    // is part of the fingerprint of the parent op.
+    // TODO: The parent op fingerprint should also be invalidated when modifying
+    // the block arguments of a block, but we do not have a
+    // `notifyBlockModified` callback yet.
+    invalidateFingerPrint(block->getParentOp());
+  }
   void notifyOperationInserted(Operation *op) override {
     RewriterBase::ForwardingListener::notifyOperationInserted(op);
     invalidateFingerPrint(op->getParentOp());
@@ -373,6 +380,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the given block was created.
   void notifyBlockCreated(Block *block) override;
 
+  /// Notify the driver that the given block is about to be removed.
+  void notifyBlockRemoved(Block *block) override;
+
   /// For debugging only: Notify the driver of a pattern match failure.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -633,6 +643,11 @@ void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
     config.listener->notifyBlockCreated(block);
 }
 
+void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
+  if (config.listener)
+    config.listener->notifyBlockRemoved(block);
+}
+
 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op

``````````

</details>


https://github.com/llvm/llvm-project/pull/78312


More information about the Mlir-commits mailing list