[flang-commits] [flang] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Tue Dec 31 04:33:10 PST 2024


Markus Böck <markus.boeck02 at gmail.com>, Matthias Springer
 <mspringer at nvidia.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/116524 at github.com>


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116524

>From 40612e0473a55a10588ec03105a63c503cf2a112 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 15 Dec 2024 17:36:49 +0100
Subject: [PATCH 1/5] ex

---
 .../lib/Optimizer/CodeGen/BoxedProcedure.cpp  |   1 -
 mlir/docs/DialectConversion.md                |  35 +-
 .../mlir/Transforms/DialectConversion.h       |  18 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   |  16 +-
 .../EmitC/Transforms/TypeConversions.cpp      |   1 -
 .../Dialect/Linalg/Transforms/Detensorize.cpp |   1 -
 .../Quant/Transforms/StripFuncQuantTypes.cpp  |   1 -
 .../Utils/SparseTensorDescriptor.cpp          |   3 -
 .../Vector/Transforms/VectorLinearize.cpp     |   1 -
 .../Transforms/Utils/DialectConversion.cpp    | 432 +++++++++---------
 mlir/test/Transforms/test-legalizer.mlir      |   7 +-
 .../Func/TestDecomposeCallGraphTypes.cpp      |   2 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |   1 -
 .../lib/Transforms/TestDialectConversion.cpp  |   1 -
 14 files changed, 224 insertions(+), 296 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 1bb91d252529f0..104ae7408b80c1 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
     addConversion([&](TypeDescType ty) {
       return TypeDescType::get(convertType(ty.getOfTy()));
     });
-    addArgumentMaterialization(materializeProcedure);
     addSourceMaterialization(materializeProcedure);
     addTargetMaterialization(materializeProcedure);
   }
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 3168f5e13c7515..abacd5a82c61eb 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure
 type safety during the conversion process. There are several types of
 materializations depending on the situation.
 
-*   Argument Materialization
-
-    -   An argument materialization is used when converting the type of a block
-        argument during a [signature conversion](#region-signature-conversion).
-        The new block argument types are specified in a `SignatureConversion`
-        object. An original block argument can be converted into multiple
-        block arguments, which is not supported everywhere in the dialect
-        conversion. (E.g., adaptors support only a single replacement value for
-        each original value.) Therefore, an argument materialization is used to
-        convert potentially multiple new block arguments back into a single SSA
-        value. An argument materialization is also used when replacing an op
-        result with multiple values.
-
 *   Source Materialization
 
     -   A source materialization is used when a value was replaced with a value
@@ -343,17 +330,6 @@ class TypeConverter {
   /// Materialization functions must be provided when a type conversion may
   /// persist after the conversion has finished.
 
-  /// This method registers a materialization that will be called when
-  /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value
-  /// with the old argument type.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
-  void addArgumentMaterialization(FnT &&callback) {
-    argumentMaterializations.emplace_back(
-        wrapMaterialization<T>(std::forward<FnT>(callback)));
-  }
-
   /// This method registers a materialization that will be called when
   /// converting a replacement value back to its original source type.
   /// This is used when some uses of the original value persist beyond the main
@@ -406,12 +382,11 @@ done explicitly via a conversion pattern.
 To convert the types of block arguments within a Region, a custom hook on the
 `ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
 uses a provided type converter to apply type conversions to all blocks of a
-given region. As noted above, the conversions performed by this method use the
-argument materialization hook on the `TypeConverter`. This hook also takes an
-optional `TypeConverter::SignatureConversion` parameter that applies a custom
-conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to the operation, e.g.,
-`func::FuncOp`, `AffineForOp`, etc.
+given region. This hook also takes an optional
+`TypeConverter::SignatureConversion` parameter that applies a custom conversion
+to the entry block of the region. The types of the entry block arguments are
+often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`,
+etc.
 
 To convert the signature of just one given block, the
 `applySignatureConversion` hook can be used.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913e3..9a6975dcf8dfae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,6 +181,10 @@ class TypeConverter {
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: Argument materializations are used only with the 1:N dialect
+  /// conversion driver. The 1:N dialect conversion driver will be removed soon
+  /// and so will be argument materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
@@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   void replaceOp(Operation *op, Operation *newOp) override;
 
   /// Replace the given operation with the new value ranges. The number of op
-  /// results and value ranges must match. If an original SSA value is replaced
-  /// by multiple SSA values (i.e., a value range has more than 1 element), the
-  /// conversion driver will insert an argument materialization to convert the
-  /// N SSA values back into 1 SSA value of the original type. The given
-  /// operation is erased.
-  ///
-  /// Note: The argument materialization is a workaround until we have full 1:N
-  /// support in the dialect conversion. (It is going to disappear from both
-  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+  /// results and value ranges must match. The given operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1285,8 +1281,8 @@ struct ConversionConfig {
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
 
-  /// If set to "true", the dialect conversion attempts to build source/target/
-  /// argument materializations through the type converter API in lieu of
+  /// If set to "true", the dialect conversion attempts to build source/target
+  /// materializations through the type converter API in lieu of
   /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
   /// at least one materialization could not be built.
   ///
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 49e2d943286645..72799e42cf3fd1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -85,7 +85,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder,
                                            UnrankedMemRefType resultType,
                                            ValueRange inputs, Location loc,
                                            const LLVMTypeConverter &converter) {
-  // An argument materialization must return a value of type
+  // A source materialization must return a value of type
   // `resultType`, so insert a cast from the memref descriptor type
   // (!llvm.struct) to the original memref type.
   Value packed =
@@ -101,7 +101,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder,
                                          MemRefType resultType,
                                          ValueRange inputs, Location loc,
                                          const LLVMTypeConverter &converter) {
-  // An argument materialization must return a value of type `resultType`,
+  // A source materialization must return a value of type `resultType`,
   // so insert a cast from the memref descriptor type (!llvm.struct) to the
   // original memref type.
   Value packed =
@@ -234,19 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   });
 
-  // Argument materializations convert from the new block argument types
+  // Source materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
-                                         *this);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
-  });
   addSourceMaterialization([&](OpBuilder &builder,
                                UnrankedMemRefType resultType, ValueRange inputs,
                                Location loc) {
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
 
   converter.addSourceMaterialization(materializeAsUnrealizedCast);
   converter.addTargetMaterialization(materializeAsUnrealizedCast);
-  converter.addArgumentMaterialization(materializeAsUnrealizedCast);
 }
 
 /// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 0e651f4cee4c36..fc6671ef811759 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
     });
 
     addSourceMaterialization(sourceMaterializationCallback);
-    addArgumentMaterialization(sourceMaterializationCallback);
   }
 };
 
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
     addConversion(convertQuantizedType);
     addConversion(convertTensorType);
 
-    addArgumentMaterialization(materializeConversion);
     addSourceMaterialization(materializeConversion);
     addTargetMaterialization(materializeConversion);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 
   // Required by scf.for 1:N type conversion.
   addSourceMaterialization(materializeTuple);
-
-  // Required as a workaround until we have full 1:N support.
-  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
 
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
-  typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 255b0ba2559ee6..96cbe07f0f12f9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,6 +54,55 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
+/// Given two insertion points in the same block, choose the later one.
+/// `begin()` is the earliest possible point and `end()` the latest one.
+static OpBuilder::InsertPoint
+chooseLaterInsertPointInBlock(OpBuilder::InsertPoint a,
+                              OpBuilder::InsertPoint b) {
+  assert(a.getBlock() == b.getBlock() && "expected same block");
+  Block *block = a.getBlock();
+  // Check `end()` before dereferencing: it does not point to an operation.
+  if (a.getPoint() == block->begin() || b.getPoint() == block->end())
+    return b;
+  if (b.getPoint() == block->begin() || a.getPoint() == block->end())
+    return a;
+  return a.getPoint()->isBeforeInBlock(&*b.getPoint()) ? b : a;
+}
+
+/// Helper function that chooses the insertion point among the two given ones
+/// that is later. Neither being later than the other is a fatal error.
+// TODO: Extend DominanceInfo API to work with block iterators.
+static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a,
+                                                     OpBuilder::InsertPoint b) {
+  // Case 1: Same block.
+  if (a.getBlock() == b.getBlock())
+    return chooseLaterInsertPointInBlock(a, b);
+
+  // Case 2: Different block, but same region.
+  if (a.getBlock()->getParent() == b.getBlock()->getParent()) {
+    DominanceInfo domInfo;
+    if (domInfo.properlyDominates(a.getBlock(), b.getBlock()))
+      return b;
+    if (domInfo.properlyDominates(b.getBlock(), a.getBlock()))
+      return a;
+    // Neither of the two blocks dominates the other.
+    llvm_unreachable("unable to find valid insertion point");
+  }
+
+  // Case 3: a is nested in b's region: choose a. (Query the block's parent op;
+  // the insertion point itself may be `end()`, which cannot be dereferenced.)
+  if (b.getBlock()->getParent()->findAncestorOpInRegion(
+          *a.getBlock()->getParentOp()))
+    return a;
+  // Case 4: b is nested in a's region: choose b.
+  if (a.getBlock()->getParent()->findAncestorOpInRegion(
+          *b.getBlock()->getParentOp()))
+    return b;
+
+  // Neither of the two operations contain each other.
+  llvm_unreachable("unable to find valid insertion point");
+}
+
 /// Helper function that computes an insertion point where the given value is
 /// defined and can be used without a dominance violation.
 static OpBuilder::InsertPoint computeInsertPoint(Value value) {
@@ -63,11 +113,36 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
   return OpBuilder::InsertPoint(insertBlock, insertPt);
 }
 
+/// Helper function that computes an insertion point at which all of the given
+/// values are defined, so that they can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint result = computeInsertPoint(vals.front());
+  for (Value val : llvm::drop_begin(vals))
+    result = chooseLaterInsertPoint(result, computeInsertPoint(val));
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
+/// Vector of SSA values. One value is stored inline, which covers the most
+/// common case of a single value without a heap allocation.
+using ValueVector = SmallVector<Value, /*N=*/1>;
+
 namespace {
+
+/// DenseMapInfo for `ValueVector` keys. Empty and tombstone keys must differ.
+struct ValueVectorMapInfo {
+  static ValueVector getEmptyKey() { return ValueVector{Value()}; }
+  static ValueVector getTombstoneKey() { return ValueVector{Value(), Value()}; }
+  static ::llvm::hash_code getHashValue(ArrayRef<Value> val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(ArrayRef<Value> a, ArrayRef<Value> b) { return a == b; }
+};
+
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
@@ -75,68 +150,103 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }
 
-  /// Lookup the most recently mapped value with the desired type in the
+  /// Lookup the most recently mapped values with the desired types in the
   /// mapping.
   ///
   /// Special cases:
-  /// - If the desired type is "null", simply return the most recently mapped
-  ///   value.
-  /// - If there is no mapping to the desired type, also return the most
-  ///   recently mapped value.
-  /// - If there is no mapping for the given value at all, return the given
-  ///   value.
-  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
-
-  /// Lookup a mapped value within the map, or return null if a mapping does not
-  /// exist. If a mapping exists, this follows the same behavior of
-  /// `lookupOrDefault`.
-  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
+  /// - If the desired type range is empty, simply return the most recently
+  ///   mapped values.
+  /// - If there is no mapping to the desired types, also return the most
+  ///   recently mapped values.
+  /// - If there is no mapping for the given values at all, return the given
+  ///   values.
+  ValueVector lookupOrDefault(ValueVector from,
+                              TypeRange desiredTypes = {}) const;
+
+  /// Lookup the given values within the map, or return an empty vector if the
+  /// values are not mapped. If they are mapped, this follows the same behavior
+  /// as `lookupOrDefault`.
+  ValueVector lookupOrNull(const ValueVector &from,
+                           TypeRange desiredTypes = {}) const;
 
   /// Map a value to the one provided.
-  void map(Value oldVal, Value newVal) {
+  void map(const ValueVector &oldVal, const ValueVector &newVal) {
     LLVM_DEBUG({
-      for (Value it = newVal; it; it = mapping.lookupOrNull(it))
-        assert(it != oldVal && "inserting cyclic mapping");
+      ValueVector next = newVal;
+      while (true) {
+        assert(next != oldVal && "inserting cyclic mapping");
+        auto it = mapping.find(next);
+        if (it == mapping.end())
+          break;
+        next = it->second;
+      }
     });
-    mapping.map(oldVal, newVal);
-    mappedTo.insert(newVal);
+    mapping[oldVal] = newVal;
+    for (Value v : newVal)
+      mappedTo.insert(v);
   }
 
-  /// Drop the last mapping for the given value.
-  void erase(Value value) { mapping.erase(value); }
+  /// Drop the last mapping for the given values.
+  void erase(ValueVector value) { mapping.erase(value); }
 
 private:
   /// Current value mappings.
-  IRMapping mapping;
+  DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
 
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
-Value ConversionValueMapping::lookupOrDefault(Value from,
-                                              Type desiredType) const {
-  // Try to find the deepest value that has the desired type. If there is no
-  // such value, simply return the deepest value.
-  Value desiredValue;
+ValueVector
+ConversionValueMapping::lookupOrDefault(ValueVector from,
+                                        TypeRange desiredTypes) const {
+  // Try to find the deepest values that have the desired types. If there is no
+  // such mapping, simply return the deepest values.
+  ValueVector desiredValue;
   do {
-    if (!desiredType || from.getType() == desiredType)
+    // Store the current value if the types match.
+    if (desiredTypes.empty() || TypeRange(from) == desiredTypes)
       desiredValue = from;
 
-    Value mappedValue = mapping.lookupOrNull(from);
-    if (!mappedValue)
+    // If possible, replace each value with (one or multiple) mapped values.
+    ValueVector next;
+    for (Value v : from) {
+      auto it = mapping.find({v});
+      if (it != mapping.end()) {
+        llvm::append_range(next, it->second);
+      } else {
+        next.push_back(v);
+      }
+    }
+    if (next != from) {
+      // If at least one value was replaced, continue the lookup from there.
+      from = next;
+      continue;
+    }
+
+    // Otherwise: Check if there is a mapping for the entire vector. Such
+    // mappings are materializations. (N:M mappings are not supported for value
+    // replacements.)
+    auto it = mapping.find(from);
+    if (it == mapping.end()) {
+      // No mapping found: The lookup stops here.
       break;
-    from = mappedValue;
+    }
+    from = it->second;
   } while (true);
 
-  // If the desired value was found use it, otherwise default to the leaf value.
-  return desiredValue ? desiredValue : from;
+  // If the desired values were found use them, otherwise default to the leaf
+  // values.
+  return !desiredValue.empty() ? desiredValue : from;
 }
 
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
-  Value result = lookupOrDefault(from, desiredType);
-  if (result == from || (desiredType && result.getType() != desiredType))
-    return nullptr;
+ValueVector ConversionValueMapping::lookupOrNull(const ValueVector &from,
+                                                 TypeRange desiredTypes) const {
+  ValueVector result = lookupOrDefault(from, desiredTypes);
+  TypeRange resultTypes(result);
+  if (result == from || (!desiredTypes.empty() && resultTypes != desiredTypes))
+    return {};
   return result;
 }
 
@@ -651,10 +761,6 @@ class CreateOperationRewrite : public OperationRewrite {
 
 /// The type of materialization.
 enum MaterializationKind {
-  /// This materialization materializes a conversion for an illegal block
-  /// argument type, to the original one.
-  Argument,
-
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
   Target,
@@ -673,7 +779,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
                                    UnrealizedConversionCastOp op,
                                    const TypeConverter *converter,
                                    MaterializationKind kind, Type originalType,
-                                   Value mappedValue);
+                                   ValueVector mappedValues);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,9 +814,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// materializations.
   Type originalType;
 
-  /// The value in the conversion value mapping that is being replaced by the
+  /// The values in the conversion value mapping that are being replaced by the
   /// results of this unresolved materialization.
-  Value mappedValue;
+  ValueVector mappedValues;
 };
 } // namespace
 
@@ -779,7 +885,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   LogicalResult remapValues(StringRef valueDiagTag,
                             std::optional<Location> inputLoc,
                             PatternRewriter &rewriter, ValueRange values,
-                            SmallVector<SmallVector<Value>> &remapped);
+                            SmallVector<ValueVector> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
@@ -820,39 +926,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// If a cast op was built, it can optionally be returned with the `castOp`
   /// output argument.
   ///
-  /// If `valueToMap` is set to a non-null Value, then that value is mapped to
+  /// If `valuesToMap` is non-empty, then those values are mapped to
   /// the results of the unresolved materialization in the conversion value
   /// mapping.
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+      ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr);
-  Value buildUnresolvedMaterialization(
-      MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
-      const TypeConverter *converter,
-      UnrealizedConversionCastOp *castOp = nullptr) {
-    return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
-                                          TypeRange(outputType), originalType,
-                                          converter, castOp)
-        .front();
-  }
-
-  /// Build an N:1 materialization for the given original value that was
-  /// replaced with the given replacement values.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. The conversion mapping can store only 1:1 replacements
-  /// and the conversion patterns only support single Value replacements in the
-  /// adaptor, so N values must be converted back to a single value. This
-  /// function will be deleted when full 1:N support has been added.
-  ///
-  /// This function inserts an argument materialization back to the original
-  /// type.
-  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
-                                 ValueRange replacements, Value originalValue,
-                                 const TypeConverter *converter);
 
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
@@ -862,16 +943,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value findOrBuildReplacementValue(Value value,
                                     const TypeConverter *converter);
 
-  /// Unpack an N:1 materialization and return the inputs of the
-  /// materialization. This function unpacks only those materializations that
-  /// were built with `insertNTo1Materialization`.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. It allows us to write 1:N conversion patterns while
-  /// 1:N support is still missing in the conversion value mapping. This
-  /// function will be deleted when full 1:N support has been added.
-  SmallVector<Value> unpackNTo1Materialization(Value value);
-
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -1041,7 +1112,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   });
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
@@ -1082,7 +1153,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
 
 void ReplaceOperationRewrite::rollback() {
   for (auto result : op->getResults())
-    rewriterImpl.mapping.erase(result);
+    rewriterImpl.mapping.erase({result});
 }
 
 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
@@ -1101,18 +1172,18 @@ void CreateOperationRewrite::rollback() {
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
     const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    Value mappedValue)
+    ValueVector mappedValues)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
       converterAndKind(converter, kind), originalType(originalType),
-      mappedValue(mappedValue) {
+      mappedValues(mappedValues) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   rewriterImpl.unresolvedMaterializations[op] = this;
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (mappedValue)
-    rewriterImpl.mapping.erase(mappedValue);
+  if (!mappedValues.empty())
+    rewriterImpl.mapping.erase(mappedValues);
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
   rewriterImpl.nTo1TempMaterializations.erase(getOperation());
   op->erase();
@@ -1160,7 +1231,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
 LogicalResult ConversionPatternRewriterImpl::remapValues(
     StringRef valueDiagTag, std::optional<Location> inputLoc,
     PatternRewriter &rewriter, ValueRange values,
-    SmallVector<SmallVector<Value>> &remapped) {
+    SmallVector<ValueVector> &remapped) {
   remapped.reserve(llvm::size(values));
 
   for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1239,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type origType = operand.getType();
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
-    // Find the most recently mapped value. Unpack all temporary N:1
-    // materializations. Such conversions are a workaround around missing
-    // 1:N support in the ConversionValueMapping. (The conversion patterns
-    // already support 1:N replacements.)
-    Value repl = mapping.lookupOrDefault(operand);
-    SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
     if (!currentTypeConverter) {
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
-      // pass through the most recently mapped value.
-      remapped.push_back(std::move(unpacked));
+      // pass through the most recently mapped values.
+      remapped.push_back(mapping.lookupOrDefault({operand}));
       continue;
     }
 
@@ -1192,51 +1256,28 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       });
       return failure();
     }
-
     // If a type is converted to 0 types, there is nothing to do.
     if (legalTypes.empty()) {
       remapped.push_back({});
       continue;
     }
 
-    if (legalTypes.size() != 1) {
-      // TODO: This is a 1:N conversion. The conversion value mapping does not
-      // store such materializations yet. If the types of the most recently
-      // mapped values do not match, build a target materialization.
-      ValueRange unpackedRange(unpacked);
-      if (TypeRange(unpackedRange) == legalTypes) {
-        remapped.push_back(std::move(unpacked));
-        continue;
-      }
-
-      // Insert a target materialization if the current pattern expects
-      // different legalized types.
-      ValueRange targetMat = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
-          /*valueToMap=*/Value(), /*inputs=*/unpacked,
-          /*outputType=*/legalTypes, /*originalType=*/origType,
-          currentTypeConverter);
-      remapped.push_back(targetMat);
+    ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes);
+    if (!repl.empty() && TypeRange(repl) == legalTypes) {
+      // Mapped values have the correct type or there is an existing
+      // materialization. Or the opreand is not mapped at all and has the
+      // correct type.
+      remapped.push_back(repl);
       continue;
     }
 
-    // Handle 1->1 type conversions.
-    Type desiredType = legalTypes.front();
-    // Try to find a mapped value with the desired type. (Or the operand itself
-    // if the value is not mapped at all.)
-    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
-    if (newOperand.getType() != desiredType) {
-      // If the looked up value's type does not have the desired type, it means
-      // that the value was replaced with a value of different type and no
-      // target materialization was created yet.
-      Value castValue = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
-          /*outputType=*/desiredType, /*originalType=*/origType,
-          currentTypeConverter);
-      newOperand = castValue;
-    }
-    remapped.push_back({newOperand});
+    // Create a materialization for the most recently mapped values.
+    repl = mapping.lookupOrDefault({operand});
+    ValueRange castValues = buildUnresolvedMaterialization(
+        MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+        /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
+        /*originalType=*/origType, currentTypeConverter);
+    remapped.push_back(castValues);
   }
   return success();
 }
@@ -1353,7 +1394,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       buildUnresolvedMaterialization(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
+          /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
@@ -1364,7 +1405,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, repl);
+      mapping.map({origArg}, {repl});
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1375,13 +1416,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    if (replArgs.size() == 1) {
-      mapping.map(origArg, replArgs.front());
-    } else {
-      insertNTo1Materialization(
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
-    }
+    ValueVector replArgVals = llvm::map_to_vector<1>(
+        replArgs, [](BlockArgument arg) -> Value { return arg; });
+    mapping.map({origArg}, replArgVals);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1402,7 +1439,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+    ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
     UnrealizedConversionCastOp *castOp) {
   assert((!originalType || kind == MaterializationKind::Target) &&
@@ -1410,10 +1447,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes) {
-    if (valueToMap) {
-      assert(inputs.size() == 1 && "1:N mapping is not supported");
-      mapping.map(valueToMap, inputs.front());
-    }
+    if (!valuesToMap.empty())
+      mapping.map(valuesToMap, inputs);
     return inputs;
   }
 
@@ -1423,37 +1458,21 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
-  if (valueToMap) {
-    assert(outputTypes.size() == 1 && "1:N mapping is not supported");
-    mapping.map(valueToMap, convertOp.getResult(0));
-  }
+  if (!valuesToMap.empty())
+    mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  originalType, valueToMap);
+                                                  originalType, valuesToMap);
   return convertOp.getResults();
 }
 
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
-    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
-    Value originalValue, const TypeConverter *converter) {
-  // Insert argument materialization back to the original type.
-  Type originalType = originalValue.getType();
-  UnrealizedConversionCastOp argCastOp;
-  buildUnresolvedMaterialization(
-      MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType,
-      /*originalType=*/Type(), converter, &argCastOp);
-  if (argCastOp)
-    nTo1TempMaterializations.insert(argCastOp);
-}
-
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
   // Find a replacement value with the same type.
-  Value repl = mapping.lookupOrNull(value, value.getType());
-  if (repl)
-    return repl;
+  ValueVector repl = mapping.lookupOrNull({value}, value.getType());
+  if (!repl.empty())
+    return repl.front();
 
   // Check if the value is dead. No replacement value is needed in that case.
   // This is an approximate check that may have false negatives but does not
@@ -1467,8 +1486,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   // No replacement value was found. Get the latest replacement value
   // (regardless of the type) and build a source materialization to the
   // original type.
-  repl = mapping.lookupOrNull(value);
-  if (!repl) {
+  repl = mapping.lookupOrNull({value});
+  if (repl.empty()) {
     // No replacement value is registered in the mapping. This means that the
     // value is dropped and no longer needed. (If the value were still needed,
     // a source materialization producing a replacement value "out of thin air"
@@ -1478,34 +1497,12 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   }
   Value castValue = buildUnresolvedMaterialization(
       MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
-      /*originalType=*/Type(), converter);
-  mapping.map(value, castValue);
+      /*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(),
+      /*originalType=*/Type(), converter)[0];
+  mapping.map({value}, {castValue});
   return castValue;
 }
 
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
-  // Unpack unrealized_conversion_cast ops that were inserted as a N:1
-  // workaround.
-  auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-  if (!castOp)
-    return {value};
-  if (!nTo1TempMaterializations.contains(castOp))
-    return {value};
-  assert(castOp->getNumResults() == 1 && "expected single result");
-
-  SmallVector<Value> result;
-  for (Value v : castOp.getOperands()) {
-    // Keep unpacking if possible. This is needed because during block
-    // signature conversions and 1:N op replacements, the driver may have
-    // inserted two materializations back-to-back: first an argument
-    // materialization, then a target materialization.
-    llvm::append_range(result, unpackNTo1Materialization(v));
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1554,7 +1551,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       // Materialize a replacement value "out of thin air".
       buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
-          result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
+          result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
       continue;
@@ -1572,16 +1569,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
     // Remap result to replacement value.
     if (repl.empty())
       continue;
-
-    if (repl.size() == 1) {
-      // Single replacement value: replace directly.
-      mapping.map(result, repl.front());
-    } else {
-      // Multiple replacement values: insert N:1 materialization.
-      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
-                                /*replacements=*/repl, /*outputValue=*/result,
-                                currentTypeConverter);
-    }
+    mapping.map({result}, repl);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1648,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
   SmallVector<ValueRange> newVals;
-  for (size_t i = 0; i < newValues.size(); ++i)
-    newVals.push_back(newValues.slice(i, 1));
+  for (size_t i = 0; i < newValues.size(); ++i) {
+    if (newValues[i]) {
+      newVals.push_back(newValues.slice(i, 1));
+    } else {
+      newVals.push_back(ValueRange());
+    }
+  }
   impl->notifyOpReplaced(op, newVals);
 }
 
@@ -1729,11 +1722,11 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to});
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
-  SmallVector<SmallVector<Value>> remappedValues;
+  SmallVector<ValueVector> remappedValues;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                                remappedValues)))
     return nullptr;
@@ -1746,7 +1739,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                                              SmallVectorImpl<Value> &results) {
   if (keys.empty())
     return success();
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<ValueVector> remapped;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
                                remapped)))
     return failure();
@@ -1872,7 +1865,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
                                              getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<ValueVector> remapped;
   if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
                                       op->getOperands(), remapped))) {
     return failure();
@@ -2625,19 +2618,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
     rewriter.setInsertionPoint(op);
     SmallVector<Value> newMaterialization;
     switch (rewrite->getMaterializationKind()) {
-    case MaterializationKind::Argument: {
-      // Try to materialize an argument conversion.
-      assert(op->getNumResults() == 1 && "expected single result");
-      Value argMat = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
-      if (argMat) {
-        newMaterialization.push_back(argMat);
-        break;
-      }
-    }
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 297eb5acef21b7..4cd196c5b44b31 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) {
 // Contents of the old block are moved to the new block.
 // CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
 
-// The new block arguments are used in "test.return".
-// CHECK-NEXT: notifyOperationModified: test.return
-
 // The old block is erased.
 // CHECK-NEXT: notifyBlockErased
 
@@ -390,8 +387,8 @@ func.func @caller() {
   // CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16)
   %0:2 = func.call @callee() : () -> (f32, i24)
 
-  // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
-  // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
+  // CHECK-DAG: %[[cast1:.*]] = "test.cast"() : () -> i24
+  // CHECK-DAG: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
   // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> ()
   // expected-remark @below{{'test.some_user' is not legalizable}}
   "test.some_user"(%0#0, %0#1) : (f32, i24) -> ()
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad50..d0b62e71ab0cf2 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
           tupleType.getFlattenedTypes(types);
           return success();
         });
-    typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+    typeConverter.addSourceMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 826c222990be4f..eae9b887e9d49a 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1284,7 +1284,6 @@ struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
   TestTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d788..a03bf0a1023d57 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
 struct PDLLTypeConverter : public TypeConverter {
   PDLLTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 

>From d6bf2cd3636c115cd7fe26a5213917832ee5eb38 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 23 Dec 2024 14:03:02 +0100
Subject: [PATCH 2/5] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
 .../Transforms/Utils/DialectConversion.cpp    | 28 ++++++++++---------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 96cbe07f0f12f9..8d6291f0f4f0d7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -187,7 +187,7 @@ struct ConversionValueMapping {
   }
 
   /// Drop the last mapping for the given values.
-  void erase(ValueVector value) { mapping.erase(value); }
+  void erase(const ValueVector &value) { mapping.erase(value); }
 
 private:
   /// Current value mappings.
@@ -221,7 +221,7 @@ ConversionValueMapping::lookupOrDefault(ValueVector from,
     }
     if (next != from) {
       // If at least one value was replaced, continue the lookup from there.
-      from = next;
+      from = std::move(next);
       continue;
     }
 
@@ -1175,7 +1175,7 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ValueVector mappedValues)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
       converterAndKind(converter, kind), originalType(originalType),
-      mappedValues(mappedValues) {
+      mappedValues(std::move(mappedValues)) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   rewriterImpl.unresolvedMaterializations[op] = this;
@@ -1265,9 +1265,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes);
     if (!repl.empty() && TypeRange(repl) == legalTypes) {
       // Mapped values have the correct type or there is an existing
-      // materialization. Or the opreand is not mapped at all and has the
+      // materialization. Or the operand is not mapped at all and has the
       // correct type.
-      remapped.push_back(repl);
+      remapped.push_back(std::move(repl));
       continue;
     }
 
@@ -1416,8 +1416,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    ValueVector replArgVals = llvm::map_to_vector<1>(
-        replArgs, [](BlockArgument arg) -> Value { return arg; });
+    ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
     mapping.map({origArg}, replArgVals);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
@@ -1462,8 +1461,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  originalType, valuesToMap);
+  appendRewrite<UnresolvedMaterializationRewrite>(
+      convertOp, converter, kind, originalType, std::move(valuesToMap));
   return convertOp.getResults();
 }
 
@@ -1495,10 +1494,13 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     // `applySignatureConversion`.)
     return Value();
   }
-  Value castValue = buildUnresolvedMaterialization(
-      MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(),
-      /*originalType=*/Type(), converter)[0];
+  Value castValue =
+      buildUnresolvedMaterialization(MaterializationKind::Source,
+                                     computeInsertPoint(repl), value.getLoc(),
+                                     /*valuesToMap=*/{value}, /*inputs=*/repl,
+                                     /*outputType=*/value.getType(),
+                                     /*originalType=*/Type(), converter)
+          .front();
   mapping.map({value}, {castValue});
   return castValue;
 }

>From d8d77d66eaadede0e20647c74c4833eca1ba5e60 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 29 Dec 2024 13:25:20 +0100
Subject: [PATCH 3/5] rebase fixes

---
 mlir/test/Transforms/test-legalizer.mlir    | 9 ++-------
 mlir/test/lib/Dialect/Test/TestPatterns.cpp | 8 --------
 2 files changed, 2 insertions(+), 15 deletions(-)

diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 4cd196c5b44b31..ae7d344b7167f9 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -491,13 +491,8 @@ func.func @test_1_to_n_block_signature_conversion() {
 
 // CHECK-LABEL: func @test_multiple_1_to_n_replacement()
 //       CHECK:   %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
-// TODO: There should be a single cast (i.e., a single target materialization).
-// This is currently not possible due to 1:N limitations of the conversion
-// mapping. Instead, we have 3 argument materializations.
-//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16
-//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16
-//       CHECK:   %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16
-//       CHECK:   "test.valid"(%[[cast3]]) : (f16) -> ()
+//       CHECK:   %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16
+//       CHECK:   "test.valid"(%[[cast]]) : (f16) -> ()
 func.func @test_multiple_1_to_n_replacement() {
   %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
   "test.invalid"(%0) : (f16) -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index eae9b887e9d49a..5b7c36c9b97bf4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1264,14 +1264,6 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
     // Replace test.multiple_1_to_n_replacement with test.step_1.
     Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
     // Now replace test.step_1 with test.legal_op.
-    // TODO: Ideally, it should not be necessary to reset the insertion point
-    // here. Based on the API calls, it looks like test.step_1 is entirely
-    // erased. But that's not the case: an argument materialization will
-    // survive. And that argument materialization will be used by the users of
-    // `op`. If we don't reset the insertion point here, we get dominance
-    // errors. This will be fixed when we have 1:N support in the conversion
-    // value mapping.
-    rewriter.setInsertionPoint(repl1);
     replaceWithDoubleResults(repl1, "test.legal_op");
     return success();
   }

>From 2efb6d71e535323f05f79b39ef429d91ec7d436f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Sat, 21 Dec 2024 19:01:19 +0100
Subject: [PATCH 4/5] use universal references for `map`

---
 .../Transforms/Utils/DialectConversion.cpp    | 36 ++++++++++++++-----
 1 file changed, 27 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8d6291f0f4f0d7..2a5c11c3d32ec3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -169,10 +169,15 @@ struct ConversionValueMapping {
   ValueVector lookupOrNull(const ValueVector &from,
                            TypeRange desiredTypes = {}) const;
 
+  template <typename T>
+  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
+
   /// Map a value to the one provided.
-  void map(const ValueVector &oldVal, const ValueVector &newVal) {
+  template <typename OldVal, typename NewVal>
+  std::enable_if_t<IsValueVector<OldVal>{} && IsValueVector<NewVal>{}>
+  map(OldVal &&oldVal, NewVal &&newVal) {
     LLVM_DEBUG({
-      ValueVector next = newVal;
+      ValueVector next(newVal);
       while (true) {
         assert(next != oldVal && "inserting cyclic mapping");
         auto it = mapping.find(next);
@@ -181,9 +186,22 @@ struct ConversionValueMapping {
         next = it->second;
       }
     });
-    mapping[oldVal] = newVal;
     for (Value v : newVal)
       mappedTo.insert(v);
+
+    mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+  }
+
+  template <typename OldVal, typename NewVal>
+  std::enable_if_t<!IsValueVector<OldVal>{} || !IsValueVector<NewVal>{}>
+  map(OldVal &&oldVal, NewVal &&newVal) {
+    if constexpr (IsValueVector<OldVal>{}) {
+      map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+    } else if constexpr (IsValueVector<NewVal>{}) {
+      map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+    } else {
+      map(ValueVector{oldVal}, ValueVector{newVal});
+    }
   }
 
   /// Drop the last mapping for the given values.
@@ -1405,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map({origArg}, {repl});
+      mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1417,7 +1435,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
     ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
-    mapping.map({origArg}, replArgVals);
+    mapping.map(origArg, std::move(replArgVals));
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1447,7 +1465,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes) {
     if (!valuesToMap.empty())
-      mapping.map(valuesToMap, inputs);
+      mapping.map(std::move(valuesToMap), inputs);
     return inputs;
   }
 
@@ -1501,7 +1519,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
                                      /*outputType=*/value.getType(),
                                      /*originalType=*/Type(), converter)
           .front();
-  mapping.map({value}, {castValue});
+  mapping.map(value, castValue);
   return castValue;
 }
 
@@ -1571,7 +1589,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
     // Remap result to replacement value.
     if (repl.empty())
       continue;
-    mapping.map({result}, repl);
+    mapping.map(result, repl);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1724,7 +1742,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to});
+  impl->mapping.map(impl->mapping.lookupOrDefault({from}), to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {

>From 8e9b48a6daff04a6d15cbacc2d60d422a6f53722 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 30 Dec 2024 19:35:01 +0100
Subject: [PATCH 5/5] address comments

---
 .../Transforms/Utils/DialectConversion.cpp    | 63 ++++++++++---------
 1 file changed, 32 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2a5c11c3d32ec3..3571e017158be9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -137,10 +137,12 @@ namespace {
 struct ValueVectorMapInfo {
   static ValueVector getEmptyKey() { return ValueVector{}; }
   static ValueVector getTombstoneKey() { return ValueVector{}; }
-  static ::llvm::hash_code getHashValue(ValueVector val) {
+  static ::llvm::hash_code getHashValue(const ValueVector &val) {
     return ::llvm::hash_combine_range(val.begin(), val.end());
   }
-  static bool isEqual(ValueVector LHS, ValueVector RHS) { return LHS == RHS; }
+  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
+    return LHS == RHS;
+  }
 };
 
 /// This class wraps a IRMapping to provide recursive lookup
@@ -159,20 +161,18 @@ struct ConversionValueMapping {
   /// - If there is no mapping to the desired types, also return the most
   ///   recently mapped values.
   /// - If there is no mapping for the given values at all, return the given
-  ///   values.
-  ValueVector lookupOrDefault(ValueVector from,
-                              TypeRange desiredTypes = {}) const;
+  ///   value.
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
 
-  /// Lookup the given values within the map, or return an empty vector if the
-  /// values are not mapped. If they are mapped, this follows the same behavior
+  /// Lookup the given value within the map, or return an empty vector if the
+  /// value is not mapped. If it is mapped, this follows the same behavior
   /// as `lookupOrDefault`.
-  ValueVector lookupOrNull(const ValueVector &from,
-                           TypeRange desiredTypes = {}) const;
+  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
 
   template <typename T>
   struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
 
-  /// Map a value to the one provided.
+  /// Map a value vector to the one provided.
   template <typename OldVal, typename NewVal>
   std::enable_if_t<IsValueVector<OldVal>{} && IsValueVector<NewVal>{}>
   map(OldVal &&oldVal, NewVal &&newVal) {
@@ -192,6 +192,7 @@ struct ConversionValueMapping {
     mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
   }
 
+  /// Map a value vector or single value to the one provided.
   template <typename OldVal, typename NewVal>
   std::enable_if_t<!IsValueVector<OldVal>{} || !IsValueVector<NewVal>{}>
   map(OldVal &&oldVal, NewVal &&newVal) {
@@ -217,19 +218,20 @@ struct ConversionValueMapping {
 } // namespace
 
 ValueVector
-ConversionValueMapping::lookupOrDefault(ValueVector from,
+ConversionValueMapping::lookupOrDefault(Value from,
                                         TypeRange desiredTypes) const {
   // Try to find the deepest values that have the desired types. If there is no
   // such mapping, simply return the deepest values.
   ValueVector desiredValue;
+  ValueVector current{from};
   do {
     // Store the current value if the types match.
-    if (desiredTypes.empty() || TypeRange(from) == desiredTypes)
-      desiredValue = from;
+    if (TypeRange(current) == desiredTypes)
+      desiredValue = current;
 
     // If possible, Replace each value with (one or multiple) mapped values.
     ValueVector next;
-    for (Value v : from) {
+    for (Value v : current) {
       auto it = mapping.find({v});
       if (it != mapping.end()) {
         llvm::append_range(next, it->second);
@@ -237,33 +239,35 @@ ConversionValueMapping::lookupOrDefault(ValueVector from,
         next.push_back(v);
       }
     }
-    if (next != from) {
+    if (next != current) {
       // If at least one value was replaced, continue the lookup from there.
-      from = std::move(next);
+      current = std::move(next);
       continue;
     }
 
     // Otherwise: Check if there is a mapping for the entire vector. Such
     // mappings are materializations. (N:M mapping are not supported for value
     // replacements.)
-    auto it = mapping.find(from);
+    auto it = mapping.find(current);
     if (it == mapping.end()) {
       // No mapping found: The lookup stops here.
       break;
     }
-    from = it->second;
+    current = it->second;
   } while (true);
 
   // If the desired values were found use them, otherwise default to the leaf
   // values.
-  return !desiredValue.empty() ? desiredValue : from;
+  // Note: If `desiredTypes` is empty, this function always returns `current`.
+  return !desiredValue.empty() ? desiredValue : current;
 }
 
-ValueVector ConversionValueMapping::lookupOrNull(const ValueVector &from,
+ValueVector ConversionValueMapping::lookupOrNull(Value from,
                                                  TypeRange desiredTypes) const {
   ValueVector result = lookupOrDefault(from, desiredTypes);
   TypeRange resultTypes(result);
-  if (result == from || (!desiredTypes.empty() && resultTypes != desiredTypes))
+  if (result == ValueVector{from} ||
+      (!desiredTypes.empty() && resultTypes != desiredTypes))
     return {};
   return result;
 }
@@ -1261,7 +1265,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
       // pass through the most recently mapped values.
-      remapped.push_back(mapping.lookupOrDefault({operand}));
+      remapped.push_back(mapping.lookupOrDefault(operand));
       continue;
     }
 
@@ -1280,7 +1284,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       continue;
     }
 
-    ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes);
+    ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
     if (!repl.empty() && TypeRange(repl) == legalTypes) {
       // Mapped values have the correct type or there is an existing
       // materialization. Or the operand is not mapped at all and has the
@@ -1290,7 +1294,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     }
 
     // Create a materialization for the most recently mapped values.
-    repl = mapping.lookupOrDefault({operand});
+    repl = mapping.lookupOrDefault(operand);
     ValueRange castValues = buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
         /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
@@ -1428,10 +1432,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       continue;
     }
 
-    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
-    // dialect conversion. Therefore, we need an argument materialization to
-    // turn the replacement block arguments into a single SSA value that can be
-    // used as a replacement.
+    // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
     ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
@@ -1487,7 +1488,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
   // Find a replacement value with the same type.
-  ValueVector repl = mapping.lookupOrNull({value}, value.getType());
+  ValueVector repl = mapping.lookupOrNull(value, value.getType());
   if (!repl.empty())
     return repl.front();
 
@@ -1503,7 +1504,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   // No replacement value was found. Get the latest replacement value
   // (regardless of the type) and build a source materialization to the
   // original type.
-  repl = mapping.lookupOrNull({value});
+  repl = mapping.lookupOrNull(value);
   if (repl.empty()) {
     // No replacement value is registered in the mapping. This means that the
     // value is dropped and no longer needed. (If the value were still needed,
@@ -1742,7 +1743,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault({from}), to);
+  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {



More information about the flang-commits mailing list