[Mlir-commits] [mlir] 6b3e000 - [mlir][Transforms][NFC] `GreedyPatternRewriteDriver`: Use composition instead of inheritance (#92785)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jun 8 01:26:21 PDT 2024


Author: Matthias Springer
Date: 2024-06-08T10:26:17+02:00
New Revision: 6b3e0002dfe0029487fc2f8f11f5d5fdc07a5e11

URL: https://github.com/llvm/llvm-project/commit/6b3e0002dfe0029487fc2f8f11f5d5fdc07a5e11
DIFF: https://github.com/llvm/llvm-project/commit/6b3e0002dfe0029487fc2f8f11f5d5fdc07a5e11.diff

LOG: [mlir][Transforms][NFC] `GreedyPatternRewriteDriver`: Use composition instead of inheritance (#92785)

This commit simplifies the design of the `GreedyPatternRewriteDriver`
class. This class used to inherit from both `PatternRewriter` and
`RewriterBase::Listener` and then attached itself as a listener.

In the new design, the class has a `PatternRewriter` field instead of
inheriting from `PatternRewriter`; composition is generally preferred
over inheritance in object-oriented programming.

---------

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2562301e499dd..ed7b9ece4a464 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -784,6 +784,7 @@ class IRRewriter : public RewriterBase {
 /// place.
 class PatternRewriter : public RewriterBase {
 public:
+  explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
   using RewriterBase::RewriterBase;
 
   /// A hook used to indicate if the pattern rewriter can recover from failure

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cfd4f9c03aaff..597cb29ce911b 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -319,8 +319,7 @@ class RandomizedWorklist : public Worklist {
 /// This abstract class manages the worklist and contains helper methods for
 /// rewriting ops on the worklist. Derived classes specify how ops are added
 /// to the worklist in the beginning.
-class GreedyPatternRewriteDriver : public PatternRewriter,
-                                   public RewriterBase::Listener {
+class GreedyPatternRewriteDriver : public RewriterBase::Listener {
 protected:
   explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
                                       const FrozenRewritePatternSet &patterns,
@@ -339,7 +338,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the specified operation was inserted. Update the
   /// worklist as needed: The operation is enqueued depending on scope and
   /// strict mode.
-  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
 
   /// Notify the driver that the specified operation was removed. Update the
   /// worklist as needed: The operation and its children are removed from the
@@ -354,6 +354,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// reached. Return `true` if any IR was changed.
   bool processWorklist();
 
+  /// The pattern rewriter that is used for making IR modifications and is
+  /// passed to rewrite patterns.
+  PatternRewriter rewriter;
+
   /// The worklist for this transformation keeps track of the operations that
   /// need to be (re)visited.
 #ifdef MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
@@ -407,7 +411,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), config(config), matcher(patterns)
+    : rewriter(ctx), config(config), matcher(patterns)
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
       // clang-format off
       , expensiveChecks(
@@ -423,9 +427,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // Send IR notifications to the debug handler. This handler will then forward
   // all notifications to this GreedyPatternRewriteDriver.
-  setListener(&expensiveChecks);
+  rewriter.setListener(&expensiveChecks);
 #else
-  setListener(this);
+  rewriter.setListener(this);
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 }
 
@@ -473,7 +477,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 
     // If the operation is trivially dead - remove it.
     if (isOpTriviallyDead(op)) {
-      eraseOp(op);
+      rewriter.eraseOp(op);
       changed = true;
 
       LLVM_DEBUG(logResultWithLine("success", "operation is trivially dead"));
@@ -505,8 +509,8 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         // Op results can be replaced with `foldResults`.
         assert(foldResults.size() == op->getNumResults() &&
                "folder produced incorrect number of results");
-        OpBuilder::InsertionGuard g(*this);
-        setInsertionPoint(op);
+        OpBuilder::InsertionGuard g(rewriter);
+        rewriter.setInsertionPoint(op);
         SmallVector<Value> replacements;
         bool materializationSucceeded = true;
         for (auto [ofr, resultType] :
@@ -519,7 +523,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           }
           // Materialize Attributes as SSA values.
           Operation *constOp = op->getDialect()->materializeConstant(
-              *this, ofr.get<Attribute>(), resultType, op->getLoc());
+              rewriter, ofr.get<Attribute>(), resultType, op->getLoc());
 
           if (!constOp) {
             // If materialization fails, cleanup any operations generated for
@@ -532,7 +536,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
               replacementOps.insert(replacement.getDefiningOp());
             }
             for (Operation *op : replacementOps) {
-              eraseOp(op);
+              rewriter.eraseOp(op);
             }
 
             materializationSucceeded = false;
@@ -547,7 +551,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         }
 
         if (materializationSucceeded) {
-          replaceOp(op, replacements);
+          rewriter.replaceOp(op, replacements);
           changed = true;
           LLVM_DEBUG(logSuccessfulFolding(dumpRootOp));
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
@@ -608,7 +612,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
     LogicalResult matchResult =
-        matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
+        matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
 
     if (succeeded(matchResult)) {
       LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
@@ -664,8 +668,8 @@ void GreedyPatternRewriteDriver::notifyBlockErased(Block *block) {
     config.listener->notifyBlockErased(block);
 }
 
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
-                                                         InsertPoint previous) {
+void GreedyPatternRewriteDriver::notifyOperationInserted(
+    Operation *op, OpBuilder::InsertPoint previous) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
@@ -822,7 +826,7 @@ class GreedyPatternRewriteIteration
 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
   bool continueRewrites = false;
   int64_t iteration = 0;
-  MLIRContext *ctx = getContext();
+  MLIRContext *ctx = rewriter.getContext();
   do {
     // Check if the iteration limit was reached.
     if (++iteration > config.maxIterations &&
@@ -834,7 +838,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
 
     // `OperationFolder` CSE's constant ops (and may move them into parents
     // regions to enable more aggressive CSE'ing).
-    OperationFolder folder(getContext(), this);
+    OperationFolder folder(ctx, this);
     auto insertKnownConstant = [&](Operation *op) {
       // Check for existing constants when populating the worklist. This avoids
       // accidentally reversing the constant order during processing.
@@ -872,7 +876,7 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
           // After applying patterns, make sure that the CFG of each of the
           // regions is kept up to date.
           if (config.enableRegionSimplification)
-            continueRewrites |= succeeded(simplifyRegions(*this, region));
+            continueRewrites |= succeeded(simplifyRegions(rewriter, region));
         },
         {&region}, iteration);
   } while (continueRewrites);


        


More information about the Mlir-commits mailing list