[Mlir-commits] [mlir] andrzej/refactor xfer flatten 2 (PR #95744)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jun 17 00:06:01 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

- **[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n)**
- **[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n)**


---

Patch is 30.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95744.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+23-5) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+247-142) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
 /// If `targetVectorBitwidth` is provided, the flattening will only happen if
 /// the trailing dimension of the vector read is smaller than the provided
 /// bitwidth.
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
-        dyn_cast<MemRefType>(collapsedSource.getType());
+        cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_write has a 1D source. Requires the source shape to be
 /// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector read is smaller than the provided
+/// bitwidth.
 class FlattenContiguousRowMajorTransferWritePattern
     : public OpRewritePattern<vector::TransferWriteOp> {
 public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
       return failure();
     if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
       return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    SmallVector<Value> collapsedIndices =
-        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
-                            transferWriteOp.getIndices(), firstDimToCollapse);
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
 
+    // 1. Collapse the source memref
     Value collapsedSource =
         collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
     int64_t collapsedRank = collapsedSourceType.getRank();
     assert(collapsedRank == firstDimToCollapse + 1);
 
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
         getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
 
+    // 2.2 New indices
+    SmallVector<Value> collapsedIndices =
+        getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+                            transferWriteOp.getIndices(), firstDimToCollapse);
+
+    // 3. Create new vector.transfer_write that writes to the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
         rewriter.create<vector::TransferWriteOp>(
             loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
     flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_write with the new one writing the
+    // collapsed shape
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d25d21b4..e96c4b785b406 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_read_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
 // CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
 
 func.func @transfer_read_dims_match_contiguous_empty_stride(
     %arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
-    return %v : vector<5x4x3x2xi8>
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
 // CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
 // contiguous subset of the memref, so "flattenable".
 
 func.func @transfer_read_dims_mismatch_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
-    return %v : vector<1x1x2x2xi8>
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+  return %v : vector<1x1x2x2xi8>
 }
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME:      %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
 // CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
 // CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
 // CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,51 +78,53 @@ func.func @transfer_read_dims_mismatch_contiguous(
 // -----
 
 func.func @transfer_read_dims_mismatch_non_zero_indices(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x43x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x43x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
 // CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
 
 // CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME:      %[[M_IN:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
 // CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
 // CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK:           %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
 func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-    %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
-    %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+    %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+    %idx0 : index,
+    %idx1 : index) -> vector<2x2xf32> {
+
   %c0 = arith.constant 0 : index
   %cst_1 = arith.constant 0.000000e+00 : f32
-  %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+  %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+    memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
   return %8 : vector<2x2xf32>
 }
 
-//       CHECK:  #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
 // CHECK-LABEL:  func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-//       CHECK:    %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-//       CHECK:    %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK:         %[[APPLY:.*]] = affine.apply #[[$MAP]]()
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 //       CHECK-128B:   memref.collapse_shape
@@ -125,80 +135,106 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
 // TODO: This case could be supported via memref.dim
 
 func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-                     %idx_1: index,
-                     %idx_2: index,
-                     %m_in: memref<1x?x4x6xi32>,
-                     %m_out: memref<1x2x6xi32>) {
+    %idx_1: index,
+    %idx_2: index,
+    %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
   %c0 = arith.constant 0 : index
   %c0_i32 = arith.constant 0 : i32
-  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+  %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
     memref<1x?x4x6xi32>, vector<1x2x6xi32>
-  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
-    vector<1x2x6xi32>, memref<1x2x6xi32>
-  return
+  return %v : vector<1x2x6xi32>
 }
 
-// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME:      %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME:      %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK:           %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous(
-    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+  return %v : vector<2x1x2x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
 //   CHECK-128B-NOT:   memref.collapse_shape
 
 // -----
 
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
-    %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
-    %c0 = arith.constant 0 : index
-    %cst = arith.constant 0 : i8
-    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
-      memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
-    return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+    %arg : memref<i8>) -> vector<i8> {
+
+  %cst = arith.constant 0 : i8
+  %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+  return %0 : vector<i8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_0d(
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+    %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+    memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
 //   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
 func.func @transfer_write_dims_match_contiguous(
-      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
-    %c0 = arith.constant 0 : index
-    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
-      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
-    return
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+  return
 }
 
 // CHECK-LABEL: func @transfer_write_dims_match_contiguous(
-// CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
 // CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
@@ -208,68 +244,161 @@ func.func @transfer_write_dims_match_contiguous(
 
 // -----
 
+func.func @transfer_write_dims_match_contiguous_empty_stride(
+    %arg : memref<5x4x3x2xi8>,
+    %vec : vector<5x4x3x2xi8>) {
+
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+    vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// CHECK-128B-LABEL: func @transfer_write_dims_match_cont...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/95744


More information about the Mlir-commits mailing list