[Mlir-commits] [mlir] 4abff4d - [mlir][Transforms] Improve `replaceOpWithMultiple` API (#132608)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 28 06:18:58 PDT 2025


Author: Matthias Springer
Date: 2025-03-28T14:18:54+01:00
New Revision: 4abff4d7b2b49f343da68f32ffdae2914ba8ae7f

URL: https://github.com/llvm/llvm-project/commit/4abff4d7b2b49f343da68f32ffdae2914ba8ae7f
DIFF: https://github.com/llvm/llvm-project/commit/4abff4d7b2b49f343da68f32ffdae2914ba8ae7f.diff

LOG: [mlir][Transforms] Improve `replaceOpWithMultiple` API (#132608)

This commit adds overloads of `replaceOpWithMultiple` that accept
additional container types. The need for them was raised by users of
the new `replaceOpWithMultiple` API.

In particular, one missing container type was
`SmallVector<SmallVector<Value>>`. The "default" `ArrayRef<ValueRange>`
container type can lead to use-after-scope errors in cases such as:
```c++
// Compute the replacement value ranges. Some replacements are single
// values, some are value ranges.
SmallVector<ValueRange> repl;
repl.push_back(someValueRange);  // OK
for (...) {
  // push_back(Value) triggers an implicit conversion to ValueRange,
  // which does not own the range.
  repl.push_back(someValue);  // triggers use-after-scope later
}
rewriter.replaceOpWithMultiple(op, repl);
```

In this example, users should declare `SmallVector<SmallVector<Value>> repl;`
instead, so that each replacement range owns its values.

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..6a9316cbc690f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -897,7 +897,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
-  void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+  void replaceOpWithMultiple(Operation *op,
+                             SmallVector<SmallVector<Value>> &&newValues);
+  template <typename RangeT = ValueRange>
+  void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
+    replaceOpWithMultiple(op,
+                          llvm::to_vector_of<SmallVector<Value>>(newValues));
+  }
+  template <typename RangeT>
+  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
+    replaceOpWithMultiple(op,
+                          ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+  }
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..e5f9717c3fbaa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     }
 
     assert(packedResultVals.size() == op.getNumResults());
-    rewriter.replaceOpWithMultiple(
-        op, llvm::to_vector_of<ValueRange>(packedResultVals));
+    rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
     return success();
   }
 };

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bca31f86683fa..b9475a7cc95a8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -173,6 +173,10 @@ struct ConversionValueMapping {
     }
   }
 
+  void map(Value oldVal, SmallVector<Value> &&newVal) {
+    map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
+  }
+
   /// Drop the last mapping for the given values.
   void erase(const ValueVector &value) { mapping.erase(value); }
 
@@ -946,7 +950,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                OpBuilder::InsertPoint previous) override;
 
   /// Notifies that an op is about to be replaced with the given values.
-  void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
+  void notifyOpReplaced(Operation *op,
+                        SmallVector<SmallVector<Value>> &&newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
@@ -1519,7 +1524,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
 }
 
 void ConversionPatternRewriterImpl::notifyOpReplaced(
-    Operation *op, ArrayRef<ValueRange> newValues) {
+    Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
   assert(newValues.size() == op->getNumResults());
   assert(!ignoredOps.contains(op) && "operation was already replaced");
 
@@ -1561,7 +1566,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
     // Remap result to replacement value.
     if (repl.empty())
       continue;
-    mapping.map(result, repl);
+    mapping.map(static_cast<Value>(result), std::move(repl));
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1639,26 +1644,22 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<ValueRange> newVals;
-  for (size_t i = 0; i < newValues.size(); ++i) {
-    if (newValues[i]) {
-      newVals.push_back(newValues.slice(i, 1));
-    } else {
-      newVals.push_back(ValueRange());
-    }
-  }
-  impl->notifyOpReplaced(op, newVals);
+  SmallVector<SmallVector<Value>> newVals =
+      llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
+        return v ? SmallVector<Value>{v} : SmallVector<Value>();
+      });
+  impl->notifyOpReplaced(op, std::move(newVals));
 }
 
 void ConversionPatternRewriter::replaceOpWithMultiple(
-    Operation *op, ArrayRef<ValueRange> newValues) {
+    Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
   LLVM_DEBUG({
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
-  impl->notifyOpReplaced(op, newValues);
+  impl->notifyOpReplaced(op, std::move(newValues));
 }
 
 void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1666,8 +1667,8 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
     impl->logger.startLine()
         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
-  impl->notifyOpReplaced(op, nullRepls);
+  SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
+  impl->notifyOpReplaced(op, std::move(nullRepls));
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..bfdcaf431eeff 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,29 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   }
 };
 
+/// Test unambiguous overload resolution of replaceOpWithMultiple. This
+/// function is just to trigger compiler errors. It is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+    ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+    SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+    SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+    SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value>> &&r7,
+    Value v, ValueRange vr, ArrayRef<Value> ar) {
+  rewriter.replaceOpWithMultiple(op, r1);
+  rewriter.replaceOpWithMultiple(op, r2);
+  rewriter.replaceOpWithMultiple(op, r3);
+  rewriter.replaceOpWithMultiple(op, r4);
+  rewriter.replaceOpWithMultiple(op, r5);
+  rewriter.replaceOpWithMultiple(op, r6);
+  rewriter.replaceOpWithMultiple(op, std::move(r7));
+  rewriter.replaceOpWithMultiple(op, {vr});
+  rewriter.replaceOpWithMultiple(op, {ar});
+  rewriter.replaceOpWithMultiple(op, {{v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+  rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+}
 } // namespace
 
 namespace {


        


More information about the Mlir-commits mailing list