[Mlir-commits] [mlir] [mlir][Transforms] Add 1:N `matchAndRewrite` overload (PR #116470)

Matthias Springer llvmlistbot at llvm.org
Fri Nov 22 20:02:09 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116470

>From 8a7a0421e5d346d428a665d8ca2f31a211af308d Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 12 Nov 2024 05:14:43 +0100
Subject: [PATCH] replace with multiple
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Apply suggestions from code review

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

address comments

[WIP] 1:N conversion pattern

update test cases

Update mlir/lib/Transforms/Utils/DialectConversion.cpp

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

Update mlir/lib/Transforms/Utils/DialectConversion.cpp

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

address comments

rollback unresolved materializations properly
---
 .../mlir/Conversion/LLVMCommon/Pattern.h      |  35 +-
 .../mlir/Transforms/DialectConversion.h       |  69 ++++
 .../Transforms/DecomposeCallGraphTypes.cpp    |  56 +---
 .../Func/Transforms/FuncConversions.cpp       |   5 +-
 .../Transforms/StructuralTypeConversions.cpp  | 106 +++---
 .../Transforms/SparseTensorCodegen.cpp        | 114 ++++---
 .../Transforms/Utils/SparseTensorDescriptor.h |  16 +-
 .../Transforms/Utils/DialectConversion.cpp    | 305 ++++++++++++------
 .../decompose-call-graph-types.mlir           |  38 +--
 mlir/test/Transforms/test-legalizer.mlir      |  11 +
 mlir/test/lib/Dialect/Test/TestOps.td         |   5 +
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  84 ++++-
 12 files changed, 509 insertions(+), 335 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f3bf5b66398e09..86ea87b55af1cd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -143,6 +143,8 @@ template <typename SourceOp>
 class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
 
   explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
                                   PatternBenefit benefit = 1)
@@ -153,8 +155,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
-            rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -162,8 +169,15 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return matchAndRewrite(cast<SourceOp>(op),
-                           OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
+  }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
@@ -175,6 +189,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override rewrite or matchAndRewrite");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -183,6 +203,13 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConvertToLLVMPattern::match;
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index de47765006f81e..9abb6eedb9d912 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -537,8 +537,15 @@ class ConversionPattern : public RewritePattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("unimplemented rewrite");
   }
+  virtual void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
   /// Hook for derived classes to implement combined matching and rewriting.
+  /// This overload supports only 1:1 replacements. The 1:N overload is called
+  /// by the driver. By default, it calls this 1:1 overload or reports a fatal
+  /// error if 1:N replacements were found.
   virtual LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
@@ -548,6 +555,14 @@ class ConversionPattern : public RewritePattern {
     return success();
   }
 
+  /// Hook for derived classes to implement combined matching and rewriting.
+  /// This overload supports 1:N replacements.
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
+
   /// Attempt to match and rewrite the IR root at the specified operation.
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final;
@@ -574,6 +589,15 @@ class ConversionPattern : public RewritePattern {
       : RewritePattern(std::forward<Args>(args)...),
         typeConverter(&typeConverter) {}
 
+  /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
+  /// try to extract the single value of each range to construct the inputs
+  /// for a 1:1 adaptor.
+  ///
+  /// This function produces a fatal error if at least one range has 0 or
+  /// more than 1 value: "pattern 'name' does not support 1:N conversion"
+  SmallVector<Value>
+  getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
+
 protected:
   /// An optional type converter for use by this pattern.
   const TypeConverter *typeConverter = nullptr;
@@ -589,6 +613,8 @@ template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
   using OpAdaptor = typename SourceOp::Adaptor;
+  using OneToNOpAdaptor =
+      typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
 
   OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
@@ -607,12 +633,24 @@ class OpConversionPattern : public ConversionPattern {
     auto sourceOp = cast<SourceOp>(op);
     rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
     return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto sourceOp = cast<SourceOp>(op);
+    return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
+                           rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -623,6 +661,12 @@ class OpConversionPattern : public ConversionPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                       ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
@@ -631,6 +675,13 @@ class OpConversionPattern : public ConversionPattern {
     rewrite(op, adaptor, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+    SmallVector<Value> oneToOneOperands =
+        getOneToOneAdaptorOperands(adaptor.getOperands());
+    return matchAndRewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
@@ -656,11 +707,20 @@ class OpInterfaceConversionPattern : public ConversionPattern {
                ConversionPatternRewriter &rewriter) const final {
     rewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  void rewrite(Operation *op, ArrayRef<ValueRange> operands,
+               ConversionPatternRewriter &rewriter) const final {
+    rewrite(cast<SourceOp>(op), operands, rewriter);
+  }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
   }
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+  }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
@@ -668,6 +728,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
                        ConversionPatternRewriter &rewriter) const {
     llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
+  virtual void rewrite(SourceOp op, ArrayRef<ValueRange> operands,
+                       ConversionPatternRewriter &rewriter) const {
+    rewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const {
@@ -676,6 +740,11 @@ class OpInterfaceConversionPattern : public ConversionPattern {
     rewrite(op, operands, rewriter);
     return success();
   }
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const {
+    return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+  }
 
 private:
   using ConversionPattern::matchAndRewrite;
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index a08764326a80b6..03be00328bda33 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -13,40 +13,6 @@
 using namespace mlir;
 using namespace mlir::func;
 
-//===----------------------------------------------------------------------===//
-// Helper functions
-//===----------------------------------------------------------------------===//
-
-/// If the given value can be decomposed with the type converter, decompose it.
-/// Otherwise, return the given value.
-// TODO: Value decomposition should happen automatically through a 1:N adaptor.
-// This function will disappear when the 1:1 and 1:N drivers are merged.
-static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
-                                         Value value,
-                                         const TypeConverter *converter) {
-  // Try to convert the given value's type. If that fails, just return the
-  // given value.
-  SmallVector<Type> convertedTypes;
-  if (failed(converter->convertType(value.getType(), convertedTypes)))
-    return {value};
-  if (convertedTypes.empty())
-    return {};
-
-  // If the given value's type is already legal, just return the given value.
-  TypeRange convertedTypeRange(convertedTypes);
-  if (convertedTypeRange == TypeRange(value.getType()))
-    return {value};
-
-  // Try to materialize a target conversion. If the materialization did not
-  // produce values of the requested type, the materialization failed. Just
-  // return the given value in that case.
-  SmallVector<Value> result = converter->materializeTargetConversion(
-      builder, loc, convertedTypeRange, value);
-  if (result.empty())
-    return {value};
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // DecomposeCallGraphTypesForFuncArgs
 //===----------------------------------------------------------------------===//
@@ -102,16 +68,11 @@ struct DecomposeCallGraphTypesForReturnOp
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
     return success();
   }
@@ -128,18 +89,13 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(CallOp op, OpAdaptor adaptor,
+  matchAndRewrite(CallOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
 
     // Create the operands list of the new `CallOp`.
     SmallVector<Value, 2> newOperands;
-    for (Value operand : adaptor.getOperands()) {
-      // TODO: We can directly take the values from the adaptor once this is a
-      // 1:N conversion pattern.
-      llvm::append_range(newOperands,
-                         decomposeValue(rewriter, operand.getLoc(), operand,
-                                        getTypeConverter()));
-    }
+    for (ValueRange operand : adaptor.getOperands())
+      llvm::append_range(newOperands, operand);
 
     // Create the new result types for the new `CallOp` and track the number of
     // replacement types for each original op result.
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..d81f822f7d4b51 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -21,7 +21,7 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
 
   /// Hook for derived classes to implement combined matching and rewriting.
   LogicalResult
-  matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
+  matchAndRewrite(CallOp callOp, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Convert the original function results.
     SmallVector<Type, 1> convertedResults;
@@ -37,7 +37,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
     // Substitute with the new result types from the corresponding FuncType
     // conversion.
     rewriter.replaceOpWithNewOp<CallOp>(
-        callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+        callOp, callOp.getCallee(), convertedResults,
+        getOneToOneAdaptorOperands(adaptor.getOperands()));
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 93a78056db1944..c0589044c26ecb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -16,20 +16,18 @@ using namespace mlir::scf;
 
 namespace {
 
-// Unpacks the single unrealized_conversion_cast using the list of inputs
-// e.g., return [%b, %c, %d] for %a = unrealized_conversion_cast(%b, %c, %d)
-static void unpackUnrealizedConversionCast(Value v,
-                                           SmallVectorImpl<Value> &unpacked) {
-  if (auto cast =
-          dyn_cast_or_null<UnrealizedConversionCastOp>(v.getDefiningOp())) {
-    if (cast.getInputs().size() != 1) {
-      // 1 : N type conversion.
-      unpacked.append(cast.getInputs().begin(), cast.getInputs().end());
-      return;
-    }
-  }
-  // 1 : 1 type conversion.
-  unpacked.push_back(v);
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
 }
 
 // CRTP
@@ -40,19 +38,21 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
 public:
   using OpConversionPattern<SourceOp>::typeConverter;
   using OpConversionPattern<SourceOp>::OpConversionPattern;
-  using OpAdaptor = typename OpConversionPattern<SourceOp>::OpAdaptor;
+  using OneToNOpAdaptor =
+      typename OpConversionPattern<SourceOp>::OneToNOpAdaptor;
 
   //
   // Derived classes should provide the following method which performs the
   // actual conversion. It should return std::nullopt upon conversion failure
   // and return the converted operation upon success.
   //
-  // std::optional<SourceOp> convertSourceOp(SourceOp op, OpAdaptor adaptor,
-  //                                    ConversionPatternRewriter &rewriter,
-  //                                    TypeRange dstTypes) const;
+  // std::optional<SourceOp> convertSourceOp(
+  //     SourceOp op, OneToNOpAdaptor adaptor,
+  //     ConversionPatternRewriter &rewriter,
+  //     TypeRange dstTypes) const;
 
   LogicalResult
-  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
+  matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Type> dstTypes;
     SmallVector<unsigned> offsets;
@@ -73,28 +73,15 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
       return rewriter.notifyMatchFailure(op, "could not convert operation");
 
     // Packs the return value.
-    SmallVector<Value> packedRets;
+    SmallVector<ValueRange> packedRets;
     for (unsigned i = 1, e = offsets.size(); i < e; i++) {
       unsigned start = offsets[i - 1], end = offsets[i];
       unsigned len = end - start;
       ValueRange mappedValue = newOp->getResults().slice(start, len);
-      if (len != 1) {
-        // 1 : N type conversion.
-        Type origType = op.getResultTypes()[i - 1];
-        Value mat = typeConverter->materializeSourceConversion(
-            rewriter, op.getLoc(), origType, mappedValue);
-        if (!mat) {
-          return rewriter.notifyMatchFailure(
-              op, "Failed to materialize 1:N type conversion");
-        }
-        packedRets.push_back(mat);
-      } else {
-        // 1 : 1 type conversion.
-        packedRets.push_back(mappedValue.front());
-      }
+      packedRets.push_back(mappedValue);
     }
 
-    rewriter.replaceOp(op, packedRets);
+    rewriter.replaceOpWithMultiple(op, packedRets);
     return success();
   }
 };
@@ -105,7 +92,7 @@ class ConvertForOpTypes
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
   // The callback required by CRTP.
-  std::optional<ForOp> convertSourceOp(ForOp op, OpAdaptor adaptor,
+  std::optional<ForOp> convertSourceOp(ForOp op, OneToNOpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter,
                                        TypeRange dstTypes) const {
     // Create a empty new op and inline the regions from the old op.
@@ -129,16 +116,13 @@ class ConvertForOpTypes
     if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
       return std::nullopt;
 
-    // Unpacked the iteration arguments.
-    SmallVector<Value> flatArgs;
-    for (Value arg : adaptor.getInitArgs())
-      unpackUnrealizedConversionCast(arg, flatArgs);
-
     // We can not do clone as the number of result types after conversion
     // might be different.
-    ForOp newOp = rewriter.create<ForOp>(op.getLoc(), adaptor.getLowerBound(),
-                                         adaptor.getUpperBound(),
-                                         adaptor.getStep(), flatArgs);
+    ForOp newOp = rewriter.create<ForOp>(
+        op.getLoc(), getSingleValue(adaptor.getLowerBound()),
+        getSingleValue(adaptor.getUpperBound()),
+        getSingleValue(adaptor.getStep()),
+        flattenValues(adaptor.getInitArgs()));
 
     // Reserve whatever attributes in the original op.
     newOp->setAttrs(op->getAttrs());
@@ -160,12 +144,12 @@ class ConvertIfOpTypes
 public:
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
-  std::optional<IfOp> convertSourceOp(IfOp op, OpAdaptor adaptor,
+  std::optional<IfOp> convertSourceOp(IfOp op, OneToNOpAdaptor adaptor,
                                       ConversionPatternRewriter &rewriter,
                                       TypeRange dstTypes) const {
 
-    IfOp newOp = rewriter.create<IfOp>(op.getLoc(), dstTypes,
-                                       adaptor.getCondition(), true);
+    IfOp newOp = rewriter.create<IfOp>(
+        op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
     newOp->setAttrs(op->getAttrs());
 
     // We do not need the empty blocks created by rewriter.
@@ -189,15 +173,11 @@ class ConvertWhileOpTypes
 public:
   using Structural1ToNConversionPattern::Structural1ToNConversionPattern;
 
-  std::optional<WhileOp> convertSourceOp(WhileOp op, OpAdaptor adaptor,
+  std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter,
                                          TypeRange dstTypes) const {
-    // Unpacked the iteration arguments.
-    SmallVector<Value> flatArgs;
-    for (Value arg : adaptor.getOperands())
-      unpackUnrealizedConversionCast(arg, flatArgs);
-
-    auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, flatArgs);
+    auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes,
+                                          flattenValues(adaptor.getOperands()));
 
     for (auto i : {0u, 1u}) {
       if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter)))
@@ -218,13 +198,10 @@ class ConvertYieldOpTypes : public OpConversionPattern<scf::YieldOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
+  matchAndRewrite(scf::YieldOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> unpackedYield;
-    for (Value operand : adaptor.getOperands())
-      unpackUnrealizedConversionCast(operand, unpackedYield);
-
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, unpackedYield);
+    rewriter.replaceOpWithNewOp<scf::YieldOp>(
+        op, flattenValues(adaptor.getOperands()));
     return success();
   }
 };
@@ -235,13 +212,10 @@ class ConvertConditionOpTypes : public OpConversionPattern<ConditionOp> {
 public:
   using OpConversionPattern<ConditionOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ConditionOp op, OpAdaptor adaptor,
+  matchAndRewrite(ConditionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> unpackedYield;
-    for (Value operand : adaptor.getOperands())
-      unpackUnrealizedConversionCast(operand, unpackedYield);
-
-    rewriter.modifyOpInPlace(op, [&]() { op->setOperands(unpackedYield); });
+    rewriter.modifyOpInPlace(
+        op, [&]() { op->setOperands(flattenValues(adaptor.getOperands())); });
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 25fca49cb0154a..9184224e7aef4b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -39,25 +39,18 @@ using namespace mlir::sparse_tensor;
 // Helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Flattens a list of operands that may contain sparse tensors.
-static void flattenOperands(ValueRange operands,
-                            SmallVectorImpl<Value> &flattened) {
-  // In case of
-  // sparse_tensor, c, sparse_tensor
-  // ==>
-  // memref ..., c, memref ...
-  for (auto operand : operands) {
-    if (getSparseTensorEncoding(operand.getType())) {
-      auto tuple = getTuple(operand);
-      // An unrealized_conversion_cast will be inserted by type converter to
-      // inter-mix the gap between 1:N conversion between sparse tensors and
-      // fields. In this case, take the operands in the cast and replace the
-      // sparse tensor output with the flattened type array.
-      flattened.append(tuple.getOperands().begin(), tuple.getOperands().end());
-    } else {
-      flattened.push_back(operand);
-    }
-  }
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+  SmallVector<Value> result;
+  for (const auto &vals : values)
+    llvm::append_range(result, vals);
+  return result;
+}
+
+/// Assert that the given value range contains a single value and return it.
+static Value getSingleValue(ValueRange values) {
+  assert(values.size() == 1 && "expected single value");
+  return values.front();
 }
 
 /// Generates a load with proper `index` typing.
@@ -567,12 +560,11 @@ class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> flattened;
-    flattenOperands(adaptor.getOperands(), flattened);
     // Create a return with the flattened value extracted from sparse tensors.
-    rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
+    rewriter.replaceOpWithNewOp<func::ReturnOp>(
+        op, flattenValues(adaptor.getOperands()));
     return success();
   }
 };
@@ -583,7 +575,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
   // The default CallOp converter can not handle 1:N type conversion.
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::CallOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     // In case of:
@@ -596,10 +588,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
       return failure();
 
     // (1) Generates new call with flattened return value.
-    SmallVector<Value> flattened;
-    flattenOperands(adaptor.getOperands(), flattened);
-    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
-                                                 finalRetTy, flattened);
+    auto newCall = rewriter.create<func::CallOp>(
+        loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
     // (2) Gather sparse tensor returns.
     SmallVector<SmallVector<Value>> packedResultVals;
     // Tracks the offset of current return value (of the original call)
@@ -643,7 +633,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(LvlOp op, OpAdaptor adaptor,
+  matchAndRewrite(LvlOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     std::optional<int64_t> lvl = op.getConstantLvlIndex();
     RankedTensorType srcType = op.getSource().getType();
@@ -662,7 +652,7 @@ class SparseLvlOpConverter : public OpConversionPattern<LvlOp> {
 struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ReorderCOOOp op, ReorderCOOOpAdaptor adaptor,
+  matchAndRewrite(ReorderCOOOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
@@ -693,7 +683,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
 
     // Since we do in-place sorting, the destinate tensor will have the same set
     // of memrefs as the source tensor.
-    rewriter.replaceOp(op, adaptor.getInputCoo());
+    rewriter.replaceOpWithMultiple(op, {adaptor.getInputCoo()});
     return success();
   }
 };
@@ -703,7 +693,8 @@ class SparseSliceGetterOpConverter : public OpConversionPattern<Op> {
 public:
   using OpConversionPattern<Op>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+  matchAndRewrite(Op op,
+                  typename OpConversionPattern<Op>::OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply lowers to specifer.get <field> operation.
     auto desc = getDescriptorFromTensorTuple(adaptor.getSlice(),
@@ -721,14 +712,14 @@ class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::CastOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Only rewrite identically annotated source/dest.
     auto encDst = getSparseTensorEncoding(op.getType());
     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
     if (!encDst || encDst != encSrc)
       return failure();
-    rewriter.replaceOp(op, adaptor.getOperands());
+    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
     return success();
   }
 };
@@ -737,10 +728,10 @@ class SparseReMapConverter : public OpConversionPattern<ReinterpretMapOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ReinterpretMapOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReinterpretMapOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Simply fold the operation.
-    rewriter.replaceOp(op, adaptor.getSource());
+    rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
     return success();
   }
 };
@@ -756,7 +747,7 @@ class SparseTensorAllocConverter
         enableBufferInitialization(enableInit) {}
 
   LogicalResult
-  matchAndRewrite(bufferization::AllocTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(bufferization::AllocTensorOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     const auto resType = getSparseTensorType(op);
     if (!resType.hasEncoding())
@@ -791,7 +782,8 @@ class SparseTensorAllocConverter
     }
     // Level size equals to dimension size since lvl2dim map is an identity map.
     SmallVector<Value> lvlSizesValues;
-    createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
+    createDimSizes(rewriter, loc, resType,
+                   flattenValues(adaptor.getDynamicSizes()),
                    /*dimSizesValues=*/lvlSizesValues);
 
     // Construct allocation for each field.
@@ -861,7 +853,7 @@ class SparseTensorDeallocConverter
         createDeallocs(createDeallocs) {}
 
   LogicalResult
-  matchAndRewrite(bufferization::DeallocTensorOp op, OpAdaptor adaptor,
+  matchAndRewrite(bufferization::DeallocTensorOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto enc = getSparseTensorEncoding(op.getTensor().getType());
     if (!enc)
@@ -892,7 +884,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(LoadOp op, OpAdaptor adaptor,
+  matchAndRewrite(LoadOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Prepare descriptor.
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
@@ -911,7 +903,7 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
+  matchAndRewrite(ExpandOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (!getSparseTensorEncoding(op.getTensor().getType()))
       return failure();
@@ -963,16 +955,16 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+  matchAndRewrite(CompressOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     SmallVector<Value> fields;
     auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
                                                 op.getTensor().getType());
-    Value values = adaptor.getValues();
-    Value filled = adaptor.getFilled();
-    Value added = adaptor.getAdded();
-    Value count = adaptor.getCount();
+    Value values = getSingleValue(adaptor.getValues());
+    Value filled = getSingleValue(adaptor.getFilled());
+    Value added = getSingleValue(adaptor.getAdded());
+    Value count = getSingleValue(adaptor.getCount());
     const SparseTensorType dstType(desc.getRankedTensorType());
     Type eltType = dstType.getElementType();
 
@@ -1005,7 +997,8 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
     SmallVector<Type> flatSpTensorTps = llvm::to_vector(
         llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
-    params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+    SmallVector<Value> flatLvlCoords = flattenValues(adaptor.getLvlCoords());
+    params.append(flatLvlCoords.begin(), flatLvlCoords.end());
     params.push_back(crd);
     params.push_back(value);
     SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
@@ -1033,9 +1026,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::InsertOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto stt = getSparseTensorType(adaptor.getDest());
+    auto stt = getSparseTensorType(op.getDest());
     if (!stt.hasEncoding())
       return failure();
     assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
@@ -1045,8 +1038,9 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
         getDescriptorFromTensorTuple(adaptor.getDest(), op.getDest().getType());
     TypeRange flatSpTensorTps = desc.getFields().getTypes();
     SmallVector<Value> params = llvm::to_vector(desc.getFields());
-    params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
-    params.push_back(adaptor.getScalar());
+    SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
+    params.append(flatIndices.begin(), flatIndices.end());
+    params.push_back(getSingleValue(adaptor.getScalar()));
     SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
                                     params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
@@ -1062,7 +1056,7 @@ class SparseToPositionsConverter : public OpConversionPattern<ToPositionsOp> {
   using OpAdaptor = typename ToPositionsOp::Adaptor;
   using OpConversionPattern<ToPositionsOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToPositionsOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToPositionsOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested position access with corresponding field.
     // The view is restricted to the actual size to ensure clients
@@ -1085,7 +1079,7 @@ class SparseToCoordinatesConverter
   using OpAdaptor = typename ToCoordinatesOp::Adaptor;
   using OpConversionPattern<ToCoordinatesOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToCoordinatesOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested coordinates access with corresponding field.
     // The view is restricted to the actual size to ensure clients
@@ -1111,7 +1105,7 @@ class SparseToCoordinatesBufferConverter
   using OpAdaptor = typename ToCoordinatesBufferOp::Adaptor;
   using OpConversionPattern<ToCoordinatesBufferOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToCoordinatesBufferOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested coordinates access with corresponding field.
     // The view is restricted to the actual size to ensure clients
@@ -1133,7 +1127,7 @@ class SparseToValuesConverter : public OpConversionPattern<ToValuesOp> {
   using OpAdaptor = typename ToValuesOp::Adaptor;
   using OpConversionPattern<ToValuesOp>::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ToValuesOp op, OpAdaptor adaptor,
+  matchAndRewrite(ToValuesOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Replace the requested values access with corresponding field.
     // The view is restricted to the actual size to ensure clients
@@ -1153,7 +1147,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
+  matchAndRewrite(ConvertOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
@@ -1173,7 +1167,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
     Type srcElemTp = op.getSource().getType().getElementType();
     // Fold the trivial cases.
     if (retElemTp == srcElemTp && encDst == encSrc) {
-      rewriter.replaceOp(op, adaptor.getSource());
+      rewriter.replaceOpWithMultiple(op, {adaptor.getSource()});
       return success();
     }
     //
@@ -1239,7 +1233,7 @@ class SparseExtractSliceConverter
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+  matchAndRewrite(tensor::ExtractSliceOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
@@ -1296,7 +1290,7 @@ class SparseNumberOfEntriesConverter
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(NumberOfEntriesOp op, OpAdaptor adaptor,
+  matchAndRewrite(NumberOfEntriesOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Query memSizes for the actually stored values.
     // FIXME: the nse value computed in this way might be wrong when there is
@@ -1430,7 +1424,7 @@ struct SparseDisassembleOpConverter
       : OpConversionPattern(typeConverter, context) {}
 
   LogicalResult
-  matchAndRewrite(DisassembleOp op, OpAdaptor adaptor,
+  matchAndRewrite(DisassembleOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor(),
                                              op.getTensor().getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
index 89858546e37e1b..869c7864d75354 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.h
@@ -228,11 +228,6 @@ class MutSparseTensorDescriptor
   }
 };
 
-/// Returns the "tuple" value of the adapted tensor.
-inline UnrealizedConversionCastOp getTuple(Value tensor) {
-  return llvm::cast<UnrealizedConversionCastOp>(tensor.getDefiningOp());
-}
-
 /// Packs the given values as a "tuple" value.
 inline Value genTuple(OpBuilder &builder, Location loc, Type tp,
                       ValueRange values) {
@@ -246,16 +241,15 @@ inline Value genTuple(OpBuilder &builder, Location loc,
 }
 
 inline SparseTensorDescriptor
-getDescriptorFromTensorTuple(Value tensor, RankedTensorType type) {
-  auto tuple = getTuple(tensor);
-  return SparseTensorDescriptor(SparseTensorType(type), tuple.getInputs());
+getDescriptorFromTensorTuple(ValueRange adaptorValues, RankedTensorType type) {
+  return SparseTensorDescriptor(SparseTensorType(type), adaptorValues);
 }
 
 inline MutSparseTensorDescriptor
-getMutDescriptorFromTensorTuple(Value tensor, SmallVectorImpl<Value> &fields,
+getMutDescriptorFromTensorTuple(ValueRange adaptorValues,
+                                SmallVectorImpl<Value> &fields,
                                 RankedTensorType type) {
-  auto tuple = getTuple(tensor);
-  fields.assign(tuple.getInputs().begin(), tuple.getInputs().end());
+  fields.assign(adaptorValues.begin(), adaptorValues.end());
   return MutSparseTensorDescriptor(SparseTensorType(type), fields);
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5acd095da8e386..71a82cae86f5f2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -67,10 +67,6 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
-/// A list of replacement SSA values. Optimized for the common case of a single
-/// SSA value.
-using ReplacementValues = SmallVector<Value, 1>;
-
 namespace {
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -674,7 +670,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
                                    UnrealizedConversionCastOp op,
                                    const TypeConverter *converter,
-                                   MaterializationKind kind, Type originalType);
+                                   MaterializationKind kind, Type originalType,
+                                   Value valueToMap);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,6 +705,10 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// The original type of the SSA value. Only used for target
   /// materializations.
   Type originalType;
+
+  /// The value in the conversion value mapping that is being replaced by the
+  /// results of this unresolved materialization.
+  Value valueToMap;
 };
 } // namespace
 
@@ -764,7 +765,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   LogicalResult remapValues(StringRef valueDiagTag,
                             std::optional<Location> inputLoc,
                             PatternRewriter &rewriter, ValueRange values,
-                            SmallVectorImpl<Value> &remapped);
+                            SmallVector<SmallVector<Value>> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
@@ -798,13 +799,31 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // Materializations
   //===--------------------------------------------------------------------===//
 
-  /// Build an unresolved materialization operation given an output type and set
-  /// of input operands.
-  Value buildUnresolvedMaterialization(MaterializationKind kind,
-                                       OpBuilder::InsertPoint ip, Location loc,
-                                       ValueRange inputs, Type outputType,
-                                       Type originalType,
-                                       const TypeConverter *converter);
+  /// Build an unresolved materialization operation given a range of output
+  /// types and a list of input operands. Returns the inputs unchanged if their
+  /// types match the output types.
+  ///
+  /// If a cast op was built, it can optionally be returned with the `castOp`
+  /// output argument.
+  ///
+  /// If `valueToMap` is set to a non-null Value, then that value is mapped to
+  /// the results of the unresolved materialization in the conversion value
+  /// mapping.
+  ValueRange buildUnresolvedMaterialization(
+      MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+      Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+      Type originalType, const TypeConverter *converter,
+      UnrealizedConversionCastOp *castOp = nullptr);
+  Value buildUnresolvedMaterialization(
+      MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+      Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
+      const TypeConverter *converter,
+      UnrealizedConversionCastOp *castOp = nullptr) {
+    return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
+                                          TypeRange(outputType), originalType,
+                                          converter, castOp)
+        .front();
+  }
 
   /// Build an N:1 materialization for the given original value that was
   /// replaced with the given replacement values.
@@ -830,6 +849,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value findOrBuildReplacementValue(Value value,
                                     const TypeConverter *converter);
 
+  /// Unpack an N:1 materialization and return the inputs of the
+  /// materialization. This function unpacks only those materializations that
+  /// were built with `insertNTo1Materialization`.
+  ///
+  /// This is a workaround for incomplete 1:N support in the dialect
+  /// conversion driver. It allows us to write 1:N conversion patterns while
+  /// 1:N support is still missing in the conversion value mapping. This
+  /// function will be deleted when full 1:N support has been added.
+  SmallVector<Value> unpackNTo1Materialization(Value value);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -839,7 +868,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                OpBuilder::InsertPoint previous) override;
 
   /// Notifies that an op is about to be replaced with the given values.
-  void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
+  void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
@@ -932,6 +961,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
       unresolvedMaterializations;
 
+  /// A set of all N:1 materializations that were added to work around
+  /// incomplete 1:N support in the dialect conversion driver.
+  DenseSet<UnrealizedConversionCastOp> nTo1TempMaterializations;
+
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
@@ -1054,20 +1087,21 @@ void CreateOperationRewrite::rollback() {
 
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
-    const TypeConverter *converter, MaterializationKind kind, Type originalType)
+    const TypeConverter *converter, MaterializationKind kind, Type originalType,
+    Value valueToMap)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-      converterAndKind(converter, kind), originalType(originalType) {
+      converterAndKind(converter, kind), originalType(originalType),
+      valueToMap(valueToMap) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   rewriterImpl.unresolvedMaterializations[op] = this;
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (getMaterializationKind() == MaterializationKind::Target) {
-    for (Value input : op->getOperands())
-      rewriterImpl.mapping.erase(input);
-  }
+  if (valueToMap)
+    rewriterImpl.mapping.erase(valueToMap);
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
+  rewriterImpl.nTo1TempMaterializations.erase(getOperation());
   op->erase();
 }
 
@@ -1113,7 +1147,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
 LogicalResult ConversionPatternRewriterImpl::remapValues(
     StringRef valueDiagTag, std::optional<Location> inputLoc,
     PatternRewriter &rewriter, ValueRange values,
-    SmallVectorImpl<Value> &remapped) {
+    SmallVector<SmallVector<Value>> &remapped) {
   remapped.reserve(llvm::size(values));
 
   for (const auto &it : llvm::enumerate(values)) {
@@ -1121,11 +1155,18 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type origType = operand.getType();
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
+    // Find the most recently mapped value. Unpack all temporary N:1
+    // materializations. Such materializations are a workaround for missing
+    // 1:N support in the ConversionValueMapping. (The conversion patterns
+    // already support 1:N replacements.)
+    Value repl = mapping.lookupOrDefault(operand);
+    SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
+
     if (!currentTypeConverter) {
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
       // pass through the most recently mapped value.
-      remapped.push_back(mapping.lookupOrDefault(operand));
+      remapped.push_back(std::move(unpacked));
       continue;
     }
 
@@ -1139,15 +1180,29 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       return failure();
     }
 
+    // If a type is converted to 0 types, there is nothing to do.
+    if (legalTypes.empty()) {
+      remapped.push_back({});
+      continue;
+    }
+
     if (legalTypes.size() != 1) {
-      // TODO: Parts of the dialect conversion infrastructure do not support
-      // 1->N type conversions yet. Therefore, if a type is converted to 0 or
-      // multiple types, the only thing that we can do for now is passing
-      // through the most recently mapped value. Fixing this requires
-      // improvements to the `ConversionValueMapping` (to be able to store 1:N
-      // mappings) and to the `ConversionPattern` adaptor handling (to be able
-      // to pass multiple remapped values for a single operand to the adaptor).
-      remapped.push_back(mapping.lookupOrDefault(operand));
+      // TODO: This is a 1:N conversion. The conversion value mapping does not
+      // store such mappings yet. If the types of the most recently
+      // mapped values do not match, build a target materialization.
+      if (TypeRange(unpacked) == legalTypes) {
+        remapped.push_back(std::move(unpacked));
+        continue;
+      }
+
+      // Insert a target materialization if the current pattern expects
+      // different legalized types.
+      ValueRange targetMat = buildUnresolvedMaterialization(
+          MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+          /*valueToMap=*/Value(),
+          /*inputs=*/unpacked, /*outputType=*/legalTypes,
+          /*originalType=*/origType, currentTypeConverter);
+      remapped.push_back(targetMat);
       continue;
     }
 
@@ -1159,16 +1214,15 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     if (newOperand.getType() != desiredType) {
       // If the looked up value's type does not have the desired type, it means
       // that the value was replaced with a value of different type and no
-      // source materialization was created yet.
+      // target materialization was created yet.
       Value castValue = buildUnresolvedMaterialization(
           MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc,
-          /*inputs=*/newOperand, /*outputType=*/desiredType,
+          operandLoc, /*valueToMap=*/newOperand,
+          /*inputs=*/unpacked, /*outputType=*/desiredType,
           /*originalType=*/origType, currentTypeConverter);
-      mapping.map(newOperand, castValue);
       newOperand = castValue;
     }
-    remapped.push_back(newOperand);
+    remapped.push_back({newOperand});
   }
   return success();
 }
@@ -1276,12 +1330,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (!inputMap) {
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
-      Value repl = buildUnresolvedMaterialization(
+      buildUnresolvedMaterialization(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*inputs=*/ValueRange(),
+          /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
-      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1323,26 +1376,38 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
 /// Build an unresolved materialization operation given an output type and set
 /// of input operands.
-Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, Type outputType, Type originalType,
-    const TypeConverter *converter) {
+    Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+    Type originalType, const TypeConverter *converter,
+    UnrealizedConversionCastOp *castOp) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
 
   // Avoid materializing an unnecessary cast.
-  if (inputs.size() == 1 && inputs.front().getType() == outputType)
-    return inputs.front();
+  if (TypeRange(inputs) == outputTypes) {
+    if (valueToMap) {
+      assert(inputs.size() == 1 && "1:N mapping is not supported");
+      mapping.map(valueToMap, inputs.front());
+    }
+    return inputs;
+  }
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
-  OpBuilder builder(outputType.getContext());
+  OpBuilder builder(outputTypes.front().getContext());
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
-      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+      builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
+  if (valueToMap) {
+    assert(outputTypes.size() == 1 && "1:N mapping is not supported");
+    mapping.map(valueToMap, convertOp.getResult(0));
+  }
+  if (castOp)
+    *castOp = convertOp;
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  originalType);
-  return convertOp.getResult(0);
+                                                  originalType, valueToMap);
+  return convertOp.getResults();
 }
 
 void ConversionPatternRewriterImpl::insertNTo1Materialization(
@@ -1350,11 +1415,13 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
     Value originalValue, const TypeConverter *converter) {
   // Insert argument materialization back to the original type.
   Type originalType = originalValue.getType();
-  Value argMat =
-      buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
-                                     /*inputs=*/replacements, originalType,
-                                     /*originalType=*/Type(), converter);
-  mapping.map(originalValue, argMat);
+  UnrealizedConversionCastOp argCastOp;
+  Value argMat = buildUnresolvedMaterialization(
+      MaterializationKind::Argument, ip, loc,
+      /*valueToMap=*/originalValue, /*inputs=*/replacements, originalType,
+      /*originalType=*/Type(), converter, &argCastOp);
+  if (argCastOp)
+    nTo1TempMaterializations.insert(argCastOp);
 
   // Insert target materialization to the legalized type.
   Type legalOutputType;
@@ -1370,11 +1437,14 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
     legalOutputType = replacements[0].getType();
   }
   if (legalOutputType && legalOutputType != originalType) {
-    Value targetMat = buildUnresolvedMaterialization(
+    UnrealizedConversionCastOp targetCastOp;
+    buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(argMat), loc,
+        /*valueToMap=*/argMat,
         /*inputs=*/argMat, /*outputType=*/legalOutputType,
-        /*originalType=*/originalType, converter);
-    mapping.map(argMat, targetMat);
+        /*originalType=*/originalType, converter, &targetCastOp);
+    if (targetCastOp)
+      nTo1TempMaterializations.insert(targetCastOp);
   }
 }
 
@@ -1408,12 +1478,34 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   }
   Value castValue = buildUnresolvedMaterialization(
       MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*inputs=*/repl, /*outputType=*/value.getType(),
+      /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
       /*originalType=*/Type(), converter);
   mapping.map(value, castValue);
   return castValue;
 }
 
+SmallVector<Value>
+ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
+  // Unpack unrealized_conversion_cast ops that were inserted as an N:1
+  // workaround.
+  auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
+  if (!castOp)
+    return {value};
+  if (!nTo1TempMaterializations.contains(castOp))
+    return {value};
+  assert(castOp->getNumResults() == 1 && "expected single result");
+
+  SmallVector<Value> result;
+  for (Value v : castOp.getOperands()) {
+    // Keep unpacking if possible. This is needed because during block
+    // signature conversions and 1:N op replacements, the driver may have
+    // inserted two materializations back-to-back: first an argument
+    // materialization, then a target materialization.
+    llvm::append_range(result, unpackNTo1Materialization(v));
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1438,7 +1530,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
 }
 
 void ConversionPatternRewriterImpl::notifyOpReplaced(
-    Operation *op, ArrayRef<ReplacementValues> newValues) {
+    Operation *op, ArrayRef<ValueRange> newValues) {
   assert(newValues.size() == op->getNumResults());
   assert(!ignoredOps.contains(op) && "operation was already replaced");
 
@@ -1450,8 +1542,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       isUnresolvedMaterialization = true;
 
   // Create mappings for each of the new result values.
-  for (auto [n, result] : llvm::zip_equal(newValues, op->getResults())) {
-    ReplacementValues repl = n;
+  for (auto [repl, result] : llvm::zip_equal(newValues, op->getResults())) {
     if (repl.empty()) {
       // This result was dropped and no replacement value was provided.
       if (isUnresolvedMaterialization) {
@@ -1461,12 +1552,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       }
 
       // Materialize a replacement value "out of thin air".
-      Value sourceMat = buildUnresolvedMaterialization(
+      buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
-          result.getLoc(), /*inputs=*/ValueRange(),
+          result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
-      repl.push_back(sourceMat);
+      continue;
     } else {
       // Make sure that the user does not mess with unresolved materializations
       // that were inserted by the conversion driver. We keep track of these
@@ -1568,10 +1659,17 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<ReplacementValues> newVals(newValues.size());
-  for (auto [index, val] : llvm::enumerate(newValues))
-    if (val)
-      newVals[index].push_back(val);
+  // Wrap each replacement value in a single-element range. A null value means
+  // that no replacement is provided for that result, which must be passed on
+  // as an empty range (`notifyOpReplaced` checks `repl.empty()`).
+  SmallVector<ValueRange> newVals;
+  newVals.reserve(newValues.size());
+  for (size_t i = 0, e = newValues.size(); i < e; ++i) {
+    if (newValues[i])
+      newVals.push_back(newValues.slice(i, 1));
+    else
+      newVals.push_back(ValueRange());
+  }
   impl->notifyOpReplaced(op, newVals);
 }
 
@@ -1583,10 +1673,7 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<ReplacementValues> newVals(newValues.size(), {});
-  for (auto [index, val] : llvm::enumerate(newValues))
-    llvm::append_range(newVals[index], val);
-  impl->notifyOpReplaced(op, newVals);
+  impl->notifyOpReplaced(op, newValues);
 }
 
 void ConversionPatternRewriter::eraseOp(Operation *op) {
@@ -1594,7 +1681,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
     impl->logger.startLine()
         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
   });
-  SmallVector<ReplacementValues> nullRepls(op->getNumResults(), {});
+  SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
   impl->notifyOpReplaced(op, nullRepls);
 }
 
@@ -1646,11 +1733,12 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
-  SmallVector<Value> remappedValues;
+  SmallVector<SmallVector<Value>> remappedValues;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                                remappedValues)))
     return nullptr;
-  return remappedValues.front();
+  assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
+  return remappedValues.front().front();
 }
 
 LogicalResult
@@ -1658,8 +1746,15 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                                              SmallVectorImpl<Value> &results) {
   if (keys.empty())
     return success();
-  return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
-                           results);
+  SmallVector<SmallVector<Value>> remapped;
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+                               remapped)))
+    return failure();
+  for (const auto &values : remapped) {
+    assert(values.size() == 1 && "1:N conversion not supported");
+    results.push_back(values.front());
+  }
+  return success();
 }
 
 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1753,6 +1848,20 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
 // ConversionPattern
 //===----------------------------------------------------------------------===//
 
+SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+    ArrayRef<ValueRange> operands) const {
+  SmallVector<Value> flattened;
+  flattened.reserve(operands.size());
+  for (ValueRange replacements : operands) {
+    // Anything other than exactly one replacement value is a fatal error.
+    if (replacements.size() != 1)
+      llvm::report_fatal_error("pattern '" + getDebugName() +
+                               "' does not support 1:N conversion");
+    flattened.push_back(replacements.front());
+  }
+  return flattened;
+}
+
 LogicalResult
 ConversionPattern::matchAndRewrite(Operation *op,
                                    PatternRewriter &rewriter) const {
@@ -1764,12 +1872,14 @@ ConversionPattern::matchAndRewrite(Operation *op,
                                              getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<Value, 4> operands;
+  SmallVector<SmallVector<Value>> remapped;
   if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
-                                      op->getOperands(), operands))) {
+                                      op->getOperands(), remapped))) {
     return failure();
   }
-  return matchAndRewrite(op, operands, dialectRewriter);
+  SmallVector<ValueRange> remappedAsRange =
+      llvm::to_vector_of<ValueRange>(remapped);
+  return matchAndRewrite(op, remappedAsRange, dialectRewriter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2509,45 +2619,52 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
   assert(!op.use_empty() &&
          "expected that dead materializations have already been DCE'd");
   Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = op.getResultTypes()[0];
 
   // Try to materialize the conversion.
   if (const TypeConverter *converter = rewrite->getConverter()) {
     rewriter.setInsertionPoint(op);
-    Value newMaterialization;
+    SmallVector<Value> newMaterialization;
     switch (rewrite->getMaterializationKind()) {
-    case MaterializationKind::Argument:
+    case MaterializationKind::Argument: {
       // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
+      assert(op->getNumResults() == 1 && "expected single result");
+      Value argMat = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
+      if (argMat) {
+        newMaterialization.push_back(argMat);
         break;
+      }
+    }
       // If an argument materialization failed, fallback to trying a target
       // materialization.
       [[fallthrough]];
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands,
+          rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
           rewrite->getOriginalType());
       break;
     case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
+      assert(op->getNumResults() == 1 && "expected single result");
+      Value sourceMat = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
+      if (sourceMat)
+        newMaterialization.push_back(sourceMat);
       break;
     }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
+    if (!newMaterialization.empty()) {
+      assert(TypeRange(newMaterialization) == op.getResultTypes() &&
              "materialization callback produced value of incorrect type");
       rewriter.replaceOp(op, newMaterialization);
       return success();
     }
   }
 
-  InFlightDiagnostic diag =
-      op->emitError() << "failed to legalize unresolved materialization "
-                         "from ("
-                      << inputOperands.getTypes() << ") to (" << outputType
-                      << ") that remained live after conversion";
+  InFlightDiagnostic diag = op->emitError()
+                            << "failed to legalize unresolved materialization "
+                               "from ("
+                            << inputOperands.getTypes() << ") to ("
+                            << op.getResultTypes()
+                            << ") that remained live after conversion";
   diag.attachNote(op->getUsers().begin()->getLoc())
       << "see existing live user here: " << *op->getUsers().begin();
   return failure();
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index b8fad63eb4de67..4e641317ac2f3d 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -9,10 +9,7 @@
 // CHECK-LABEL:   func @identity(
 // CHECK-SAME:                   %[[ARG0:.*]]: i1,
 // CHECK-SAME:                   %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK:           %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK:           return %[[ARG0]], %[[ARG1]] : i1, i32
 // CHECK-12N-LABEL:   func @identity(
 // CHECK-12N-SAME:                   %[[ARG0:.*]]: i1,
 // CHECK-12N-SAME:                   %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -56,18 +53,7 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
 // CHECK-LABEL:   func @mixed_recursive_decomposition(
 // CHECK-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-SAME:                 %[[ARG1:.*]]: i2) -> (i1, i2) {
-// CHECK:           %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
-// CHECK:           %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
-// CHECK:           %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
-// CHECK:           %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
-// CHECK:           %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
-// CHECK:           %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 0 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
-// CHECK:           %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 1 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
-// CHECK:           %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) <{index = 0 : i32}> : (tuple<i1>) -> i1
-// CHECK:           %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) <{index = 2 : i32}> : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
-// CHECK:           %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) <{index = 0 : i32}> : (tuple<tuple<i2>>) -> tuple<i2>
-// CHECK:           %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) <{index = 0 : i32}> : (tuple<i2>) -> i2
-// CHECK:           return %[[V7]], %[[V10]] : i1, i2
+// CHECK:           return %[[ARG0]], %[[ARG1]] : i1, i2
 // CHECK-12N-LABEL:   func @mixed_recursive_decomposition(
 // CHECK-12N-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-12N-SAME:                 %[[ARG1:.*]]: i2) -> (i1, i2) {
@@ -87,14 +73,8 @@ func.func private @callee(tuple<i1, i32>) -> tuple<i1, i32>
 // CHECK-LABEL:   func @caller(
 // CHECK-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
-// CHECK:           %[[ARG_MATERIALIZED:.*]] = "test.make_tuple"(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> tuple<i1, i32>
-// CHECK:           %[[CALL_ARG0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK:           %[[CALL_ARG1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK:           %[[DECOMPOSED:.*]]:2 = call @callee(%[[CALL_ARG0]], %[[CALL_ARG1]]) : (i1, i32) -> (i1, i32)
-// CHECK:           %[[CALL_RESULT_RECOMPOSED:.*]] = "test.make_tuple"(%[[DECOMPOSED]]#0, %[[DECOMPOSED]]#1) : (i1, i32) -> tuple<i1, i32>
-// CHECK:           %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 0 : i32}> : (tuple<i1, i32>) -> i1
-// CHECK:           %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) <{index = 1 : i32}> : (tuple<i1, i32>) -> i32
-// CHECK:           return %[[RET0]], %[[RET1]] : i1, i32
+// CHECK:           %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32)
+// CHECK:           return %[[V0]]#0, %[[V0]]#1 : i1, i32
 // CHECK-12N-LABEL:   func @caller(
 // CHECK-12N-SAME:                 %[[ARG0:.*]]: i1,
 // CHECK-12N-SAME:                 %[[ARG1:.*]]: i32) -> (i1, i32) {
@@ -190,14 +170,8 @@ func.func private @callee(tuple<>, i1, tuple<i2>, i3, tuple<i4, i5>, i6) -> (tup
 // CHECK-SAME:                 %[[I4:.*]]: i4,
 // CHECK-SAME:                 %[[I5:.*]]: i5,
 // CHECK-SAME:                 %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) {
-// CHECK:           %[[ARG_TUPLE:.*]] = "test.make_tuple"(%[[I4]], %[[I5]]) : (i4, i5) -> tuple<i4, i5>
-// CHECK:           %[[ARG_TUPLE_0:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK:           %[[ARG_TUPLE_1:.*]] = "test.get_tuple_element"(%[[ARG_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK:           %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[ARG_TUPLE_0]], %[[ARG_TUPLE_1]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
-// CHECK:           %[[RET_TUPLE:.*]] = "test.make_tuple"(%[[CALL]]#3, %[[CALL]]#4) : (i4, i5) -> tuple<i4, i5>
-// CHECK:           %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 0 : i32}> : (tuple<i4, i5>) -> i4
-// CHECK:           %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) <{index = 1 : i32}> : (tuple<i4, i5>) -> i5
-// CHECK:           return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
+// CHECK:           %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6)
+// CHECK:           return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6
 // CHECK-12N-LABEL:   func @caller(
 // CHECK-12N-SAME:                 %[[I1:.*]]: i1,
 // CHECK-12N-SAME:                 %[[I2:.*]]: i2,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e5503ee8920424..3ebee795a251df 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -463,3 +463,14 @@ func.func @circular_mapping() {
   %0 = "test.erase_op"() : () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
+
+// -----
+
+func.func @test_1_to_n_block_signature_conversion() {
+  "test.duplicate_block_args"() ({
+  ^bb0(%arg0: i64):
+    "test.repetitive_1_to_n_consumer"(%arg0) : (i64) -> ()
+  }) {} : () -> ()
+  "test.return"() : () -> ()
+}
+
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index cfe19a2fd5c08b..d59caacc18ae44 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1886,6 +1886,11 @@ def LegalOpC : TEST_Op<"legal_op_c">,
   Arguments<(ins I32)>, Results<(outs I32)>;
 def LegalOpD : TEST_Op<"legal_op_d">, Arguments<(ins AnyType)>;
 
+def DuplicateBlockArgsOp : TEST_Op<"duplicate_block_args", [SingleBlock]> {
+  let arguments = (ins UnitAttr:$is_legal);
+  let regions = (region SizedRegion<1>:$body);
+}
+
 // Check that the conversion infrastructure can properly undo the creation of
 // operations where an operation was created before its parent, in this case,
 // in the parent's builder.
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3df6cff3c0a60b..0b5239168efc43 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -982,9 +982,25 @@ struct TestPassthroughInvalidOp : public ConversionPattern {
   TestPassthroughInvalidOp(MLIRContext *ctx)
       : ConversionPattern("test.invalid", 1, ctx) {}
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, operands,
+    SmallVector<Value> flattened;
+    for (auto it : llvm::enumerate(operands)) {
+      ValueRange range = it.value();
+      if (range.size() == 1) {
+        flattened.push_back(range.front());
+        continue;
+      }
+
+      // This is a 1:N replacement. Insert a test.cast op. (That's what the
+      // argument materialization used to do.)
+      flattened.push_back(
+          rewriter
+              .create<TestCastOp>(op->getLoc(),
+                                  op->getOperand(it.index()).getType(), range)
+              .getResult());
+    }
+    rewriter.replaceOpWithNewOp<TestValidOp>(op, std::nullopt, flattened,
                                              std::nullopt);
     return success();
   }
@@ -1010,23 +1026,13 @@ struct TestSplitReturnType : public ConversionPattern {
   TestSplitReturnType(MLIRContext *ctx)
       : ConversionPattern("test.return", 1, ctx) {}
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
     // Check for a return of F32.
     if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32())
       return failure();
-
-    // Check if the first operation is a cast operation, if it is we use the
-    // results directly.
-    auto *defOp = operands[0].getDefiningOp();
-    if (auto packerOp =
-            llvm::dyn_cast_or_null<UnrealizedConversionCastOp>(defOp)) {
-      rewriter.replaceOpWithNewOp<TestReturnOp>(op, packerOp.getOperands());
-      return success();
-    }
-
-    // Otherwise, fail to match.
-    return failure();
+    rewriter.replaceOpWithNewOp<TestReturnOp>(op, operands[0]);
+    return success();
   }
 };
 
@@ -1181,6 +1187,47 @@ class TestEraseOp : public ConversionPattern {
   }
 };
 
+/// Conversion pattern for test.duplicate_block_args: rewrites the signature of
+/// the body block so that every block argument appears twice.
+class TestDuplicateBlockArgs
+    : public OpConversionPattern<DuplicateBlockArgsOp> {
+  using OpConversionPattern<DuplicateBlockArgsOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(DuplicateBlockArgsOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (op.getIsLegal())
+      return failure();
+    rewriter.startOpModification(op);
+    Block *body = &op.getBody().front();
+    TypeConverter::SignatureConversion conversion(body->getNumArguments());
+    for (auto [idx, type] : llvm::enumerate(body->getArgumentTypes()))
+      conversion.addInputs(idx, {type, type});
+    rewriter.applySignatureConversion(body, conversion, getTypeConverter());
+    op.setIsLegal(true);
+    rewriter.finalizeOpModification(op);
+    return success();
+  }
+};
+
+/// Rewrites test.repetitive_1_to_n_consumer into test.valid, forwarding all
+/// (possibly 1:N) replacement values of its single operand to the new op.
+class TestRepetitive1ToNConsumer : public ConversionPattern {
+public:
+  TestRepetitive1ToNConsumer(MLIRContext *ctx)
+      : ConversionPattern("test.repetitive_1_to_n_consumer", 1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // Ops that do not have exactly one operand are not matched.
+    if (op->getNumOperands() == 1) {
+      rewriter.replaceOpWithNewOp<TestValidOp>(op, operands.front());
+      return success();
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 namespace {
@@ -1258,9 +1305,11 @@ struct TestLegalizePatternDriver
         TestUpdateConsumerType, TestNonRootReplacement,
         TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
         TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-        TestUndoPropertiesModification, TestEraseOp>(&getContext());
+        TestUndoPropertiesModification, TestEraseOp,
+        TestRepetitive1ToNConsumer>(&getContext());
     patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
         &getContext(), converter);
+    patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
@@ -1312,6 +1361,9 @@ struct TestLegalizePatternDriver
     target.addDynamicallyLegalOp<TestOpInPlaceSelfFold>(
         [](TestOpInPlaceSelfFold op) { return op.getFolded(); });
 
+    target.addDynamicallyLegalOp<DuplicateBlockArgsOp>(
+        [](DuplicateBlockArgsOp op) { return op.getIsLegal(); });
+
     // Handle a partial conversion.
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;



More information about the Mlir-commits mailing list