[Mlir-commits] [mlir] [mlir][Transforms] Improve `replaceOpWithMultiple` API (PR #132608)

Matthias Springer llvmlistbot at llvm.org
Sun Mar 23 05:09:39 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/132608

This commit adds a templated overload of `replaceOpWithMultiple` that accepts other container types of replacement values. Users of the new `replaceOpWithMultiple` API asked for it.

In particular, one missing container type was `SmallVector<SmallVector<Value>>`. The "default" `ArrayRef<ValueRange>` container type can lead to use-after-scope errors because `ValueRange` does not own the values it refers to, e.g.:
```c++
// Compute the replacement value ranges. Some replacements are single
// values, some are value ranges.
SmallVector<ValueRange> repl;
repl.push_back(someValueRange);  // OK
for (...) {
  // push_back(Value) triggers an implicit conversion to ValueRange,
  // which does not own the Value.
  repl.push_back(someValue);  // triggers use-after-scope later
}
```

In this example, users should instead declare `SmallVector<SmallVector<Value>> repl;`, which owns the replacement values. The new overload accepts this container directly.


>From c4e8397dd18f1b686b26788619b8dab81f5640cd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 23 Mar 2025 12:59:17 +0100
Subject: [PATCH] [mlir][Transforms] Improve `replaceOpWithMultiple` API

---
 .../mlir/Transforms/DialectConversion.h       | 12 ++++++++++++
 .../Transforms/SparseTensorCodegen.cpp        |  3 +--
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 22 ++++++++++++++++++++++
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..cbf60b784af94 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,6 +898,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+  /// Overload for containers other than ArrayRef<ValueRange> whose elements
+  /// are convertible to ValueRange, e.g. SmallVector<SmallVector<Value>>.
+  template <typename RangeT>
+  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
+    // Non-owning views into the elements of `newValues`. Taking `newValues`
+    // by reference avoids deep-copying the container.
+    SmallVector<ValueRange> ranges = llvm::to_vector_of<ValueRange>(newValues);
+    // Convert to ArrayRef explicitly: passing `ranges` as-is would select
+    // this template again (exact match beats the user-defined conversion to
+    // ArrayRef) and recurse infinitely.
+    replaceOpWithMultiple(op, ArrayRef<ValueRange>(ranges));
+  }
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..6291f3ea37230 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     }
 
     assert(packedResultVals.size() == op.getNumResults());
-    rewriter.replaceOpWithMultiple(
-        op, llvm::to_vector_of<ValueRange>(packedResultVals));
+    rewriter.replaceOpWithMultiple(op, packedResultVals);
     return success();
   }
 };
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..e325003f5384c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   }
 };
 
+/// Compile-only check: each call below must resolve unambiguously to one
+/// replaceOpWithMultiple overload. This function is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+    ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+    SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+    SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+    SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+    ArrayRef<Value> ar) {
+  rewriter.replaceOpWithMultiple(op, r1); // Non-template overload.
+  rewriter.replaceOpWithMultiple(op, r2); // r2-r6: template overload.
+  rewriter.replaceOpWithMultiple(op, r3);
+  rewriter.replaceOpWithMultiple(op, r4);
+  rewriter.replaceOpWithMultiple(op, r5);
+  rewriter.replaceOpWithMultiple(op, r6);
+  rewriter.replaceOpWithMultiple(op, {vr}); // Braced lists: non-template.
+  rewriter.replaceOpWithMultiple(op, {ar});
+  rewriter.replaceOpWithMultiple(op, {{v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+  rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+}
 } // namespace
 
 namespace {



More information about the Mlir-commits mailing list