[Mlir-commits] [mlir] [mlir][IR] Make `replaceOp` / `replaceAllUsesWith` API consistent (PR #82629)

Matthias Springer llvmlistbot at llvm.org
Wed Mar 6 17:08:38 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82629

>From 5074d55305ea160a145567234cda4e25a7644324 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 22 Feb 2024 15:03:30 +0000
Subject: [PATCH] [mlir][IR] Make `replaceOp` / `replaceAllUsesWith` API
 consistent

* `replaceOp` replaces all uses of the original op and erases the old op.
* `replaceAllUsesWith` replaces all uses of the original op/value/block. It does not erase any IR.

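This commit renames `replaceOpWithIf` to `replaceUsesWithIf`. `replaceOpWithIf` was a misnomer because the function never erases the original op. Similarly, `replaceOpWithinBlock` is renamed to `replaceUsesWithinBlock` (it was documented as erasing the op once all uses are replaced, but it never did). Because the original op is no longer treated as replaced, the renamed functions no longer send a `notifyOperationReplaced` notification; the listener is still notified about every in-place modification of a user op.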

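Also improve comments. In addition, `replaceAllUsesWith` and `replaceUsesWithIf` gain overloads that take an `Operation *` (replacing the uses of all of its results), `replaceUsesWithIf` gains an optional `allUsesReplaced` out-parameter, and the `ConversionPatternRewriter::replaceOpWithIf` override (which always hit `llvm_unreachable`) is removed.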

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 mlir/include/mlir/IR/PatternMatch.h           | 87 +++++++++----------
 .../mlir/Transforms/DialectConversion.h       |  6 --
 .../Linalg/Transforms/DecomposeLinalgOps.cpp  |  4 +-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |  2 +-
 mlir/lib/IR/PatternMatch.cpp                  | 73 ++++++----------
 .../Transforms/Utils/DialectConversion.cpp    | 13 ---
 mlir/lib/Transforms/Utils/RegionUtils.cpp     |  2 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  2 +-
 8 files changed, 74 insertions(+), 115 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index f8d22cfb22afd0..e3500b3f9446d8 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -497,42 +497,19 @@ class RewriterBase : public OpBuilder {
                           Region::iterator before);
   void inlineRegionBefore(Region &region, Block *before);
 
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when the provided `functor` returns true for a specific use.
-  /// The number of values in `newValues` is required to match the number of
-  /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
-  /// the uses of `op` were replaced. Note that in some rewriters, the given
-  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
-  /// As such, the function should not capture by reference and instead use
-  /// value capture as necessary.
-  virtual void
-  replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
-                  llvm::unique_function<bool(OpOperand &) const> functor);
-  void replaceOpWithIf(Operation *op, ValueRange newValues,
-                       llvm::unique_function<bool(OpOperand &) const> functor) {
-    replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
-                    std::move(functor));
-  }
-
-  /// This method replaces the uses of the results of `op` with the values in
-  /// `newValues` when a use is nested within the given `block`. The number of
-  /// values in `newValues` is required to match the number of results of `op`.
-  /// If all uses of this operation are replaced, the operation is erased.
-  void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
-                            bool *allUsesReplaced = nullptr);
-
-  /// This method replaces the results of the operation with the specified list
-  /// of values. The number of provided values must match the number of results
-  /// of the operation. The replaced op is erased.
+  /// Replace the results of the given (original) operation with the specified
+  /// list of values (replacements). The result types of the given op and the
+  /// replacements must match. The original op is erased.
   virtual void replaceOp(Operation *op, ValueRange newValues);
 
-  /// This method replaces the results of the operation with the specified
-  /// new op (replacement). The number of results of the two operations must
-  /// match. The replaced op is erased.
+  /// Replace the results of the given (original) operation with the specified
+  /// new op (replacement). The result types of the two ops must match. The
+  /// original op is erased.
   virtual void replaceOp(Operation *op, Operation *newOp);
 
-  /// Replaces the result op with a new op that is created without verification.
-  /// The result values of the two ops must be the same types.
+  /// Replace the results of the given (original) op with a new op that is
+  /// created without verification (replacement). The result types of the two
+  /// ops must match. The original op is erased.
   template <typename OpTy, typename... Args>
   OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
@@ -634,9 +611,8 @@ class RewriterBase : public OpBuilder {
     finalizeOpModification(root);
   }
 
-  /// Find uses of `from` and replace them with `to`. It also marks every
-  /// modified uses and notifies the rewriter that an in-place operation
-  /// modification is about to happen.
+  /// Find uses of `from` and replace them with `to`. Also notify the listener
+  /// about every in-place op modification (for every use that was replaced).
   void replaceAllUsesWith(Value from, Value to) {
     return replaceAllUsesWith(from.getImpl(), to);
   }
@@ -652,22 +628,43 @@ class RewriterBase : public OpBuilder {
     for (auto it : llvm::zip(from, to))
       replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
   }
+  void replaceAllUsesWith(Operation *from, ValueRange to) {
+    replaceAllUsesWith(from->getResults(), to);
+  }
 
   /// Find uses of `from` and replace them with `to` if the `functor` returns
-  /// true. It also marks every modified uses and notifies the rewriter that an
-  /// in-place operation modification is about to happen.
+  /// true. Also notify the listener about every in-place op modification (for
+  /// every use that was replaced). The optional `allUsesReplaced` flag is set
+  /// to "true" if all uses were replaced.
   void replaceUsesWithIf(Value from, Value to,
-                         function_ref<bool(OpOperand &)> functor);
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr);
   void replaceUsesWithIf(ValueRange from, ValueRange to,
-                         function_ref<bool(OpOperand &)> functor) {
-    assert(from.size() == to.size() && "incorrect number of replacements");
-    for (auto it : llvm::zip(from, to))
-      replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr);
+  void replaceUsesWithIf(Operation *from, ValueRange to,
+                         function_ref<bool(OpOperand &)> functor,
+                         bool *allUsesReplaced = nullptr) {
+    replaceUsesWithIf(from->getResults(), to, functor, allUsesReplaced);
+  }
+
+  /// Replace uses of the results of `op` with `newValues` where the user is
+  /// nested in the parent op of `block`. Also notify the listener about every
+  /// in-place op modification (for every replaced use). `allUsesReplaced`, if
+  /// non-null, is set to "true" if all uses were replaced.
+  void replaceUsesWithinBlock(Operation *op, ValueRange newValues, Block *block,
+                              bool *allUsesReplaced = nullptr) {
+    replaceUsesWithIf(
+        op, newValues,
+        [block](OpOperand &use) {
+          return block->getParentOp()->isProperAncestor(use.getOwner());
+        },
+        allUsesReplaced);
   }
 
   /// Find uses of `from` and replace them with `to` except if the user is
-  /// `exceptedUser`. It also marks every modified uses and notifies the
-  /// rewriter that an in-place operation modification is about to happen.
+  /// `exceptedUser`. Also notify the listener about every in-place op
+  /// modification (for every use that was replaced).
   void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
     return replaceUsesWithIf(from, to, [&](OpOperand &use) {
       Operation *user = use.getOwner();
@@ -675,7 +672,7 @@ class RewriterBase : public OpBuilder {
     });
   }
 
-  /// Used to notify the rewriter that the IR failed to be rewritten because of
+  /// Used to notify the listener that the IR failed to be rewritten because of
   /// a match failure, and provide a callback to populate a diagnostic with the
   /// reason why the failure occurred. This method allows for derived rewriters
   /// to optionally hook into the reason why a rewrite failed, and display it to
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..01fde101ef3cb6 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -720,12 +720,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// PatternRewriter hook for replacing an operation when the given functor
-  /// returns "true".
-  void replaceOpWithIf(
-      Operation *op, ValueRange newValues, bool *allUsesReplaced,
-      llvm::unique_function<bool(OpOperand &) const> functor) override;
-
   /// PatternRewriter hook for replacing an operation.
   void replaceOp(Operation *op, ValueRange newValues) override;
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
index 5cd6d4597affaf..1658ea67a46077 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp
@@ -370,8 +370,8 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,
       scalarReplacements.push_back(
           residualGenericOpBody->getArgument(num + origNumInputs));
     bool allUsesReplaced = false;
-    rewriter.replaceOpWithinBlock(peeledScalarOperation, scalarReplacements,
-                                  residualGenericOpBody, &allUsesReplaced);
+    rewriter.replaceUsesWithinBlock(peeledScalarOperation, scalarReplacements,
+                                    residualGenericOpBody, &allUsesReplaced);
     assert(!allUsesReplaced &&
            "peeled scalar operation is erased when it wasnt expected to be");
   }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index 43c408a97687ce..74f6d97aeea53c 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -870,7 +870,7 @@ void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
         {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
     Value materialized =
         getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
-    b.replaceOpWithIf(indexOp, materialized, [&](OpOperand &use) {
+    b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
       return use.getOwner() != materialized.getDefiningOp();
     });
   }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 8796289d725707..0a88e40f73ec6c 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -110,41 +110,6 @@ RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
 
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when the provided `functor` returns true for a specific use.
-/// The number of values in `newValues` is required to match the number of
-/// results of `op`.
-void RewriterBase::replaceOpWithIf(
-    Operation *op, ValueRange newValues, bool *allUsesReplaced,
-    llvm::unique_function<bool(OpOperand &) const> functor) {
-  assert(op->getNumResults() == newValues.size() &&
-         "incorrect number of values to replace operation");
-
-  // Notify the listener that we're about to replace this op.
-  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
-    rewriteListener->notifyOperationReplaced(op, newValues);
-
-  // Replace each use of the results when the functor is true.
-  bool replacedAllUses = true;
-  for (auto it : llvm::zip(op->getResults(), newValues)) {
-    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
-    replacedAllUses &= std::get<0>(it).use_empty();
-  }
-  if (allUsesReplaced)
-    *allUsesReplaced = replacedAllUses;
-}
-
-/// This method replaces the uses of the results of `op` with the values in
-/// `newValues` when a use is nested within the given `block`. The number of
-/// values in `newValues` is required to match the number of results of `op`.
-/// If all uses of this operation are replaced, the operation is erased.
-void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
-                                        Block *block, bool *allUsesReplaced) {
-  replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
-    return block->getParentOp()->isProperAncestor(use.getOwner());
-  });
-}
-
 /// This method replaces the results of the operation with the specified list of
 /// values. The number of provided values must match the number of results of
 /// the operation. The replaced op is erased.
@@ -156,9 +121,8 @@ void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newValues);
 
-  // Replace results one-by-one. Also notifies the listener of modifications.
-  for (auto it : llvm::zip(op->getResults(), newValues))
-    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  // Replace all result uses. Also notifies the listener of modifications.
+  replaceAllUsesWith(op, newValues);
 
   // Erase op and notify listener.
   eraseOp(op);
@@ -176,9 +140,8 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) {
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newOp);
 
-  // Replace results one-by-one. Also notifies the listener of modifications.
-  for (auto it : llvm::zip(op->getResults(), newOp->getResults()))
-    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
+  // Replace all result uses. Also notifies the listener of modifications.
+  replaceAllUsesWith(op, newOp->getResults());
 
   // Erase op and notify listener.
   eraseOp(op);
@@ -279,15 +242,33 @@ void RewriterBase::finalizeOpModification(Operation *op) {
     rewriteListener->notifyOperationModified(op);
 }
 
-/// Find uses of `from` and replace them with `to` if the `functor` returns
-/// true. It also marks every modified uses and notifies the rewriter that an
-/// in-place operation modification is about to happen.
 void RewriterBase::replaceUsesWithIf(Value from, Value to,
-                                     function_ref<bool(OpOperand &)> functor) {
+                                     function_ref<bool(OpOperand &)> functor,
+                                     bool *allUsesReplaced) {
+  bool allReplaced = true;
   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
-    if (functor(operand))
+    bool replace = functor(operand);
+    if (replace)
       modifyOpInPlace(operand.getOwner(), [&]() { operand.set(to); });
+    allReplaced &= replace;
   }
+  if (allUsesReplaced)
+    *allUsesReplaced = allReplaced;
+}
+
+void RewriterBase::replaceUsesWithIf(ValueRange from, ValueRange to,
+                                     function_ref<bool(OpOperand &)> functor,
+                                     bool *allUsesReplaced) {
+  assert(from.size() == to.size() && "incorrect number of replacements");
+  bool allReplaced = true;
+  for (auto it : llvm::zip_equal(from, to)) {
+    bool r;
+    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor,
+                      /*allUsesReplaced=*/&r);
+    allReplaced &= r;
+  }
+  if (allUsesReplaced)
+    *allUsesReplaced = allReplaced;
 }
 
 void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4741110bc60682..d7dc902a9a5ebd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1528,19 +1528,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
-void ConversionPatternRewriter::replaceOpWithIf(
-    Operation *op, ValueRange newValues, bool *allUsesReplaced,
-    llvm::unique_function<bool(OpOperand &) const> functor) {
-  // TODO: To support this we will need to rework a bit of how replacements are
-  // tracked, given that this isn't guranteed to replace all of the uses of an
-  // operation. The main change is that now an operation can be replaced
-  // multiple times, in parts. The current "set" based tracking is mainly useful
-  // for tracking if a replaced operation should be ignored, i.e. if all of the
-  // uses will be replaced.
-  llvm_unreachable(
-      "replaceOpWithIf is currently not supported by DialectConversion");
-}
-
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
   assert(op && newOp && "expected non-null op");
   replaceOp(op, newOp->getResults());
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index e8b07143fc60bd..eff8acdfb33d20 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -161,7 +161,7 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
   rewriter.setInsertionPointToStart(newEntryBlock);
   for (auto *clonedOp : clonedOperations) {
     Operation *newOp = rewriter.clone(*clonedOp, map);
-    rewriter.replaceOpWithIf(clonedOp, newOp->getResults(), replaceIfFn);
+    rewriter.replaceUsesWithIf(clonedOp, newOp->getResults(), replaceIfFn);
   }
   rewriter.mergeBlocks(
       entryBlock, newEntryBlock,
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index abc0e43c7b7f2d..27eae2ffd694b5 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1836,7 +1836,7 @@ struct TestSelectiveOpReplacementPattern : public OpRewritePattern<TestCastOp> {
     OperandRange operands = op.getOperands();
 
     // Replace non-terminator uses with the first operand.
-    rewriter.replaceOpWithIf(op, operands[0], [](OpOperand &operand) {
+    rewriter.replaceUsesWithIf(op, operands[0], [](OpOperand &operand) {
       return operand.getOwner()->hasTrait<OpTrait::IsTerminator>();
     });
     // Replace everything else with the second operand if the operation isn't



More information about the Mlir-commits mailing list