[Mlir-commits] [mlir] andrzej/refactor xfer flatten 3 (PR #95745)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jun 17 00:06:57 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/95745

- **[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n)**
- **[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n)**
- **[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (3/n)**


>From 9a5be401b9fb38cce6de5a83ee00506c2a47264e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 15 Jun 2024 20:27:07 +0100
Subject: [PATCH 1/3] [mlir][vector] Refactor vector-transfer-flatten.mlir
 (nfc) (1/n)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), to remove duplicates, and to add tests for scalable
vectors.

The main contributions of this PR:
  * split tests that covered xfer_read + xfer_write into separate tests
    (the majority of the existing tests check _one xfer Op_ at a time),
  * organise tests for xfer_read and xfer_write into separate groups.

Note, all tests are preserved and some new tests are added. Deletions
that you will see in `git diff` correspond to xfer_write and xfer_read
Ops being extracted to separate functions (so that there's one xfer Op
per function). In particular, the number of test functions has grown
from 26 to 30.

In addition, this PR unifies the tests so that:
  * input variable names are consistent (e.g. make sure that the input
    memref is always `arg`)
  * CHECK lines use similar indentations
  * two levels of indentation (4 spaces) are always used for function
    arguments and one level (2 spaces) for the function body

Finally, changes in "VectorTransferOpTransforms.cpp" are merely meant to
unify comments and logic between
  * `FlattenContiguousRowMajorTransferWritePattern` and
  * `FlattenContiguousRowMajorTransferReadPattern`.
---
 .../Transforms/VectorTransferOpTransforms.cpp |  28 +-
 .../Vector/vector-transfer-flatten.mlir       | 349 ++++++++++++------
 2 files changed, 264 insertions(+), 113 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
 /// the trailing dimension of the vector read is smaller than the provided
 /// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
-        dyn_cast<MemRefType>(collapsedSource.getType());
+        cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_write has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector written is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
 public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    SmallVector<Value> collapsedIndices =
-        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
-                            transferWriteOp.getIndices(), firstDimToCollapse);
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
+    // 1. Collapse the source memref
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
+    // 2. Generate input args for a new vector.transfer_write that will write
+    // to the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
+    // 2.2 New indices
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
+
+    // 3. Create new vector.transfer_write that writes to the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
         rewriter.create<vector::TransferWriteOp>(
             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_write with the new one writing the
+    // collapsed shape
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d25d21b4..0de5a807affe0 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_read_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
     %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 // contiguous subset of the memref, so "flattenable".
 
 func.func @transfer_read_dims_mismatch_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-    return %v : vector<1x1x2x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+  return %v : vector<1x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,135 +78,160 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x43x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
 func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-    %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+    %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+    %idx0 : index,
+    %idx1 : index) -> vector<2x2xf32> {
+
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant 0.000000e+00 : f32
-  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+    memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
   return %8 : vector<2x2xf32>
 }
 
-//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
 // CHECK-LABEL:  func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK:         %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
+func.func @transfer_read_dims_mismatch_non_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+  return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
 func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x?x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+  return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
-    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+    %arg : memref<i8>) -> vector<i8> {
+
+  %cst = arith.constant 0 : i8
+  %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+  return %0 : vector<i8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_0d
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_0d(
 //   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_write_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-    return
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
 }
 
 // CHECK-LABEL: func @transfer_write_dims_match_contiguous(
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
@@ -208,42 +241,101 @@ func.func @transfer_write_dims_match_contiguous(
 
 // -----
 
+func.func @transfer_write_dims_match_contiguous_empty_stride(
+    %arg : memref<5x4x3x2xi8>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 func.func @transfer_write_dims_mismatch_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-      vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-    return
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<1x1x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
 }
 
 // CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous
-// CHECK-SAME:                                            %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-// CHECK-SAME:                                            %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
 // CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
 // CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
-// CHECK:           return
-// CHECK:         }
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
 //       CHECK-128B:   memref.collapse_shape
 
 // -----
 
+func.func @transfer_write_dims_mismatch_non_zero_indices(
+    %idx_1: index,
+    %idx_2: index,
+    %arg: memref<1x43x4x6xi32>,
+    %vec: vector<1x2x6xi32>) {
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x43x4x6xi32>
+  return
+}
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+
+// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
+// CHECK-SAME:      %[[ARG:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[VEC:.*]]: vector<1x2x6xi32>) {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
+// CHECK:           %[[CS:.*]] = memref.collapse_shape %[[ARG]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+// Overall, the destination memref is non-contiguous. However, the slice to
+// which the input vector is to be written _is_ contiguous. Hence the
+// flattening works fine.
+
 func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
     %value : vector<2x2xf32>,
     %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index, %idx1 : index) {
+    %idx0 : index,
+    %idx1 : index) {
+
   %c0 = arith.constant 0 : index
   vector.transfer_write %value, %subview[%c0, %idx0, %idx1, %c0] {in_bounds = [true, true]} : vector<2x2xf32>, memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>
   return
 }
 
-//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
 // CHECK-LABEL:  func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
-//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
-//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK:         %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 //       CHECK-128B:   memref.collapse_shape
@@ -251,11 +343,13 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 // -----
 
 func.func @transfer_write_dims_mismatch_non_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-      vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-    return
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<2x1x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
 }
 
 // CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
@@ -267,37 +361,76 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
 
 // -----
 
-func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
-      vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
-      return
+// The input memref has a dynamic trailing shape and hence is not flattened.
+// TODO: This case could be supported via memref.dim
+
+func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+    %idx_1: index,
+    %idx_2: index,
+    %vec : vector<1x2x6xi32>,
+    %arg: memref<1x?x4x6xi32>) {
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  vector.transfer_write %vec, %arg[%c0, %idx_1, %idx_2, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x?x4x6xi32>
+  return
 }
 
-// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_write_0d(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
 //   CHECK-128B-NOT:   memref.collapse_shape
-//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
-func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
-      %cst = arith.constant 0 : i8
-      %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
-      return %0 : vector<i8>
+// The vector to be written represents a _non-contiguous_ slice of the output
+// memref.
+
+func.func @transfer_write_dims_mismatch_non_contiguous_slice(
+    %arg : memref<5x4x3x2xi8>,
+    %vec : vector<2x1x2x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+    vector<2x1x2x2xi8>, memref<5x4x3x2xi8>
+  return
 }
 
-// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous_slice(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_0d(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous_slice(
+//   CHECK-128B-NOT:   memref.collapse_shape
+
+// -----
+
+func.func @transfer_write_0d(
+    %arg : memref<i8>,
+    %vec : vector<i8>) {
+
+  vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_write_0d(
 //   CHECK-128B-NOT:   memref.collapse_shape
 //   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// TODO: Categorize + re-format
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
     %c0_i8 = arith.constant 0 : i8
     %c0 = arith.constant 0 : index

>From 2d23d6c3bd32f3861135db4d9c29b36f0c391161 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 15 Jun 2024 20:27:07 +0100
Subject: [PATCH 2/3] [mlir][vector] Refactor vector-transfer-flatten.mlir
 (nfc) (2/n)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:
1. `@transfer_{read|write}_dims_mismatch_non_contiguous` and
   `@transfer_read_flattenable_negative` duplicated
   `@transfer_{read|write}_dims_mismatch_non_contiguous_slice`. These
   tests are deleted
   (`@transfer_{read|write}_dims_mismatch_non_contiguous_slice` is
   preserved).

2. `@transfer_read_flattenable_negative2` is replaced with
   `@transfer_read_non_contiguous_src` and
   `@transfer_write_non_contiguous_src` (i.e. dedicated tests for
   xfer_read and xfer_write with more descriptive function names)
---
 .../Vector/vector-transfer-flatten.mlir       | 116 +++++++-----------
 1 file changed, 44 insertions(+), 72 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 0de5a807affe0..e96c4b785b406 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -131,25 +131,6 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 0 : i8
-  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
-  return %v : vector<2x1x2x2xi8>
-}
-
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
@@ -214,6 +195,28 @@ func.func @transfer_read_0d(
 
 // -----
 
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.transfer_write
 /// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
@@ -342,25 +345,6 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
-func.func @transfer_write_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-    %vec : vector<2x1x2x2xi8>) {
-
-  %c0 = arith.constant 0 : index
-  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-    vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-  return
-}
-
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
-// CHECK-NOT: memref.collapse_shape
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
@@ -427,6 +411,28 @@ func.func @transfer_write_0d(
 
 // -----
 
+// The strides make the destination memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_write_non_contiguous_src(
+    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] :
+   vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func.func @transfer_write_non_contiguous_src
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_write_non_contiguous_src
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// TODO: Categorize + re-format
 ///----------------------------------------------------------------------------------------
@@ -478,40 +484,6 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
 
 // -----
 
-func.func @transfer_read_flattenable_negative(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x2x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x2x2x2xi8>
-    return %v : vector<2x2x2x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative
-//       CHECK:   vector.transfer_read {{.*}} vector<2x2x2x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
-func.func @transfer_read_flattenable_negative2(
-      %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_negative2
-//       CHECK:   vector.transfer_read {{.*}} vector<5x4x3x2xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
-//   CHECK-128B-NOT:   memref.collapse_shape
-
-// -----
-
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
    %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
    return %add : vector<1x8xi32>

>From b9ee24df98be5aa4e22929fcbc6b8176fca37acd Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 16 Jun 2024 13:44:12 +0100
Subject: [PATCH 3/3] [mlir][vector] Refactor vector-transfer-flatten.mlir
 (nfc) (3/n)

The main goal of this and subsequent PRs is to unify and categorize
tests in:
  * vector-transfer-flatten.mlir
This should make it easier to identify the edge cases being tested (and
how they differ), to remove duplicates and to add tests for scalable
vectors.

The main contributions of this PR:

1. Refactor `@transfer_read_flattenable_with_dynamic_dims_and_indices`,
   i.e. move it near other tests for xfer_read, unify variable names to
   match other xfer_read tests, highlight what makes this a positive
   test to better contrast it with
   `@transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim`

2. Similar changes for
   `@transfer_write_flattenable_with_dynamic_dims_and_indices`.

Depends on https://github.com/llvm/llvm-project/pull/95743 and https://github.com/llvm/llvm-project/pull/95744

**Only review the top commit**
---
 .../Vector/vector-transfer-flatten.mlir       | 124 ++++++++++--------
 1 file changed, 70 insertions(+), 54 deletions(-)

diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index e96c4b785b406..d38cb4a56d722 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -131,10 +131,42 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
+// The leading dynamic shapes don't affect whether this example is flattenable
+// or not as those dynamic shapes are not candidates for flattening anyway.
+
+func.func @transfer_read_leading_dynamic_dims(
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) -> vector<8x4xi8> {
+
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg[%idx_1, %idx_2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
+// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK:       return %[[VEC2D]] : vector<8x4xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_leading_dynamic_dims
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
     %idx_1: index,
     %idx_2: index,
     %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
@@ -146,11 +178,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
   return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_trailing_dynamic_dim
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
@@ -345,10 +377,40 @@ func.func @transfer_write_dims_mismatch_non_contiguous_non_zero_indices(
 
 // -----
 
+// The leading dynamic shapes don't affect whether this example is flattenable
+// or not as those dynamic shapes are not candidates for flattening anyway.
+
+func.func @transfer_write_leading_dynamic_dims(
+    %vec : vector<8x4xi8>,
+    %arg : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>,
+    %idx_1 : index,
+    %idx_2 : index) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%idx_1, %idx_2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
+// CHECK-SAME:    %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:       %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK:       vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME:    [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME:    {in_bounds = [true]}
+// CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
+//       CHECK-128B:   memref.collapse_shape
+
+// -----
+
 // The input memref has a dynamic trailing shape and hence is not flattened.
 // TODO: This case could be supported via memref.dim
 
-func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
     %idx_1: index,
     %idx_2: index,
     %vec : vector<1x2x6xi32>,
@@ -361,11 +423,11 @@ func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
   return
 }
 
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices_trailing_dynamic_dim(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
@@ -434,56 +496,10 @@ func.func @transfer_write_non_contiguous_src(
 // -----
 
 ///----------------------------------------------------------------------------------------
-/// TODO: Categorize + re-format
+/// [Pattern: DropUnitDimFromElementwiseOps]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
 ///----------------------------------------------------------------------------------------
 
-func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
-    %c0_i8 = arith.constant 0 : i8
-    %c0 = arith.constant 0 : index
-    %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, vector<8x4xi8>
-    return %result : vector<8x4xi8>
-}
-
-// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
-// CHECK:       %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:       %[[C0:.+]] = arith.constant 0 : index
-// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK:       %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME:    [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
-// CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:    : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
-// CHECK:       %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
-// CHECK:       return %[[VEC2D]] : vector<8x4xi8>
-
-// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
-//       CHECK-128B:   memref.collapse_shape
-
-// -----
-
-func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>
-    return
-}
-
-// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
-// CHECK-SAME:    %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK:       %[[C0:.+]] = arith.constant 0 : index
-// CHECK:       %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:    : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
-// CHECK:       %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
-// CHECK:       vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME:    [%[[ARG2]], %[[ARG3]], %[[C0]]]
-// CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:    : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
-
-// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
-//       CHECK-128B:   memref.collapse_shape
-
-// -----
-
 func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
    %add = arith.addi %arg0, %arg0 : vector<1x8xi32>
    return %add : vector<1x8xi32>



More information about the Mlir-commits mailing list