[Mlir-commits] [mlir] 71d50c8 - [mlir][IR] Improve listener notifications for ops without results

Matthias Springer llvmlistbot at llvm.org
Tue Jun 13 23:51:29 PDT 2023


Author: Matthias Springer
Date: 2023-06-14T08:51:14+02:00
New Revision: 71d50c890bad943ab23ee9b32638b2366351f8f8

URL: https://github.com/llvm/llvm-project/commit/71d50c890bad943ab23ee9b32638b2366351f8f8
DIFF: https://github.com/llvm/llvm-project/commit/71d50c890bad943ab23ee9b32638b2366351f8f8.diff

LOG: [mlir][IR] Improve listener notifications for ops without results

`RewriterBase::Listener::notifyOperationReplaced` notifies observers that an op is about to be replaced with a range of values. This notification is not very useful for ops without results: it does not identify the replacement op, and the replacement op cannot be deduced from the replacement values. For such ops, it provides no information beyond what the `notifyOperationRemoved` notification already conveys.

This revision adds a new notification that is triggered when a rewriter replaces an op with another op. By default, this notification triggers the original "op replaced with values" notification, so there is no functional change for existing code.

This new API is useful for the transform dialect, which needs to track op replacements. (The transform dialect will be updated to use it in a subsequent revision.)

Also includes minor documentation improvements.

Differential Revision: https://reviews.llvm.org/D152814

Added: 
    

Modified: 
    mlir/include/mlir/IR/PatternMatch.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/IR/PatternMatch.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 4614649caae12..3843ff249ddf7 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -406,13 +406,24 @@ class RewriterBase : public OpBuilder {
     virtual void notifyOperationModified(Operation *op) {}
 
     /// Notify the listener that the specified operation is about to be replaced
-    /// with the set of values potentially produced by new operations. This is
-    /// called before the uses of the operation have been changed.
+    /// with another operation. This is called before the uses of the old
+    /// operation have been changed.
+    ///
+    /// By default, this function calls the "operation replaced with values"
+    /// notification.
+    virtual void notifyOperationReplaced(Operation *op,
+                                         Operation *replacement) {
+      notifyOperationReplaced(op, replacement->getResults());
+    }
+
+    /// Notify the listener that the specified operation is about to be replaced
+    /// with a range of values, potentially produced by other operations.
+    /// This is called before the uses of the operation have been changed.
     virtual void notifyOperationReplaced(Operation *op,
                                          ValueRange replacement) {}
 
-    /// This is called on an operation that a rewrite is removing, right before
-    /// the operation is deleted. At this point, the operation has zero uses.
+    /// Notify the listener that the specified operation is about to be erased.
+    /// At this point, the operation has zero uses.
     virtual void notifyOperationRemoved(Operation *op) {}
 
     /// Notify the listener that the pattern failed to match the given
@@ -444,6 +455,9 @@ class RewriterBase : public OpBuilder {
     void notifyOperationModified(Operation *op) override {
       listener->notifyOperationModified(op);
     }
+    void notifyOperationReplaced(Operation *op, Operation *newOp) override {
+      listener->notifyOperationReplaced(op, newOp);
+    }
     void notifyOperationReplaced(Operation *op,
                                  ValueRange replacement) override {
       listener->notifyOperationReplaced(op, replacement);
@@ -505,15 +519,20 @@ class RewriterBase : public OpBuilder {
 
   /// This method replaces the results of the operation with the specified list
   /// of values. The number of provided values must match the number of results
-  /// of the operation.
+  /// of the operation. The replaced op is erased.
   virtual void replaceOp(Operation *op, ValueRange newValues);
 
+  /// This method replaces the results of the operation with the specified
+  /// new op (replacement). The number of results of the two operations must
+  /// match. The replaced op is erased.
+  virtual void replaceOp(Operation *op, Operation *newOp);
+
   /// Replaces the result op with a new op that is created without verification.
   /// The result values of the two ops must be the same types.
   template <typename OpTy, typename... Args>
   OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
-    replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
+    replaceOp(op, newOp.getOperation());
     return newOp;
   }
 
@@ -666,10 +685,6 @@ class RewriterBase : public OpBuilder {
 private:
   void operator=(const RewriterBase &) = delete;
   RewriterBase(const RewriterBase &) = delete;
-
-  /// 'op' and 'newOp' are known to have the same number of results, replace the
-  /// uses of op with uses of newOp.
-  void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f242eea767786..f5206b1a4da4f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -695,15 +695,17 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// PatternRewriter hook for replacing the results of an operation when the
-  /// given functor returns true.
+  /// PatternRewriter hook for replacing an operation when the given functor
+  /// returns "true".
   void replaceOpWithIf(
       Operation *op, ValueRange newValues, bool *allUsesReplaced,
       llvm::unique_function<bool(OpOperand &) const> functor) override;
 
-  /// PatternRewriter hook for replacing the results of an operation.
+  /// PatternRewriter hook for replacing an operation.
   void replaceOp(Operation *op, ValueRange newValues) override;
-  using PatternRewriter::replaceOp;
+
+  /// PatternRewriter hook for replacing an operation.
+  void replaceOp(Operation *op, Operation *newOp) override;
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,

diff  --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index d07d6518d57c0..9dfe07797ff4b 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -139,7 +139,7 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
       loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
   rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
                                     loopBlock, atomicRes);
-  rewriter.replaceOp(atomicOp, {});
+  rewriter.eraseOp(atomicOp);
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 8ed790c421a4c..01ad2dd20e7ca 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -331,10 +331,7 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(),
         alloc.getAlignmentAttr());
     // Insert a cast so we have the same type as the old alloc.
-    auto resultCast =
-        rewriter.create<CastOp>(alloc.getLoc(), alloc.getType(), newAlloc);
-
-    rewriter.replaceOp(alloc, {resultCast});
+    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
     return success();
   }
 };

diff  --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 052696d5cb13a..db920c14ea08d 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -262,12 +262,12 @@ void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
 
 /// This method replaces the results of the operation with the specified list of
 /// values. The number of provided values must match the number of results of
-/// the operation.
+/// the operation. The replaced op is erased.
 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
 
-  // Notify the listener that we're about to remove this op.
+  // Notify the listener that we're about to replace this op.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newValues);
 
@@ -275,9 +275,28 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   for (auto it : llvm::zip(op->getResults(), newValues))
     replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
+  // Erase the op.
+  eraseOp(op);
+}
+
+/// This method replaces the results of the operation with the specified new op
+/// (replacement). The number of results of the two operations must match. The
+/// replaced op is erased.
+void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
+  assert(op && newOp && "expected non-null op");
+  assert(op->getNumResults() == newOp->getNumResults() &&
         "ops have different number of results");
+
+  // Notify the listener that we're about to replace this op.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationRemoved(op);
-  op->erase();
+    rewriteListener->notifyOperationReplaced(op, newOp);
+
+  // Replace results one-by-one. Also notifies the listener of modifications.
+  for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+
+  // Erase the old op.
+  eraseOp(op);
 }
 
 /// This method erases an operation that is known to have no uses. The uses of
@@ -364,17 +383,6 @@ Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
   return block->splitBlock(before);
 }
 
-/// 'op' and 'newOp' are known to have the same number of results, replace the
-/// uses of op with uses of newOp
-void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
-                                                   Operation *newOp) {
-  assert(op->getNumResults() == newOp->getNumResults() &&
-         "replacement op doesn't match results of original op");
-  if (op->getNumResults() == 1)
-    return replaceOp(op, newOp->getResult(0));
-  return replaceOp(op, newOp->getResults());
-}
-
 /// Move the blocks that belong to "region" before the given position in
 /// another region.  The two regions must be different.  The caller is in
 /// charge to update create the operation transferring the control flow to the

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 615c8e4a99ceb..411111358b1e1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1452,7 +1452,14 @@ void ConversionPatternRewriter::replaceOpWithIf(
       "replaceOpWithIf is currently not supported by DialectConversion");
 }
 
+void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
+  assert(op && newOp && "expected non-null op");
+  replaceOp(op, newOp->getResults());
+}
+
 void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
+  assert(op->getNumResults() == newValues.size() &&
+         "incorrect # of replacement values");
   LLVM_DEBUG({
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3a1faeabe84c1..8a0ca056cd8b3 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -601,7 +601,7 @@ struct TestCreateBlock : public RewritePattern {
     Location loc = op->getLoc();
     rewriter.createBlock(&region, region.end(), {i32Type, i32Type}, {loc, loc});
     rewriter.create<TerminatorOp>(loc);
-    rewriter.replaceOp(op, {});
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -621,7 +621,7 @@ struct TestCreateIllegalBlock : public RewritePattern {
     // Create an illegal op to ensure the conversion fails.
     rewriter.create<ILLegalOpF>(loc, i32Type);
     rewriter.create<TerminatorOp>(loc);
-    rewriter.replaceOp(op, {});
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -793,8 +793,8 @@ struct TestNonRootReplacement : public RewritePattern {
     auto illegalOp = rewriter.create<ILLegalOpF>(op->getLoc(), resultType);
     auto legalOp = rewriter.create<LegalOpB>(op->getLoc(), resultType);
 
-    rewriter.replaceOp(illegalOp, {legalOp});
-    rewriter.replaceOp(op, {illegalOp});
+    rewriter.replaceOp(illegalOp, legalOp);
+    rewriter.replaceOp(op, illegalOp);
     return success();
   }
 };


        


More information about the Mlir-commits mailing list