[Mlir-commits] [mlir] [mlir][Transforms] Improve `replaceOpWithMultiple` API (PR #132608)
Matthias Springer
llvmlistbot at llvm.org
Sun Mar 23 05:09:39 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/132608
This commit adds an overload of `replaceOpWithMultiple` that accepts more container types. Users of the new `replaceOpWithMultiple` API have asked for this.
In particular, one missing container type was `SmallVector<SmallVector<Value>>`. The "default" `ArrayRef<ValueRange>` container type can lead to use-after-scope errors in cases such as:
```c++
// Compute the replacement value ranges. Some replacements are single
// values, some are value ranges.
SmallVector<ValueRange> repl;
repl.push_back(someValueRange); // OK
for (...) {
// push_back(Value) triggers an implicit conversion to ValueRange,
// which does not own the Value.
repl.push_back(someValue); // triggers use-after-scope later
}
```
In this example, users should declare `SmallVector<SmallVector<Value>> repl;` instead, so that each element owns the values it refers to.
From c4e8397dd18f1b686b26788619b8dab81f5640cd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 23 Mar 2025 12:59:17 +0100
Subject: [PATCH] [mlir][Transforms] Improve `replaceOpWithMultiple` API
---
.../mlir/Transforms/DialectConversion.h | 4 ++++
.../Transforms/SparseTensorCodegen.cpp | 3 +--
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 22 +++++++++++++++++++
3 files changed, 27 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..cbf60b784af94 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,6 +898,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+ /// Same as above, but accepts any range whose elements convert to
+ /// `ValueRange`, e.g., `SmallVector<SmallVector<Value>>` or
+ /// `SmallVector<ArrayRef<Value>>`. The created `ValueRange`s do not own
+ /// their values; they view the elements of `newValues`.
+ template <typename RangeT>
+ void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
+ // The explicit ArrayRef is required: a SmallVector<ValueRange> argument
+ // would bind this template (exact match) instead of the ArrayRef overload
+ // (user-defined conversion), causing infinite recursion.
+ replaceOpWithMultiple(
+ op, ArrayRef<ValueRange>(llvm::to_vector_of<ValueRange>(newValues)));
+ }
/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..6291f3ea37230 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
}
assert(packedResultVals.size() == op.getNumResults());
- rewriter.replaceOpWithMultiple(
- op, llvm::to_vector_of<ValueRange>(packedResultVals));
+ rewriter.replaceOpWithMultiple(op, packedResultVals);
return success();
}
};
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..e325003f5384c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
}
};
+/// Compile-only check: every call below must resolve to exactly one
+/// `replaceOpWithMultiple` overload. This function is never invoked.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+ ConversionPatternRewriter &rw, Operation *target,
+ ArrayRef<ValueRange> arrRanges, SmallVector<ValueRange> vecRanges,
+ ArrayRef<SmallVector<Value>> arrVecs,
+ SmallVector<SmallVector<Value>> vecVecs, ArrayRef<ArrayRef<Value>> arrArrs,
+ SmallVector<ArrayRef<Value>> vecArrs, Value v, ValueRange vr, ArrayRef<Value> ar) {
+ rw.replaceOpWithMultiple(target, arrRanges);
+ rw.replaceOpWithMultiple(target, vecRanges);
+ rw.replaceOpWithMultiple(target, arrVecs);
+ rw.replaceOpWithMultiple(target, vecVecs);
+ rw.replaceOpWithMultiple(target, arrArrs);
+ rw.replaceOpWithMultiple(target, vecArrs);
+ rw.replaceOpWithMultiple(target, {vr});
+ rw.replaceOpWithMultiple(target, {ar});
+ rw.replaceOpWithMultiple(target, {{v}});
+ rw.replaceOpWithMultiple(target, {{v, v}});
+ rw.replaceOpWithMultiple(target, {{v, v}, vr});
+ rw.replaceOpWithMultiple(target, {{v, v}, ar});
+ rw.replaceOpWithMultiple(target, {ar, {v, v}, vr});
+}
} // namespace
namespace {
More information about the Mlir-commits
mailing list