[llvm-branch-commits] [flang] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Dec 15 09:53:13 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116524

>From bf0d13553b2bc2124a266e398976ba80a1114580 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 16:34:47 +0100
Subject: [PATCH 1/4] [mlir][Vector] Move mask materialization patterns to
 greedy rewrite

The mask materialization patterns during `VectorToLLVM` are rewrite patterns. They should run as part of the greedy pattern rewrite and not the dialect conversion. (Rewrite patterns and conversion patterns are not generally compatible.)

The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.
---
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  7 +-
 .../VectorToLLVM/vector-mask-to-llvm.mlir     |  4 +-
 .../VectorToLLVM/vector-to-llvm.mlir          |  4 +-
 .../VectorToLLVM/vector-xfer-to-llvm.mlir     | 80 +++++++++----------
 4 files changed, 44 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4623b9667998cc..64a9ad8e9bade0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -61,8 +61,8 @@ struct ConvertVectorToLLVMPass
 } // namespace
 
 void ConvertVectorToLLVMPass::runOnOperation() {
-  // Perform progressive lowering of operations on slices and
-  // all contraction operations. Also applies folding and DCE.
+  // Perform progressive lowering of operations on slices and all contraction
+  // operations. Also materializes masks, applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -76,6 +76,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
                                             VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
+    populateVectorMaskMaterializationPatterns(patterns,
+                                              force32BitVectorIndices);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
@@ -83,7 +85,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LowerToLLVMOptions options(&getContext());
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
-  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 82351eb7c98a43..91e5358622b69d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -7,7 +7,7 @@
 // CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
 // CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
 // CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi32>
-// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi32>
 // CMP32: return %[[T4]] : vector<11xi1>
 
 // CMP64-LABEL: @genbool_var_1d(
@@ -16,7 +16,7 @@
 // CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
 // CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
 // CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
-// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
 // CMP64: return %[[T4]] : vector<11xi1>
 
 func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2473fe933ffcb2..ea88fece9e662d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3097,7 +3097,7 @@ func.func @create_mask_0d(%num_elems : index) -> vector<i1> {
 // CHECK:  %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK:  %[[BOUNDS:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK:  %[[BOUNDS_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOUNDS]] : vector<1xi32> to vector<i32>
-// CHECK:  %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS_CAST]] : vector<i32>
+// CHECK:  %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS_CAST]], %[[INDICES]] : vector<i32>
 // CHECK:  return %[[RESULT]] : vector<i1>
 
 // -----
@@ -3113,7 +3113,7 @@ func.func @create_mask_1d(%num_elems : index) -> vector<4xi1> {
 // CHECK:  %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK:  %[[BOUNDS_INSERT:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK:  %[[BOUNDS:.*]] = llvm.shufflevector %[[BOUNDS_INSERT]]
-// CHECK:  %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS]] : vector<4xi32>
+// CHECK:  %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS]], %[[INDICES]] : vector<4xi32>
 // CHECK:  return %[[RESULT]] : vector<4xi1>
 
 // -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
index 8f01cc2b8d44c3..d3f6d7eca90b41 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
@@ -14,30 +14,28 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
 // CHECK-LABEL: func @transfer_read_write_1d
 //  CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 //  CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
-//       CHECK: %[[C7:.*]] = arith.constant 7.0
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
-//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
+// 1. Create pass-through vector.
+//   CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<17xf32>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = arith.constant dense
+//   CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense
 //  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 3. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
+//
+// 4. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] :
 //  CMP32-SAME: index to i32
 //  CMP64-SAME: index to i64
 //       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 //       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]>
+//       CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 //  CMP64-SAME: : vector<17xi64>
 //
-// 4. Create pass-through vector.
-//       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
-//
 // 5. Bitcast to vector form.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 //  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -48,28 +46,23 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
 //  CHECK-SAME: -> vector<17xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 //       CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex_b:.*]] = arith.constant dense
-//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
+// 2. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]]
 //  CMP32-SAME: index to i32
 //       CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
 //       CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
-//       CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
-//  CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]>
+//       CHECK: %[[mask_b:.*]] = arith.cmpi sgt, %[[boundVect_b]],
+//  CHECK-SAME: %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 //
-// 4. Bitcast to vector form.
+// 3. Bitcast to vector form.
 //       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
 //  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
 //
-// 5. Rewrite as a masked write.
+// 4. Rewrite as a masked write.
 //       CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
 //  CHECK-SAME: {alignment = 4 : i32} :
 //  CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
@@ -87,17 +80,18 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 // CHECK-LABEL: func @transfer_read_write_1d_scalable
 //  CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 //  CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
-//       CHECK: %[[C7:.*]] = arith.constant 7.0
+// 1. Create pass-through vector.
+//   CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<[17]xf32>
 //
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
+// 2. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 //       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 3. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 //       CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 4. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 //       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
@@ -105,9 +99,6 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 //       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
 //  CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]>
 //
-// 4. Create pass-through vector.
-//       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
-//
 // 5. Bitcast to vector form.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 //  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -118,8 +109,7 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 //  CHECK-SAME: -> vector<[17]xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 //       CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
@@ -197,23 +187,23 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
 }
 // CHECK-LABEL: func @transfer_read_2d_to_1d
 //  CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-//       CHECK: %[[c1:.*]] = arith.constant 1 : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+//   CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+//  CHECK-SAME: vector<17x[[$IDX_TYPE]]>
+//
+//   CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
 //
 // Compute the in-bound index (dim - offset)
 //       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
 //
-// Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = arith.constant dense
-//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-//  CHECK-SAME: vector<17x[[$IDX_TYPE]]>
-//
 // Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 //       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 //       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+//       CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]]
 
 func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
   %f7 = arith.constant 7.0: f32
@@ -255,12 +245,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %bas
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
 //  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
 //
+//       CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 //  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-//       CHECK: %[[c0:.*]] = arith.constant 0 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -280,12 +271,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32,
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
 //  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
 //
+//       CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 //  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-//       CHECK: %[[c0:.*]] = arith.constant 0 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -330,10 +322,10 @@ func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index)
 
 // CHECK-LABEL: func @transfer_read_write_1d_mask
 // CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
-// CHECK: %[[cmpi:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi:.*]] = arith.cmpi sgt
 // CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
 // CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
-// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi_1:.*]] = arith.cmpi sgt
 // CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
 // CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
 // CHECK: return %[[r]]

>From e5926b63835eec731c513d91ce9c451c429ca572 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 17:19:18 +0100
Subject: [PATCH 2/4] [mlir][Vector] Clean up
 `populateVectorToLLVMConversionPatterns`

---
 .../Vector/Transforms/LoweringPatterns.h      |  4 +++
 .../GPUCommon/GPUToLLVMConversion.cpp         | 12 +++++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 27 ++++++++++---------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  6 ++++-
 .../Conversion/GPUCommon/lower-vector.mlir    |  4 +--
 .../VectorToLLVM/vector-to-llvm.mlir          |  5 ----
 6 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                            int64_t targetRank = 1,
                                            PatternBenefit benefit = 1);
 
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2fe3b1302e5e5b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Error.h"
@@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
+
+  // Perform progressive lowering of vector transfer operations.
+  {
+    RewritePatternSet patterns(&getContext());
+    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+    vector::populateVectorTransferLoweringPatterns(patterns,
+                                                   /*maxTransferRank=*/1);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..577b74bb7e0c26 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,16 @@ class VectorTypeCastOpConversion
 
 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
 /// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
-    : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+    : public OpConversionPattern<vector::CreateMaskOp> {
 public:
-  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
                                             bool enableIndexOpt)
-      : OpRewritePattern<vector::CreateMaskOp>(context),
+      : OpConversionPattern<vector::CreateMaskOp>(context),
         force32BitVectorIndices(enableIndexOpt) {}
 
-  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
-                                PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
       return failure();
@@ -1495,7 +1495,7 @@ class VectorCreateMaskOpRewritePattern
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
-                                                 op.getOperand(0));
+                                                 adaptor.getOperands()[0]);
     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
@@ -1896,16 +1896,19 @@ struct VectorScalableStepOpLowering
 
 } // namespace
 
+void mlir::vector::populateVectorRankReducingFMAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions, bool force32BitVectorIndices) {
+  // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
-  populateVectorInsertExtractStridedSliceTransforms(patterns);
-  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1925,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
 
 void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and all contraction
-  // operations. Also materializes masks, applies folding and DCE.
+  // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+  // applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
                                               force32BitVectorIndices);
+    populateVectorInsertExtractStridedSliceTransforms(patterns);
+    populateVectorStepLoweringPatterns(patterns);
+    populateVectorRankReducingFMAPattern(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
-  func.func @func(%arg: vector<11xf32>) {
+  func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
     // CHECK: vector.mask
     // CHECK-SAME: vector.yield %arg0
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
-    return
+    return %127 : vector<11xf32>
   }
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
 //  CHECK-SAME:     %[[ARG:.*]]: vector<4x[8]xf32>)
 // CHECK:           %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK:           %[[T2:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
 // CHECK:           %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
 // CHECK:           %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
   return %0 : vector<4x4x4xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 
 // -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
   return %0 : vector<4x4x[4]xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 
 // -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
   return %0 : vector<4x4x4xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 
 // -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
   return %0 : vector<4x4x[4]xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 
 // -----

>From ddc92940e9a1490f5a26d77576f4b663253e3cdd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH 3/4] [mlir][Transforms] Dialect Conversion: Do not build
 target mat. during 1:N replacement

fix test

experiment
---
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 130 ++++++++++++------
 .../Transforms/Utils/DialectConversion.cpp    |  46 ++-----
 mlir/test/Transforms/test-legalizer.mlir      |   8 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  47 +++----
 4 files changed, 128 insertions(+), 103 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Add generic source and target materializations to handle cases where
+  // non-LLVM types persist after an LLVM conversion.
+  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+
   // Helper function that checks if the given value range is a bare pointer.
   auto isBarePointer = [](ValueRange values) {
     return values.size() == 1 &&
            isa<LLVM::LLVMPointerType>(values.front().getType());
   };
 
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type. The dialect conversion framework will then
-  // insert a target materialization from the original block argument type to
-  // a legal type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packUnrankedMemRefDesc =
+      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+          Location loc, LLVMTypeConverter &converter) -> Value {
     // Note: Bare pointers are not supported for unranked memrefs because a
     // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
       return Value();
-    Value desc =
-        UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                          inputs);
+  };
+
+  // MemRef descriptor elements -> UnrankedMemRefType
+  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+                                          UnrankedMemRefType resultType,
+                                          ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type
     // `resultType`, so insert a cast from the memref descriptor type
     // (!llvm.struct) to the original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    Value desc;
-    if (isBarePointer(inputs)) {
-      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
-                                               inputs[0]);
-    } else if (TypeRange(inputs) ==
-               getMemRefDescriptorFields(resultType,
-                                         /*unpackAggregates=*/true)) {
-      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    } else {
-      // The inputs are neither a bare pointer nor an unpacked memref
-      // descriptor. This materialization function cannot be used.
+    Value packed =
+        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-    }
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+        .getResult(0);
+  };
+
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  LLVMTypeConverter &converter) -> Value {
+    assert(resultType && "expected non-null result type");
+    if (isBarePointer(inputs))
+      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                               resultType, inputs[0]);
+    if (TypeRange(inputs) ==
+        converter.getMemRefDescriptorFields(resultType,
+                                            /*unpackAggregates=*/true))
+      return MemRefDescriptor::pack(builder, loc, converter, resultType,
+                                    inputs);
+    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+    // This materialization function cannot be used.
+    return Value();
+  };
+
+  // MemRef descriptor elements -> MemRefType
+  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
     // original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  // Add generic source and target materializations to handle cases where
-  // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+    Value packed =
+        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
         .getResult(0);
-  });
+  };
+
+  // Argument materializations convert from the new block argument types
+  // (multiple SSA values that make up a memref descriptor) back to the
+  // original block argument type.
+  addArgumentMaterialization(unrakedMemRefMaterialization);
+  addArgumentMaterialization(rankedMemRefMaterialization);
+  addSourceMaterialization(unrakedMemRefMaterialization);
+  addSourceMaterialization(rankedMemRefMaterialization);
+
+  // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> Value {
+    // The original MemRef type is required to build a MemRef descriptor
+    // because the sizes/strides of the MemRef cannot be inferred from just the
+    // bare pointer.
+    if (!originalType)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
+    if (resultType != convertType(originalType))
+      return Value();
+    if (auto memrefType = dyn_cast<MemRefType>(originalType))
+      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+                                    *this);
+    return Value();
   });
 
   // Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// function will be deleted when full 1:N support has been added.
   ///
   /// This function inserts an argument materialization back to the original
-  /// type, followed by a target materialization to the legalized type (if
-  /// applicable).
+  /// type.
   void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
                                  ValueRange replacements, Value originalValue,
                                  const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    insertNTo1Materialization(
-        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    if (replArgs.size() == 1) {
+      mapping.map(origArg, replArgs.front());
+    } else {
+      insertNTo1Materialization(
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
   // Insert argument materialization back to the original type.
   Type originalType = originalValue.getType();
   UnrealizedConversionCastOp argCastOp;
-  Value argMat = buildUnresolvedMaterialization(
+  buildUnresolvedMaterialization(
       MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
-      &argCastOp);
+      /*inputs=*/replacements, originalType,
+      /*originalType=*/Type(), converter, &argCastOp);
   if (argCastOp)
     nTo1TempMaterializations.insert(argCastOp);
-
-  // Insert target materialization to the legalized type.
-  Type legalOutputType;
-  if (converter) {
-    legalOutputType = converter->convertType(originalType);
-  } else if (replacements.size() == 1) {
-    // When there is no type converter, assume that the replacement value
-    // types are legal. This is reasonable to assume because they were
-    // specified by the user.
-    // FIXME: This won't work for 1->N conversions because multiple output
-    // types are not supported in parts of the dialect conversion. In such a
-    // case, we currently use the original value type.
-    legalOutputType = replacements[0].getType();
-  }
-  if (legalOutputType && legalOutputType != originalType) {
-    UnrealizedConversionCastOp targetCastOp;
-    buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(argMat), loc,
-        /*valueToMap=*/argMat, /*inputs=*/argMat,
-        /*outputType=*/legalOutputType, /*originalType=*/originalType,
-        converter, &targetCastOp);
-    if (targetCastOp)
-      nTo1TempMaterializations.insert(targetCastOp);
-  }
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
+  assert(this && "expected non-null type converter");
+  assert(t && "expected non-null type");
+
   {
     std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
                                                          std::defer_lock);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
   // expected-remark at +1 {{op 'foo.region' is not legalizable}}
   "foo.region"() ({
-    // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
-    ^bb0(%i0: i64, %unused: i16, %i1: i64):
-      // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
-      "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+    // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+    ^bb0(%i0: f64, %unused: i16, %i1: f64):
+      // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+      "test.invalid"(%i0, %i1) : (f64, f64) -> ()
   }) : () -> ()
   // expected-remark at +1 {{op 'func.return' is not legalizable}}
   return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 8a0bc597c56beb..466ae7ff6f46f1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
 };
 /// This pattern simply updates the operands of the given operation.
 struct TestPassthroughInvalidOp : public ConversionPattern {
-  TestPassthroughInvalidOp(MLIRContext *ctx)
-      : ConversionPattern("test.invalid", 1, ctx) {}
+  TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.invalid", 1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1301,19 +1301,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns.add<
-        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-        TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
-        TestSplitReturnType, TestChangeProducerTypeI32ToF32,
-        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
-        TestUpdateConsumerType, TestNonRootReplacement,
-        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
-        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-        TestUndoPropertiesModification, TestEraseOp,
-        TestRepetitive1ToNConsumer>(&getContext());
-    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
-        &getContext(), converter);
+    patterns
+        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+             TestNonRootReplacement, TestBoundedRecursiveRewrite,
+             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+             TestUndoPropertiesModification, TestEraseOp,
+             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+                 TestPassthroughInvalidOp>(&getContext(), converter);
     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1749,8 +1749,9 @@ struct TestTypeConversionAnotherProducer
 };
 
 struct TestReplaceWithLegalOp : public ConversionPattern {
-  TestReplaceWithLegalOp(MLIRContext *ctx)
-      : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+  TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(converter, "test.replace_with_legal_op",
+                          /*benefit=*/1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1872,12 +1873,12 @@ struct TestTypeConversionDriver
 
     // Initialize the set of rewrite patterns.
     RewritePatternSet patterns(&getContext());
-    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
-                 TestSignatureConversionUndo,
-                 TestTestSignatureConversionNoConverter>(converter,
-                                                         &getContext());
-    patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
-        &getContext());
+    patterns
+        .add<TestTypeConsumerForward, TestTypeConversionProducer,
+             TestSignatureConversionUndo,
+             TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+            converter, &getContext());
+    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
 

>From 98214ce9901c3e7624fc292f1ec8f12cfe146c23 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 15 Dec 2024 17:36:49 +0100
Subject: [PATCH 4/4] ex

---
 .../lib/Optimizer/CodeGen/BoxedProcedure.cpp  |   1 -
 mlir/docs/DialectConversion.md                |  35 +-
 .../mlir/Transforms/DialectConversion.h       |  18 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   |  14 +-
 .../EmitC/Transforms/TypeConversions.cpp      |   1 -
 .../Dialect/Linalg/Transforms/Detensorize.cpp |   1 -
 .../Quant/Transforms/StripFuncQuantTypes.cpp  |   1 -
 .../Utils/SparseTensorDescriptor.cpp          |   3 -
 .../Vector/Transforms/VectorLinearize.cpp     |   1 -
 .../Transforms/Utils/DialectConversion.cpp    | 345 ++++++++----------
 mlir/test/Transforms/test-legalizer.mlir      |   7 +-
 .../Func/TestDecomposeCallGraphTypes.cpp      |   2 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |   1 -
 .../lib/Transforms/TestDialectConversion.cpp  |   1 -
 14 files changed, 173 insertions(+), 258 deletions(-)

diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
index 1bb91d252529f0..104ae7408b80c1 100644
--- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
+++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp
@@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter {
     addConversion([&](TypeDescType ty) {
       return TypeDescType::get(convertType(ty.getOfTy()));
     });
-    addArgumentMaterialization(materializeProcedure);
     addSourceMaterialization(materializeProcedure);
     addTargetMaterialization(materializeProcedure);
   }
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 3168f5e13c7515..abacd5a82c61eb 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure
 type safety during the conversion process. There are several types of
 materializations depending on the situation.
 
-*   Argument Materialization
-
-    -   An argument materialization is used when converting the type of a block
-        argument during a [signature conversion](#region-signature-conversion).
-        The new block argument types are specified in a `SignatureConversion`
-        object. An original block argument can be converted into multiple
-        block arguments, which is not supported everywhere in the dialect
-        conversion. (E.g., adaptors support only a single replacement value for
-        each original value.) Therefore, an argument materialization is used to
-        convert potentially multiple new block arguments back into a single SSA
-        value. An argument materialization is also used when replacing an op
-        result with multiple values.
-
 *   Source Materialization
 
     -   A source materialization is used when a value was replaced with a value
@@ -343,17 +330,6 @@ class TypeConverter {
   /// Materialization functions must be provided when a type conversion may
   /// persist after the conversion has finished.
 
-  /// This method registers a materialization that will be called when
-  /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value
-  /// with the old argument type.
-  template <typename FnT,
-            typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
-  void addArgumentMaterialization(FnT &&callback) {
-    argumentMaterializations.emplace_back(
-        wrapMaterialization<T>(std::forward<FnT>(callback)));
-  }
-
   /// This method registers a materialization that will be called when
   /// converting a replacement value back to its original source type.
   /// This is used when some uses of the original value persist beyond the main
@@ -406,12 +382,11 @@ done explicitly via a conversion pattern.
 To convert the types of block arguments within a Region, a custom hook on the
 `ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook
 uses a provided type converter to apply type conversions to all blocks of a
-given region. As noted above, the conversions performed by this method use the
-argument materialization hook on the `TypeConverter`. This hook also takes an
-optional `TypeConverter::SignatureConversion` parameter that applies a custom
-conversion to the entry block of the region. The types of the entry block
-arguments are often tied semantically to the operation, e.g.,
-`func::FuncOp`, `AffineForOp`, etc.
+given region. This hook also takes an optional
+`TypeConverter::SignatureConversion` parameter that applies a custom conversion
+to the entry block of the region. The types of the entry block arguments are
+often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`,
+etc.
 
 To convert the signature of just one given block, the
 `applySignatureConversion` hook can be used.
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 28150e886913e3..9a6975dcf8dfae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,6 +181,10 @@ class TypeConverter {
   /// converting (potentially multiple) block arguments that were the result of
   /// a signature conversion of a single block argument, to a single SSA value
   /// with the old block argument type.
+  ///
+  /// Note: Argument materializations are used only with the 1:N dialect
+  /// conversion driver. The 1:N dialect conversion driver will be removed soon
+  /// and so will argument materializations.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
@@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
   void replaceOp(Operation *op, Operation *newOp) override;
 
   /// Replace the given operation with the new value ranges. The number of op
-  /// results and value ranges must match. If an original SSA value is replaced
-  /// by multiple SSA values (i.e., a value range has more than 1 element), the
-  /// conversion driver will insert an argument materialization to convert the
-  /// N SSA values back into 1 SSA value of the original type. The given
-  /// operation is erased.
-  ///
-  /// Note: The argument materialization is a workaround until we have full 1:N
-  /// support in the dialect conversion. (It is going to disappear from both
-  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+  /// results and value ranges must match. The given operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1285,8 +1281,8 @@ struct ConversionConfig {
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
 
-  /// If set to "true", the dialect conversion attempts to build source/target/
-  /// argument materializations through the type converter API in lieu of
+  /// If set to "true", the dialect conversion attempts to build source/target
+  /// materializations through the type converter API in lieu of
   /// "builtin.unrealized_conversion_cast ops". The conversion process fails if
   /// at least one materialization could not be built.
   ///
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..d27b557736c924 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -189,9 +189,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
                                           UnrankedMemRefType resultType,
                                           ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type
-    // `resultType`, so insert a cast from the memref descriptor type
-    // (!llvm.struct) to the original memref type.
+    // A source materialization must return a value of type `resultType`, so
+    // insert a cast from the memref descriptor type (!llvm.struct) to the
+    // original memref type.
     Value packed =
         packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
     if (!packed)
@@ -223,7 +223,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   auto rankedMemRefMaterialization = [&](OpBuilder &builder,
                                          MemRefType resultType,
                                          ValueRange inputs, Location loc) {
-    // An argument materialization must return a value of type `resultType`,
+    // A source materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
     // original memref type.
     Value packed =
@@ -234,11 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   };
 
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
+  // Source materializations convert from the new block argument types
+  // (e.g., multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization(unrakedMemRefMaterialization);
-  addArgumentMaterialization(rankedMemRefMaterialization);
   addSourceMaterialization(unrakedMemRefMaterialization);
   addSourceMaterialization(rankedMemRefMaterialization);
 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
 
   converter.addSourceMaterialization(materializeAsUnrealizedCast);
   converter.addTargetMaterialization(materializeAsUnrealizedCast);
-  converter.addArgumentMaterialization(materializeAsUnrealizedCast);
 }
 
 /// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index af38485291182f..61bc5022893741 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
     });
 
     addSourceMaterialization(sourceMaterializationCallback);
-    addArgumentMaterialization(sourceMaterializationCallback);
   }
 };
 
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
     addConversion(convertQuantizedType);
     addConversion(convertTensorType);
 
-    addArgumentMaterialization(materializeConversion);
     addSourceMaterialization(materializeConversion);
     addTargetMaterialization(materializeConversion);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 
   // Required by scf.for 1:N type conversion.
   addSourceMaterialization(materializeTuple);
-
-  // Required as a workaround until we have full 1:N support.
-  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
 
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
-  typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 51686646a0a2fc..d521df819895dd 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -63,11 +63,64 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
   return OpBuilder::InsertPoint(insertBlock, insertPt);
 }
 
+/// Helper function that computes an insertion point where all given values are
+/// defined and can be used without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+  for (Value v : vals.drop_front()) {
+    OpBuilder::InsertPoint pt2 = computeInsertPoint(v);
+    assert(pt.getBlock() == pt2.getBlock());
+    if (pt.getPoint() == pt.getBlock()->begin()) {
+      pt = pt2;
+      continue;
+    }
+    if (pt2.getPoint() == pt2.getBlock()->begin()) {
+      continue;
+    }
+    if (pt.getPoint()->isBeforeInBlock(&*pt2.getPoint()))
+      pt = pt2;
+  }
+  return pt;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+struct SmallVectorMapInfo {
+  static SmallVector<Value, 1> getEmptyKey() { return SmallVector<Value, 1>{}; }
+  static SmallVector<Value, 1> getTombstoneKey() {
+    return SmallVector<Value, 1>{};
+  }
+  static ::llvm::hash_code getHashValue(SmallVector<Value, 1> val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(SmallVector<Value, 1> LHS, SmallVector<Value, 1> RHS) {
+    return LHS == RHS;
+  }
+};
+
+struct MultiIRMapping {
+  SmallVector<Value, 1> lookupOrNull(SmallVector<Value, 1> from) const {
+    auto it = mapping.find(from);
+    if (it == mapping.end())
+      return {};
+    return it->second;
+  }
+
+  void map(SmallVector<Value, 1> from, SmallVector<Value, 1> to) {
+    mapping[from] = to;
+  }
+
+  void erase(SmallVector<Value, 1> from) { mapping.erase(from); }
+
+  DenseMap<SmallVector<Value, 1>, SmallVector<Value, 1>, SmallVectorMapInfo>
+      mapping;
+};
+
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
@@ -85,58 +138,80 @@ struct ConversionValueMapping {
   ///   recently mapped value.
   /// - If there is no mapping for the given value at all, return the given
   ///   value.
-  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
+  SmallVector<Value, 1> lookupOrDefault(SmallVector<Value, 1> from,
+                                        TypeRange desiredType = {}) const;
 
   /// Lookup a mapped value within the map, or return null if a mapping does not
   /// exist. If a mapping exists, this follows the same behavior of
   /// `lookupOrDefault`.
-  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
+  SmallVector<Value, 1> lookupOrNull(SmallVector<Value, 1> from,
+                                     TypeRange desiredType = {}) const;
 
   /// Map a value to the one provided.
-  void map(Value oldVal, Value newVal) {
+  void map(SmallVector<Value, 1> oldVal, SmallVector<Value, 1> newVal) {
     LLVM_DEBUG({
-      for (Value it = newVal; it; it = mapping.lookupOrNull(it))
+      for (SmallVector<Value, 1> it = newVal; !it.empty();
+           it = mapping.lookupOrNull(it))
         assert(it != oldVal && "inserting cyclic mapping");
     });
     mapping.map(oldVal, newVal);
-    mappedTo.insert(newVal);
+    for (Value v : newVal)
+      mappedTo.insert(v);
   }
 
   /// Drop the last mapping for the given value.
-  void erase(Value value) { mapping.erase(value); }
+  void erase(SmallVector<Value, 1> value) { mapping.erase(value); }
 
 private:
   /// Current value mappings.
-  IRMapping mapping;
+  MultiIRMapping mapping;
 
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
-Value ConversionValueMapping::lookupOrDefault(Value from,
-                                              Type desiredType) const {
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrDefault(SmallVector<Value, 1> from,
+                                        TypeRange desiredType) const {
   // Try to find the deepest value that has the desired type. If there is no
   // such value, simply return the deepest value.
-  Value desiredValue;
+  SmallVector<Value, 1> desiredValue;
   do {
-    if (!desiredType || from.getType() == desiredType)
+    if (desiredType.empty() || TypeRange(from) == desiredType)
       desiredValue = from;
 
-    Value mappedValue = mapping.lookupOrNull(from);
-    if (!mappedValue)
+    SmallVector<Value, 1> next;
+    for (Value v : from) {
+      SmallVector<Value, 1> mappedValue = mapping.lookupOrNull({v});
+      if (!mappedValue.empty()) {
+        llvm::append_range(next, mappedValue);
+      } else {
+        next.push_back(v);
+      }
+    }
+    if (next != from) {
+      from = next;
+      continue;
+    }
+    next.clear();
+    next = mapping.lookupOrNull(from);
+    if (next.empty())
       break;
-    from = mappedValue;
+    from = next;
   } while (true);
 
   // If the desired value was found use it, otherwise default to the leaf value.
-  return desiredValue ? desiredValue : from;
+  return !desiredValue.empty() ? desiredValue : from;
 }
 
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
-  Value result = lookupOrDefault(from, desiredType);
-  if (result == from || (desiredType && result.getType() != desiredType))
-    return nullptr;
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrNull(SmallVector<Value, 1> from,
+                                     TypeRange desiredType) const {
+  SmallVector<Value, 1> result = lookupOrDefault(from, desiredType);
+  if (result == from ||
+      (!desiredType.empty() && TypeRange(result) != desiredType))
+    return {};
   return result;
 }
 
@@ -651,10 +726,6 @@ class CreateOperationRewrite : public OperationRewrite {
 
 /// The type of materialization.
 enum MaterializationKind {
-  /// This materialization materializes a conversion for an illegal block
-  /// argument type, to the original one.
-  Argument,
-
   /// This materialization materializes a conversion from an illegal type to a
   /// legal one.
   Target,
@@ -673,7 +744,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
                                    UnrealizedConversionCastOp op,
                                    const TypeConverter *converter,
                                    MaterializationKind kind, Type originalType,
-                                   Value mappedValue);
+                                   SmallVector<Value, 1> mappedValue);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -710,7 +781,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   /// The value in the conversion value mapping that is being replaced by the
   /// results of this unresolved materialization.
-  Value mappedValue;
+  SmallVector<Value, 1> mappedValue;
 };
 } // namespace
 
@@ -779,7 +850,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   LogicalResult remapValues(StringRef valueDiagTag,
                             std::optional<Location> inputLoc,
                             PatternRewriter &rewriter, ValueRange values,
-                            SmallVector<SmallVector<Value>> &remapped);
+                            SmallVector<SmallVector<Value, 1>> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
@@ -825,35 +896,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// mapping.
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      Value valueToMap, ValueRange inputs, TypeRange outputTypes,
-      Type originalType, const TypeConverter *converter,
+      SmallVector<Value, 1> valueToMap, ValueRange inputs,
+      TypeRange outputTypes, Type originalType, const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr);
   Value buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
       Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
       const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr) {
-    return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
-                                          TypeRange(outputType), originalType,
-                                          converter, castOp)
+    return buildUnresolvedMaterialization(
+               kind, ip, loc,
+               valueToMap ? SmallVector<Value, 1>{valueToMap}
+                          : SmallVector<Value, 1>(),
+               inputs, TypeRange(outputType), originalType, converter, castOp)
         .front();
   }
 
-  /// Build an N:1 materialization for the given original value that was
-  /// replaced with the given replacement values.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. The conversion mapping can store only 1:1 replacements
-  /// and the conversion patterns only support single Value replacements in the
-  /// adaptor, so N values must be converted back to a single value. This
-  /// function will be deleted when full 1:N support has been added.
-  ///
-  /// This function inserts an argument materialization back to the original
-  /// type.
-  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
-                                 ValueRange replacements, Value originalValue,
-                                 const TypeConverter *converter);
-
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
   /// value. If there is no replacement value with the correct type, find the
@@ -862,16 +920,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value findOrBuildReplacementValue(Value value,
                                     const TypeConverter *converter);
 
-  /// Unpack an N:1 materialization and return the inputs of the
-  /// materialization. This function unpacks only those materializations that
-  /// were built with `insertNTo1Materialization`.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. It allows us to write 1:N conversion patterns while
-  /// 1:N support is still missing in the conversion value mapping. This
-  /// function will be deleted when full 1:N support has been added.
-  SmallVector<Value> unpackNTo1Materialization(Value value);
-
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -1041,7 +1089,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   });
 }
 
-void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
+void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   auto *listener =
@@ -1082,7 +1130,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
 
 void ReplaceOperationRewrite::rollback() {
   for (auto result : op->getResults())
-    rewriterImpl.mapping.erase(result);
+    rewriterImpl.mapping.erase({result});
 }
 
 void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
@@ -1101,7 +1149,7 @@ void CreateOperationRewrite::rollback() {
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
     const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    Value mappedValue)
+    SmallVector<Value, 1> mappedValue)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
       converterAndKind(converter, kind), originalType(originalType),
       mappedValue(mappedValue) {
@@ -1111,7 +1159,7 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (mappedValue)
+  if (!mappedValue.empty())
     rewriterImpl.mapping.erase(mappedValue);
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
   rewriterImpl.nTo1TempMaterializations.erase(getOperation());
@@ -1160,7 +1208,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
 LogicalResult ConversionPatternRewriterImpl::remapValues(
     StringRef valueDiagTag, std::optional<Location> inputLoc,
     PatternRewriter &rewriter, ValueRange values,
-    SmallVector<SmallVector<Value>> &remapped) {
+    SmallVector<SmallVector<Value, 1>> &remapped) {
   remapped.reserve(llvm::size(values));
 
   for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1216,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type origType = operand.getType();
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
-    // Find the most recently mapped value. Unpack all temporary N:1
-    // materializations. Such conversions are a workaround around missing
-    // 1:N support in the ConversionValueMapping. (The conversion patterns
-    // already support 1:N replacements.)
-    Value repl = mapping.lookupOrDefault(operand);
-    SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
     if (!currentTypeConverter) {
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
       // pass through the most recently mapped value.
-      remapped.push_back(std::move(unpacked));
+      SmallVector<Value, 1> repl = mapping.lookupOrDefault({operand});
+      remapped.push_back(repl);
       continue;
     }
 
@@ -1192,51 +1234,29 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       });
       return failure();
     }
-
     // If a type is converted to 0 types, there is nothing to do.
     if (legalTypes.empty()) {
       remapped.push_back({});
       continue;
     }
 
-    if (legalTypes.size() != 1) {
-      // TODO: This is a 1:N conversion. The conversion value mapping does not
-      // store such materializations yet. If the types of the most recently
-      // mapped values do not match, build a target materialization.
-      ValueRange unpackedRange(unpacked);
-      if (TypeRange(unpackedRange) == legalTypes) {
-        remapped.push_back(std::move(unpacked));
-        continue;
-      }
-
-      // Insert a target materialization if the current pattern expects
-      // different legalized types.
-      ValueRange targetMat = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
-          /*valueToMap=*/Value(), /*inputs=*/unpacked,
-          /*outputType=*/legalTypes, /*originalType=*/origType,
-          currentTypeConverter);
-      remapped.push_back(targetMat);
+    SmallVector<Value, 1> mat = mapping.lookupOrDefault({operand}, legalTypes);
+    if (!mat.empty() && TypeRange(mat) == legalTypes) {
+      // The mapped value (or an existing materialization of it) already has
+      // the legal types, or the value is not mapped at all and already has
+      // the legal types.
+      remapped.push_back(mat);
       continue;
     }
 
-    // Handle 1->1 type conversions.
-    Type desiredType = legalTypes.front();
-    // Try to find a mapped value with the desired type. (Or the operand itself
-    // if the value is not mapped at all.)
-    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
-    if (newOperand.getType() != desiredType) {
-      // If the looked up value's type does not have the desired type, it means
-      // that the value was replaced with a value of different type and no
-      // target materialization was created yet.
-      Value castValue = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
-          /*outputType=*/desiredType, /*originalType=*/origType,
-          currentTypeConverter);
-      newOperand = castValue;
-    }
-    remapped.push_back({newOperand});
+    // Create a materialization for the most recently mapped value.
+    SmallVector<Value, 1> vals = mapping.lookupOrDefault({operand});
+    ValueRange castValues = buildUnresolvedMaterialization(
+        MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
+        /*valueToMap=*/vals,
+        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+        currentTypeConverter);
+    remapped.push_back(castValues);
   }
   return success();
 }
@@ -1364,7 +1384,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       assert(inputMap->size == 0 &&
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
-      mapping.map(origArg, repl);
+      mapping.map({origArg}, {repl});
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1375,13 +1395,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    if (replArgs.size() == 1) {
-      mapping.map(origArg, replArgs.front());
-    } else {
-      insertNTo1Materialization(
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
-    }
+    SmallVector<Value, 1> replArgVals = llvm::map_to_vector<1>(
+        replArgs, [](BlockArgument arg) -> Value { return arg; });
+    mapping.map({origArg}, replArgVals);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1402,7 +1418,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+    SmallVector<Value, 1> valueToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
     UnrealizedConversionCastOp *castOp) {
   assert((!originalType || kind == MaterializationKind::Target) &&
@@ -1410,10 +1426,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes) {
-    if (valueToMap) {
-      assert(inputs.size() == 1 && "1:N mapping is not supported");
-      mapping.map(valueToMap, inputs.front());
-    }
+    if (!valueToMap.empty())
+      mapping.map(valueToMap, inputs);
     return inputs;
   }
 
@@ -1423,10 +1437,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
-  if (valueToMap) {
-    assert(outputTypes.size() == 1 && "1:N mapping is not supported");
-    mapping.map(valueToMap, convertOp.getResult(0));
-  }
+  if (!valueToMap.empty())
+    mapping.map(valueToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
@@ -1434,26 +1446,12 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   return convertOp.getResults();
 }
 
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
-    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
-    Value originalValue, const TypeConverter *converter) {
-  // Insert argument materialization back to the original type.
-  Type originalType = originalValue.getType();
-  UnrealizedConversionCastOp argCastOp;
-  buildUnresolvedMaterialization(
-      MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType,
-      /*originalType=*/Type(), converter, &argCastOp);
-  if (argCastOp)
-    nTo1TempMaterializations.insert(argCastOp);
-}
-
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
   // Find a replacement value with the same type.
-  Value repl = mapping.lookupOrNull(value, value.getType());
-  if (repl)
-    return repl;
+  SmallVector<Value, 1> repl = mapping.lookupOrNull({value}, value.getType());
+  if (!repl.empty())
+    return repl.front();
 
   // Check if the value is dead. No replacement value is needed in that case.
   // This is an approximate check that may have false negatives but does not
@@ -1467,8 +1465,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   // No replacement value was found. Get the latest replacement value
   // (regardless of the type) and build a source materialization to the
   // original type.
-  repl = mapping.lookupOrNull(value);
-  if (!repl) {
+  repl = mapping.lookupOrNull({value});
+  if (repl.empty()) {
     // No replacement value is registered in the mapping. This means that the
     // value is dropped and no longer needed. (If the value were still needed,
     // a source materialization producing a replacement value "out of thin air"
@@ -1480,32 +1478,10 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
       MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
       /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
       /*originalType=*/Type(), converter);
-  mapping.map(value, castValue);
+  mapping.map({value}, {castValue});
   return castValue;
 }
 
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
-  // Unpack unrealized_conversion_cast ops that were inserted as a N:1
-  // workaround.
-  auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-  if (!castOp)
-    return {value};
-  if (!nTo1TempMaterializations.contains(castOp))
-    return {value};
-  assert(castOp->getNumResults() == 1 && "expected single result");
-
-  SmallVector<Value> result;
-  for (Value v : castOp.getOperands()) {
-    // Keep unpacking if possible. This is needed because during block
-    // signature conversions and 1:N op replacements, the driver may have
-    // inserted two materializations back-to-back: first an argument
-    // materialization, then a target materialization.
-    llvm::append_range(result, unpackNTo1Materialization(v));
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1572,16 +1548,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
     // Remap result to replacement value.
     if (repl.empty())
       continue;
-
-    if (repl.size() == 1) {
-      // Single replacement value: replace directly.
-      mapping.map(result, repl.front());
-    } else {
-      // Multiple replacement values: insert N:1 materialization.
-      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
-                                /*replacements=*/repl, /*outputValue=*/result,
-                                currentTypeConverter);
-    }
+    mapping.map({result}, repl);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1627,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
   SmallVector<ValueRange> newVals;
-  for (size_t i = 0; i < newValues.size(); ++i)
-    newVals.push_back(newValues.slice(i, 1));
+  for (size_t i = 0; i < newValues.size(); ++i) {
+    if (newValues[i]) {
+      newVals.push_back(newValues.slice(i, 1));
+    } else {
+      newVals.push_back(ValueRange());
+    }
+  }
   impl->notifyOpReplaced(op, newVals);
 }
 
@@ -1729,11 +1701,11 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to});
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
-  SmallVector<SmallVector<Value>> remappedValues;
+  SmallVector<SmallVector<Value, 1>> remappedValues;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                                remappedValues)))
     return nullptr;
@@ -1746,7 +1718,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                                              SmallVectorImpl<Value> &results) {
   if (keys.empty())
     return success();
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<SmallVector<Value, 1>> remapped;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
                                remapped)))
     return failure();
@@ -1872,7 +1844,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
                                              getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<SmallVector<Value, 1>> remapped;
   if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
                                       op->getOperands(), remapped))) {
     return failure();
@@ -2625,19 +2597,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
     rewriter.setInsertionPoint(op);
     SmallVector<Value> newMaterialization;
     switch (rewrite->getMaterializationKind()) {
-    case MaterializationKind::Argument: {
-      // Try to materialize an argument conversion.
-      assert(op->getNumResults() == 1 && "expected single result");
-      Value argMat = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
-      if (argMat) {
-        newMaterialization.push_back(argMat);
-        break;
-      }
-    }
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
     case MaterializationKind::Target:
       newMaterialization = converter->materializeTargetConversion(
           rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 2ca5f49637523f..51dce9b251628e 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) {
 // Contents of the old block are moved to the new block.
 // CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
 
-// The new block arguments are used in "test.return".
-// CHECK-NEXT: notifyOperationModified: test.return
-
 // The old block is erased.
 // CHECK-NEXT: notifyBlockErased
 
@@ -390,8 +387,8 @@ func.func @caller() {
   // CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16)
   %0:2 = func.call @callee() : () -> (f32, i24)
 
-  // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
-  // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
+  // CHECK-DAG: %[[cast1:.*]] = "test.cast"() : () -> i24
+  // CHECK-DAG: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32
   // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> ()
   // expected-remark @below{{'test.some_user' is not legalizable}}
   "test.some_user"(%0#0, %0#1) : (f32, i24) -> ()
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad50..d0b62e71ab0cf2 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
           tupleType.getFlattenedTypes(types);
           return success();
         });
-    typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+    typeConverter.addSourceMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 466ae7ff6f46f1..749df2cb9ea7cc 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1235,7 +1235,6 @@ struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
   TestTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d788..a03bf0a1023d57 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
 struct PDLLTypeConverter : public TypeConverter {
   PDLLTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 



More information about the llvm-branch-commits mailing list