[Mlir-commits] [mlir] [mlir][Transforms] Improve `replaceOpWithMultiple` API (PR #132608)

Matthias Springer llvmlistbot at llvm.org
Mon Mar 24 01:42:38 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/132608

>From 8d9755ace68ef6ecb5872c91335c198a006fba94 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 23 Mar 2025 12:59:17 +0100
Subject: [PATCH 1/3] [mlir][Transforms] Improve `replaceOpWithMultiple` API

---
 .../mlir/Transforms/DialectConversion.h       |  5 +++++
 .../Transforms/SparseTensorCodegen.cpp        |  3 +--
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 22 +++++++++++++++++++
 3 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..b537b790687c8 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,6 +898,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+  template <typename RangeT>
+  void replaceOpWithMultiple(Operation *op, RangeT newValues) {
+    replaceOpWithMultiple(op,
+                          ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+  }
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..6291f3ea37230 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     }
 
     assert(packedResultVals.size() == op.getNumResults());
-    rewriter.replaceOpWithMultiple(
-        op, llvm::to_vector_of<ValueRange>(packedResultVals));
+    rewriter.replaceOpWithMultiple(op, packedResultVals);
     return success();
   }
 };
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..e325003f5384c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   }
 };
 
+/// Test unambiguous overload resolution of replaceOpWithMultiple. This
+/// function exists only to surface compile errors; it is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+    ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+    SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+    SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+    SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+    ArrayRef<Value> ar) {
+  rewriter.replaceOpWithMultiple(op, r1);
+  rewriter.replaceOpWithMultiple(op, r2);
+  rewriter.replaceOpWithMultiple(op, r3);
+  rewriter.replaceOpWithMultiple(op, r4);
+  rewriter.replaceOpWithMultiple(op, r5);
+  rewriter.replaceOpWithMultiple(op, r6);
+  rewriter.replaceOpWithMultiple(op, {vr});
+  rewriter.replaceOpWithMultiple(op, {ar});
+  rewriter.replaceOpWithMultiple(op, {{v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+  rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+}
 } // namespace
 
 namespace {

>From 09f16b2fc932f51f5a04f9eabf72e7db2b372214 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 24 Mar 2025 08:53:04 +0100
Subject: [PATCH 2/3] Update mlir/include/mlir/Transforms/DialectConversion.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
 mlir/include/mlir/Transforms/DialectConversion.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b537b790687c8..00a5389b4101e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -899,7 +899,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
   template <typename RangeT>
-  void replaceOpWithMultiple(Operation *op, RangeT newValues) {
+  void replaceOpWithMultiple(Operation *op, RangeT&& newValues) {
     replaceOpWithMultiple(op,
                           ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
   }

>From 287bcbd7d2c6ac35427a253298177b2bcf82b1d6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 24 Mar 2025 09:42:14 +0100
Subject: [PATCH 3/3] clang-format

---
 mlir/include/mlir/Transforms/DialectConversion.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 00a5389b4101e..66a5e6a889905 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -899,7 +899,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
   template <typename RangeT>
-  void replaceOpWithMultiple(Operation *op, RangeT&& newValues) {
+  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
     replaceOpWithMultiple(op,
                           ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
   }



More information about the Mlir-commits mailing list