[Mlir-commits] [mlir] [mlir][Transforms] Improve `replaceOpWithMultiple` API (PR #132608)

Matthias Springer llvmlistbot at llvm.org
Tue Mar 25 06:41:11 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/132608

>From a833e1a01cab16f943dab1ab17f895ba9b70e12c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 23 Mar 2025 12:59:17 +0100
Subject: [PATCH 1/5] [mlir][Transforms] Improve `replaceOpWithMultiple` API

---
 .../mlir/Transforms/DialectConversion.h       |  5 +++++
 .../Transforms/SparseTensorCodegen.cpp        |  3 +--
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 22 +++++++++++++++++++
 3 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..b537b790687c8 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,6 +898,11 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+  template <typename RangeT>
+  void replaceOpWithMultiple(Operation *op, RangeT newValues) {
+    replaceOpWithMultiple(op,
+                          ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+  }
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..6291f3ea37230 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     }
 
     assert(packedResultVals.size() == op.getNumResults());
-    rewriter.replaceOpWithMultiple(
-        op, llvm::to_vector_of<ValueRange>(packedResultVals));
+    rewriter.replaceOpWithMultiple(op, packedResultVals);
     return success();
   }
 };
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..e325003f5384c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   }
 };
 
+/// Test unambiguous overload resolution of replaceOpWithMultiple. This
+/// function is just to trigger compiler errors. It is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+    ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+    SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+    SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+    SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+    ArrayRef<Value> ar) {
+  rewriter.replaceOpWithMultiple(op, r1);
+  rewriter.replaceOpWithMultiple(op, r2);
+  rewriter.replaceOpWithMultiple(op, r3);
+  rewriter.replaceOpWithMultiple(op, r4);
+  rewriter.replaceOpWithMultiple(op, r5);
+  rewriter.replaceOpWithMultiple(op, r6);
+  rewriter.replaceOpWithMultiple(op, {vr});
+  rewriter.replaceOpWithMultiple(op, {ar});
+  rewriter.replaceOpWithMultiple(op, {{v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+  rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+}
 } // namespace
 
 namespace {

>From 0948f752a931dabdac753d084fc998f0067007d0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 24 Mar 2025 08:53:04 +0100
Subject: [PATCH 2/5] Update mlir/include/mlir/Transforms/DialectConversion.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
 mlir/include/mlir/Transforms/DialectConversion.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b537b790687c8..00a5389b4101e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -899,7 +899,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
   template <typename RangeT>
-  void replaceOpWithMultiple(Operation *op, RangeT newValues) {
+  void replaceOpWithMultiple(Operation *op, RangeT&& newValues) {
     replaceOpWithMultiple(op,
                           ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
   }

>From 5aadd7fae2d95b1d45134169366c33950a65ef79 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 24 Mar 2025 09:42:14 +0100
Subject: [PATCH 3/5] clang-format

---
 mlir/include/mlir/Transforms/DialectConversion.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 00a5389b4101e..66a5e6a889905 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -899,7 +899,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
   template <typename RangeT>
-  void replaceOpWithMultiple(Operation *op, RangeT&& newValues) {
+  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
     replaceOpWithMultiple(op,
                           ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
   }

>From b6a28ebbbceffb8c104c8b1af429461662154f07 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 25 Mar 2025 14:33:15 +0100
Subject: [PATCH 4/5] Use SmallVector<Value,1> by default

---
 .../mlir/Transforms/DialectConversion.h       | 32 ++++++++++++++++---
 .../ArmSME/Transforms/VectorLegalization.cpp  |  6 ++--
 .../Transforms/SparseTensorCodegen.cpp        |  8 ++---
 .../Transforms/Utils/DialectConversion.cpp    | 20 +++++-------
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  5 ++-
 5 files changed, 46 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 66a5e6a889905..e8ee9fecf711d 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -897,11 +897,33 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
-  void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
-  template <typename RangeT>
-  void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
-    replaceOpWithMultiple(op,
-                          ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
+  void replaceOpWithMultiple(Operation *op,
+                             ArrayRef<SmallVector<Value, 1>> newValues);
+  // Note: This overload matches SmallVector<ValueRange>,
+  // SmallVector<SmallVector<Value>>, etc.
+  template <typename RangeRangeT>
+  void replaceOpWithMultiple(Operation *op, RangeRangeT &&newValues) {
+    // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
+    // does not copy the replacements vector.
+    auto vals = llvm::map_to_vector(newValues, [](const auto &r) {
+      // Note: Create intermediate ValueRange because SmallVector<Value, 1>
+      // is not constructible from SmallVector<Value>.
+      return SmallVector<Value, 1>(ValueRange(r));
+    });
+    replaceOpWithMultiple(op, ArrayRef(vals));
+  }
+  // Note: This overload matches initializer list of ValueRange,
+  // SmallVector<Value>, etc.
+  template <typename RangeT = ValueRange>
+  void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
+    // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
+    // does not copy the replacements vector.
+    auto vals = llvm::map_to_vector(newValues, [](const RangeT &r) {
+      // Note: Create intermediate ValueRange because SmallVector<Value, 1>
+      // is not constructible from SmallVector<Value>.
+      return SmallVector<Value, 1>(ValueRange(r));
+    });
+    replaceOpWithMultiple(op, ArrayRef(vals));
   }
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index dec3dca988ae9..d6128785927b4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -192,7 +192,7 @@ struct LegalizeArithConstantOpsByDecomposition
     auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
     auto tileSplat = rewriter.create<arith::ConstantOp>(
         constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
-    SmallVector<Value> repl(tileCount, tileSplat);
+    SmallVector<Value, 1> repl(tileCount, tileSplat);
     rewriter.replaceOpWithMultiple(constantOp, {repl});
 
     return success();
@@ -232,7 +232,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
     VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
 
-    SmallVector<Value> resultSMETiles;
+    SmallVector<Value, 1> resultSMETiles;
     for (auto [index, smeTile] : llvm::enumerate(
              decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
 
@@ -310,7 +310,7 @@ struct LegalizeTransferReadOpsByDecomposition
     auto loc = readOp.getLoc();
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
 
-    SmallVector<Value> resultSMETiles;
+    SmallVector<Value, 1> resultSMETiles;
     for (SMESubTile smeTile :
          decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
       auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6291f3ea37230..80969cf30cf88 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -585,7 +585,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     auto newCall = rewriter.create<func::CallOp>(
         loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
     // (2) Gather sparse tensor returns.
-    SmallVector<SmallVector<Value>> packedResultVals;
+    SmallVector<SmallVector<Value, 1>> packedResultVals;
     // Tracks the offset of current return value (of the original call)
     // relative to the new call (after sparse tensor flattening);
     unsigned retOffset = 0;
@@ -752,7 +752,7 @@ class SparseTensorAllocConverter
     if (op.getCopy()) {
       auto desc = getDescriptorFromTensorTuple(
           adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
-      SmallVector<Value> fields;
+      SmallVector<Value, 1> fields;
       fields.reserve(desc.getNumFields());
       // Memcpy on memref fields.
       for (auto field : desc.getMemRefFields()) {
@@ -823,7 +823,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
                    /*dimSizesValues=*/lvlSizesValues);
     // Construct allocation for each field.
     Value sizeHint; // none
-    SmallVector<Value> fields;
+    SmallVector<Value, 1> fields;
     createAllocFields(rewriter, loc, resType, enableBufferInitialization,
                       sizeHint, lvlSizesValues, fields);
 
@@ -1176,7 +1176,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
     Location loc = op.getLoc();
     auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
                                                 op.getSource().getType());
-    SmallVector<Value> fields;
+    SmallVector<Value, 1> fields;
     foreachFieldAndTypeInSparseTensor(
         SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
         [&rewriter, &fields, srcDesc,
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9779436c947cf..c692aafd34aa3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -947,7 +947,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                OpBuilder::InsertPoint previous) override;
 
   /// Notifies that an op is about to be replaced with the given values.
-  void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
+  void notifyOpReplaced(Operation *op, ArrayRef<ValueVector> newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
@@ -1520,7 +1520,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
 }
 
 void ConversionPatternRewriterImpl::notifyOpReplaced(
-    Operation *op, ArrayRef<ValueRange> newValues) {
+    Operation *op, ArrayRef<ValueVector> newValues) {
   assert(newValues.size() == op->getNumResults());
   assert(!ignoredOps.contains(op) && "operation was already replaced");
 
@@ -1640,19 +1640,15 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<ValueRange> newVals;
-  for (size_t i = 0; i < newValues.size(); ++i) {
-    if (newValues[i]) {
-      newVals.push_back(newValues.slice(i, 1));
-    } else {
-      newVals.push_back(ValueRange());
-    }
-  }
+  SmallVector<ValueVector> newVals =
+      llvm::map_to_vector(newValues, [](Value v) -> ValueVector {
+        return v ? ValueVector{v} : ValueVector();
+      });
   impl->notifyOpReplaced(op, newVals);
 }
 
 void ConversionPatternRewriter::replaceOpWithMultiple(
-    Operation *op, ArrayRef<ValueRange> newValues) {
+    Operation *op, ArrayRef<SmallVector<Value, 1>> newValues) {
   assert(op->getNumResults() == newValues.size() &&
          "incorrect # of replacement values");
   LLVM_DEBUG({
@@ -1667,7 +1663,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
     impl->logger.startLine()
         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
+  SmallVector<ValueVector> nullRepls(op->getNumResults(), ValueVector());
   impl->notifyOpReplaced(op, nullRepls);
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index e325003f5384c..764d8bd8575f5 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1284,7 +1284,8 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
     ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
     SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
     SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
-    SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+    SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value, 1>> r7,
+    ArrayRef<SmallVector<Value, 1>> r8, Value v, ValueRange vr,
     ArrayRef<Value> ar) {
   rewriter.replaceOpWithMultiple(op, r1);
   rewriter.replaceOpWithMultiple(op, r2);
@@ -1292,6 +1293,8 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   rewriter.replaceOpWithMultiple(op, r4);
   rewriter.replaceOpWithMultiple(op, r5);
   rewriter.replaceOpWithMultiple(op, r6);
+  rewriter.replaceOpWithMultiple(op, r7);
+  rewriter.replaceOpWithMultiple(op, r8);
   rewriter.replaceOpWithMultiple(op, {vr});
   rewriter.replaceOpWithMultiple(op, {ar});
   rewriter.replaceOpWithMultiple(op, {{v}});

>From 1b6eddfcd96bb749c05a6ab927c3ccd666b6d984 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 25 Mar 2025 14:40:51 +0100
Subject: [PATCH 5/5] Use std::begin/std::end instead of intermediate ValueRange

---
 mlir/include/mlir/Transforms/DialectConversion.h | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e8ee9fecf711d..e4a785eaaf855 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -906,9 +906,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
     // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
     // does not copy the replacements vector.
     auto vals = llvm::map_to_vector(newValues, [](const auto &r) {
-      // Note: Create intermediate ValueRange because SmallVector<Value, 1>
-      // is not constructible from SmallVector<Value>.
-      return SmallVector<Value, 1>(ValueRange(r));
+      return SmallVector<Value, 1>(std::begin(r), std::end(r));
     });
     replaceOpWithMultiple(op, ArrayRef(vals));
   }
@@ -919,9 +917,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
     // Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
     // does not copy the replacements vector.
     auto vals = llvm::map_to_vector(newValues, [](const RangeT &r) {
-      // Note: Create intermediate ValueRange because SmallVector<Value, 1>
-      // is not constructible from SmallVector<Value>.
-      return SmallVector<Value, 1>(ValueRange(r));
+      return SmallVector<Value, 1>(std::begin(r), std::end(r));
     });
     replaceOpWithMultiple(op, ArrayRef(vals));
   }



More information about the Mlir-commits mailing list