[Mlir-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)

Momchil Velikov llvmlistbot at llvm.org
Mon Jun 23 02:11:05 PDT 2025


https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422

>From 2798815d1659bc8bc55bdd4dd1e10333c9c59202 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 2 Jun 2025 15:13:13 +0000
Subject: [PATCH 1/9] [MLIR] Fix incorrect slice contiguity inference in
 `vector::isContiguousSlice`

Previously, slices were sometimes marked as non-contiguous when
they were actually contiguous. This occurred when the vector type had
leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`.
In such cases, only the trailing n dimensions of the memref need to be
contiguous, not the entire vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write` ops. The pattern used
to collapse a number of dimensions equal to the vector rank, which
may be incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref
dimensions as are actually contiguous.
---
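For readers who want to try the rule outside MLIR, the check described in
the commit message above can be sketched roughly as follows. This is a
minimal standalone approximation, not the code in the patch: shapes and
strides are plain `std::vector<int64_t>` values, and the names `Shape`,
`trailingDimsContiguous` and `isContiguousSliceSketch` are hypothetical.
The stride test is a simplified stand-in for
`MemRefType::areTrailingDimsContiguous` and assumes static shapes and
strides only.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical, simplified stand-in for areTrailingDimsContiguous(n):
// returns true if the last `n` dims have dense row-major strides.
bool trailingDimsContiguous(const Shape &shape, const Shape &strides,
                            std::size_t n) {
  if (n > shape.size() || shape.size() != strides.size())
    return false;
  int64_t expected = 1;
  for (std::size_t i = 0; i < n; ++i) {
    std::size_t d = shape.size() - 1 - i;
    if (strides[d] != expected)
      return false;
    expected *= shape[d];
  }
  return true;
}

// Sketch of the contiguity rule: ignore leading unit dims of the vector,
// require the N trailing memref dims to be contiguous, and require the
// N-1 trailing dims of the vector and the memref to match.
bool isContiguousSliceSketch(const Shape &memShape, const Shape &memStrides,
                             const Shape &vecShape) {
  auto firstNonUnit = std::find_if(vecShape.begin(), vecShape.end(),
                                   [](int64_t v) { return v != 1; });
  Shape vec(firstNonUnit, vecShape.end());
  std::size_t n = vec.size();

  if (!trailingDimsContiguous(memShape, memStrides, n))
    return false;
  if (n == 0)
    return true;

  // Compare everything except the first remaining vector dim.
  return std::equal(vec.begin() + 1, vec.end(),
                    memShape.end() - static_cast<std::ptrdiff_t>(n - 1));
}
```

Under these assumptions, `vector<1x1x2xi32>` over a `2x2x2` memref with
strides `[8, 4, 1]` is accepted, because only the innermost dimension has to
be contiguous, while `vector<1x2x2xi32>` over the same memref is rejected.
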
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  54 ++++-----
 .../Transforms/VectorTransferOpTransforms.cpp |   8 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  25 ++--
 .../Vector/vector-transfer-flatten.mlir       | 108 +++++++++++++-----
 4 files changed, 120 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
 ///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+///   a) the N trailing dimensions of the latter must be contiguous, and
+///   b) the trailing N dimensions of `vectorType` and `memrefType`,
+///      except the first of them, must match.
 ///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-///         of `memrefType` then the leading dim of `vectorType` can be
-///         arbitrary.
-///
-///        Ex. 1.1 contiguous slice, perfect match
-///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-///         corresponding trailing dim in `memrefType` then the remaining
-///         leading dims of `vectorType` have to be 1 (the first non-matching
-///         dim can be arbitrary).
+/// Examples:
 ///
-///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.2  contiguous slice, 2 != 3 and the leading dim == <1>
-///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.1 contiguous slice, perfect match
+///     vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+///     vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.3 non-contiguous slice, 2 != 3
+///     vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+///        2 != 3 (allowed)
+///     vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.5. contiguous slice, leasing two unit dims of the vector ignored,
+///         2 != 3 (allowed)
+///     vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
+///     vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.7 contiguous slice, memref needs to be contiguous only on the last
+///        dimension
+///     vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+///   Ex.8 non-contiguous slice, memref needs to be contiguous one the last
+///        two dimensions, and it isn't
+///     vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 
 /// Returns an iterator for all positions in the leading dimensions of `vType`
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 384717aeca665..80d0f196190e2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -630,7 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determinine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -722,7 +724,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determinine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 590d244daef40..109be454bba61 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (vectorType.isScalable())
     return false;
 
-  ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  auto vecRank = vectorType.getRank();
+  // Ignore a leading contiguous sequence of unit dimensions in the vector.
+  ArrayRef<int64_t> vectorShape =
+      vectorType.getShape().drop_while([](auto v) { return v == 1; });
+  auto vecRank = vectorShape.size();
 
   if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 
-  // Extract the trailing dims and strides of the input memref
+  // Extract the trailing dims of the input memref
   auto memrefShape = memrefType.getShape().take_back(vecRank);
 
-  // Compare the dims of `vectorType` against `memrefType` (in reverse).
-  // In the most basic case, all dims will match.
-  auto firstNonMatchingDim =
-      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
-                    memrefShape.rbegin(), memrefShape.rend());
-  if (firstNonMatchingDim.first == vectorShape.rend())
-    return true;
-
-  // One non-matching dim is still fine, however the remaining leading dims of
-  // `vectorType` need to be 1.
-  SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
-                                   vectorShape.rend());
-
-  return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+  // Compare the dims of `vectorType` against `memrefType`.
+  // All of the dimensions, except the first must match.
+  return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
 }
 
 std::optional<StaticTileOffsetRange>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 45873aa93153d..70b9c53e49cea 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -116,10 +116,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME:         : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -170,16 +171,18 @@ func.func @transfer_read_leading_dynamic_dims(
   return %res : vector<8x4xi8>
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
 // CHECK-SAME:    %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
 // CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
 // CHECK:         %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME:    [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
 // CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:      : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME:      : memref<?x?xi8, {{.+}}>, vector<32xi8>
 // CHECK:         %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK:         return %[[RES]] : vector<8x4xi8>
 
@@ -236,13 +239,12 @@ func.func @transfer_read_dynamic_dim_to_flatten(
 // CHECK-SAME:    %[[IDX_2:arg1]]
 // CHECK-SAME:    %[[MEM:arg2]]
 // CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
-// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME{LITERAL}:  [[0, 1, 2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
 // CHECK:              %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
 // CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
 
@@ -423,11 +425,10 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x2x6xi32>) {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -475,16 +476,18 @@ func.func @transfer_write_leading_dynamic_dims(
   return
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
 // CHECK-SAME:    %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
 // CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME:      [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME:      [%[[ARG2]], %[[COLLAPSED_IDX]]]
 // CHECK-SAME:      {in_bounds = [true]}
-// CHECK-SAME:      : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+// CHECK-SAME:      : vector<32xi8>, memref<?x?xi8, {{.+}}>
 
 // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
 //       CHECK-128B:   memref.collapse_shape
@@ -536,14 +539,13 @@ func.func @transfer_write_dynamic_dim_to_flatten(
 // CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
 // CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
 
-// CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
-// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME{LITERAL}:  [[0, 1, 2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -621,8 +623,12 @@ func.func @negative_out_of_bound_transfer_read(
     memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
   return %res : vector<5x4x3x2xi8>
 }
-// CHECK:     func.func @negative_out_of_bound_transfer_read
-// CHECK-NOT:   memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT:     memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
@@ -633,5 +639,47 @@ func.func @negative_out_of_bound_transfer_write(
     vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
-// CHECK:     func.func @negative_out_of_bound_transfer_write
-// CHECK-NOT:   memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT:     memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_contig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-SAME:    %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
+// CHECK-SAME:    %[[VEC:.+]]: vector<1x1x8xi32>
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
+// CHECK:       vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME:    : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+
+// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_discontig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-NOT:    vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
+//   CHECK-128B-NOT:   vector.shape_cast
+

>From 36a692612475ea6bf21f8e4ad4eaed3b1fea4e96 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 3 Jun 2025 16:59:27 +0000
Subject: [PATCH 2/9] [fixup] Don't try to collapse non-leftmost dynamic
 dimension

Even though it's possible in principle, the affected patterns need
strides to be determined statically.
---
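The clamping introduced by this fixup can be illustrated with a small
standalone sketch. It assumes plain `std::vector<int64_t>` shapes, a local
`kDynamic` sentinel standing in for `ShapedType::kDynamic`, and a
hypothetical `firstDimToCollapse` helper that takes the number of contiguous
trailing dimensions as a parameter instead of querying the memref type.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <vector>

// Local stand-in for ShapedType::kDynamic; the value is an assumption of
// this sketch and only needs to differ from any static extent.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Index of the last dynamic dimension in `shape`, or -1 if there is none.
int64_t lastDynIndex(const std::vector<int64_t> &shape) {
  auto it = std::find(shape.rbegin(), shape.rend(), kDynamic);
  return static_cast<int64_t>(std::distance(it, shape.rend())) - 1;
}

// Hypothetical helper: never start collapsing to the left of the last
// dynamic dimension, even if more trailing dims are contiguous.
int64_t firstDimToCollapse(const std::vector<int64_t> &shape,
                           int64_t numContiguousTrailingDims) {
  int64_t rank = static_cast<int64_t>(shape.size());
  return std::max(lastDynIndex(shape), rank - numContiguousTrailingDims);
}
```

For a `1x?x4x6` shape with four contiguous trailing dimensions this yields 1,
which corresponds to the reassociation `[[0], [1, 2, 3]]`; for a fully static
`1x43x4x6` shape it yields 0, so all four dimensions are collapsed.
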
 .../Transforms/VectorTransferOpTransforms.cpp | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 80d0f196190e2..9edcfa124d126 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -582,6 +582,15 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
 
 namespace {
 
+/// Helper functon to return the index of the last dynamic dimension in `shape`.
+int64_t lastDynIndex(ArrayRef<int64_t> shape) {
+  return static_cast<int64_t>(
+      std::distance(
+          std::find(shape.rbegin(), shape.rend(), ShapedType::kDynamic),
+          shape.rend()) -
+      1);
+}
+
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -631,8 +640,9 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
 
     // Determinine the first memref dimension to collapse
-    int64_t firstDimToCollapse =
-        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
+    int64_t firstDimToCollapse = std::max(
+        lastDynIndex(sourceType.getShape()),
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -725,8 +735,9 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
 
     // Determinine the first memref dimension to collapse
-    int64_t firstDimToCollapse =
-        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
+    int64_t firstDimToCollapse = std::max(
+        lastDynIndex(sourceType.getShape()),
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
 
     // 1. Collapse the source memref
     Value collapsedSource =

>From 0720f87e2646124298a04887c1eb464b8f102c25 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 4 Jun 2025 11:06:39 +0000
Subject: [PATCH 3/9] [fixup] Update a member function name and fix a test
 failure

---
 .../Transforms/VectorTransferOpTransforms.cpp |  4 +-
 .../Vector/vector-transfer-flatten.mlir       | 49 ++++++++++---------
 2 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 9edcfa124d126..e6ffb469af79a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -642,7 +642,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     // Determinine the first memref dimension to collapse
     int64_t firstDimToCollapse = std::max(
         lastDynIndex(sourceType.getShape()),
-        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
+        sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -737,7 +737,7 @@ class FlattenContiguousRowMajorTransferWritePattern
     // Determinine the first memref dimension to collapse
     int64_t firstDimToCollapse = std::max(
         lastDynIndex(sourceType.getShape()),
-        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
+        sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 70b9c53e49cea..d3e0d0fcac2ab 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -113,14 +113,14 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
-// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>
-// CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME:      %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
+// CHECK-SAME:      %[[MEM:.+]]: memref<1x43x4x6xi32>
+// CHECK:           %[[C_0:.+]] = arith.constant 0 : i32
+// CHECK:           %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
 // CHECK-SAME:         : memref<1x43x4x6xi32> into memref<1032xi32>
-// CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
+// CHECK:           %[[COLLAPSED_IDX:.+]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:           %[[READ:.+]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -232,20 +232,21 @@ func.func @transfer_read_dynamic_dim_to_flatten(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
 // CHECK-SAME:    %[[IDX_1:arg0]]
 // CHECK-SAME:    %[[IDX_2:arg1]]
 // CHECK-SAME:    %[[MEM:arg2]]
-// CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}:  [[0, 1, 2, 3]]
-// CHECK-SAME:           memref<1x?x4x6xi32> into memref<?xi32>
-// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
-// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
-// CHECK:              %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK:              %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK:              %[[C0:.+]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK:              %[[RESULT:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
 // CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
 
 
@@ -531,21 +532,21 @@ func.func @transfer_write_dynamic_dim_to_flatten(
   return
 }
 
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
 // CHECK-SAME:    %[[IDX_1:arg0]]: index
 // CHECK-SAME:    %[[IDX_2:arg1]]: index
 // CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
 // CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
-
-// CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}:  [[0, 1, 2, 3]]
-// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<?xi32>
-// CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
-// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
+// CHECK:              %[[C0:.+]] = arith.constant 0 : index
+// CHECK:              %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK:              %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape

>From 97d2a5faf93fe6a13d4e5a0663f034be6dd7bbcd Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 5 Jun 2025 16:46:18 +0000
Subject: [PATCH 4/9] [fixup] One more rename

---
 .../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp  | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e6ffb469af79a..e3e131473751f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -642,7 +642,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     // Determinine the first memref dimension to collapse
     int64_t firstDimToCollapse = std::max(
         lastDynIndex(sourceType.getShape()),
-        sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
+        sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -737,7 +737,7 @@ class FlattenContiguousRowMajorTransferWritePattern
     // Determinine the first memref dimension to collapse
     int64_t firstDimToCollapse = std::max(
         lastDynIndex(sourceType.getShape()),
-        sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
+        sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
 
     // 1. Collapse the source memref
     Value collapsedSource =

>From dc05e13cf05374c5523640e7ddd2cf38bf656c75 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 6 Jun 2025 12:58:34 +0000
Subject: [PATCH 5/9] [fixup] Comment fixes

---
 mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 12 ++++++------
 .../Vector/Transforms/VectorTransferOpTransforms.cpp |  3 ++-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ed06d7a029494..508fb017007ed 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -54,9 +54,9 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 /// dimensions after ignoring a leading sequence of unit ones.
 ///
 /// For `vectorType` to be a contiguous slice of `memrefType`
-///   a) the N trailing dimensions of the latter must be contiguous, and
-///   b) the trailing N dimensions of `vectorType` and `memrefType`,
-///      except the first of them, must match.
+///   a) the N trailing dimensions of `memrefType` must be contiguous, and
+///   b) the N-1 trailing dimensions of `vectorType` and `memrefType`
+///      must match.
 ///
 /// Examples:
 ///
@@ -69,15 +69,15 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 ///   Ex.4 contiguous slice, leading unit dimension of the vector ignored,
 ///        2 != 3 (allowed)
 ///     vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///   Ex.5. contiguous slice, leasing two unit dims of the vector ignored,
+///   Ex.5. contiguous slice, leading two unit dims of the vector ignored,
 ///         2 != 3 (allowed)
 ///     vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
 ///   Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
 ///     vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
-///   Ex.7 contiguous slice, memref needs to be contiguous only on the last
+///   Ex.7 contiguous slice, memref needs to be contiguous only in the last
 ///        dimension
 ///     vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
-///   Ex.8 non-contiguous slice, memref needs to be contiguous one the last
+///   Ex.8 non-contiguous slice, memref needs to be contiguous in the last
 ///        two dimensions, and it isn't
 ///     vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e3e131473751f..471c0ce82a99a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -582,7 +582,8 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
 
 namespace {
 
-/// Helper functon to return the index of the last dynamic dimension in `shape`.
+/// Helper function to return the index of the last dynamic dimension
+/// in `shape` or -1 if there are no dynamic dimensions.
 int64_t lastDynIndex(ArrayRef<int64_t> shape) {
   return static_cast<int64_t>(
       std::distance(

>From 25f7bdd0ac47d6cee75c067eff991af9e4b5e4a5 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 9 Jun 2025 13:31:01 +0000
Subject: [PATCH 6/9] [fixup] Test tweaks for better coverage

---
 .../Vector/vector-transfer-flatten.mlir       | 112 +++++++++++-------
 1 file changed, 67 insertions(+), 45 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d3e0d0fcac2ab..7f793d2dda0c2 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 
 // -----
 
-// The shape of the memref and the vector don't match, but the vector is a
-// contiguous subset of the memref, so "flattenable".
+// The shape of the memref and the vector don't match, but the vector,
+// ignoring the unit dimensions, is a contiguous subset of the memref,
+// so "flattenable"
 
-func.func @transfer_read_dims_mismatch_contiguous(
+func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 
   %c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
   return %res : vector<1x1x2x2xi8>
 }
 
-// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
 // CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable"
+
+func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
+  return %res : vector<2x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
+// CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME:          : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK:         %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
+// CHECK-SAME:      : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
+// CHECK:         %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
+// CHECK:         return %[[VEC]] : vector<2x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
@@ -384,7 +415,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
 
 // -----
 
-func.func @transfer_write_dims_mismatch_contiguous(
+func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
     %vec : vector<1x1x2x2xi8>) {
 
@@ -394,7 +425,7 @@ func.func @transfer_write_dims_mismatch_contiguous(
   return
 }
 
-// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous_unit_dims
 // CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x1x2x2xi8>) {
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
@@ -402,7 +433,33 @@ func.func @transfer_write_dims_mismatch_contiguous(
 // CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
 // CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
 
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
+func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+    vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
+// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+// CHECK-SAME:    %[[VEC:.+]]: vector<2x2xi8>
+// CHECK:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:    %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME:          : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK:    %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+// CHECK:    vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
+// CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
@@ -626,6 +683,7 @@ func.func @negative_out_of_bound_transfer_read(
 }
 // CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
 // CHECK-NOT:     memref.collapse_shape
+// CHECK-NOT:     vector.shape_cast
 
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -642,45 +700,9 @@ func.func @negative_out_of_bound_transfer_write(
 }
 // CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
 // CHECK-NOT:     memref.collapse_shape
+// CHECK-NOT:     vector.shape_cast
 
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
 //   CHECK-128B-NOT:   memref.collapse_shape
 //   CHECK-128B-NOT:   vector.shape_cast
 
-// -----
-
-func.func @discontig_mem_contig_slice(
-    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
-  return
-}
-
-// CHECK-LABEL: func.func @discontig_mem_contig_slice
-// CHECK-SAME:    %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
-// CHECK-SAME:    %[[VEC:.+]]: vector<1x1x8xi32>
-// CHECK:       %[[C0:.+]] = arith.constant 0 : index
-// CHECK:       %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
-// CHECK:       vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
-// CHECK-SAME:    : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
-
-// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
-//   CHECK-128B-NOT:   vector.shape_cast
-
-// -----
-
-func.func @discontig_mem_discontig_slice(
-    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
-  return
-}
-
-// CHECK-LABEL: func.func @discontig_mem_discontig_slice
-// CHECK-NOT:    vector.shape_cast
-
-// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
-//   CHECK-128B-NOT:   vector.shape_cast
-

>From 9fdb00c627b17d38060954b0f645cad40df22073 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 12 Jun 2025 10:18:42 +0000
Subject: [PATCH 7/9] [fixup] Do minimal collapsing of the memref

---
 .../Transforms/VectorTransferOpTransforms.cpp |  29 ++---
 .../Vector/vector-transfer-flatten.mlir       | 113 +++++++++---------
 2 files changed, 68 insertions(+), 74 deletions(-)
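
Reviewer note (illustrative only, not part of the patch): the new
`firstDimToCollapse` computation is the memref rank minus the number of vector
dimensions left after dropping the leading unit dimensions. The standalone
sketch below reproduces that arithmetic and the resulting reassociation groups
seen in the updated CHECK lines. The helper names `firstDimToCollapseSketch`
and `reassociationSketch` are hypothetical and use plain `std::vector` rather
than MLIR's `ArrayRef`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: memref rank minus the vector rank after dropping the
// leading unit dimensions of the vector shape.
int64_t firstDimToCollapseSketch(int64_t memrefRank,
                                 const std::vector<int64_t> &vecShape) {
  size_t lead = 0;
  while (lead < vecShape.size() && vecShape[lead] == 1)
    ++lead;
  return memrefRank - static_cast<int64_t>(vecShape.size() - lead);
}

// Hypothetical helper: every dimension before `firstDim` stays in its own
// group and the remaining trailing dimensions form a single collapsed group.
std::vector<std::vector<int64_t>> reassociationSketch(int64_t memrefRank,
                                                      int64_t firstDim) {
  std::vector<std::vector<int64_t>> groups;
  for (int64_t d = 0; d < firstDim; ++d)
    groups.push_back({d});
  std::vector<int64_t> collapsed;
  for (int64_t d = firstDim; d < memrefRank; ++d)
    collapsed.push_back(d);
  groups.push_back(collapsed);
  return groups;
}
```

For a rank-4 memref, vector<1x1x2x2xi8> gives a first collapsed dimension of 2
and groups [[0], [1], [2, 3]], while vector<2x3x2xi8> gives 1 and
[[0], [1, 2, 3]], matching the updated tests.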

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 471c0ce82a99a..785a8aaf3f0a9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -581,17 +581,6 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
 }
 
 namespace {
-
-/// Helper function to return the index of the last dynamic dimension
-/// in `shape` or -1 if there are no dynamic dimensions.
-int64_t lastDynIndex(ArrayRef<int64_t> shape) {
-  return static_cast<int64_t>(
-      std::distance(
-          std::find(shape.rbegin(), shape.rend(), ShapedType::kDynamic),
-          shape.rend()) -
-      1);
-}
-
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -640,10 +629,11 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
 
-    // Determinine the first memref dimension to collapse
-    int64_t firstDimToCollapse = std::max(
-        lastDynIndex(sourceType.getShape()),
-        sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
+    // Determine the first memref dimension to collapse - just enough so we can
+    // read a flattened vector.
+    int64_t firstDimToCollapse =
+        sourceType.getRank() -
+        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -735,10 +725,11 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    // Determinine the first memref dimension to collapse
-    int64_t firstDimToCollapse = std::max(
-        lastDynIndex(sourceType.getShape()),
-        sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
+    // Determine the first memref dimension to collapse - just enough so we can
+    // write a flattened vector.
+    int64_t firstDimToCollapse =
+        sourceType.getRank() -
+        vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 7f793d2dda0c2..ae960969e18e7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -88,8 +88,10 @@ func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
 // CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
-// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>, vector<4xi8>
 // CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
 
@@ -116,10 +118,10 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
 // CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
 // CHECK:         %[[C0:.+]] = arith.constant 0 : index
 // CHECK:         %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
-// CHECK-SAME:          : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
-// CHECK:         %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
-// CHECK-SAME:      : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME:          : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
+// CHECK:         %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]], %[[C0]]], %[[C0_I8]] {in_bounds = [true]}
+// CHECK-SAME:      : memref<5x24xi8, strided<[24, 1], offset: ?>>, vector<12xi8>
 // CHECK:         %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
 // CHECK:         return %[[VEC]] : vector<2x3x2xi8>
 
@@ -141,17 +143,18 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
 // CHECK-SAME:      %[[MEM:.+]]: memref<1x43x4x6xi32>
-// CHECK:           %[[C_0:.+]] = arith.constant 0 : i32
+// CHECK:           %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK:           %[[C_0:.+]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
-// CHECK-SAME:         : memref<1x43x4x6xi32> into memref<1032xi32>
-// CHECK:           %[[COLLAPSED_IDX:.+]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:           %[[READ:.+]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:         : memref<1x43x4x6xi32> into memref<1x43x24xi32>
+// CHECK:           %[[COLLAPSED_IDX:.+]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]]]
+// CHECK:           %[[READ:.+]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0]], %[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I32]] {in_bounds = [true]} : memref<1x43x24xi32>, vector<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -202,18 +205,16 @@ func.func @transfer_read_leading_dynamic_dims(
   return %res : vector<8x4xi8>
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
-
 // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
 // CHECK-SAME:    %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
 // CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
-// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
 // CHECK:         %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME:    [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
-// CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:      : memref<?x?xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME:      [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:      {in_bounds = [true]} : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
 // CHECK:         %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK:         return %[[RES]] : vector<8x4xi8>
 
@@ -263,7 +264,7 @@ func.func @transfer_read_dynamic_dim_to_flatten(
   return %res : vector<1x2x6xi32>
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 6)>
 
 // CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
 // CHECK-SAME:    %[[IDX_1:arg0]]
@@ -272,11 +273,11 @@ func.func @transfer_read_dynamic_dim_to_flatten(
 // CHECK:              %[[C0_I32:.+]] = arith.constant 0 : i32
 // CHECK:              %[[C0:.+]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
-// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
-// CHECK:              %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:              %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK-SAME{LITERAL}:  [[0], [1], [2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?x24xi32>
+// CHECK:              %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
+// CHECK:              %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[IDX_1]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?x24xi32>, vector<12xi32>
 // CHECK:              %[[RESULT:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
 // CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
 
@@ -428,10 +429,12 @@ func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
 // CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous_unit_dims
 // CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x1x2x2xi8>) {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
-// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
-// CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
+// CHECK:           %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK:           vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
@@ -451,13 +454,13 @@ func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
 // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
 // CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
 // CHECK-SAME:    %[[VEC:.+]]: vector<2x2xi8>
-// CHECK:    %[[C0:.+]] = arith.constant 0 : index
-// CHECK:    %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
-// CHECK-SAME:          : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
-// CHECK:    %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
-// CHECK:    vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
-// CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
+// CHECK:         %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+// CHECK:         vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME:      : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
@@ -477,16 +480,18 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
   return
 }
 
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_write_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x2x6xi32>) {
-// CHECK-DAG:       %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]]]
+// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-DAG-SAME{LITERAL}: [[0], [1], [2, 3]] : memref<1x43x4x6xi32> into memref<1x43x24xi32>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[CS]][%[[C0]], %[[IDX_1]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x43x24xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -534,18 +539,16 @@ func.func @transfer_write_leading_dynamic_dims(
   return
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
-
 // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
 // CHECK-SAME:    %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
-// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
 // CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME:      [%[[ARG2]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME:      {in_bounds = [true]}
-// CHECK-SAME:      : vector<32xi8>, memref<?x?xi8, {{.+}}>
+// CHECK-SAME:      [%[[ARG2]], %[[ARG3]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME:      : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
 
 // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
 //       CHECK-128B:   memref.collapse_shape
@@ -589,7 +592,7 @@ func.func @transfer_write_dynamic_dim_to_flatten(
   return
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 6)>
 
 // CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
 // CHECK-SAME:    %[[IDX_1:arg0]]: index
@@ -598,12 +601,12 @@ func.func @transfer_write_dynamic_dim_to_flatten(
 // CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
 // CHECK:              %[[C0:.+]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}:  [[0], [1, 2, 3]]
-// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
-// CHECK:              %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK-SAME{LITERAL}:  [[0], [1], [2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?x24xi32>
+// CHECK:              %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
 // CHECK:              %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[IDX_1]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?x24xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape

>From 4a1b0efad7fd22f84eff85eebd787502192d14df Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 13 Jun 2025 14:58:18 +0000
Subject: [PATCH 8/9] [fixup] Misc comments

---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  4 ++-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  2 +-
 .../Vector/vector-transfer-flatten.mlir       | 36 ++++++++++++-------
 3 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 508fb017007ed..cc8421b23a074 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -47,7 +47,9 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
 /// on a 2D slice. Otherwise, returns a failure.
 FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 
-/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+/// Return true if `vectorType` is a contiguous slice of `memrefType`,
+/// in the sense that it can be read/written from/to a contiguous area
+/// of the memref.
 ///
 /// The leading unit dimensions of the vector type are ignored as they
 /// are not relevant to the result. Let N be the number of the vector
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 109be454bba61..0c7cfc13f87da 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,7 +258,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (vectorType.isScalable())
     return false;
 
-  // Ignore a leading contiguous sequence of unit dimensions in the vector.
+  // Ignore a leading sequence of adjacent unit dimensions in the vector.
   ArrayRef<int64_t> vectorShape =
       vectorType.getShape().drop_while([](auto v) { return v == 1; });
   auto vecRank = vectorShape.size();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae960969e18e7..4712871367071 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -70,28 +70,29 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 
 // -----
 
-// The shape of the memref and the vector don't match, but the vector,
-// ignoring the unit dimensions, is a contiguous subset of the memref,
-// so "flattenable"
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable". The leading unit dimensions
+// of the vector have no effect on the memref area read even if they
+// span a non-contiguous part of the memref.
 
 func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
-    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 
   %c0 = arith.constant 0 : index
   %cst = arith.constant 0 : i8
   %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
-    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+    memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
   return %res : vector<1x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
-// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
-// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
-// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>, vector<4xi8>
+// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
 // CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
 // CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
 
@@ -416,31 +417,40 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
 
 // -----
 
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable". The leading unit dimensions
+// of the vector have no effect on the memref area written even if they
+// span a non-contiguous part of the memref.
+
 func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
-    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
     %vec : vector<1x1x2x2xi8>) {
 
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
-    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
   return
 }
 
 // CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous_unit_dims
-// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x1x2x2xi8>) {
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
-// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
+// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
 // CHECK:           %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
-// CHECK:           vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
+// CHECK:           vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME:        {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
 func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
     %vec : vector<2x2xi8>) {

>From 1740d45de6d3d4af83891d3804c9b3eba0a26415 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 16 Jun 2025 12:57:33 +0000
Subject: [PATCH 9/9] [fixup] Add/change a few tests

---
 .../Vector/vector-transfer-flatten.mlir       | 197 ++++++++++++------
 1 file changed, 130 insertions(+), 67 deletions(-)
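
Note for readers of the tests below: the following is a minimal Python sketch of the contiguity rule these tests exercise, as described in the PR summary (leading unit dimensions of the vector are ignored, and only the remaining trailing memref dimensions must be contiguous). It is an illustrative model, not the code in VectorUtils.cpp. The names `trailing_dims_contiguous` and `is_contiguous_slice` are hypothetical, and the sketch assumes static shapes and strides and ignores scalable vectors and the offset.

```python
from typing import Sequence


def trailing_dims_contiguous(shape: Sequence[int], strides: Sequence[int], n: int) -> bool:
    """Return True if the innermost `n` dims are densely packed, row-major."""
    expected = 1
    # Walk from the innermost dimension outwards over the last `n` dims.
    for size, stride in zip(reversed(shape[len(shape) - n:]),
                            reversed(strides[len(strides) - n:])):
        if stride != expected:
            return False
        expected *= size
    return True


def is_contiguous_slice(mem_shape: Sequence[int], mem_strides: Sequence[int],
                        vec_shape: Sequence[int]) -> bool:
    """Simplified model: is a vector of `vec_shape` a contiguous memref slice?"""
    # Leading unit dims of the vector select a single position and do not
    # affect which memory is touched, so drop them (assumption from the PR text).
    first = 0
    while first < len(vec_shape) and vec_shape[first] == 1:
        first += 1
    vec = list(vec_shape[first:])
    n = len(vec)
    if n > len(mem_shape):
        return False
    if not trailing_dims_contiguous(mem_shape, mem_strides, n):
        return False
    # Every remaining vector dim except the outermost must span the whole
    # corresponding memref dim.
    mem_tail = list(mem_shape[len(mem_shape) - n:])
    return vec[1:] == mem_tail[1:]


if __name__ == "__main__":
    # Shapes taken from @transfer_read_non_contiguous_unit_dims.
    print(is_contiguous_slice([5, 4, 3, 2], [48, 6, 2, 1], [1, 1, 3, 2]))  # True
    # A vector spanning all four dims hits the stride-48 gap.
    print(is_contiguous_slice([5, 4, 3, 2], [48, 6, 2, 1], [2, 4, 3, 2]))  # False
```

Under this model, `vector<1x1x3x2xi8>` over `strided<[48, 6, 2, 1]>` needs only the trailing two memref dims to be contiguous, which is consistent with the `[[0], [1], [2, 3]]` reassociation expected by the CHECK lines.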

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 4712871367071..0f04d3b79b535 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -70,41 +70,10 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 
 // -----
 
-// The shape of the memref and the vector don't match, but the vector is a
-// contiguous subset of the memref, so "flattenable". The leading unit dimensions
-// of the vector have no effect on the memref area read even if they
-// span a non-contiguous part of the memref.
-
-func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
-    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0 : i8
-  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
-    memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-  return %res : vector<1x1x2x2xi8>
-}
-
-// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
-// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
-// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
-// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
-// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
-// CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
-// CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
-//       CHECK-128B:   memref.collapse_shape
-
-// -----
-
 // The shape of the memref and the vector don't match, but the vector is a
 // contiguous subset of the memref, so "flattenable"
 
-func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+func.func @transfer_read_dims_mismatch_contiguous(
     %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
 
   %c0 = arith.constant 0 : index
@@ -114,7 +83,7 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
   return %res : vector<2x3x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
 // CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
 // CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
 // CHECK:         %[[C0:.+]] = arith.constant 0 : index
@@ -126,9 +95,73 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
 // CHECK:         %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
 // CHECK:         return %[[VEC]] : vector<2x3x2xi8>
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
+// The shape of the memref and the vector don't match, but the mismatch is only
+// at the leading unit dimensions of the vector.
+
+func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+    %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>) -> vector<1x1x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0, %c0], %cst :
+    memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>, vector<1x1x4x3x2xi8>
+  return %res : vector<1x1x4x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+// CHECK-SAME:    %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
+// CHECK-SAME:    -> vector<1x1x4x3x2xi8>
+// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+// CHECK-SAME:    : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+// CHECK-SAME:      into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+// CHECK:       %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
+// CHECK:       %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
+// CHECK:       return %[[VEC]] : vector<1x1x4x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
+// The memref is non-contiguous, but the vector is a contiguous subset of the
+// memref, so "flattenable". The leading unit dimensions of the vector have no
+// effect on the memref area read even if they span the non-contiguous part of
+// the memref.
+
+func.func @transfer_read_non_contiguous_unit_dims(
+    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x3x2xi8>
+  return %res : vector<1x1x3x2xi8>
+}
+
+// CHECK-LABEL:   func.func @transfer_read_non_contiguous_unit_dims(
+// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
+// CHECK:           return %[[VAL_5]] : vector<1x1x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
 
+
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -418,61 +451,92 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
 // -----
 
 // The shape of the memref and the vector don't match, but the vector is a
-// contiguous subset of the memref, so "flattenable". The leading unit dimensions
-// of the vector have no effect on the memref area written even if they
-// span a non-contiguous part of the memref.
+// contiguous subset of the memref, so "flattenable".
 
-func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
-    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
-    %vec : vector<1x1x2x2xi8>) {
+func.func @transfer_write_dims_mismatch_contiguous(
+    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<2x2xi8>) {
 
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
-    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
+    vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
 
-// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous_unit_dims
-// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
-// CHECK-SAME:      %[[VEC:.*]]: vector<1x1x2x2xi8>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+// CHECK-SAME:    %[[VEC:.+]]: vector<2x2xi8>
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
-// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
-// CHECK:           %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
-// CHECK:           vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
-// CHECK-SAME:        {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+// CHECK-SAME:      : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
+// CHECK:         %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+// CHECK:         vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME:      : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
+// The shape of the memref and the vector don't match, but the mismatch is only
+// at the leading unit dimensions of the vector.
+
+func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+    %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>,
+    %vec : vector<1x1x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0, %c0] :
+    vector<1x1x4x3x2xi8>, memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+
+  return
+}
+
+// CHECK-LABEL:  func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+// CHECK-SAME:   %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+// CHECK-SAME:   %[[VEC:.+]]: vector<1x1x4x3x2xi8>
+// CHECK:          %[[C0:.+]] = arith.constant 0 : index
+// CHECK:          %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+// CHECK-SAME:       : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+// CHECK-SAME:         into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+// CHECK:          %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
+// CHECK:          vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME:       {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
-// The shape of the memref and the vector don't match, but the vector is a
-// contiguous subset of the memref, so "flattenable".
+// The memref is non-contiguous, but the vector is a contiguous subset of the
+// memref, so "flattenable". The leading unit dimensions of the vector have no
+// effect on the memref area written even if they span the non-contiguous part
+// of the memref.
 
-func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
-    %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-    %vec : vector<2x2xi8>) {
+func.func @transfer_write_non_contiguous_unit_dims(
+    %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+    %vec : vector<1x1x3x2xi8>) {
 
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
-    vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    vector<1x1x3x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
   return
 }
 
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
-// CHECK-SAME:    %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
-// CHECK-SAME:    %[[VEC:.+]]: vector<2x2xi8>
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-LABEL:   func.func @transfer_write_non_contiguous_unit_dims
+// CHECK-SAME:      %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:      %[[VEC:.*]]: vector<1x1x3x2xi8>) {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
 // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
-// CHECK-SAME:      : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
-// CHECK:         %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
-// CHECK:         vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
-// CHECK-SAME:      : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
+// CHECK-SAME:        : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+// CHECK:           %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8> to vector<6xi8>
+// CHECK:           vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME:        {in_bounds = [true]} : vector<6xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
 
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
@@ -718,4 +782,3 @@ func.func @negative_out_of_bound_transfer_write(
 // CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
 //   CHECK-128B-NOT:   memref.collapse_shape
 //   CHECK-128B-NOT:   vector.shape_cast
-


