[Mlir-commits] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)

Matthias Springer llvmlistbot at llvm.org
Sat Dec 14 06:03:16 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116524

>From f9e875810d6c8af844c90128d7a052a7a40297c0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH 1/2] [mlir][Transforms] Dialect Conversion: Do not build
 target mat. during 1:N replacement

fix test

experiment
---
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 130 ++++++++++++------
 .../Transforms/Utils/DialectConversion.cpp    |  46 ++-----
 mlir/test/Transforms/test-legalizer.mlir      |   8 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  47 +++----
 4 files changed, 128 insertions(+), 103 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Add generic source and target materializations to handle cases where
+  // non-LLVM types persist after an LLVM conversion.
+  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+
   // Helper function that checks if the given value range is a bare pointer.
   auto isBarePointer = [](ValueRange values) {
     return values.size() == 1 &&
            isa<LLVM::LLVMPointerType>(values.front().getType());
   };
 
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type. The dialect conversion framework will then
-  // insert a target materialization from the original block argument type to
-  // a legal type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packUnrankedMemRefDesc =
+      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+          Location loc, LLVMTypeConverter &converter) -> Value {
     // Note: Bare pointers are not supported for unranked memrefs because a
     // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
       return Value();
-    Value desc =
-        UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                          inputs);
+  };
+
+  // MemRef descriptor elements -> UnrankedMemRefType
+  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+                                          UnrankedMemRefType resultType,
+                                          ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type
     // `resultType`, so insert a cast from the memref descriptor type
     // (!llvm.struct) to the original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    Value desc;
-    if (isBarePointer(inputs)) {
-      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
-                                               inputs[0]);
-    } else if (TypeRange(inputs) ==
-               getMemRefDescriptorFields(resultType,
-                                         /*unpackAggregates=*/true)) {
-      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    } else {
-      // The inputs are neither a bare pointer nor an unpacked memref
-      // descriptor. This materialization function cannot be used.
+    Value packed =
+        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-    }
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+        .getResult(0);
+  };
+
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  LLVMTypeConverter &converter) -> Value {
+    assert(resultType && "expected non-null result type");
+    if (isBarePointer(inputs))
+      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                               resultType, inputs[0]);
+    if (TypeRange(inputs) ==
+        converter.getMemRefDescriptorFields(resultType,
+                                            /*unpackAggregates=*/true))
+      return MemRefDescriptor::pack(builder, loc, converter, resultType,
+                                    inputs);
+    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+    // This materialization function cannot be used.
+    return Value();
+  };
+
+  // MemRef descriptor elements -> MemRefType
+  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
     // original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  // Add generic source and target materializations to handle cases where
-  // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+    Value packed =
+        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
         .getResult(0);
-  });
+  };
+
+  // Argument materializations convert from the new block argument types
+  // (multiple SSA values that make up a memref descriptor) back to the
+  // original block argument type.
+  addArgumentMaterialization(unrakedMemRefMaterialization);
+  addArgumentMaterialization(rankedMemRefMaterialization);
+  addSourceMaterialization(unrakedMemRefMaterialization);
+  addSourceMaterialization(rankedMemRefMaterialization);
+
+  // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> Value {
+    // The original MemRef type is required to build a MemRef descriptor
+    // because the sizes/strides of the MemRef cannot be inferred from just the
+    // bare pointer.
+    if (!originalType)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
+    if (resultType != convertType(originalType))
+      return Value();
+    if (auto memrefType = dyn_cast<MemRefType>(originalType))
+      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+                                    *this);
+    return Value();
   });
 
   // Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// function will be deleted when full 1:N support has been added.
   ///
   /// This function inserts an argument materialization back to the original
-  /// type, followed by a target materialization to the legalized type (if
-  /// applicable).
+  /// type.
   void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
                                  ValueRange replacements, Value originalValue,
                                  const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    insertNTo1Materialization(
-        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    if (replArgs.size() == 1) {
+      mapping.map(origArg, replArgs.front());
+    } else {
+      insertNTo1Materialization(
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
   // Insert argument materialization back to the original type.
   Type originalType = originalValue.getType();
   UnrealizedConversionCastOp argCastOp;
-  Value argMat = buildUnresolvedMaterialization(
+  buildUnresolvedMaterialization(
       MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
-      &argCastOp);
+      /*inputs=*/replacements, originalType,
+      /*originalType=*/Type(), converter, &argCastOp);
   if (argCastOp)
     nTo1TempMaterializations.insert(argCastOp);
-
-  // Insert target materialization to the legalized type.
-  Type legalOutputType;
-  if (converter) {
-    legalOutputType = converter->convertType(originalType);
-  } else if (replacements.size() == 1) {
-    // When there is no type converter, assume that the replacement value
-    // types are legal. This is reasonable to assume because they were
-    // specified by the user.
-    // FIXME: This won't work for 1->N conversions because multiple output
-    // types are not supported in parts of the dialect conversion. In such a
-    // case, we currently use the original value type.
-    legalOutputType = replacements[0].getType();
-  }
-  if (legalOutputType && legalOutputType != originalType) {
-    UnrealizedConversionCastOp targetCastOp;
-    buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(argMat), loc,
-        /*valueToMap=*/argMat, /*inputs=*/argMat,
-        /*outputType=*/legalOutputType, /*originalType=*/originalType,
-        converter, &targetCastOp);
-    if (targetCastOp)
-      nTo1TempMaterializations.insert(targetCastOp);
-  }
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
+  assert(this && "expected non-null type converter");
+  assert(t && "expected non-null type");
+
   {
     std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
                                                          std::defer_lock);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
   // expected-remark at +1 {{op 'foo.region' is not legalizable}}
   "foo.region"() ({
-    // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
-    ^bb0(%i0: i64, %unused: i16, %i1: i64):
-      // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
-      "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+    // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+    ^bb0(%i0: f64, %unused: i16, %i1: f64):
+      // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+      "test.invalid"(%i0, %i1) : (f64, f64) -> ()
   }) : () -> ()
   // expected-remark at +1 {{op 'func.return' is not legalizable}}
   return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 8a0bc597c56beb..466ae7ff6f46f1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
 };
 /// This pattern simply updates the operands of the given operation.
 struct TestPassthroughInvalidOp : public ConversionPattern {
-  TestPassthroughInvalidOp(MLIRContext *ctx)
-      : ConversionPattern("test.invalid", 1, ctx) {}
+  TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.invalid", 1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1301,19 +1301,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns.add<
-        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-        TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
-        TestSplitReturnType, TestChangeProducerTypeI32ToF32,
-        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
-        TestUpdateConsumerType, TestNonRootReplacement,
-        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
-        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-        TestUndoPropertiesModification, TestEraseOp,
-        TestRepetitive1ToNConsumer>(&getContext());
-    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
-        &getContext(), converter);
+    patterns
+        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+             TestNonRootReplacement, TestBoundedRecursiveRewrite,
+             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+             TestUndoPropertiesModification, TestEraseOp,
+             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+                 TestPassthroughInvalidOp>(&getContext(), converter);
     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1749,8 +1749,9 @@ struct TestTypeConversionAnotherProducer
 };
 
 struct TestReplaceWithLegalOp : public ConversionPattern {
-  TestReplaceWithLegalOp(MLIRContext *ctx)
-      : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+  TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(converter, "test.replace_with_legal_op",
+                          /*benefit=*/1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1872,12 +1873,12 @@ struct TestTypeConversionDriver
 
     // Initialize the set of rewrite patterns.
     RewritePatternSet patterns(&getContext());
-    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
-                 TestSignatureConversionUndo,
-                 TestTestSignatureConversionNoConverter>(converter,
-                                                         &getContext());
-    patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
-        &getContext());
+    patterns
+        .add<TestTypeConsumerForward, TestTypeConversionProducer,
+             TestSignatureConversionUndo,
+             TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+            converter, &getContext());
+    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
 

>From f983e30b097770fa04358d7a827545a0997de04f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 15:02:10 +0100
Subject: [PATCH 2/2] prototype

---
 .../Conversion/LLVMCommon/TypeConverter.cpp   |   2 -
 .../EmitC/Transforms/TypeConversions.cpp      |   1 -
 .../Dialect/Linalg/Transforms/Detensorize.cpp |   1 -
 .../Quant/Transforms/StripFuncQuantTypes.cpp  |   1 -
 .../Utils/SparseTensorDescriptor.cpp          |   3 -
 .../Vector/Transforms/VectorLinearize.cpp     |   1 -
 .../Transforms/Utils/DialectConversion.cpp    | 526 +++++++++++-------
 .../Func/TestDecomposeCallGraphTypes.cpp      |   2 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |   1 -
 .../lib/Transforms/TestDialectConversion.cpp  |   1 -
 10 files changed, 318 insertions(+), 221 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..ef8181e80cee38 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -237,8 +237,6 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization(unrakedMemRefMaterialization);
-  addArgumentMaterialization(rankedMemRefMaterialization);
   addSourceMaterialization(unrakedMemRefMaterialization);
   addSourceMaterialization(rankedMemRefMaterialization);
 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
 
   converter.addSourceMaterialization(materializeAsUnrealizedCast);
   converter.addTargetMaterialization(materializeAsUnrealizedCast);
-  converter.addArgumentMaterialization(materializeAsUnrealizedCast);
 }
 
 /// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index af38485291182f..61bc5022893741 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
     });
 
     addSourceMaterialization(sourceMaterializationCallback);
-    addArgumentMaterialization(sourceMaterializationCallback);
   }
 };
 
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
     addConversion(convertQuantizedType);
     addConversion(convertTensorType);
 
-    addArgumentMaterialization(materializeConversion);
     addSourceMaterialization(materializeConversion);
     addTargetMaterialization(materializeConversion);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 
   // Required by scf.for 1:N type conversion.
   addSourceMaterialization(materializeTuple);
-
-  // Required as a workaround until we have full 1:N support.
-  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
 
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
-  typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 51686646a0a2fc..81c8c1f422551f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -63,11 +63,45 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
   return OpBuilder::InsertPoint(insertBlock, insertPt);
 }
 
+/// Helper function that computes an insertion point where all given values are
+/// defined and can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+  for (Value v : vals.drop_front()) {
+    OpBuilder::InsertPoint pt2 = computeInsertPoint(v);
+    assert(pt.getBlock() == pt2.getBlock());
+    if (pt.getPoint() == pt.getBlock()->begin()) {
+      pt = pt2;
+      continue;
+    }
+    if (pt2.getPoint() == pt2.getBlock()->begin()) {
+      continue;
+    }
+    if (pt.getPoint()->isBeforeInBlock(&*pt2.getPoint()))
+      pt = pt2;
+  }
+  return pt;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
 namespace {
+struct SmallVectorMapInfo {
+  static SmallVector<Value, 1> getEmptyKey() { return SmallVector<Value, 1>{}; }
+  static SmallVector<Value, 1> getTombstoneKey() {
+    return SmallVector<Value, 1>{};
+  }
+  static ::llvm::hash_code getHashValue(SmallVector<Value, 1> val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(SmallVector<Value, 1> LHS, SmallVector<Value, 1> RHS) {
+    return LHS == RHS;
+  }
+};
+
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
@@ -75,71 +109,240 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }
 
-  /// Lookup the most recently mapped value with the desired type in the
-  /// mapping.
+  /// Find the most recently mapped values for the given value. If the value is
+  /// not mapped at all, return the given value.
+  SmallVector<Value, 1> lookupOrDefault(Value from) const;
+
+  /// TODO: Find the most recently mapped values, or a materialization of them,
+  /// with the desired types. May return the given value if its type matches.
+  SmallVector<Value, 1>
+  lookupOrDefault(Value from, SmallVector<Type, 1> desiredTypes) const;
+
+  Value lookupDirectSingleReplacement(Value from) const {
+    auto it = mapping.find(from);
+    if (it == mapping.end())
+      return Value();
+    const SmallVector<Value, 1> &repl = it->second;
+    if (repl.size() != 1)
+      return Value();
+    return repl.front();
+    /*
+        if (!mapping.contains(from)) return Value();
+        auto it = llvm::find(mapping, from);
+        const SmallVector<Value, 1> &repl = it->second;
+        if (repl.size() != 1) return Value();
+        return repl.front();
+        */
+  }
+
+  SmallVector<Value,1> lookupDirectReplacement(Value from) const {
+    auto it = mapping.find(from);
+    if (it == mapping.end())
+      return {};
+    return it->second;
+  }
+
+  /// Find the most recently mapped values for the given value. If the value is
+  /// not mapped at all, return an empty vector.
+  SmallVector<Value, 1> lookupOrNull(Value from) const;
+
+  /// Find the most recently mapped values for the given value. If those values
+  /// have the desired types, return them. Otherwise, try to find a
+  /// materialization to the desired types.
   ///
-  /// Special cases:
-  /// - If the desired type is "null", simply return the most recently mapped
-  ///   value.
-  /// - If there is no mapping to the desired type, also return the most
-  ///   recently mapped value.
-  /// - If there is no mapping for the given value at all, return the given
-  ///   value.
-  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
-
-  /// Lookup a mapped value within the map, or return null if a mapping does not
-  /// exist. If a mapping exists, this follows the same behavior of
-  /// `lookupOrDefault`.
-  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
-
-  /// Map a value to the one provided.
-  void map(Value oldVal, Value newVal) {
-    LLVM_DEBUG({
-      for (Value it = newVal; it; it = mapping.lookupOrNull(it))
-        assert(it != oldVal && "inserting cyclic mapping");
-    });
-    mapping.map(oldVal, newVal);
-    mappedTo.insert(newVal);
+  /// If the given value is not mapped at all or if there are no mapped values/
+  /// materialization results with the desired types, return an empty vector.
+  SmallVector<Value, 1> lookupOrNull(Value from,
+                                     SmallVector<Type, 1> desiredTypes) const;
+
+  Value lookupOrNull(Value from, Type desiredType) {
+    SmallVector<Value, 1> vals =
+        lookupOrNull(from, SmallVector<Type, 1>{desiredType});
+    if (vals.empty())
+      return Value();
+    assert(vals.size() == 1 && "expected single value");
+    return vals.front();
+  }
+
+  void erase(Value from) { mapping.erase(from); }
+
+  void map(Value from, ValueRange to) {
+#ifndef NDEBUG
+    assert(from && "expected non-null value");
+    assert(!to.empty() && "cannot map to zero values");
+    for (Value v : to)
+      assert(v && "expected non-null value");
+#endif
+    // assert(from != to && "cannot map value to itself");
+    //  TODO: Check for cyclic mapping.
+    assert(!mapping.contains(from) && "value is already mapped");
+    mapping[from].assign(to.begin(), to.end());
+    for (Value v : to)
+      mappedTo.insert(v);
+  }
+
+  void map(Value from, ArrayRef<BlockArgument> to) {
+    SmallVector<Value> vals;
+    for (Value v : to)
+      vals.push_back(v);
+    map(from, vals);
+  }
+  /*
+    void map(Value from, ArrayRef<Value> to) {
+  #ifndef NDEBUG
+      assert(from && "expected non-null value");
+      assert(!to.empty() && "cannot map to zero values");
+      for (Value v : to)
+        assert(v && "expected non-null value");
+  #endif
+      // assert(from != to && "cannot map value to itself");
+      //  TODO: Check for cyclic mapping.
+      assert(!mapping.contains(from) && "value is already mapped");
+      mapping[from].assign(to.begin(), to.end());
+    }
+  */
+
+  void mapMaterialization(SmallVector<Value, 1> from,
+                          SmallVector<Value, 1> to) {
+#ifndef NDEBUG
+    assert(!from.empty() && "from cannot be empty");
+    assert(!to.empty() && "to cannot be empty");
+    for (Value v : from) {
+      assert(v && "expected non-null value");
+      assert(!mapping.contains(v) &&
+             "cannot add materialization for mapped value");
+    }
+    for (Value v : to) {
+      assert(v && "expected non-null value");
+    }
+    assert(TypeRange(from) != TypeRange(to) &&
+           "cannot add materialization for identical type");
+    for (const SmallVector<Value, 1> &mat : materializations[from])
+      assert(TypeRange(mat) != TypeRange(to) &&
+             "cannot register duplicate materialization");
+#endif // NDEBUG
+    materializations[from].push_back(to);
+    for (Value v : to)
+      mappedTo.insert(v);
+  }
+
+  void eraseMaterialization(SmallVector<Value, 1> from,
+                            SmallVector<Value, 1> to) {
+    if (!materializations.count(from))
+      return;
+    auto it = llvm::find(materializations[from], to);
+    if (it == materializations[from].end())
+      return;
+    if (materializations[from].size() == 1)
+      materializations.erase(from);
+    else
+      materializations[from].erase(it);
   }
 
-  /// Drop the last mapping for the given value.
-  void erase(Value value) { mapping.erase(value); }
+  /// Returns the inverse raw value mapping (without recursive query support).
+  DenseMap<Value, SmallVector<Value>> getInverse() const {
+    DenseMap<Value, SmallVector<Value>> inverse;
+
+    for (auto &it : mapping)
+      for (Value v : it.second)
+        inverse[v].push_back(it.first);
+
+    for (auto &it : materializations)
+      for (const SmallVector<Value, 1> &mat : it.second)
+        for (Value v : mat)
+          for (Value v2 : it.first)
+            inverse[v].push_back(v2);
+
+    return inverse;
+  }
 
 private:
-  /// Current value mappings.
-  IRMapping mapping;
+  /// Replacement mapping: Value -> ValueRange
+  DenseMap<Value, SmallVector<Value, 1>> mapping;
+
+  /// Materializations: ValueRange -> ValueRange*
+  DenseMap<SmallVector<Value, 1>, SmallVector<SmallVector<Value, 1>>,
+           SmallVectorMapInfo>
+      materializations;
 
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
-Value ConversionValueMapping::lookupOrDefault(Value from,
-                                              Type desiredType) const {
-  // Try to find the deepest value that has the desired type. If there is no
-  // such value, simply return the deepest value.
-  Value desiredValue;
-  do {
-    if (!desiredType || from.getType() == desiredType)
-      desiredValue = from;
-
-    Value mappedValue = mapping.lookupOrNull(from);
-    if (!mappedValue)
-      break;
-    from = mappedValue;
-  } while (true);
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrDefault(Value from) const {
+  SmallVector<Value, 1> to = lookupOrNull(from);
+  return to.empty() ? SmallVector<Value, 1>{from} : to;
+}
 
-  // If the desired value was found use it, otherwise default to the leaf value.
-  return desiredValue ? desiredValue : from;
+SmallVector<Value, 1> ConversionValueMapping::lookupOrDefault(
+    Value from, SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+  assert(desiredTypes.size() > 0 && "expected non-empty types");
+  for (Type t : desiredTypes)
+    assert(t && "expected non-null type");
+#endif // NDEBUG
+
+  SmallVector<Value, 1> vals = lookupOrNull(from);
+  if (vals.empty()) {
+    // Value is not mapped. Return if the type matches.
+    if (TypeRange(from) == desiredTypes)
+      return {from};
+    // Check materializations.
+    auto it = materializations.find({from});
+    if (it == materializations.end())
+      return {};
+    for (const SmallVector<Value, 1> &mat : it->second)
+      if (TypeRange(mat) == desiredTypes)
+        return mat;
+    return {};
+  }
+
+  return lookupOrNull(from, desiredTypes);
 }
 
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
-  Value result = lookupOrDefault(from, desiredType);
-  if (result == from || (desiredType && result.getType() != desiredType))
-    return nullptr;
+SmallVector<Value, 1> ConversionValueMapping::lookupOrNull(Value from) const {
+  auto it = mapping.find(from);
+  if (it == mapping.end())
+    return {};
+  SmallVector<Value, 1> result;
+  for (Value v : it->second) {
+    llvm::append_range(result, lookupOrDefault(v));
+  }
   return result;
 }
 
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrNull(Value from,
+                                     SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+  assert(desiredTypes.size() > 0 && "expected non-empty types");
+  for (Type t : desiredTypes)
+    assert(t && "expected non-null type");
+#endif // NDEBUG
+
+  SmallVector<Value, 1> vals = lookupOrNull(from);
+  if (vals.empty())
+    return {};
+
+  // There is a mapping and the types match.
+  if (TypeRange(vals) == desiredTypes)
+    return vals;
+
+  // There is a mapping, but the types do not match. Try to find a matching
+  // materialization.
+  auto it = materializations.find(vals);
+  if (it == materializations.end())
+    return {};
+  for (const SmallVector<Value, 1> &mat : it->second)
+    if (TypeRange(mat) == desiredTypes)
+      return mat;
+
+  // No materialization found. Return an empty vector.
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter and Translation State
 //===----------------------------------------------------------------------===//
@@ -673,7 +876,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
                                    UnrealizedConversionCastOp op,
                                    const TypeConverter *converter,
                                    MaterializationKind kind, Type originalType,
-                                   Value mappedValue);
+                                   SmallVector<Value, 1> mappedValue);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -710,7 +913,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   /// The value in the conversion value mapping that is being replaced by the
   /// results of this unresolved materialization.
-  Value mappedValue;
+  SmallVector<Value, 1> mappedValue;
 };
 } // namespace
 
@@ -779,7 +982,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   LogicalResult remapValues(StringRef valueDiagTag,
                             std::optional<Location> inputLoc,
                             PatternRewriter &rewriter, ValueRange values,
-                            SmallVector<SmallVector<Value>> &remapped);
+                            SmallVector<SmallVector<Value, 1>> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
@@ -825,7 +1028,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// mapping.
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+      SmallVector<Value,1> valueToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr);
   Value buildUnresolvedMaterialization(
@@ -833,27 +1036,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
       const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr) {
-    return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
+    SmallVector<Value,1> valuesToMap;
+    if (valueToMap) valuesToMap.push_back(valueToMap);
+    return buildUnresolvedMaterialization(kind, ip, loc, valuesToMap, inputs,
                                           TypeRange(outputType), originalType,
                                           converter, castOp)
         .front();
   }
 
-  /// Build an N:1 materialization for the given original value that was
-  /// replaced with the given replacement values.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. The conversion mapping can store only 1:1 replacements
-  /// and the conversion patterns only support single Value replacements in the
-  /// adaptor, so N values must be converted back to a single value. This
-  /// function will be deleted when full 1:N support has been added.
-  ///
-  /// This function inserts an argument materialization back to the original
-  /// type.
-  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
-                                 ValueRange replacements, Value originalValue,
-                                 const TypeConverter *converter);
-
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
   /// value. If there is no replacement value with the correct type, find the
@@ -862,16 +1052,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value findOrBuildReplacementValue(Value value,
                                     const TypeConverter *converter);
 
-  /// Unpack an N:1 materialization and return the inputs of the
-  /// materialization. This function unpacks only those materializations that
-  /// were built with `insertNTo1Materialization`.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. It allows us to write 1:N conversion patterns while
-  /// 1:N support is still missing in the conversion value mapping. This
-  /// function will be deleted when full 1:N support has been added.
-  SmallVector<Value> unpackNTo1Materialization(Value value);
-
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -1101,7 +1281,7 @@ void CreateOperationRewrite::rollback() {
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
     const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    Value mappedValue)
+    SmallVector<Value, 1> mappedValue)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
       converterAndKind(converter, kind), originalType(originalType),
       mappedValue(mappedValue) {
@@ -1111,8 +1291,8 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (mappedValue)
-    rewriterImpl.mapping.erase(mappedValue);
+  if (!mappedValue.empty())
+    rewriterImpl.mapping.eraseMaterialization(mappedValue, op->getResults());
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
   rewriterImpl.nTo1TempMaterializations.erase(getOperation());
   op->erase();
@@ -1160,7 +1340,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
 LogicalResult ConversionPatternRewriterImpl::remapValues(
     StringRef valueDiagTag, std::optional<Location> inputLoc,
     PatternRewriter &rewriter, ValueRange values,
-    SmallVector<SmallVector<Value>> &remapped) {
+    SmallVector<SmallVector<Value, 1>> &remapped) {
   remapped.reserve(llvm::size(values));
 
   for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1348,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type origType = operand.getType();
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
-    // Find the most recently mapped value. Unpack all temporary N:1
-    // materializations. Such conversions are a workaround around missing
-    // 1:N support in the ConversionValueMapping. (The conversion patterns
-    // already support 1:N replacements.)
-    Value repl = mapping.lookupOrDefault(operand);
-    SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
     if (!currentTypeConverter) {
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
       // pass through the most recently mapped value.
-      remapped.push_back(std::move(unpacked));
+      SmallVector<Value, 1> repl = mapping.lookupOrDefault(operand);
+      remapped.push_back(repl);
       continue;
     }
 
@@ -1199,44 +1373,23 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       continue;
     }
 
-    if (legalTypes.size() != 1) {
-      // TODO: This is a 1:N conversion. The conversion value mapping does not
-      // store such materializations yet. If the types of the most recently
-      // mapped values do not match, build a target materialization.
-      ValueRange unpackedRange(unpacked);
-      if (TypeRange(unpackedRange) == legalTypes) {
-        remapped.push_back(std::move(unpacked));
-        continue;
-      }
-
-      // Insert a target materialization if the current pattern expects
-      // different legalized types.
-      ValueRange targetMat = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
-          /*valueToMap=*/Value(), /*inputs=*/unpacked,
-          /*outputType=*/legalTypes, /*originalType=*/origType,
-          currentTypeConverter);
-      remapped.push_back(targetMat);
+    SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes);
+    if (!mat.empty()) {
+      // Mapped value has the correct type or there is an existing
+      // materialization. Or the value is not mapped at all and has the
+      // correct type.
+      remapped.push_back(mat);
       continue;
     }
 
-    // Handle 1->1 type conversions.
-    Type desiredType = legalTypes.front();
-    // Try to find a mapped value with the desired type. (Or the operand itself
-    // if the value is not mapped at all.)
-    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
-    if (newOperand.getType() != desiredType) {
-      // If the looked up value's type does not have the desired type, it means
-      // that the value was replaced with a value of different type and no
-      // target materialization was created yet.
-      Value castValue = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
-          /*outputType=*/desiredType, /*originalType=*/origType,
-          currentTypeConverter);
-      newOperand = castValue;
-    }
-    remapped.push_back({newOperand});
+    // Create a materialization for the most recently mapped value.
+    SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
+    ValueRange castValues = buildUnresolvedMaterialization(
+        MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
+        /*valueToMap=*/vals,
+        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+        currentTypeConverter);
+    remapped.push_back(castValues);
   }
   return success();
 }
@@ -1350,11 +1503,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (!inputMap) {
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
-      buildUnresolvedMaterialization(
+      Value sourceMat = buildUnresolvedMaterialization(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
+          /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
+      mapping.map(origArg, sourceMat);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1369,19 +1523,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       continue;
     }
 
-    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
-    // dialect conversion. Therefore, we need an argument materialization to
-    // turn the replacement block arguments into a single SSA value that can be
-    // used as a replacement.
+    // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    if (replArgs.size() == 1) {
-      mapping.map(origArg, replArgs.front());
-    } else {
-      insertNTo1Materialization(
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
-    }
+    mapping.map(origArg, replArgs);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1402,7 +1547,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+    SmallVector<Value,1> valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
     UnrealizedConversionCastOp *castOp) {
   assert((!originalType || kind == MaterializationKind::Target) &&
@@ -1410,10 +1555,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes) {
-    if (valueToMap) {
-      assert(inputs.size() == 1 && "1:N mapping is not supported");
-      mapping.map(valueToMap, inputs.front());
-    }
+    if (!valuesToMap.empty())
+      mapping.mapMaterialization(valuesToMap, inputs);
     return inputs;
   }
 
@@ -1423,36 +1566,23 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
-  if (valueToMap) {
-    assert(outputTypes.size() == 1 && "1:N mapping is not supported");
-    mapping.map(valueToMap, convertOp.getResult(0));
-  }
+  if (!valuesToMap.empty())
+    mapping.mapMaterialization(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  originalType, valueToMap);
+                                                  originalType, valuesToMap);
   return convertOp.getResults();
 }
 
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
-    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
-    Value originalValue, const TypeConverter *converter) {
-  // Insert argument materialization back to the original type.
-  Type originalType = originalValue.getType();
-  UnrealizedConversionCastOp argCastOp;
-  buildUnresolvedMaterialization(
-      MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType,
-      /*originalType=*/Type(), converter, &argCastOp);
-  if (argCastOp)
-    nTo1TempMaterializations.insert(argCastOp);
-}
-
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
+  //if (Value repl = mapping.lookupDirectSingleReplacement(value))
+  //  if (repl.getType() == value.getType())
+  //    return repl;
+
   // Find a replacement value with the same type.
-  Value repl = mapping.lookupOrNull(value, value.getType());
-  if (repl)
+  if (Value repl = mapping.lookupOrNull(value, value.getType()))
     return repl;
 
   // Check if the value is dead. No replacement value is needed in that case.
@@ -1467,8 +1597,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   // No replacement value was found. Get the latest replacement value
   // (regardless of the type) and build a source materialization to the
   // original type.
-  repl = mapping.lookupOrNull(value);
-  if (!repl) {
+  SmallVector<Value, 1> repl = mapping.lookupOrNull(value);
+  if (repl.empty()) {
     // No replacement value is registered in the mapping. This means that the
     // value is dropped and no longer needed. (If the value were still needed,
     // a source materialization producing a replacement value "out of thin air"
@@ -1478,34 +1608,12 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   }
   Value castValue = buildUnresolvedMaterialization(
       MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
-      /*originalType=*/Type(), converter);
-  mapping.map(value, castValue);
+      /*valueToMap=*/repl, /*inputs=*/repl, /*outputType=*/{value.getType()},
+      /*originalType=*/Type(), converter)[0];
+  // No extra mapping for `value`: the cast is a materialization of `repl`.
   return castValue;
 }
 
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
-  // Unpack unrealized_conversion_cast ops that were inserted as a N:1
-  // workaround.
-  auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-  if (!castOp)
-    return {value};
-  if (!nTo1TempMaterializations.contains(castOp))
-    return {value};
-  assert(castOp->getNumResults() == 1 && "expected single result");
-
-  SmallVector<Value> result;
-  for (Value v : castOp.getOperands()) {
-    // Keep unpacking if possible. This is needed because during block
-    // signature conversions and 1:N op replacements, the driver may have
-    // inserted two materializations back-to-back: first an argument
-    // materialization, then a target materialization.
-    llvm::append_range(result, unpackNTo1Materialization(v));
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1552,11 +1660,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       }
 
       // Materialize a replacement value "out of thin air".
-      buildUnresolvedMaterialization(
+      Value sourceMat = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
-          result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
+          result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
+      mapping.map(result, sourceMat);
       continue;
     } else {
       // Make sure that the user does not mess with unresolved materializations
@@ -1572,16 +1681,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
     // Remap result to replacement value.
     if (repl.empty())
       continue;
-
-    if (repl.size() == 1) {
-      // Single replacement value: replace directly.
-      mapping.map(result, repl.front());
-    } else {
-      // Multiple replacement values: insert N:1 materialization.
-      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
-                                /*replacements=*/repl, /*outputValue=*/result,
-                                currentTypeConverter);
-    }
+    mapping.map(result, repl);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1760,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
   SmallVector<ValueRange> newVals;
-  for (size_t i = 0; i < newValues.size(); ++i)
-    newVals.push_back(newValues.slice(i, 1));
+  for (size_t i = 0; i < newValues.size(); ++i) {
+    if (newValues[i]) {
+      newVals.push_back(newValues.slice(i, 1));
+    } else {
+      newVals.push_back(ValueRange());
+    }
+  }
   impl->notifyOpReplaced(op, newVals);
 }
 
@@ -1729,11 +1834,14 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from);
+  assert(mapped.size() == 1 &&
+         "replaceUsesOfBlockArgument is not supported for 1:N replacements");
+  impl->mapping.map(mapped.front(), to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
-  SmallVector<SmallVector<Value>> remappedValues;
+  SmallVector<SmallVector<Value, 1>> remappedValues;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                                remappedValues)))
     return nullptr;
@@ -1746,7 +1854,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                                              SmallVectorImpl<Value> &results) {
   if (keys.empty())
     return success();
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<SmallVector<Value, 1>> remapped;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
                                remapped)))
     return failure();
@@ -1872,7 +1980,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
                                              getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<SmallVector<Value, 1>> remapped;
   if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
                                       op->getOperands(), remapped))) {
     return failure();
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad50..d0b62e71ab0cf2 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
           tupleType.getFlattenedTypes(types);
           return success();
         });
-    typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+    typeConverter.addSourceMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 466ae7ff6f46f1..749df2cb9ea7cc 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1235,7 +1235,6 @@ struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
   TestTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d788..a03bf0a1023d57 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
 struct PDLLTypeConverter : public TypeConverter {
   PDLLTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 



More information about the Mlir-commits mailing list