[Mlir-commits] [mlir] [mlir][Transforms] Improve `replaceOpWithMultiple` API (PR #132608)
Matthias Springer
llvmlistbot at llvm.org
Tue Mar 25 06:33:43 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/132608
>From a833e1a01cab16f943dab1ab17f895ba9b70e12c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 23 Mar 2025 12:59:17 +0100
Subject: [PATCH 1/4] [mlir][Transforms] Improve `replaceOpWithMultiple` API
---
.../mlir/Transforms/DialectConversion.h | 5 +++++
.../Transforms/SparseTensorCodegen.cpp | 3 +--
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 22 +++++++++++++++++++
3 files changed, 28 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..b537b790687c8 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,6 +898,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+ template <typename RangeT>
+ void replaceOpWithMultiple(Operation *op, RangeT newValues) {
+ replaceOpWithMultiple(op,
+ ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+ }
/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..6291f3ea37230 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
}
assert(packedResultVals.size() == op.getNumResults());
- rewriter.replaceOpWithMultiple(
- op, llvm::to_vector_of<ValueRange>(packedResultVals));
+ rewriter.replaceOpWithMultiple(op, packedResultVals);
return success();
}
};
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..e325003f5384c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
}
};
+/// Test unambiguous overload resolution of replaceOpWithMultiple. This
+/// function is just to trigger compiler errors. It is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+ ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+ SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+ SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+ SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+ ArrayRef<Value> ar) {
+ rewriter.replaceOpWithMultiple(op, r1);
+ rewriter.replaceOpWithMultiple(op, r2);
+ rewriter.replaceOpWithMultiple(op, r3);
+ rewriter.replaceOpWithMultiple(op, r4);
+ rewriter.replaceOpWithMultiple(op, r5);
+ rewriter.replaceOpWithMultiple(op, r6);
+ rewriter.replaceOpWithMultiple(op, {vr});
+ rewriter.replaceOpWithMultiple(op, {ar});
+ rewriter.replaceOpWithMultiple(op, {{v}});
+ rewriter.replaceOpWithMultiple(op, {{v, v}});
+ rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+ rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+ rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+}
} // namespace
namespace {
>From 0948f752a931dabdac753d084fc998f0067007d0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 24 Mar 2025 08:53:04 +0100
Subject: [PATCH 2/4] Update mlir/include/mlir/Transforms/DialectConversion.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
mlir/include/mlir/Transforms/DialectConversion.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b537b790687c8..00a5389b4101e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -899,7 +899,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
template <typename RangeT>
- void replaceOpWithMultiple(Operation *op, RangeT newValues) {
+ void replaceOpWithMultiple(Operation *op, RangeT&& newValues) {
replaceOpWithMultiple(op,
ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
}
>From 5aadd7fae2d95b1d45134169366c33950a65ef79 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 24 Mar 2025 09:42:14 +0100
Subject: [PATCH 3/4] clang-format
---
mlir/include/mlir/Transforms/DialectConversion.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 00a5389b4101e..66a5e6a889905 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -899,7 +899,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
template <typename RangeT>
- void replaceOpWithMultiple(Operation *op, RangeT&& newValues) {
+ void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
replaceOpWithMultiple(op,
ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
}
>From b6a28ebbbceffb8c104c8b1af429461662154f07 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 25 Mar 2025 14:33:15 +0100
Subject: [PATCH 4/4] Use SmallVector<Value,1> by default
---
.../mlir/Transforms/DialectConversion.h | 32 ++++++++++++++++---
.../ArmSME/Transforms/VectorLegalization.cpp | 6 ++--
.../Transforms/SparseTensorCodegen.cpp | 8 ++---
.../Transforms/Utils/DialectConversion.cpp | 20 +++++-------
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 5 ++-
5 files changed, 46 insertions(+), 25 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 66a5e6a889905..e8ee9fecf711d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -897,11 +897,33 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. The given operation is erased.
- void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
- template <typename RangeT>
- void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
- replaceOpWithMultiple(op,
- ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+ void replaceOpWithMultiple(Operation *op,
+ ArrayRef<SmallVector<Value, 1>> newValues);
+ // Note: This overload matches SmallVector<ValueRange>,
+ // SmallVector<SmallVector<Value>>, etc.
+ template <typename RangeRangeT>
+ void replaceOpWithMultiple(Operation *op, RangeRangeT &&newValues) {
+ // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
+ // does not copy the replacements vector.
+ auto vals = llvm::map_to_vector(newValues, [](const auto &r) {
+ // Note: Create intermediate ValueRange because SmallVector<Value, 1>
+ // is not constructible from SmallVector<Value>.
+ return SmallVector<Value, 1>(ValueRange(r));
+ });
+ replaceOpWithMultiple(op, ArrayRef(vals));
+ }
+ // Note: This overload matches initializer list of ValueRange,
+ // SmallVector<Value>, etc.
+ template <typename RangeT = ValueRange>
+ void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
+ // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
+ // does not copy the replacements vector.
+ auto vals = llvm::map_to_vector(newValues, [](const RangeT &r) {
+ // Note: Create intermediate ValueRange because SmallVector<Value, 1>
+ // is not constructible from SmallVector<Value>.
+ return SmallVector<Value, 1>(ValueRange(r));
+ });
+ replaceOpWithMultiple(op, ArrayRef(vals));
}
/// PatternRewriter hook for erasing a dead operation. The uses of this
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index dec3dca988ae9..d6128785927b4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -192,7 +192,7 @@ struct LegalizeArithConstantOpsByDecomposition
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
auto tileSplat = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
- SmallVector<Value> repl(tileCount, tileSplat);
+ SmallVector<Value, 1> repl(tileCount, tileSplat);
rewriter.replaceOpWithMultiple(constantOp, {repl});
return success();
@@ -232,7 +232,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
- SmallVector<Value> resultSMETiles;
+ SmallVector<Value, 1> resultSMETiles;
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
@@ -310,7 +310,7 @@ struct LegalizeTransferReadOpsByDecomposition
auto loc = readOp.getLoc();
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
- SmallVector<Value> resultSMETiles;
+ SmallVector<Value, 1> resultSMETiles;
for (SMESubTile smeTile :
decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6291f3ea37230..80969cf30cf88 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -585,7 +585,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
auto newCall = rewriter.create<func::CallOp>(
loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
// (2) Gather sparse tensor returns.
- SmallVector<SmallVector<Value>> packedResultVals;
+ SmallVector<SmallVector<Value, 1>> packedResultVals;
// Tracks the offset of current return value (of the original call)
// relative to the new call (after sparse tensor flattening);
unsigned retOffset = 0;
@@ -752,7 +752,7 @@ class SparseTensorAllocConverter
if (op.getCopy()) {
auto desc = getDescriptorFromTensorTuple(
adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
- SmallVector<Value> fields;
+ SmallVector<Value, 1> fields;
fields.reserve(desc.getNumFields());
// Memcpy on memref fields.
for (auto field : desc.getMemRefFields()) {
@@ -823,7 +823,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
/*dimSizesValues=*/lvlSizesValues);
// Construct allocation for each field.
Value sizeHint; // none
- SmallVector<Value> fields;
+ SmallVector<Value, 1> fields;
createAllocFields(rewriter, loc, resType, enableBufferInitialization,
sizeHint, lvlSizesValues, fields);
@@ -1176,7 +1176,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
Location loc = op.getLoc();
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
op.getSource().getType());
- SmallVector<Value> fields;
+ SmallVector<Value, 1> fields;
foreachFieldAndTypeInSparseTensor(
SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
[&rewriter, &fields, srcDesc,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9779436c947cf..c692aafd34aa3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -947,7 +947,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
OpBuilder::InsertPoint previous) override;
/// Notifies that an op is about to be replaced with the given values.
- void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
+ void notifyOpReplaced(Operation *op, ArrayRef<ValueVector> newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -1520,7 +1520,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
}
void ConversionPatternRewriterImpl::notifyOpReplaced(
- Operation *op, ArrayRef<ValueRange> newValues) {
+ Operation *op, ArrayRef<ValueVector> newValues) {
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");
@@ -1640,19 +1640,15 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ValueRange> newVals;
- for (size_t i = 0; i < newValues.size(); ++i) {
- if (newValues[i]) {
- newVals.push_back(newValues.slice(i, 1));
- } else {
- newVals.push_back(ValueRange());
- }
- }
+ SmallVector<ValueVector> newVals =
+ llvm::map_to_vector(newValues, [](Value v) -> ValueVector {
+ return v ? ValueVector{v} : ValueVector();
+ });
impl->notifyOpReplaced(op, newVals);
}
void ConversionPatternRewriter::replaceOpWithMultiple(
- Operation *op, ArrayRef<ValueRange> newValues) {
+ Operation *op, ArrayRef<SmallVector<Value, 1>> newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
LLVM_DEBUG({
@@ -1667,7 +1663,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
- SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
+ SmallVector<ValueVector> nullRepls(op->getNumResults(), ValueVector());
impl->notifyOpReplaced(op, nullRepls);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e325003f5384c..764d8bd8575f5 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1284,7 +1284,8 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
- SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+ SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value, 1>> r7,
+ ArrayRef<SmallVector<Value, 1>> r8, Value v, ValueRange vr,
ArrayRef<Value> ar) {
rewriter.replaceOpWithMultiple(op, r1);
rewriter.replaceOpWithMultiple(op, r2);
@@ -1292,6 +1293,8 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
rewriter.replaceOpWithMultiple(op, r4);
rewriter.replaceOpWithMultiple(op, r5);
rewriter.replaceOpWithMultiple(op, r6);
+ rewriter.replaceOpWithMultiple(op, r7);
+ rewriter.replaceOpWithMultiple(op, r8);
rewriter.replaceOpWithMultiple(op, {vr});
rewriter.replaceOpWithMultiple(op, {ar});
rewriter.replaceOpWithMultiple(op, {{v}});
More information about the Mlir-commits mailing list