[Mlir-commits] [mlir] 4abff4d - [mlir][Transforms] Improve `replaceOpWithMultiple` API (#132608)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 28 06:18:58 PDT 2025
Author: Matthias Springer
Date: 2025-03-28T14:18:54+01:00
New Revision: 4abff4d7b2b49f343da68f32ffdae2914ba8ae7f
URL: https://github.com/llvm/llvm-project/commit/4abff4d7b2b49f343da68f32ffdae2914ba8ae7f
DIFF: https://github.com/llvm/llvm-project/commit/4abff4d7b2b49f343da68f32ffdae2914ba8ae7f.diff
LOG: [mlir][Transforms] Improve `replaceOpWithMultiple` API (#132608)
This commit adds further overloads of `replaceOpWithMultiple` that
accept more container types. This was requested by users of
the new `replaceOpWithMultiple` API.
In particular, one missing container type was
`SmallVector<SmallVector<Value>>`. The "default" `ArrayRef<ValueRange>`
container type can lead to use-after-scope errors in cases such as:
```c++
// Compute the replacement value ranges. Some replacements are single
// values, some are value ranges.
SmallVector<ValueRange> repl;
repl.push_back(someValueRange); // OK
for (...) {
// push_back(Value) triggers an implicit conversion to ValueRange,
// which does not own the range.
repl.push_back(someValue); // triggers use-after-scope later
}
rewriter.replaceOpWithMultiple(op, repl);
```
In this example, users should instead declare
`SmallVector<SmallVector<Value>> repl;`, whose inner vectors own their values.
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..6a9316cbc690f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -897,7 +897,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. The given operation is erased.
- void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+ void replaceOpWithMultiple(Operation *op,
+ SmallVector<SmallVector<Value>> &&newValues);
+ template <typename RangeT = ValueRange>
+ void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
+ replaceOpWithMultiple(op,
+ llvm::to_vector_of<SmallVector<Value>>(newValues));
+ }
+ template <typename RangeT>
+ void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
+ replaceOpWithMultiple(op,
+ ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+ }
/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..e5f9717c3fbaa 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
}
assert(packedResultVals.size() == op.getNumResults());
- rewriter.replaceOpWithMultiple(
- op, llvm::to_vector_of<ValueRange>(packedResultVals));
+ rewriter.replaceOpWithMultiple(op, std::move(packedResultVals));
return success();
}
};
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bca31f86683fa..b9475a7cc95a8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -173,6 +173,10 @@ struct ConversionValueMapping {
}
}
+ void map(Value oldVal, SmallVector<Value> &&newVal) {
+ map(ValueVector{oldVal}, ValueVector(std::move(newVal)));
+ }
+
/// Drop the last mapping for the given values.
void erase(const ValueVector &value) { mapping.erase(value); }
@@ -946,7 +950,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
OpBuilder::InsertPoint previous) override;
/// Notifies that an op is about to be replaced with the given values.
- void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
+ void notifyOpReplaced(Operation *op,
+ SmallVector<SmallVector<Value>> &&newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -1519,7 +1524,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
}
void ConversionPatternRewriterImpl::notifyOpReplaced(
- Operation *op, ArrayRef<ValueRange> newValues) {
+ Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");
@@ -1561,7 +1566,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Remap result to replacement value.
if (repl.empty())
continue;
- mapping.map(result, repl);
+ mapping.map(static_cast<Value>(result), std::move(repl));
}
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1639,26 +1644,22 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ValueRange> newVals;
- for (size_t i = 0; i < newValues.size(); ++i) {
- if (newValues[i]) {
- newVals.push_back(newValues.slice(i, 1));
- } else {
- newVals.push_back(ValueRange());
- }
- }
- impl->notifyOpReplaced(op, newVals);
+ SmallVector<SmallVector<Value>> newVals =
+ llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
+ return v ? SmallVector<Value>{v} : SmallVector<Value>();
+ });
+ impl->notifyOpReplaced(op, std::move(newVals));
}
void ConversionPatternRewriter::replaceOpWithMultiple(
- Operation *op, ArrayRef<ValueRange> newValues) {
+ Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
LLVM_DEBUG({
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- impl->notifyOpReplaced(op, newValues);
+ impl->notifyOpReplaced(op, std::move(newValues));
}
void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1666,8 +1667,8 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
- impl->notifyOpReplaced(op, nullRepls);
+ SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
+ impl->notifyOpReplaced(op, std::move(nullRepls));
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..bfdcaf431eeff 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,29 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
}
};
+/// Test unambiguous overload resolution of replaceOpWithMultiple. This
+/// function is just to trigger compiler errors. It is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+ ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+ SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+ SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+ SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value>> &&r7,
+ Value v, ValueRange vr, ArrayRef<Value> ar) {
+ rewriter.replaceOpWithMultiple(op, r1);
+ rewriter.replaceOpWithMultiple(op, r2);
+ rewriter.replaceOpWithMultiple(op, r3);
+ rewriter.replaceOpWithMultiple(op, r4);
+ rewriter.replaceOpWithMultiple(op, r5);
+ rewriter.replaceOpWithMultiple(op, r6);
+ rewriter.replaceOpWithMultiple(op, std::move(r7));
+ rewriter.replaceOpWithMultiple(op, {vr});
+ rewriter.replaceOpWithMultiple(op, {ar});
+ rewriter.replaceOpWithMultiple(op, {{v}});
+ rewriter.replaceOpWithMultiple(op, {{v, v}});
+ rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+ rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+ rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+}
} // namespace
namespace {
More information about the Mlir-commits
mailing list