[llvm-branch-commits] [mlir] [mlir][Transforms] Support 1:N mappings in `ConversionValueMapping` (PR #116524)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Dec 15 04:50:21 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116524

>From bf0d13553b2bc2124a266e398976ba80a1114580 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 16:34:47 +0100
Subject: [PATCH 1/4] [mlir][Vector] Move mask materialization patterns to
 greedy rewrite

The mask materialization patterns used by `VectorToLLVM` are rewrite patterns. They should run as part of the greedy pattern rewrite, not as part of the dialect conversion. (Rewrite patterns and conversion patterns are not generally compatible.)

The current combination of rewrite patterns and conversion patterns triggered an edge case when merging the 1:1 and 1:N dialect conversions.
---
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  7 +-
 .../VectorToLLVM/vector-mask-to-llvm.mlir     |  4 +-
 .../VectorToLLVM/vector-to-llvm.mlir          |  4 +-
 .../VectorToLLVM/vector-xfer-to-llvm.mlir     | 80 +++++++++----------
 4 files changed, 44 insertions(+), 51 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 4623b9667998cc..64a9ad8e9bade0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -61,8 +61,8 @@ struct ConvertVectorToLLVMPass
 } // namespace
 
 void ConvertVectorToLLVMPass::runOnOperation() {
-  // Perform progressive lowering of operations on slices and
-  // all contraction operations. Also applies folding and DCE.
+  // Perform progressive lowering of operations on slices and all contraction
+  // operations. Also materializes masks, applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -76,6 +76,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
                                             VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
+    populateVectorMaskMaterializationPatterns(patterns,
+                                              force32BitVectorIndices);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
@@ -83,7 +85,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
   LowerToLLVMOptions options(&getContext());
   LLVMTypeConverter converter(&getContext(), options);
   RewritePatternSet patterns(&getContext());
-  populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
   populateVectorTransferLoweringPatterns(patterns);
   populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
   populateVectorToLLVMConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
index 82351eb7c98a43..91e5358622b69d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-mask-to-llvm.mlir
@@ -7,7 +7,7 @@
 // CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
 // CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
 // CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi32>
-// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
+// CMP32: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi32>
 // CMP32: return %[[T4]] : vector<11xi1>
 
 // CMP64-LABEL: @genbool_var_1d(
@@ -16,7 +16,7 @@
 // CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
 // CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
 // CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
-// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
+// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
 // CMP64: return %[[T4]] : vector<11xi1>
 
 func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 2473fe933ffcb2..ea88fece9e662d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3097,7 +3097,7 @@ func.func @create_mask_0d(%num_elems : index) -> vector<i1> {
 // CHECK:  %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK:  %[[BOUNDS:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK:  %[[BOUNDS_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOUNDS]] : vector<1xi32> to vector<i32>
-// CHECK:  %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS_CAST]] : vector<i32>
+// CHECK:  %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS_CAST]], %[[INDICES]] : vector<i32>
 // CHECK:  return %[[RESULT]] : vector<i1>
 
 // -----
@@ -3113,7 +3113,7 @@ func.func @create_mask_1d(%num_elems : index) -> vector<4xi1> {
 // CHECK:  %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
 // CHECK:  %[[BOUNDS_INSERT:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
 // CHECK:  %[[BOUNDS:.*]] = llvm.shufflevector %[[BOUNDS_INSERT]]
-// CHECK:  %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS]] : vector<4xi32>
+// CHECK:  %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS]], %[[INDICES]] : vector<4xi32>
 // CHECK:  return %[[RESULT]] : vector<4xi1>
 
 // -----
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
index 8f01cc2b8d44c3..d3f6d7eca90b41 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-xfer-to-llvm.mlir
@@ -14,30 +14,28 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
 // CHECK-LABEL: func @transfer_read_write_1d
 //  CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 //  CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
-//       CHECK: %[[C7:.*]] = arith.constant 7.0
-//
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
-//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
+// 1. Create pass-through vector.
+//   CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<17xf32>
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = arith.constant dense
+//   CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense
 //  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 3. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
+//       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
+//
+// 4. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] :
 //  CMP32-SAME: index to i32
 //  CMP64-SAME: index to i64
 //       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 //       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]>
+//       CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 //  CMP64-SAME: : vector<17xi64>
 //
-// 4. Create pass-through vector.
-//       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
-//
 // 5. Bitcast to vector form.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 //  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -48,28 +46,23 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
 //  CHECK-SAME: -> vector<17xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 //       CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex_b:.*]] = arith.constant dense
-//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
-//
-// 3. Create bound vector to compute in-bound mask:
+// 2. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]]
 //  CMP32-SAME: index to i32
 //       CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
 //       CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
-//       CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
-//  CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]>
+//       CHECK: %[[mask_b:.*]] = arith.cmpi sgt, %[[boundVect_b]],
+//  CHECK-SAME: %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
 //
-// 4. Bitcast to vector form.
+// 3. Bitcast to vector form.
 //       CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
 //  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
 //
-// 5. Rewrite as a masked write.
+// 4. Rewrite as a masked write.
 //       CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
 //  CHECK-SAME: {alignment = 4 : i32} :
 //  CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
@@ -87,17 +80,18 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 // CHECK-LABEL: func @transfer_read_write_1d_scalable
 //  CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
 //  CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
-//       CHECK: %[[C7:.*]] = arith.constant 7.0
+// 1. Create pass-through vector.
+//   CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<[17]xf32>
 //
-// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0:.*]] = arith.constant 0 : index
+// 2. Let dim be the memref dimension, compute the in-bound index (dim - offset)
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 //       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]],  %[[BASE]] : index
 //
-// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
+// 3. Create a vector with linear indices [ 0 .. vector_length - 1 ].
 //       CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
 //
-// 3. Create bound vector to compute in-bound mask:
+// 4. Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 //       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
@@ -105,9 +99,6 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 //       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
 //  CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]>
 //
-// 4. Create pass-through vector.
-//       CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
-//
 // 5. Bitcast to vector form.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
 //  CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -118,8 +109,7 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
 //  CHECK-SAME: -> vector<[17]xf32>
 //
 // 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
-//       CHECK: %[[C0_b:.*]] = arith.constant 0 : index
-//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
+//       CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
 //       CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
 //
 // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
@@ -197,23 +187,23 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
 }
 // CHECK-LABEL: func @transfer_read_2d_to_1d
 //  CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
-//       CHECK: %[[c1:.*]] = arith.constant 1 : index
+//
+// Create a vector with linear indices [ 0 .. vector_length - 1 ].
+//   CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
+//  CHECK-SAME: vector<17x[[$IDX_TYPE]]>
+//
+//   CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
 //
 // Compute the in-bound index (dim - offset)
 //       CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
 //
-// Create a vector with linear indices [ 0 .. vector_length - 1 ].
-//       CHECK: %[[linearIndex:.*]] = arith.constant dense
-//  CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
-//  CHECK-SAME: vector<17x[[$IDX_TYPE]]>
-//
 // Create bound vector to compute in-bound mask:
 //    [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
 //       CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
 //       CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
 //       CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
-//       CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
+//       CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]]
 
 func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
   %f7 = arith.constant 7.0: f32
@@ -255,12 +245,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %bas
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
 //  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
 //
+//       CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 //  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-//       CHECK: %[[c0:.*]] = arith.constant 0 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -280,12 +271,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32,
 // CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
 //  CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
 //
+//       CHECK: %[[c0:.*]] = arith.constant 0 : index
+//
 // 1. Check address space for GEP is correct.
 //       CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
 //  CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
 //
 // 2. Check address space of the memref is correct.
-//       CHECK: %[[c0:.*]] = arith.constant 0 : index
 //       CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
 //
 // 3. Check address space for GEP is correct.
@@ -330,10 +322,10 @@ func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index)
 
 // CHECK-LABEL: func @transfer_read_write_1d_mask
 // CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
-// CHECK: %[[cmpi:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi:.*]] = arith.cmpi sgt
 // CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
 // CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
-// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
+// CHECK: %[[cmpi_1:.*]] = arith.cmpi sgt
 // CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
 // CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
 // CHECK: return %[[r]]

>From e5926b63835eec731c513d91ce9c451c429ca572 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 17:19:18 +0100
Subject: [PATCH 2/4] [mlir][Vector] Clean up
 `populateVectorToLLVMConversionPatterns`

---
 .../Vector/Transforms/LoweringPatterns.h      |  4 +++
 .../GPUCommon/GPUToLLVMConversion.cpp         | 12 +++++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 27 ++++++++++---------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  6 ++++-
 .../Conversion/GPUCommon/lower-vector.mlir    |  4 +--
 .../VectorToLLVM/vector-to-llvm.mlir          |  5 ----
 6 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 3d643c96b45008..c507b23c6d4de6 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -292,6 +292,10 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
                                            int64_t targetRank = 1,
                                            PatternBenefit benefit = 1);
 
+/// Populates a pattern that rank-reduces n-D FMAs into (n-1)-D FMAs where
+/// n > 1.
+void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 1497d662dcdbdd..2fe3b1302e5e5b 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -32,10 +32,12 @@
 #include "mlir/Dialect/GPU/Transforms/Passes.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Error.h"
@@ -522,6 +524,16 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 void GpuToLLVMConversionPass::runOnOperation() {
   MLIRContext *context = &getContext();
+
+  // Perform progressive lowering of vector transfer operations.
+  {
+    RewritePatternSet patterns(context);
+    // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
+    vector::populateVectorTransferLoweringPatterns(patterns,
+                                                   /*maxTransferRank=*/1);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+
   LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
   RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index a9a07c323c7358..577b74bb7e0c26 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1475,16 +1475,17 @@ class VectorTypeCastOpConversion
 
 /// Conversion pattern for a `vector.create_mask` (1-D scalable vectors only).
 /// Non-scalable versions of this operation are handled in Vector Transforms.
-class VectorCreateMaskOpRewritePattern
-    : public OpRewritePattern<vector::CreateMaskOp> {
+class VectorCreateMaskOpConversion
+    : public OpConversionPattern<vector::CreateMaskOp> {
 public:
-  explicit VectorCreateMaskOpRewritePattern(MLIRContext *context,
-                                            bool enableIndexOpt)
-      : OpRewritePattern<vector::CreateMaskOp>(context),
+  explicit VectorCreateMaskOpConversion(MLIRContext *context,
+                                        bool enableIndexOpt)
+      : OpConversionPattern<vector::CreateMaskOp>(context),
         force32BitVectorIndices(enableIndexOpt) {}
 
-  LogicalResult matchAndRewrite(vector::CreateMaskOp op,
-                                PatternRewriter &rewriter) const override {
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto dstType = op.getType();
     if (dstType.getRank() != 1 || !cast<VectorType>(dstType).isScalable())
       return failure();
@@ -1495,7 +1496,7 @@ class VectorCreateMaskOpRewritePattern
         loc, LLVM::getVectorType(idxType, dstType.getShape()[0],
                                  /*isScalable=*/true));
     auto bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType,
-                                                 op.getOperand(0));
+                                                 adaptor.getOperands()[0]);
     Value bounds = rewriter.create<SplatOp>(loc, indices.getType(), bound);
     Value comp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 indices, bounds);
@@ -1896,16 +1896,19 @@ struct VectorScalableStepOpLowering
 
 } // namespace
 
+void mlir::vector::populateVectorRankReducingFMAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<VectorFMAOpNDRewritePattern>(patterns.getContext());
+}
+
 /// Populate the given list with patterns that convert from Vector to LLVM.
 void mlir::populateVectorToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
     bool reassociateFPReductions, bool force32BitVectorIndices) {
+  // This function populates only ConversionPatterns, not RewritePatterns.
   MLIRContext *ctx = converter.getDialect()->getContext();
-  patterns.add<VectorFMAOpNDRewritePattern>(ctx);
-  populateVectorInsertExtractStridedSliceTransforms(patterns);
-  populateVectorStepLoweringPatterns(patterns);
   patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
-  patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
+  patterns.add<VectorCreateMaskOpConversion>(ctx, force32BitVectorIndices);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
                VectorExtractElementOpConversion, VectorExtractOpConversion,
                VectorFMAOp1DConversion, VectorInsertElementOpConversion,
@@ -1922,8 +1925,6 @@ void mlir::populateVectorToLLVMConversionPatterns(
                MaskedReductionOpConversion, VectorInterleaveOpLowering,
                VectorDeinterleaveOpLowering, VectorFromElementsLowering,
                VectorScalableStepOpLowering>(converter);
-  // Transfer ops with rank > 1 are handled by VectorToSCF.
-  populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
 
 void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 64a9ad8e9bade0..2d94c2f2e85a08 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -62,7 +62,8 @@ struct ConvertVectorToLLVMPass
 
 void ConvertVectorToLLVMPass::runOnOperation() {
   // Perform progressive lowering of operations on slices and all contraction
-  // operations. Also materializes masks, applies folding and DCE.
+  // operations. Also materializes masks, lowers vector.step, rank-reduces FMA,
+  // applies folding and DCE.
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -78,6 +79,9 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
     populateVectorMaskMaterializationPatterns(patterns,
                                               force32BitVectorIndices);
+    populateVectorInsertExtractStridedSliceTransforms(patterns);
+    populateVectorStepLoweringPatterns(patterns);
+    populateVectorRankReducingFMAPattern(patterns);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 44deb45cd752b4..532a2383cea9ef 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,11 +1,11 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
-  func.func @func(%arg: vector<11xf32>) {
+  func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
     // CHECK: vector.mask
     // CHECK-SAME: vector.yield %arg0
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
-    return
+    return %127 : vector<11xf32>
   }
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index ea88fece9e662d..f95e943250bd44 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2046,7 +2046,6 @@ func.func @extract_strided_slice_f32_2d_from_2d_scalable(%arg0: vector<4x[8]xf32
 // CHECK-LABEL: @extract_strided_slice_f32_2d_from_2d_scalable(
 //  CHECK-SAME:     %[[ARG:.*]]: vector<4x[8]xf32>)
 // CHECK:           %[[T1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x[8]xf32> to !llvm.array<4 x vector<[8]xf32>>
-// CHECK:           %[[T2:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x[8]xf32>
 // CHECK:           %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x[8]xf32> to !llvm.array<2 x vector<[8]xf32>>
 // CHECK:           %[[T5:.*]] = llvm.extractvalue %[[T1]][2] : !llvm.array<4 x vector<[8]xf32>>
@@ -2067,7 +2066,6 @@ func.func @insert_strided_slice_f32_2d_into_3d(%b: vector<4x4xf32>, %c: vector<4
   return %0 : vector<4x4x4xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xf32>>>
 
 // -----
@@ -2077,7 +2075,6 @@ func.func @insert_strided_slice_f32_2d_into_3d_scalable(%b: vector<4x[4]xf32>, %
   return %0 : vector<4x4x[4]xf32>
 }
 // CHECK-LABEL: @insert_strided_slice_f32_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xf32>>>
 
 // -----
@@ -2087,7 +2084,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d(%b: vector<4x4xindex>, %c
   return %0 : vector<4x4x4xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>>
 
 // -----
@@ -2097,7 +2093,6 @@ func.func @insert_strided_index_slice_index_2d_into_3d_scalable(%b: vector<4x[4]
   return %0 : vector<4x4x[4]xindex>
 }
 // CHECK-LABEL: @insert_strided_index_slice_index_2d_into_3d_scalable
-//       CHECK:    llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 //       CHECK:    llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<[4]xi64>>>
 
 // -----

>From ddc92940e9a1490f5a26d77576f4b663253e3cdd Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 25 Nov 2024 04:13:24 +0100
Subject: [PATCH 3/4] [mlir][Transforms] Dialect Conversion: Do not build
 target materializations during 1:N replacement

fix test

experiment
---
 .../Conversion/LLVMCommon/TypeConverter.cpp   | 130 ++++++++++++------
 .../Transforms/Utils/DialectConversion.cpp    |  46 ++-----
 mlir/test/Transforms/test-legalizer.mlir      |   8 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  47 +++----
 4 files changed, 128 insertions(+), 103 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 59b0f5c9b09bcd..e2ab0ed6f66cc5 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
+  // Add generic source and target materializations to handle cases where
+  // non-LLVM types persist after an LLVM conversion.
+  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
+                               ValueRange inputs, Location loc) {
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+        .getResult(0);
+  });
+
   // Helper function that checks if the given value range is a bare pointer.
   auto isBarePointer = [](ValueRange values) {
     return values.size() == 1 &&
            isa<LLVM::LLVMPointerType>(values.front().getType());
   };
 
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type. The dialect conversion framework will then
-  // insert a target materialization from the original block argument type to
-  // a legal type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packUnrankedMemRefDesc =
+      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
+          Location loc, LLVMTypeConverter &converter) -> Value {
     // Note: Bare pointers are not supported for unranked memrefs because a
     // memref descriptor cannot be built just from a bare pointer.
-    if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
+    if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
       return Value();
-    Value desc =
-        UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
+                                          inputs);
+  };
+
+  // MemRef descriptor elements -> UnrankedMemRefType
+  auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
+                                          UnrankedMemRefType resultType,
+                                          ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type
     // `resultType`, so insert a cast from the memref descriptor type
     // (!llvm.struct) to the original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    Value desc;
-    if (isBarePointer(inputs)) {
-      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
-                                               inputs[0]);
-    } else if (TypeRange(inputs) ==
-               getMemRefDescriptorFields(resultType,
-                                         /*unpackAggregates=*/true)) {
-      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    } else {
-      // The inputs are neither a bare pointer nor an unpacked memref
-      // descriptor. This materialization function cannot be used.
+    Value packed =
+        packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-    }
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
+        .getResult(0);
+  };
+
+  // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
+  // must be passed explicitly.
+  auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
+                                  ValueRange inputs, Location loc,
+                                  LLVMTypeConverter &converter) -> Value {
+    assert(resultType && "expected non-null result type");
+    if (isBarePointer(inputs))
+      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
+                                               resultType, inputs[0]);
+    if (TypeRange(inputs) ==
+        converter.getMemRefDescriptorFields(resultType,
+                                            /*unpackAggregates=*/true))
+      return MemRefDescriptor::pack(builder, loc, converter, resultType,
+                                    inputs);
+    // The inputs are neither a bare pointer nor an unpacked memref descriptor.
+    // This materialization function cannot be used.
+    return Value();
+  };
+
+  // MemRef descriptor elements -> MemRefType
+  auto rankedMemRefMaterialization = [&](OpBuilder &builder,
+                                         MemRefType resultType,
+                                         ValueRange inputs, Location loc) {
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
     // original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  // Add generic source and target materializations to handle cases where
-  // non-LLVM types persist after an LLVM conversion.
-  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+    Value packed =
+        packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
+    if (!packed)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
         .getResult(0);
-  });
+  };
+
+  // Argument materializations convert from the new block argument types
+  // (multiple SSA values that make up a memref descriptor) back to the
+  // original block argument type.
+  addArgumentMaterialization(unrakedMemRefMaterialization);
+  addArgumentMaterialization(rankedMemRefMaterialization);
+  addSourceMaterialization(unrakedMemRefMaterialization);
+  addSourceMaterialization(rankedMemRefMaterialization);
+
+  // Bare pointer -> Packed MemRef descriptor
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs, Location loc) {
-    if (inputs.size() != 1)
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> Value {
+    // The original MemRef type is required to build a MemRef descriptor
+    // because the sizes/strides of the MemRef cannot be inferred from just the
+    // bare pointer.
+    if (!originalType)
       return Value();
-
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-        .getResult(0);
+    if (resultType != convertType(originalType))
+      return Value();
+    if (auto memrefType = dyn_cast<MemRefType>(originalType))
+      return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
+    if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
+      return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
+                                    *this);
+    return Value();
   });
 
   // Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1607740a1ee076..51686646a0a2fc 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// function will be deleted when full 1:N support has been added.
   ///
   /// This function inserts an argument materialization back to the original
-  /// type, followed by a target materialization to the legalized type (if
-  /// applicable).
+  /// type.
   void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
                                  ValueRange replacements, Value originalValue,
                                  const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     // used as a replacement.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    insertNTo1Materialization(
-        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    if (replArgs.size() == 1) {
+      mapping.map(origArg, replArgs.front());
+    } else {
+      insertNTo1Materialization(
+          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
+    }
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
   // Insert argument materialization back to the original type.
   Type originalType = originalValue.getType();
   UnrealizedConversionCastOp argCastOp;
-  Value argMat = buildUnresolvedMaterialization(
+  buildUnresolvedMaterialization(
       MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
-      &argCastOp);
+      /*inputs=*/replacements, originalType,
+      /*originalType=*/Type(), converter, &argCastOp);
   if (argCastOp)
     nTo1TempMaterializations.insert(argCastOp);
-
-  // Insert target materialization to the legalized type.
-  Type legalOutputType;
-  if (converter) {
-    legalOutputType = converter->convertType(originalType);
-  } else if (replacements.size() == 1) {
-    // When there is no type converter, assume that the replacement value
-    // types are legal. This is reasonable to assume because they were
-    // specified by the user.
-    // FIXME: This won't work for 1->N conversions because multiple output
-    // types are not supported in parts of the dialect conversion. In such a
-    // case, we currently use the original value type.
-    legalOutputType = replacements[0].getType();
-  }
-  if (legalOutputType && legalOutputType != originalType) {
-    UnrealizedConversionCastOp targetCastOp;
-    buildUnresolvedMaterialization(
-        MaterializationKind::Target, computeInsertPoint(argMat), loc,
-        /*valueToMap=*/argMat, /*inputs=*/argMat,
-        /*outputType=*/legalOutputType, /*originalType=*/originalType,
-        converter, &targetCastOp);
-    if (targetCastOp)
-      nTo1TempMaterializations.insert(targetCastOp);
-  }
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
 
 LogicalResult TypeConverter::convertType(Type t,
                                          SmallVectorImpl<Type> &results) const {
+  assert(this && "expected non-null type converter");
+  assert(t && "expected non-null type");
+
   {
     std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
                                                          std::defer_lock);
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d98a6a036e6b1f..2ca5f49637523f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
   // expected-remark at +1 {{op 'foo.region' is not legalizable}}
   "foo.region"() ({
-    // CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
-    ^bb0(%i0: i64, %unused: i16, %i1: i64):
-      // CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
-      "test.invalid"(%i0, %i1) : (i64, i64) -> ()
+    // CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
+    ^bb0(%i0: f64, %unused: i16, %i1: f64):
+      // CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
+      "test.invalid"(%i0, %i1) : (f64, f64) -> ()
   }) : () -> ()
   // expected-remark at +1 {{op 'func.return' is not legalizable}}
   return
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 8a0bc597c56beb..466ae7ff6f46f1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
 };
 /// This pattern simply updates the operands of the given operation.
 struct TestPassthroughInvalidOp : public ConversionPattern {
-  TestPassthroughInvalidOp(MLIRContext *ctx)
-      : ConversionPattern("test.invalid", 1, ctx) {}
+  TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
+      : ConversionPattern(converter, "test.invalid", 1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1301,19 +1301,19 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::RewritePatternSet patterns(&getContext());
     populateWithGenerated(patterns);
-    patterns.add<
-        TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
-        TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
-        TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
-        TestSplitReturnType, TestChangeProducerTypeI32ToF32,
-        TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
-        TestUpdateConsumerType, TestNonRootReplacement,
-        TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
-        TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
-        TestUndoPropertiesModification, TestEraseOp,
-        TestRepetitive1ToNConsumer>(&getContext());
-    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
-        &getContext(), converter);
+    patterns
+        .add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
+             TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
+             TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
+             TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+             TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+             TestNonRootReplacement, TestBoundedRecursiveRewrite,
+             TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore,
+             TestUndoPropertiesModification, TestEraseOp,
+             TestRepetitive1ToNConsumer>(&getContext());
+    patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
+                 TestPassthroughInvalidOp>(&getContext(), converter);
     patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1749,8 +1749,9 @@ struct TestTypeConversionAnotherProducer
 };
 
 struct TestReplaceWithLegalOp : public ConversionPattern {
-  TestReplaceWithLegalOp(MLIRContext *ctx)
-      : ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
+  TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
+      : ConversionPattern(converter, "test.replace_with_legal_op",
+                          /*benefit=*/1, ctx) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
@@ -1872,12 +1873,12 @@ struct TestTypeConversionDriver
 
     // Initialize the set of rewrite patterns.
     RewritePatternSet patterns(&getContext());
-    patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
-                 TestSignatureConversionUndo,
-                 TestTestSignatureConversionNoConverter>(converter,
-                                                         &getContext());
-    patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
-        &getContext());
+    patterns
+        .add<TestTypeConsumerForward, TestTypeConversionProducer,
+             TestSignatureConversionUndo,
+             TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
+            converter, &getContext());
+    patterns.add<TestTypeConversionAnotherProducer>(&getContext());
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
 

>From b11cc17e948370509b6c5abf106325d0e2d87f3b Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 14 Dec 2024 15:02:10 +0100
Subject: [PATCH 4/4] prototype

---
 .../Conversion/LLVMCommon/TypeConverter.cpp   |   2 -
 .../EmitC/Transforms/TypeConversions.cpp      |   1 -
 .../Dialect/Linalg/Transforms/Detensorize.cpp |   1 -
 .../Quant/Transforms/StripFuncQuantTypes.cpp  |   1 -
 .../Utils/SparseTensorDescriptor.cpp          |   3 -
 .../Vector/Transforms/VectorLinearize.cpp     |   1 -
 .../Transforms/Utils/DialectConversion.cpp    | 526 +++++++++++-------
 .../Func/TestDecomposeCallGraphTypes.cpp      |   2 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |   1 -
 .../lib/Transforms/TestDialectConversion.cpp  |   1 -
 10 files changed, 318 insertions(+), 221 deletions(-)

diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index e2ab0ed6f66cc5..ef8181e80cee38 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -237,8 +237,6 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   // Argument materializations convert from the new block argument types
   // (multiple SSA values that make up a memref descriptor) back to the
   // original block argument type.
-  addArgumentMaterialization(unrakedMemRefMaterialization);
-  addArgumentMaterialization(rankedMemRefMaterialization);
   addSourceMaterialization(unrakedMemRefMaterialization);
   addSourceMaterialization(rankedMemRefMaterialization);
 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
index 0b3a494794f3f5..72c8fd0f324850 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
 
   converter.addSourceMaterialization(materializeAsUnrealizedCast);
   converter.addTargetMaterialization(materializeAsUnrealizedCast);
-  converter.addArgumentMaterialization(materializeAsUnrealizedCast);
 }
 
 /// Get an unsigned integer or size data type corresponding to \p ty.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index af38485291182f..61bc5022893741 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
     });
 
     addSourceMaterialization(sourceMaterializationCallback);
-    addArgumentMaterialization(sourceMaterializationCallback);
   }
 };
 
diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
index 61912722662830..71b88d1be1b05b 100644
--- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
     addConversion(convertQuantizedType);
     addConversion(convertTensorType);
 
-    addArgumentMaterialization(materializeConversion);
     addSourceMaterialization(materializeConversion);
     addTargetMaterialization(materializeConversion);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index 834e3634cc130d..8bbb2cac5efdf3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
 
   // Required by scf.for 1:N type conversion.
   addSourceMaterialization(materializeTuple);
-
-  // Required as a workaround until we have full 1:N support.
-  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 757631944f224f..68535ae5a7a5c6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
 
     return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
   };
-  typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
   target.markUnknownOpDynamicallyLegal(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 51686646a0a2fc..81c8c1f422551f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -63,11 +63,45 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) {
   return OpBuilder::InsertPoint(insertBlock, insertPt);
 }
 
+/// Computes an insertion point at which all of the given values (which must be
+/// defined in the same block) are available without a dominance violation.
+static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
+  assert(!vals.empty() && "expected at least one value");
+  OpBuilder::InsertPoint pt = computeInsertPoint(vals.front());
+  for (Value v : vals.drop_front()) {
+    OpBuilder::InsertPoint pt2 = computeInsertPoint(v);
+    assert(pt.getBlock() == pt2.getBlock() && "expected values in same block");
+    // `begin()` is the earliest and `end()` the latest possible point. Handle
+    // both explicitly: `end()` must not be dereferenced.
+    Block::iterator first = pt.getBlock()->begin();
+    Block::iterator last = pt.getBlock()->end();
+    if (pt.getPoint() == last || pt2.getPoint() == first)
+      continue;
+    if (pt2.getPoint() == last || pt.getPoint() == first ||
+        pt.getPoint()->isBeforeInBlock(&*pt2.getPoint()))
+      pt = pt2;
+  }
+  return pt;
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
 namespace {
+/// DenseMap traits for value vectors; empty and tombstone keys must differ.
+struct SmallVectorMapInfo {
+  static SmallVector<Value, 1> getEmptyKey() { return {Value()}; }
+  static SmallVector<Value, 1> getTombstoneKey() { return {Value(), Value()}; }
+  static ::llvm::hash_code getHashValue(const SmallVector<Value, 1> &val) {
+    return ::llvm::hash_combine_range(val.begin(), val.end());
+  }
+  static bool isEqual(const SmallVector<Value, 1> &lhs,
+                      const SmallVector<Value, 1> &rhs) {
+    return lhs == rhs;
+  }
+};
+
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
@@ -75,71 +109,240 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }
 
-  /// Lookup the most recently mapped value with the desired type in the
-  /// mapping.
+  /// Find the most recently mapped values for the given value. If the value is
+  /// not mapped at all, return the given value.
+  SmallVector<Value, 1> lookupOrDefault(Value from) const;
+
+  /// TODO: Find most recently mapped or materialization with matching type. May
+  /// return the given value if the type matches.
+  SmallVector<Value, 1>
+  lookupOrDefault(Value from, SmallVector<Type, 1> desiredTypes) const;
+
+  /// Returns the single value that `from` is directly mapped to.
+  ///
+  /// The mapping is not followed recursively and registered materializations
+  /// are ignored.
+  ///
+  /// Returns a null value if `from` is not mapped or if it is mapped to more
+  /// than one value.
+  Value lookupDirectSingleReplacement(Value from) const {
+    auto it = mapping.find(from);
+    if (it == mapping.end())
+      return Value();
+    const SmallVector<Value, 1> &repl = it->second;
+    if (repl.size() != 1)
+      return Value();
+    return repl.front();
+  }
+
+  /// Returns the values `from` is directly mapped to (no recursive lookup),
+  /// or an empty vector if `from` is not mapped.
+  SmallVector<Value, 1> lookupDirectReplacement(Value from) const {
+    auto entry = mapping.find(from);
+    return entry == mapping.end() ? SmallVector<Value, 1>{} : entry->second;
+  }
+
+  /// Find the most recently mapped values for the given value. If the value is
+  /// not mapped at all, return an empty vector.
+  SmallVector<Value, 1> lookupOrNull(Value from) const;
+
+  /// Find the most recently mapped values for the given value. If those values
+  /// have the desired types, return them. Otherwise, try to find a
+  /// materialization to the desired types.
   ///
-  /// Special cases:
-  /// - If the desired type is "null", simply return the most recently mapped
-  ///   value.
-  /// - If there is no mapping to the desired type, also return the most
-  ///   recently mapped value.
-  /// - If there is no mapping for the given value at all, return the given
-  ///   value.
-  Value lookupOrDefault(Value from, Type desiredType = nullptr) const;
-
-  /// Lookup a mapped value within the map, or return null if a mapping does not
-  /// exist. If a mapping exists, this follows the same behavior of
-  /// `lookupOrDefault`.
-  Value lookupOrNull(Value from, Type desiredType = nullptr) const;
-
-  /// Map a value to the one provided.
-  void map(Value oldVal, Value newVal) {
-    LLVM_DEBUG({
-      for (Value it = newVal; it; it = mapping.lookupOrNull(it))
-        assert(it != oldVal && "inserting cyclic mapping");
-    });
-    mapping.map(oldVal, newVal);
-    mappedTo.insert(newVal);
+  /// If the given value is not mapped at all or if there are no mapped values/
+  /// materialization results with the desired types, return an empty vector.
+  SmallVector<Value, 1> lookupOrNull(Value from,
+                                     SmallVector<Type, 1> desiredTypes) const;
+
+  /// Single-type variant of `lookupOrNull`: returns the value of type
+  /// `desiredType` for `from`, or a null value if there is none.
+  Value lookupOrNull(Value from, Type desiredType) const {
+    SmallVector<Value, 1> vals =
+        lookupOrNull(from, SmallVector<Type, 1>{desiredType});
+    assert(vals.size() <= 1 && "expected at most one value");
+    return vals.empty() ? Value() : vals.front();
+  }
+
+  /// Drops the replacement mapping of `from`, if any.
+  void erase(Value from) { mapping.erase(from); }
+
+  /// Maps `from` to the replacement values `to`; `from` must not be mapped yet.
+  void map(Value from, ValueRange to) {
+#ifndef NDEBUG
+    assert(from && "expected non-null value");
+    assert(!to.empty() && "cannot map to zero values");
+    for (Value v : to)
+      assert(v && "expected non-null value");
+#endif
+    // Self-mappings would recurse forever. TODO: Detect longer cycles.
+    assert(!llvm::is_contained(to, from) && "cannot map value to itself");
+    assert(!mapping.contains(from) && "value is already mapped");
+    mapping[from].assign(to.begin(), to.end());
+    for (Value v : to)
+      mappedTo.insert(v);
+  }
+
+  /// Convenience overload that forwards block arguments as a `ValueRange`.
+  void map(Value from, ArrayRef<BlockArgument> to) {
+    map(from, ValueRange(to));
+  }
+  /*
+    void map(Value from, ArrayRef<Value> to) {
+  #ifndef NDEBUG
+      assert(from && "expected non-null value");
+      assert(!to.empty() && "cannot map to zero values");
+      for (Value v : to)
+        assert(v && "expected non-null value");
+  #endif
+      // assert(from != to && "cannot map value to itself");
+      //  TODO: Check for cyclic mapping.
+      assert(!mapping.contains(from) && "value is already mapped");
+      mapping[from].assign(to.begin(), to.end());
+    }
+  */
+
+  /// Registers `to` as a materialization of the (unmapped) values `from`. Each
+  /// type list may be registered at most once per `from`.
+  void mapMaterialization(SmallVector<Value, 1> from,
+                          SmallVector<Value, 1> to) {
+#ifndef NDEBUG
+    assert(!from.empty() && "from cannot be empty");
+    assert(!to.empty() && "to cannot be empty");
+    for (Value v : from) {
+      assert(v && "expected non-null value");
+      assert(!mapping.contains(v) &&
+             "cannot add materialization for mapped value");
+    }
+    for (Value v : to)
+      assert(v && "expected non-null value");
+    assert(TypeRange(from) != TypeRange(to) &&
+           "cannot add materialization for identical type");
+    for (const SmallVector<Value, 1> &mat : materializations[from])
+      assert(TypeRange(mat) != TypeRange(to) &&
+             "cannot register duplicate materialization");
+#endif // NDEBUG
+    materializations[from].push_back(to);
+    mappedTo.insert(to.begin(), to.end());
+  }
+
+  /// Unregisters the materialization `to` of `from`, if it is registered.
+  void eraseMaterialization(SmallVector<Value, 1> from,
+                            SmallVector<Value, 1> to) {
+    auto entry = materializations.find(from);
+    if (entry == materializations.end())
+      return;
+    SmallVector<SmallVector<Value, 1>> &mats = entry->second;
+    if (auto it = llvm::find(mats, to); it != mats.end())
+      mats.erase(it);
+    if (mats.empty())
+      materializations.erase(entry);
   }
 
-  /// Drop the last mapping for the given value.
-  void erase(Value value) { mapping.erase(value); }
+  /// Returns the inverse raw value mapping (without recursive query support).
+  /// Materialization results map back to all of their input values.
+  DenseMap<Value, SmallVector<Value>> getInverse() const {
+    DenseMap<Value, SmallVector<Value>> inverse;
+
+    for (const auto &[from, replacements] : mapping)
+      for (Value to : replacements)
+        inverse[to].push_back(from);
+
+    for (const auto &[inputs, results] : materializations)
+      for (const SmallVector<Value, 1> &mat : results)
+        for (Value to : mat)
+          llvm::append_range(inverse[to], inputs);
+
+    return inverse;
+  }
 
 private:
-  /// Current value mappings.
-  IRMapping mapping;
+  /// Replacement mapping: Value -> ValueRange
+  DenseMap<Value, SmallVector<Value, 1>> mapping;
+
+  /// Materializations: ValueRange -> ValueRange*
+  DenseMap<SmallVector<Value, 1>, SmallVector<SmallVector<Value, 1>>,
+           SmallVectorMapInfo>
+      materializations;
 
   /// All SSA values that are mapped to. May contain false positives.
   DenseSet<Value> mappedTo;
 };
 } // namespace
 
-Value ConversionValueMapping::lookupOrDefault(Value from,
-                                              Type desiredType) const {
-  // Try to find the deepest value that has the desired type. If there is no
-  // such value, simply return the deepest value.
-  Value desiredValue;
-  do {
-    if (!desiredType || from.getType() == desiredType)
-      desiredValue = from;
-
-    Value mappedValue = mapping.lookupOrNull(from);
-    if (!mappedValue)
-      break;
-    from = mappedValue;
-  } while (true);
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrDefault(Value from) const {
+  SmallVector<Value, 1> mapped = lookupOrNull(from);
+  return !mapped.empty() ? mapped : SmallVector<Value, 1>{from};
+}
 
-  // If the desired value was found use it, otherwise default to the leaf value.
-  return desiredValue ? desiredValue : from;
+SmallVector<Value, 1> ConversionValueMapping::lookupOrDefault(
+    Value from, SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+  assert(desiredTypes.size() > 0 && "expected non-empty types");
+  for (Type t : desiredTypes)
+    assert(t && "expected non-null type");
+#endif // NDEBUG
+
+  // Resolve the mapping only once: the most recently mapped values, or `from`
+  // itself if it is not mapped.
+  SmallVector<Value, 1> vals = lookupOrDefault(from);
+
+  // The resolved values already have the desired types.
+  if (TypeRange(vals) == desiredTypes)
+    return vals;
+
+  // Otherwise, try to find a materialization with matching types.
+  auto it = materializations.find(vals);
+  if (it == materializations.end())
+    return {};
+  for (const SmallVector<Value, 1> &mat : it->second)
+    if (TypeRange(mat) == desiredTypes)
+      return mat;
+  return {};
 }
 
-Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const {
-  Value result = lookupOrDefault(from, desiredType);
-  if (result == from || (desiredType && result.getType() != desiredType))
-    return nullptr;
+SmallVector<Value, 1> ConversionValueMapping::lookupOrNull(Value from) const {
+  auto it = mapping.find(from);
+  if (it == mapping.end())
+    return {};
+  // Resolve each replacement recursively; unmapped replacements are kept.
+  SmallVector<Value, 1> result;
+  for (Value v : it->second)
+    llvm::append_range(result, lookupOrDefault(v));
   return result;
 }
 
+SmallVector<Value, 1>
+ConversionValueMapping::lookupOrNull(Value from,
+                                     SmallVector<Type, 1> desiredTypes) const {
+#ifndef NDEBUG
+  assert(desiredTypes.size() > 0 && "expected non-empty types");
+  for (Type t : desiredTypes)
+    assert(t && "expected non-null type");
+#endif // NDEBUG
+
+  // Not mapped at all: nothing to return.
+  SmallVector<Value, 1> mapped = lookupOrNull(from);
+  if (mapped.empty())
+    return {};
+
+  // The mapped values already have the desired types.
+  if (TypeRange(mapped) == desiredTypes)
+    return mapped;
+
+  // Otherwise, search the materializations of the mapped values for one
+  // that produces the desired types.
+  auto matIt = materializations.find(mapped);
+  if (matIt == materializations.end())
+    return {};
+  const auto &candidates = matIt->second;
+  auto match = llvm::find_if(candidates, [&](const SmallVector<Value, 1> &mat) {
+    return TypeRange(mat) == desiredTypes;
+  });
+  return match == candidates.end() ? SmallVector<Value, 1>{} : *match;
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter and Translation State
 //===----------------------------------------------------------------------===//
@@ -673,7 +876,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
                                    UnrealizedConversionCastOp op,
                                    const TypeConverter *converter,
                                    MaterializationKind kind, Type originalType,
-                                   Value mappedValue);
+                                   SmallVector<Value, 1> mappedValue);
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -710,7 +913,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   /// The value in the conversion value mapping that is being replaced by the
   /// results of this unresolved materialization.
-  Value mappedValue;
+  SmallVector<Value, 1> mappedValue;
 };
 } // namespace
 
@@ -779,7 +982,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   LogicalResult remapValues(StringRef valueDiagTag,
                             std::optional<Location> inputLoc,
                             PatternRewriter &rewriter, ValueRange values,
-                            SmallVector<SmallVector<Value>> &remapped);
+                            SmallVector<SmallVector<Value, 1>> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
@@ -825,7 +1028,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// mapping.
   ValueRange buildUnresolvedMaterialization(
       MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-      Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+      SmallVector<Value,1> valueToMap, ValueRange inputs, TypeRange outputTypes,
       Type originalType, const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr);
   Value buildUnresolvedMaterialization(
@@ -833,27 +1036,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
       const TypeConverter *converter,
       UnrealizedConversionCastOp *castOp = nullptr) {
-    return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs,
+    SmallVector<Value, 1> valuesToMap;
+    if (valueToMap) valuesToMap.push_back(valueToMap);
+    return buildUnresolvedMaterialization(kind, ip, loc, valuesToMap, inputs,
                                           TypeRange(outputType), originalType,
                                           converter, castOp)
         .front();
   }
 
-  /// Build an N:1 materialization for the given original value that was
-  /// replaced with the given replacement values.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. The conversion mapping can store only 1:1 replacements
-  /// and the conversion patterns only support single Value replacements in the
-  /// adaptor, so N values must be converted back to a single value. This
-  /// function will be deleted when full 1:N support has been added.
-  ///
-  /// This function inserts an argument materialization back to the original
-  /// type.
-  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
-                                 ValueRange replacements, Value originalValue,
-                                 const TypeConverter *converter);
-
   /// Find a replacement value for the given SSA value in the conversion value
   /// mapping. The replacement value must have the same type as the given SSA
   /// value. If there is no replacement value with the correct type, find the
@@ -862,16 +1052,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   Value findOrBuildReplacementValue(Value value,
                                     const TypeConverter *converter);
 
-  /// Unpack an N:1 materialization and return the inputs of the
-  /// materialization. This function unpacks only those materializations that
-  /// were built with `insertNTo1Materialization`.
-  ///
-  /// This is a workaround around incomplete 1:N support in the dialect
-  /// conversion driver. It allows us to write 1:N conversion patterns while
-  /// 1:N support is still missing in the conversion value mapping. This
-  /// function will be deleted when full 1:N support has been added.
-  SmallVector<Value> unpackNTo1Materialization(Value value);
-
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -1101,7 +1281,7 @@ void CreateOperationRewrite::rollback() {
 UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
     ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
     const TypeConverter *converter, MaterializationKind kind, Type originalType,
-    Value mappedValue)
+    SmallVector<Value, 1> mappedValue)
     : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
       converterAndKind(converter, kind), originalType(originalType),
       mappedValue(mappedValue) {
@@ -1111,8 +1291,8 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
 }
 
 void UnresolvedMaterializationRewrite::rollback() {
-  if (mappedValue)
-    rewriterImpl.mapping.erase(mappedValue);
+  if (!mappedValue.empty())
+    rewriterImpl.mapping.eraseMaterialization(mappedValue, op->getResults());
   rewriterImpl.unresolvedMaterializations.erase(getOperation());
   rewriterImpl.nTo1TempMaterializations.erase(getOperation());
   op->erase();
@@ -1160,7 +1340,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
 LogicalResult ConversionPatternRewriterImpl::remapValues(
     StringRef valueDiagTag, std::optional<Location> inputLoc,
     PatternRewriter &rewriter, ValueRange values,
-    SmallVector<SmallVector<Value>> &remapped) {
+    SmallVector<SmallVector<Value, 1>> &remapped) {
   remapped.reserve(llvm::size(values));
 
   for (const auto &it : llvm::enumerate(values)) {
@@ -1168,18 +1348,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type origType = operand.getType();
     Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
 
-    // Find the most recently mapped value. Unpack all temporary N:1
-    // materializations. Such conversions are a workaround around missing
-    // 1:N support in the ConversionValueMapping. (The conversion patterns
-    // already support 1:N replacements.)
-    Value repl = mapping.lookupOrDefault(operand);
-    SmallVector<Value> unpacked = unpackNTo1Materialization(repl);
-
     if (!currentTypeConverter) {
       // The current pattern does not have a type converter. I.e., it does not
       // distinguish between legal and illegal types. For each operand, simply
       // pass through the most recently mapped value.
-      remapped.push_back(std::move(unpacked));
+      SmallVector<Value, 1> repl = mapping.lookupOrDefault(operand);
+      remapped.push_back(repl);
       continue;
     }
 
@@ -1199,44 +1373,23 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       continue;
     }
 
-    if (legalTypes.size() != 1) {
-      // TODO: This is a 1:N conversion. The conversion value mapping does not
-      // store such materializations yet. If the types of the most recently
-      // mapped values do not match, build a target materialization.
-      ValueRange unpackedRange(unpacked);
-      if (TypeRange(unpackedRange) == legalTypes) {
-        remapped.push_back(std::move(unpacked));
-        continue;
-      }
-
-      // Insert a target materialization if the current pattern expects
-      // different legalized types.
-      ValueRange targetMat = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(repl), operandLoc,
-          /*valueToMap=*/Value(), /*inputs=*/unpacked,
-          /*outputType=*/legalTypes, /*originalType=*/origType,
-          currentTypeConverter);
-      remapped.push_back(targetMat);
+    SmallVector<Value, 1> mat = mapping.lookupOrDefault(operand, legalTypes);
+    if (!mat.empty()) {
+      // Mapped value has the correct type or there is an existing
+      // materialization. Or the value is not mapped at all and has the
+      // correct type.
+      remapped.push_back(mat);
       continue;
     }
 
-    // Handle 1->1 type conversions.
-    Type desiredType = legalTypes.front();
-    // Try to find a mapped value with the desired type. (Or the operand itself
-    // if the value is not mapped at all.)
-    Value newOperand = mapping.lookupOrDefault(operand, desiredType);
-    if (newOperand.getType() != desiredType) {
-      // If the looked up value's type does not have the desired type, it means
-      // that the value was replaced with a value of different type and no
-      // target materialization was created yet.
-      Value castValue = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked,
-          /*outputType=*/desiredType, /*originalType=*/origType,
-          currentTypeConverter);
-      newOperand = castValue;
-    }
-    remapped.push_back({newOperand});
+    // Create a materialization for the most recently mapped value.
+    SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
+    ValueRange castValues = buildUnresolvedMaterialization(
+        MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
+        /*valueToMap=*/vals,
+        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+        currentTypeConverter);
+    remapped.push_back(castValues);
   }
   return success();
 }
@@ -1350,11 +1503,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (!inputMap) {
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
-      buildUnresolvedMaterialization(
+      Value sourceMat = buildUnresolvedMaterialization(
           MaterializationKind::Source,
           OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
+          /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
+      mapping.map(origArg, sourceMat);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
       continue;
     }
@@ -1369,19 +1523,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       continue;
     }
 
-    // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
-    // dialect conversion. Therefore, we need an argument materialization to
-    // turn the replacement block arguments into a single SSA value that can be
-    // used as a replacement.
+    // This is a 1->1+ mapping.
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
-    if (replArgs.size() == 1) {
-      mapping.map(origArg, replArgs.front());
-    } else {
-      insertNTo1Materialization(
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
-    }
+    mapping.map(origArg, replArgs);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
   }
 
@@ -1402,7 +1547,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    Value valueToMap, ValueRange inputs, TypeRange outputTypes,
+    SmallVector<Value, 1> valuesToMap, ValueRange inputs, TypeRange outputTypes,
     Type originalType, const TypeConverter *converter,
     UnrealizedConversionCastOp *castOp) {
   assert((!originalType || kind == MaterializationKind::Target) &&
@@ -1410,10 +1555,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes) {
-    if (valueToMap) {
-      assert(inputs.size() == 1 && "1:N mapping is not supported");
-      mapping.map(valueToMap, inputs.front());
-    }
+    if (!valuesToMap.empty())
+      mapping.mapMaterialization(valuesToMap, inputs);
     return inputs;
   }
 
@@ -1423,36 +1566,23 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
-  if (valueToMap) {
-    assert(outputTypes.size() == 1 && "1:N mapping is not supported");
-    mapping.map(valueToMap, convertOp.getResult(0));
-  }
+  if (!valuesToMap.empty())
+    mapping.mapMaterialization(valuesToMap, convertOp.getResults());
   if (castOp)
     *castOp = convertOp;
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  originalType, valueToMap);
+                                                  originalType, valuesToMap);
   return convertOp.getResults();
 }
 
-void ConversionPatternRewriterImpl::insertNTo1Materialization(
-    OpBuilder::InsertPoint ip, Location loc, ValueRange replacements,
-    Value originalValue, const TypeConverter *converter) {
-  // Insert argument materialization back to the original type.
-  Type originalType = originalValue.getType();
-  UnrealizedConversionCastOp argCastOp;
-  buildUnresolvedMaterialization(
-      MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
-      /*inputs=*/replacements, originalType,
-      /*originalType=*/Type(), converter, &argCastOp);
-  if (argCastOp)
-    nTo1TempMaterializations.insert(argCastOp);
-}
-
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
+  // The conversion value mapping may map `value` 1:N. A replacement of the
+  // original type is either the mapped value itself or a materialization
+  // that was recorded for the mapped values.
+
   // Find a replacement value with the same type.
-  Value repl = mapping.lookupOrNull(value, value.getType());
-  if (repl)
+  if (Value repl = mapping.lookupOrNull(value, value.getType()))
     return repl;
 
   // Check if the value is dead. No replacement value is needed in that case.
@@ -1467,8 +1597,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   // No replacement value was found. Get the latest replacement value
   // (regardless of the type) and build a source materialization to the
   // original type.
-  repl = mapping.lookupOrNull(value);
-  if (!repl) {
+  SmallVector<Value, 1> repl = mapping.lookupOrNull(value);
+  if (repl.empty()) {
     // No replacement value is registered in the mapping. This means that the
     // value is dropped and no longer needed. (If the value were still needed,
     // a source materialization producing a replacement value "out of thin air"
@@ -1478,34 +1608,12 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
   }
   Value castValue = buildUnresolvedMaterialization(
       MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
-      /*originalType=*/Type(), converter);
-  mapping.map(value, castValue);
+      /*valueToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/{value.getType()},
+      /*originalType=*/Type(), converter)[0];
+  // The cast result is registered above as a materialization of `repl`.
   return castValue;
 }
 
-SmallVector<Value>
-ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) {
-  // Unpack unrealized_conversion_cast ops that were inserted as a N:1
-  // workaround.
-  auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-  if (!castOp)
-    return {value};
-  if (!nTo1TempMaterializations.contains(castOp))
-    return {value};
-  assert(castOp->getNumResults() == 1 && "expected single result");
-
-  SmallVector<Value> result;
-  for (Value v : castOp.getOperands()) {
-    // Keep unpacking if possible. This is needed because during block
-    // signature conversions and 1:N op replacements, the driver may have
-    // inserted two materializations back-to-back: first an argument
-    // materialization, then a target materialization.
-    llvm::append_range(result, unpackNTo1Materialization(v));
-  }
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1552,11 +1660,12 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
       }
 
       // Materialize a replacement value "out of thin air".
-      buildUnresolvedMaterialization(
+      Value sourceMat = buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(result),
-          result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(),
+          result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
           /*outputType=*/result.getType(), /*originalType=*/Type(),
           currentTypeConverter);
+      mapping.map(result, sourceMat);
       continue;
     } else {
       // Make sure that the user does not mess with unresolved materializations
@@ -1572,16 +1681,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
     // Remap result to replacement value.
     if (repl.empty())
       continue;
-
-    if (repl.size() == 1) {
-      // Single replacement value: replace directly.
-      mapping.map(result, repl.front());
-    } else {
-      // Multiple replacement values: insert N:1 materialization.
-      insertNTo1Materialization(computeInsertPoint(result), result.getLoc(),
-                                /*replacements=*/repl, /*outputValue=*/result,
-                                currentTypeConverter);
-    }
+    mapping.map(result, repl);
   }
 
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);
@@ -1660,8 +1760,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
   SmallVector<ValueRange> newVals;
-  for (size_t i = 0; i < newValues.size(); ++i)
-    newVals.push_back(newValues.slice(i, 1));
+  for (size_t i = 0, e = newValues.size(); i < e; ++i) {
+    // A null entry in `newValues` is forwarded as an empty range, which
+    // notifyOpReplaced treats as "no replacement value" for the corresponding
+    // result.
+    Value newValue = newValues[i];
+    newVals.push_back(newValue ? newValues.slice(i, 1) : ValueRange());
+  }
   impl->notifyOpReplaced(op, newVals);
 }
 
@@ -1729,11 +1834,14 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
                                               impl->currentTypeConverter);
-  impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
+  SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from);
+  assert(mapped.size() == 1 &&
+         "replaceUsesOfBlockArgument is not supported for 1:N replacements");
+  impl->mapping.map(mapped.front(), to);
 }
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
-  SmallVector<SmallVector<Value>> remappedValues;
+  SmallVector<SmallVector<Value, 1>> remappedValues;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
                                remappedValues)))
     return nullptr;
@@ -1746,7 +1854,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                                              SmallVectorImpl<Value> &results) {
   if (keys.empty())
     return success();
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<SmallVector<Value, 1>> remapped;
   if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
                                remapped)))
     return failure();
@@ -1872,7 +1980,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
                                              getTypeConverter());
 
   // Remap the operands of the operation.
-  SmallVector<SmallVector<Value>> remapped;
+  SmallVector<SmallVector<Value, 1>> remapped;
   if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
                                       op->getOperands(), remapped))) {
     return failure();
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 09c5b4b2a0ad50..d0b62e71ab0cf2 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes
           tupleType.getFlattenedTypes(types);
           return success();
         });
-    typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+    typeConverter.addSourceMaterialization(buildMakeTupleOp);
     typeConverter.addTargetMaterialization(buildDecomposeTuple);
 
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 466ae7ff6f46f1..749df2cb9ea7cc 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1235,7 +1235,6 @@ struct TestTypeConverter : public TypeConverter {
   using TypeConverter::TypeConverter;
   TestTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 
diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp
index 2cc1fb5d39d788..a03bf0a1023d57 100644
--- a/mlir/test/lib/Transforms/TestDialectConversion.cpp
+++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp
@@ -28,7 +28,6 @@ namespace {
 struct PDLLTypeConverter : public TypeConverter {
   PDLLTypeConverter() {
     addConversion(convertType);
-    addArgumentMaterialization(materializeCast);
     addSourceMaterialization(materializeCast);
   }
 



More information about the llvm-branch-commits mailing list