[Mlir-commits] [mlir] [mlir][Transforms] Improve `replaceOpWithMultiple` API (PR #132608)
Matthias Springer
llvmlistbot at llvm.org
Sun Mar 23 05:23:26 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/132608
From 8d9755ace68ef6ecb5872c91335c198a006fba94 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 23 Mar 2025 12:59:17 +0100
Subject: [PATCH] [mlir][Transforms] Improve `replaceOpWithMultiple` API
---
.../mlir/Transforms/DialectConversion.h | 5 +++++
.../Transforms/SparseTensorCodegen.cpp | 3 +--
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 22 +++++++++++++++++++
3 files changed, 28 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..b537b790687c8 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,6 +898,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. The given operation is erased.
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+  /// Replace the given operation with the given range of value ranges. This
+  /// overload accepts any range whose elements can be converted to
+  /// `ValueRange` (e.g., `SmallVector<SmallVector<Value>>`), so callers do not
+  /// have to build a `SmallVector<ValueRange>` themselves. The range is taken
+  /// by const reference so that nested containers are not copied.
+  template <typename RangeT>
+  void replaceOpWithMultiple(Operation *op, const RangeT &newValues) {
+    // The explicit ArrayRef wrapper makes the non-template overload the best
+    // match; passing the SmallVector<ValueRange> directly would re-select this
+    // template and recurse forever.
+    replaceOpWithMultiple(op,
+                          ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+  }
/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..6291f3ea37230 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
}
assert(packedResultVals.size() == op.getNumResults());
- rewriter.replaceOpWithMultiple(
- op, llvm::to_vector_of<ValueRange>(packedResultVals));
+ rewriter.replaceOpWithMultiple(op, packedResultVals);
return success();
}
};
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..e325003f5384c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
}
};
+/// Test unambiguous overload resolution of replaceOpWithMultiple. This
+/// function exists only so that ambiguous or missing overloads show up as
+/// compile errors; it is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+    ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+    SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+    SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+    SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+    ArrayRef<Value> ar) {
+  // Ranges of value ranges passed as non-const lvalues.
+  rewriter.replaceOpWithMultiple(op, r1);
+  rewriter.replaceOpWithMultiple(op, r2);
+  rewriter.replaceOpWithMultiple(op, r3);
+  rewriter.replaceOpWithMultiple(op, r4);
+  rewriter.replaceOpWithMultiple(op, r5);
+  rewriter.replaceOpWithMultiple(op, r6);
+  // Braced initializer lists cannot deduce the template parameter, so these
+  // must resolve to the ArrayRef<ValueRange> overload.
+  rewriter.replaceOpWithMultiple(op, {vr});
+  rewriter.replaceOpWithMultiple(op, {ar});
+  rewriter.replaceOpWithMultiple(op, {{v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+  rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+  // Const lvalue and rvalue ranges must resolve unambiguously as well.
+  const SmallVector<SmallVector<Value>> &cr4 = r4;
+  rewriter.replaceOpWithMultiple(op, cr4);
+  rewriter.replaceOpWithMultiple(op, std::move(r4));
+}
} // namespace
namespace {
More information about the Mlir-commits mailing list