[flang-commits] [flang] 3ace685 - [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (#116524)

via flang-commits flang-commits at lists.llvm.org
Fri Jan 3 07:12:03 PST 2025


Author: Matthias Springer
Date: 2025-01-03T16:11:56+01:00
New Revision: 3ace685105d3b50bca68328bf0c945af22d70f23

URL: https://github.com/llvm/llvm-project/commit/3ace685105d3b50bca68328bf0c945af22d70f23
DIFF: https://github.com/llvm/llvm-project/commit/3ace685105d3b50bca68328bf0c945af22d70f23.diff

LOG: [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (#116524)

This commit updates the internal `ConversionValueMapping` data structure
in the dialect conversion driver to support 1:N replacements. This is
the last major commit for adding 1:N support to the dialect conversion
driver.

Since #116470, the infrastructure has supported 1:N replacements, but
the `ConversionValueMapping` still stored only 1:1 value mappings. To
work around that, the driver inserted temporary argument
materializations (converting N SSA values into 1 value). This is no
longer the case: the driver no longer uses argument materializations at
all. (They will be deleted from the type converter after some time, when
we delete the old 1:N dialect conversion driver.)
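
As a rough illustration of what storing 1:N mappings means for lookups, here is a minimal standalone sketch. It is not the MLIR implementation: `ToyValue`, `ToyValueVector`, and `ToyValueMapping` are hypothetical names, SSA values are modeled as plain integers, and the type-directed part of the lookup (`desiredTypes`) as well as whole-vector materialization mappings are left out.

```cpp
#include <cassert>
#include <map>
#include <utility>
#include <vector>

// Hypothetical stand-in for an SSA value: just an integer id.
using ToyValue = int;
using ToyValueVector = std::vector<ToyValue>;

// Hypothetical, simplified model of a value mapping that stores 1:N
// replacements. Types and materializations are deliberately ignored.
struct ToyValueMapping {
  // Record that `oldVal` was replaced by `newVals` (one or more values).
  void map(ToyValue oldVal, ToyValueVector newVals) {
    // Reject the most trivial cyclic mapping (a value replaced by itself).
    for (ToyValue v : newVals)
      assert(v != oldVal && "inserting cyclic mapping");
    mapping[oldVal] = std::move(newVals);
  }

  // Repeatedly replace each value with its mapped values until nothing
  // changes. Assumes the mapping is acyclic.
  ToyValueVector lookupOrDefault(ToyValue from) const {
    ToyValueVector current{from};
    while (true) {
      ToyValueVector next;
      bool changed = false;
      for (ToyValue v : current) {
        auto it = mapping.find(v);
        if (it != mapping.end()) {
          // `v` was replaced: splice in all of its replacement values.
          next.insert(next.end(), it->second.begin(), it->second.end());
          changed = true;
        } else {
          // `v` has no mapping: keep it as is.
          next.push_back(v);
        }
      }
      if (!changed)
        return current;
      current = std::move(next);
    }
  }

  std::map<ToyValue, ToyValueVector> mapping;
};
```

In this toy model, mapping value 1 to {2, 3} and then value 3 to {4} makes a lookup of 1 yield {2, 4}, while a lookup of an unmapped value yields that value alone. No N:1 cast back to a single value is needed at any point, which is the role the temporary argument materializations used to play.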

Note for LLVM integration: Replace all occurrences of
`addArgumentMaterialization` (except for 1:N dialect conversion passes)
with `addSourceMaterialization`.

---------

Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>

Added: 
    

Modified: 
    flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
    mlir/docs/DialectConversion.md
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
    mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Transforms/TestDialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 1bb91d252529f0..104ae7408b80c1 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
     addConversion([&](TypeDescType ty) {
       return TypeDescType::get(convertType(ty.getOfTy()));
     });
-    addArgumentMaterialization(materializeProcedure);
     addSourceMaterialization(materializeProcedure);
     addTargetMaterialization(materializeProcedure);
   }

diff  --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 3168f5e13c7515..abacd5a82c61eb 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure
 type safety during the conversion process. There are several types of
 materializations depending on the situation.
 
-*   Argument Materialization
-
-    -   An argument materialization is used when converting the type of a block
-        argument during a [signature conversion](#region-signature-conversion).
-        The new block argument types are specified in a `SignatureConversion`
-        object. An original block argument can be converted into multiple
-        block arguments, which is not supported everywhere in the dialect
-        conversion. (E.g., adaptors support only a single replacement value for
-        each original value.) Therefore, an argument materialization is used to
-        convert potentially multiple new block arguments back into a single SSA
-        value. An argument materialization is also used when replacing an op
-        result with multiple values.
-
 *   Source Materialization
 
     -   A source materialization is used when a value was replaced with a value
@@ -343,17 +330,6 @@ class TypeConverter {
   /// Materialization functions must be provided when a type conversion may
   /// persist after the conversion has finished.
 
-  /// This method registers a materialization that will be called when
-  /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value
-  /// with the old argument type.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
-  void addArgumentMaterialization(FnT &&callback) {
-    argumentMaterializations.emplace_back(
-        wrapMaterialization<T>(std::forward<FnT>(callback)));
-  }
-
   /// This method registers a materialization that will be called when
   /// converting a replacement value back to its original source type.
   /// This is used when some uses of the original value persist beyond the main
@@ -406,12 +382,11 @@ done explicitly via a conversion pattern.
 To convert the types of block arguments within a Region, a custom hook on the
 `ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
 uses a provided type converter to apply type conversions to all blocks of a
-given region. As noted above, the conversions performed by this method use the
-argument materialization hook on the `TypeConverter`. This hook also takes an
-optional `TypeConverter::SignatureConversion` parameter that applies a custom
-conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to the operation, e.g.,
-`func::FuncOp`, `AffineForOp`, etc.
+given region. This hook also takes an optional
+`TypeConverter::SignatureConversion` parameter that applies a custom conversion
+to the entry block of the region. The types of the entry block arguments are
+often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`,
+etc.
 
 To convert the signature of just one given block, the
 `applySignatureConversion` hook can be used.

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913e3..9a6975dcf8dfae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,6 +181,10 @@ class TypeConverter {
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: Argument materializations are used only with the 1:N dialect
+  /// conversion driver. The 1:N dialect conversion driver will be removed soon
+  /// and so will argument materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
@@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   void replaceOp(Operation *op, Operation *newOp) override;
 
   /// Replace the given operation with the new value ranges. The number of op
-  /// results and value ranges must match. If an original SSA value is replaced
-  /// by multiple SSA values (i.e., a value range has more than 1 element), the
-  /// conversion driver will insert an argument materialization to convert the
-  /// N SSA values back into 1 SSA value of the original type. The given
-  /// operation is erased.
-  ///
-  /// Note: The argument materialization is a workaround until we have full 1:N
-  /// support in the dialect conversion. (It is going to disappear from both
-  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+  /// results and value ranges must match. The given operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1285,8 +1281,8 @@ struct ConversionConfig {
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
 
-  /// If set to "true", the dialect conversion attempts to build source/target/
-  /// argument materializations through the type converter API in lieu of
+  /// If set to "true", the dialect conversion attempts to build source/target
+  /// materializations through the type converter API in lieu of
   /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
   /// at least one materialization could not be built.
   ///

diff  --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 49e2d943286645..72799e42cf3fd1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -85,7 +85,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder,
                                            UnrankedMemRefType resultType,
                                            ValueRange inputs, Location loc,
                                            const LLVMTypeConverter &converter) {
-  // An argument materialization must return a value of type
+  // A source materialization must return a value of type
   // `resultType`, so insert a cast from the memref descriptor type
   // (!llvm.struct) to the original memref type.
   Value packed =
@@ -101,7 +101,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder,
                                          MemRefType resultType,
                                          ValueRange inputs, Location loc,
                                          const LLVMTypeConverter &converter) {
-  // An argument materialization must return a value of type `resultType`,
+  // A source materialization must return a value of type `resultType`,
   // so insert a cast from the memref descriptor type (!llvm.struct) to the
   // original memref type.
   Value packed =
@@ -234,19 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   });
 
-  // Argument materializations convert from the new block argument types
+  // Source materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    return unrankedMemRefMaterialization(builder, resultType, inputs, loc,
-                                         *this);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this);
-  });
   addSourceMaterialization([&](OpBuilder &builder,
                                UnrankedMemRefType resultType, ValueRange inputs,
                                Location loc) {

diff  --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
 
   converter.addSourceMaterialization(materializeAsUnrealizedCast);
   converter.addTargetMaterialization(materializeAsUnrealizedCast);
-  converter.addArgumentMaterialization(materializeAsUnrealizedCast);
 }
 
 /// Get an unsigned integer or size data type corresponding to \p ty.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 0e651f4cee4c36..fc6671ef811759 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
     });
 
     addSourceMaterialization(sourceMaterializationCallback);
-    addArgumentMaterialization(sourceMaterializationCallback);
   }
 };
 

diff  --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
     addConversion(convertQuantizedType);
     addConversion(convertTensorType);
 
-    addArgumentMaterialization(materializeConversion);
     addSourceMaterialization(materializeConversion);
     addTargetMaterialization(materializeConversion);
   }

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 
   // Required by scf.for 1:N type conversion.
   addSourceMaterialization(materializeTuple);
-
-  // Required as a workaround until we have full 1:N support.
-  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
 
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
-  typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 2b006430d38175..0c5520988eff38 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,6 +54,55 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
+/// Given two insertion points in the same block, choose the later one.
+static OpBuilder::InsertPoint
+chooseLaterInsertPointInBlock(OpBuilder::InsertPoint a,
+                              OpBuilder::InsertPoint b) {
+  assert(a.getBlock() == b.getBlock() && "expected same block");
+  Block *block = a.getBlock();
+  if (a.getPoint() == block->begin())
+    return b;
+  if (b.getPoint() == block->begin())
+    return a;
+  if (a.getPoint()->isBeforeInBlock(&*b.getPoint()))
+    return b;
+  return a;
+}
+
+/// Helper function that chooses the insertion point among the two given ones
+/// that is later.
+// TODO: Extend DominanceInfo API to work with block iterators.
+static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a,
+                                                     OpBuilder::InsertPoint b) {
+  // Case 1: Fast path: Same block. This is the most common case.
+  if (LLVM_LIKELY(a.getBlock() == b.getBlock()))
+    return chooseLaterInsertPointInBlock(a, b);
+
+  // Case 2: Different block, but same region.
+  if (a.getBlock()->getParent() == b.getBlock()->getParent()) {
+    DominanceInfo domInfo;
+    if (domInfo.properlyDominates(a.getBlock(), b.getBlock()))
+      return b;
+    if (domInfo.properlyDominates(b.getBlock(), a.getBlock()))
+      return a;
+    // Neither of the two blocks dominates the other.
+    llvm_unreachable("unable to find valid insertion point");
+  }
+
+  // Case 3: b's region contains a: choose a.
+  if (b.getBlock()->getParent()->findAncestorOpInRegion(
+          *a.getPoint()->getParentOp()))
+    return a;
+
+  // Case 4: a's region contains b: choose b.
+  if (a.getBlock()->getParent()->findAncestorOpInRegion(
+          *b.getPoint()->getParentOp()))
+    return b;
+
+  // Neither of the two operations contains the other.
+  llvm_unreachable("unable to find valid insertion point");
+}
+
 /// Helper function that computes an insertion point where the given value is
 /// defined and can be used without a dominance violation.
 static OpBuilder::InsertPoint computeInsertPoint(Value value) {
@@ -63,11 +113,38 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
   return OpBuilder::InsertPoint(insertBlock, insertPt);
 }
 
+/// Helper function that computes an insertion point where the given values are
+/// defined and can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+  for (Value v : vals.drop_front())
+    pt = chooseLaterInsertPoint(pt, computeInsertPoint(v));
+  return pt;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
+/// A vector of SSA values, optimized for the most common case of a single
+/// value.
+using ValueVector = SmallVector<Value, 1>;
+
 namespace {
+
+/// Helper class to make it possible to use `ValueVector` as a key in DenseMap.
+struct ValueVectorMapInfo {
+  static ValueVector getEmptyKey() { return ValueVector{}; }
+  static ValueVector getTombstoneKey() { return ValueVector{}; }
+  static ::llvm::hash_code getHashValue(const ValueVector &val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) {
+    return LHS == RHS;
+  }
+};
+
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
@@ -75,68 +152,129 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }
 
-  /// Lookup the most recently mapped value with the desired type in the
+  /// Lookup the most recently mapped values with the desired types in the
   /// mapping.
   ///
   /// Special cases:
-  /// - If the desired type is "null", simply return the most recently mapped
+  /// - If the desired type range is empty, simply return the most recently
+  ///   mapped values.
+  /// - If there is no mapping to the desired types, also return the most
+  ///   recently mapped values.
+  /// - If there is no mapping for the given values at all, return the given
   ///   value.
-  /// - If there is no mapping to the desired type, also return the most
-  ///   recently mapped value.
-  /// - If there is no mapping for the given value at all, return the given
-  ///   value.
-  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
+  ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const;
+
+  /// Lookup the given value within the map, or return an empty vector if the
+  /// value is not mapped. If it is mapped, this follows the same behavior
+  /// as `lookupOrDefault`.
+  ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const;
 
-  /// Lookup a mapped value within the map, or return null if a mapping does not
-  /// exist. If a mapping exists, this follows the same behavior of
-  /// `lookupOrDefault`.
-  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
+  template <typename T>
+  struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
 
-  /// Map a value to the one provided.
-  void map(Value oldVal, Value newVal) {
+  /// Map a value vector to the one provided.
+  template <typename OldVal, typename NewVal>
+  std::enable_if_t<IsValueVector<OldVal>::value && IsValueVector<NewVal>::value>
+  map(OldVal &&oldVal, NewVal &&newVal) {
     LLVM_DEBUG({
-      for (Value it = newVal; it; it = mapping.lookupOrNull(it))
-        assert(it != oldVal && "inserting cyclic mapping");
+      ValueVector next(newVal);
+      while (true) {
+        assert(next != oldVal && "inserting cyclic mapping");
+        auto it = mapping.find(next);
+        if (it == mapping.end())
+          break;
+        next = it->second;
+      }
     });
-    mapping.map(oldVal, newVal);
-    mappedTo.insert(newVal);
+    for (Value v : newVal)
+      mappedTo.insert(v);
+
+    mapping[std::forward<OldVal>(oldVal)] = std::forward<NewVal>(newVal);
+  }
+
+  /// Map a value vector or single value to the one provided.
+  template <typename OldVal, typename NewVal>
+  std::enable_if_t<!IsValueVector<OldVal>::value ||
+                   !IsValueVector<NewVal>::value>
+  map(OldVal &&oldVal, NewVal &&newVal) {
+    if constexpr (IsValueVector<OldVal>{}) {
+      map(std::forward<OldVal>(oldVal), ValueVector{newVal});
+    } else if constexpr (IsValueVector<NewVal>{}) {
+      map(ValueVector{oldVal}, std::forward<NewVal>(newVal));
+    } else {
+      map(ValueVector{oldVal}, ValueVector{newVal});
+    }
   }
 
-  /// Drop the last mapping for the given value.
-  void erase(Value value) { mapping.erase(value); }
+  /// Drop the last mapping for the given values.
+  void erase(const ValueVector &value) { mapping.erase(value); }
 
 private:
   /// Current value mappings.
-  IRMapping mapping;
+  DenseMap<ValueVector, ValueVector, ValueVectorMapInfo> mapping;
 
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
-Value ConversionValueMapping::lookupOrDefault(Value from,
-                                              Type desiredType) const {
-  // Try to find the deepest value that has the desired type. If there is no
-  // such value, simply return the deepest value.
-  Value desiredValue;
+ValueVector
+ConversionValueMapping::lookupOrDefault(Value from,
+                                        TypeRange desiredTypes) const {
+  // Try to find the deepest values that have the desired types. If there is no
+  // such mapping, simply return the deepest values.
+  ValueVector desiredValue;
+  ValueVector current{from};
   do {
-    if (!desiredType || from.getType() == desiredType)
-      desiredValue = from;
+    // Store the current value if the types match.
+    if (TypeRange(current) == desiredTypes)
+      desiredValue = current;
+
+    // If possible, replace each value with (one or multiple) mapped values.
+    ValueVector next;
+    for (Value v : current) {
+      auto it = mapping.find({v});
+      if (it != mapping.end()) {
+        llvm::append_range(next, it->second);
+      } else {
+        next.push_back(v);
+      }
+    }
+    if (next != current) {
+      // If at least one value was replaced, continue the lookup from there.
+      current = std::move(next);
+      continue;
+    }
 
-    Value mappedValue = mapping.lookupOrNull(from);
-    if (!mappedValue)
+    // Otherwise: Check if there is a mapping for the entire vector. Such
+    // mappings are materializations. (N:M mappings are not supported for value
+    // replacements.)
+    //
+    // Note: From a correctness point of view, materializations do not have to
+    // be stored (and looked up) in the mapping. But for performance reasons,
+    // we choose to reuse existing IR (when possible) instead of creating it
+    // multiple times.
+    auto it = mapping.find(current);
+    if (it == mapping.end()) {
+      // No mapping found: The lookup stops here.
       break;
-    from = mappedValue;
+    }
+    current = it->second;
   } while (true);
 
-  // If the desired value was found use it, otherwise default to the leaf value.
-  return desiredValue ? desiredValue : from;
+  // If the desired values were found use them, otherwise default to the leaf
+  // values.
+  // Note: If `desiredTypes` is empty, this function always returns `current`.
+  return !desiredValue.empty() ? std::move(desiredValue) : std::move(current);
 }
 
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
-  Value result = lookupOrDefault(from, desiredType);
-  if (result == from || (desiredType && result.getType() != desiredType))
-    return nullptr;
+ValueVector ConversionValueMapping::lookupOrNull(Value from,
+                                                 TypeRange desiredTypes) const {
+  ValueVector result = lookupOrDefault(from, desiredTypes);
+  TypeRange resultTypes(result);
+  if (result == ValueVector{from} ||
+      (!desiredTypes.empty() && resultTypes != desiredTypes))
+    return {};
   return result;
 }
 
@@ -651,10 +789,6 @@ class CreateOperationRewrite : public OperationRewrite {
 
 /// The type of materialization.
 enum MaterializationKind {
-  /// This materialization materializes a conversion for an illegal block
-  /// argument type, to the original one.
-  Argument,
-
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
   Target,
@@ -673,7 +807,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
                                    UnrealizedConversionCastOp op,
                                    const TypeConverter *converter,
                                    MaterializationKind kind, Type originalType,
-                                   Value mappedValue);
+                                   ValueVector mappedValues);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,9 +842,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   /// materializations.
   Type originalType;
 
-  /// The value in the conversion value mapping that is being replaced by the
+  /// The values in the conversion value mapping that are being replaced by the
   /// results of this unresolved materialization.
-  Value mappedValue;
+  ValueVector mappedValues;
 };
 } // namespace
 
@@ -779,7 +913,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   LogicalResult remapValues(StringRef valueDiagTag,
                             std::optional<Location> inputLoc,
                             PatternRewriter &rewriter, ValueRange values,
-                            SmallVector<SmallVector<Value>> &remapped);
+                            SmallVector<ValueVector> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
@@ -820,39 +954,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// If a cast op was built, it can optionally be returned with the `castOp`
   /// output argument.
   ///
-  /// If `valueToMap` is set to a non-null Value, then that value is mapped to
+  /// If `valuesToMap` is non-empty, then those values are mapped to
   /// the results of the unresolved materialization in the conversion value
   /// mapping.
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+      ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr);
-  Value buildUnresolvedMaterialization(
-      MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
-      const TypeConverter *converter,
-      UnrealizedConversionCastOp *castOp = nullptr) {
-    return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
-                                          TypeRange(outputType), originalType,
-                                          converter, castOp)
-        .front();
-  }
-
-  /// Build an N:1 materialization for the given original value that was
-  /// replaced with the given replacement values.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. The conversion mapping can store only 1:1 replacements
-  /// and the conversion patterns only support single Value replacements in the
-  /// adaptor, so N values must be converted back to a single value. This
-  /// function will be deleted when full 1:N support has been added.
-  ///
-  /// This function inserts an argument materialization back to the original
-  /// type.
-  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
-                                 ValueRange replacements, Value originalValue,
-                                 const TypeConverter *converter);
 
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
@@ -862,16 +971,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value findOrBuildReplacementValue(Value value,
                                     const TypeConverter *converter);
 
-  /// Unpack an N:1 materialization and return the inputs of the
-  /// materialization. This function unpacks only those materializations that
-  /// were built with `insertNTo1Materialization`.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. It allows us to write 1:N conversion patterns while
-  /// 1:N support is still missing in the conversion value mapping. This
-  /// function will be deleted when full 1:N support has been added.
-  SmallVector<Value> unpackNTo1Materialization(Value value);
-
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -1041,7 +1140,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   });
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
@@ -1082,7 +1181,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
 
 void ReplaceOperationRewrite::rollback() {
   for (auto result : op->getResults())
-    rewriterImpl.mapping.erase(result);
+    rewriterImpl.mapping.erase({result});
 }
 
 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
@@ -1101,18 +1200,18 @@ void CreateOperationRewrite::rollback() {
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
     const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    Value mappedValue)
+    ValueVector mappedValues)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
       converterAndKind(converter, kind), originalType(originalType),
-      mappedValue(mappedValue) {
+      mappedValues(std::move(mappedValues)) {
   assert((!originalType || kind == MaterializationKind::Target) &&
          "original type is valid only for target materializations");
   rewriterImpl.unresolvedMaterializations[op] = this;
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (mappedValue)
-    rewriterImpl.mapping.erase(mappedValue);
+  if (!mappedValues.empty())
+    rewriterImpl.mapping.erase(mappedValues);
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
   rewriterImpl.nTo1TempMaterializations.erase(getOperation());
   op->erase();
@@ -1160,7 +1259,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
 LogicalResult ConversionPatternRewriterImpl::remapValues(
     StringRef valueDiagTag, std::optional<Location> inputLoc,
     PatternRewriter &rewriter, ValueRange values,
-    SmallVector<SmallVector<Value>> &remapped) {
+    SmallVector<ValueVector> &remapped) {
   remapped.reserve(llvm::size(values));
 
   for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1267,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type origType = operand.getType();
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
-    // Find the most recently mapped value. Unpack all temporary N:1
-    // materializations. Such conversions are a workaround around missing
-    // 1:N support in the ConversionValueMapping. (The conversion patterns
-    // already support 1:N replacements.)
-    Value repl = mapping.lookupOrDefault(operand);
-    SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
     if (!currentTypeConverter) {
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
-      // pass through the most recently mapped value.
-      remapped.push_back(std::move(unpacked));
+      // pass through the most recently mapped values.
+      remapped.push_back(mapping.lookupOrDefault(operand));
       continue;
     }
 
@@ -1192,51 +1284,28 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       });
       return failure();
     }
-
     // If a type is converted to 0 types, there is nothing to do.
     if (legalTypes.empty()) {
       remapped.push_back({});
       continue;
     }
 
-    if (legalTypes.size() != 1) {
-      // TODO: This is a 1:N conversion. The conversion value mapping does not
-      // store such materializations yet. If the types of the most recently
-      // mapped values do not match, build a target materialization.
-      ValueRange unpackedRange(unpacked);
-      if (TypeRange(unpackedRange) == legalTypes) {
-        remapped.push_back(std::move(unpacked));
-        continue;
-      }
-
-      // Insert a target materialization if the current pattern expects
-      // different legalized types.
-      ValueRange targetMat = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
-          /*valueToMap=*/Value(), /*inputs=*/unpacked,
-          /*outputTypes=*/legalTypes, /*originalType=*/origType,
-          currentTypeConverter);
-      remapped.push_back(targetMat);
+    ValueVector repl = mapping.lookupOrDefault(operand, legalTypes);
+    if (!repl.empty() && TypeRange(repl) == legalTypes) {
+      // Mapped values have the correct type or there is an existing
+      // materialization. Or the operand is not mapped at all and has the
+      // correct type.
+      remapped.push_back(std::move(repl));
       continue;
     }
 
-    // Handle 1->1 type conversions.
-    Type desiredType = legalTypes.front();
-    // Try to find a mapped value with the desired type. (Or the operand itself
-    // if the value is not mapped at all.)
-    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
-    if (newOperand.getType() != desiredType) {
-      // If the looked up value's type does not have the desired type, it means
-      // that the value was replaced with a value of different type and no
-      // target materialization was created yet.
-      Value castValue = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
-          /*outputType=*/desiredType, /*originalType=*/origType,
-          currentTypeConverter);
-      newOperand = castValue;
-    }
-    remapped.push_back({newOperand});
+    // Create a materialization for the most recently mapped values.
+    repl = mapping.lookupOrDefault(operand);
+    ValueRange castValues = buildUnresolvedMaterialization(
+        MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
+        /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes,
+        /*originalType=*/origType, currentTypeConverter);
+    remapped.push_back(castValues);
   }
   return success();
 }
@@ -1353,7 +1422,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       buildUnresolvedMaterialization(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
+          /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
@@ -1369,19 +1438,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       continue;
     }
 
-    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
-    // dialect conversion. Therefore, we need an argument materialization to
-    // turn the replacement block arguments into a single SSA value that can be
-    // used as a replacement.
+    // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    if (replArgs.size() == 1) {
-      mapping.map(origArg, replArgs.front());
-    } else {
-      insertNTo1Materialization(
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
-    }
+    ValueVector replArgVals = llvm::to_vector_of<Value, 1>(replArgs);
+    mapping.map(origArg, std::move(replArgVals));
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1402,7 +1463,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+    ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
     UnrealizedConversionCastOp *castOp) {
   assert((!originalType || kind == MaterializationKind::Target) &&
@@ -1410,10 +1471,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes) {
-    if (valueToMap) {
-      assert(inputs.size() == 1 && "1:N mapping is not supported");
-      mapping.map(valueToMap, inputs.front());
-    }
+    if (!valuesToMap.empty())
+      mapping.map(std::move(valuesToMap), inputs);
     return inputs;
   }
 
@@ -1423,37 +1482,21 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
-  if (valueToMap) {
-    assert(outputTypes.size() == 1 && "1:N mapping is not supported");
-    mapping.map(valueToMap, convertOp.getResult(0));
-  }
+  if (!valuesToMap.empty())
+    mapping.map(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  originalType, valueToMap);
+  appendRewrite<UnresolvedMaterializationRewrite>(
+      convertOp, converter, kind, originalType, std::move(valuesToMap));
   return convertOp.getResults();
 }
 
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
-    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
-    Value originalValue, const TypeConverter *converter) {
-  // Insert argument materialization back to the original type.
-  Type originalType = originalValue.getType();
-  UnrealizedConversionCastOp argCastOp;
-  buildUnresolvedMaterialization(
-      MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType,
-      /*originalType=*/Type(), converter, &argCastOp);
-  if (argCastOp)
-    nTo1TempMaterializations.insert(argCastOp);
-}
-
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
   // Find a replacement value with the same type.
-  Value repl = mapping.lookupOrNull(value, value.getType());
-  if (repl)
-    return repl;
+  ValueVector repl = mapping.lookupOrNull(value, value.getType());
+  if (!repl.empty())
+    return repl.front();
 
   // Check if the value is dead. No replacement value is needed in that case.
   // This is an approximate check that may have false negatives but does not
@@ -1468,7 +1511,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   // (regardless of the type) and build a source materialization to the
   // original type.
   repl = mapping.lookupOrNull(value);
-  if (!repl) {
+  if (repl.empty()) {
     // No replacement value is registered in the mapping. This means that the
     // value is dropped and no longer needed. (If the value were still needed,
     // a source materialization producing a replacement value "out of thin air"
@@ -1476,36 +1519,29 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     // `applySignatureConversion`.)
     return Value();
   }
-  Value castValue = buildUnresolvedMaterialization(
-      MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
-      /*originalType=*/Type(), converter);
+
+  // Note: `computeInsertPoint` computes the "earliest" insertion point at
+  // which all values in `repl` are defined. It is important to emit the
+  // materialization at that location because the same materialization may be
+  // reused in a different context. (That's because materializations are cached
+  // in the conversion value mapping.) The insertion point of the
+  // materialization must be valid for all future users that may be created
+  // later in the conversion process.
+  //
+  // Note: Instead of creating new IR, `buildUnresolvedMaterialization` may
+  // return an already existing, cached materialization from the conversion
+  // value mapping.
+  Value castValue =
+      buildUnresolvedMaterialization(MaterializationKind::Source,
+                                     computeInsertPoint(repl), value.getLoc(),
+                                     /*valuesToMap=*/{value}, /*inputs=*/repl,
+                                     /*outputType=*/value.getType(),
+                                     /*originalType=*/Type(), converter)
+          .front();
   mapping.map(value, castValue);
   return castValue;
 }
 
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
-  // Unpack unrealized_conversion_cast ops that were inserted as a N:1
-  // workaround.
-  auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-  if (!castOp)
-    return {value};
-  if (!nTo1TempMaterializations.contains(castOp))
-    return {value};
-  assert(castOp->getNumResults() == 1 && "expected single result");
-
-  SmallVector<Value> result;
-  for (Value v : castOp.getOperands()) {
-    // Keep unpacking if possible. This is needed because during block
-    // signature conversions and 1:N op replacements, the driver may have
-    // inserted two materializations back-to-back: first an argument
-    // materialization, then a target materialization.
-    llvm::append_range(result, unpackNTo1Materialization(v));
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1554,7 +1590,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       // Materialize a replacement value "out of thin air".
       buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
-          result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
+          result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
       continue;
@@ -1572,16 +1608,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
     // Remap result to replacement value.
     if (repl.empty())
       continue;
-
-    if (repl.size() == 1) {
-      // Single replacement value: replace directly.
-      mapping.map(result, repl.front());
-    } else {
-      // Multiple replacement values: insert N:1 materialization.
-      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
-                                /*replacements=*/repl, /*outputValue=*/result,
-                                currentTypeConverter);
-    }
+    mapping.map(result, repl);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1687,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
   SmallVector<ValueRange> newVals;
-  for (size_t i = 0; i < newValues.size(); ++i)
-    newVals.push_back(newValues.slice(i, 1));
+  for (size_t i = 0; i < newValues.size(); ++i) {
+    if (newValues[i]) {
+      newVals.push_back(newValues.slice(i, 1));
+    } else {
+      newVals.push_back(ValueRange());
+    }
+  }
   impl->notifyOpReplaced(op, newVals);
 }
 
@@ -1733,7 +1765,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
-  SmallVector<SmallVector<Value>> remappedValues;
+  SmallVector<ValueVector> remappedValues;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                                remappedValues)))
     return nullptr;
@@ -1746,7 +1778,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                                              SmallVectorImpl<Value> &results) {
   if (keys.empty())
     return success();
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<ValueVector> remapped;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
                                remapped)))
     return failure();
@@ -1872,7 +1904,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
                                              getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<ValueVector> remapped;
   if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
                                       op->getOperands(), remapped))) {
     return failure();
@@ -2625,19 +2657,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
     rewriter.setInsertionPoint(op);
     SmallVector<Value> newMaterialization;
     switch (rewrite->getMaterializationKind()) {
-    case MaterializationKind::Argument: {
-      // Try to materialize an argument conversion.
-      assert(op->getNumResults() == 1 && "expected single result");
-      Value argMat = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
-      if (argMat) {
-        newMaterialization.push_back(argMat);
-        break;
-      }
-    }
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), op.getResultTypes(), inputOperands,

diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 297eb5acef21b7..ae7d344b7167f9 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) {
 // Contents of the old block are moved to the new block.
 // CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
 
-// The new block arguments are used in "test.return".
-// CHECK-NEXT: notifyOperationModified: test.return
-
 // The old block is erased.
 // CHECK-NEXT: notifyBlockErased
 
@@ -390,8 +387,8 @@ func.func @caller() {
   // CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16)
   %0:2 = func.call @callee() : () -> (f32, i24)
 
-  // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
-  // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
+  // CHECK-DAG: %[[cast1:.*]] = "test.cast"() : () -> i24
+  // CHECK-DAG: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
   // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> ()
   // expected-remark @below{{'test.some_user' is not legalizable}}
   "test.some_user"(%0#0, %0#1) : (f32, i24) -> ()
@@ -494,13 +491,8 @@ func.func @test_1_to_n_block_signature_conversion() {
 
 // CHECK-LABEL: func @test_multiple_1_to_n_replacement()
 //       CHECK:   %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16)
-// TODO: There should be a single cast (i.e., a single target materialization).
-// This is currently not possible due to 1:N limitations of the conversion
-// mapping. Instead, we have 3 argument materializations.
-//       CHECK:   %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16
-//       CHECK:   %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16
-//       CHECK:   %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16
-//       CHECK:   "test.valid"(%[[cast3]]) : (f16) -> ()
+//       CHECK:   %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16
+//       CHECK:   "test.valid"(%[[cast]]) : (f16) -> ()
 func.func @test_multiple_1_to_n_replacement() {
   %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
   "test.invalid"(%0) : (f16) -> ()

diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad50..d0b62e71ab0cf2 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
           tupleType.getFlattenedTypes(types);
           return success();
         });
-    typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+    typeConverter.addSourceMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(

diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 826c222990be4f..5b7c36c9b97bf4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1264,14 +1264,6 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
     // Replace test.multiple_1_to_n_replacement with test.step_1.
     Operation *repl1 = replaceWithDoubleResults(op, "test.step_1");
     // Now replace test.step_1 with test.legal_op.
-    // TODO: Ideally, it should not be necessary to reset the insertion point
-    // here. Based on the API calls, it looks like test.step_1 is entirely
-    // erased. But that's not the case: an argument materialization will
-    // survive. And that argument materialization will be used by the users of
-    // `op`. If we don't reset the insertion point here, we get dominance
-    // errors. This will be fixed when we have 1:N support in the conversion
-    // value mapping.
-    rewriter.setInsertionPoint(repl1);
     replaceWithDoubleResults(repl1, "test.legal_op");
     return success();
   }
@@ -1284,7 +1276,6 @@ struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
   TestTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 

diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d788..a03bf0a1023d57 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
 struct PDLLTypeConverter : public TypeConverter {
   PDLLTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 


        

