[Mlir-commits] [mlir] andrzej/refactor xfer flatten 2 (PR #95744)
llvmlistbot at llvm.org
Mon Jun 17 00:06:01 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
- **[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (1/n)**
- **[mlir][vector] Refactor vector-transfer-flatten.mlir (nfc) (2/n)**
---
Patch is 30.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95744.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+23-5)
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+247-142)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c131fde517f80..4c93d3841bf87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -568,6 +568,7 @@ namespace {
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
+///
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
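///
/// A sketch of the rewrite (illustration only, not lines from this patch):
///
///   %v = vector.transfer_read %src[%c0, %c0], %pad
///     : memref<4x2xi8>, vector<4x2xi8>
///
/// becomes a 1D read from the collapsed memref followed by a shape_cast:
///
///   %flat_src = memref.collapse_shape %src [[0, 1]]
///     : memref<4x2xi8> into memref<8xi8>
///   %flat = vector.transfer_read %flat_src[%c0], %pad {in_bounds = [true]}
///     : memref<8xi8>, vector<8xi8>
///   %v = vector.shape_cast %flat : vector<8xi8> to vector<4x2xi8>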
@@ -617,7 +618,7 @@ class FlattenContiguousRowMajorTransferReadPattern
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
- dyn_cast<MemRefType>(collapsedSource.getType());
+ cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
@@ -658,6 +659,10 @@ class FlattenContiguousRowMajorTransferReadPattern
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
+///
+/// If `targetVectorBitwidth` is provided, the flattening will only happen if
+/// the trailing dimension of the vector write is smaller than the provided
+/// bitwidth.
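///
/// A sketch of the rewrite (illustration only, not lines from this patch):
///
///   vector.transfer_write %vec, %dst[%c0, %c0]
///     : vector<4x2xi8>, memref<4x2xi8>
///
/// becomes a 1D write of the flattened vector into the collapsed memref:
///
///   %flat_vec = vector.shape_cast %vec : vector<4x2xi8> to vector<8xi8>
///   %flat_dst = memref.collapse_shape %dst [[0, 1]]
///     : memref<4x2xi8> into memref<8xi8>
///   vector.transfer_write %flat_vec, %flat_dst[%c0] {in_bounds = [true]}
///     : vector<8xi8>, memref<8xi8>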
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
public:
@@ -674,9 +679,12 @@ class FlattenContiguousRowMajorTransferWritePattern
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+ // 0. Check pre-conditions
// Contiguity check is valid on tensors only.
if (!sourceType)
return failure();
+ // If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
@@ -688,7 +696,6 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
@@ -697,10 +704,9 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
- SmallVector<Value> collapsedIndices =
- getCollapsedIndices(rewriter, loc, sourceType.getShape(),
- transferWriteOp.getIndices(), firstDimToCollapse);
+ int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // 1. Collapse the source memref
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
@@ -708,11 +714,20 @@ class FlattenContiguousRowMajorTransferWritePattern
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
+    // 2. Generate input args for a new vector.transfer_write that will write
+    //    to the collapsed memref.
+ // 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
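// Example (illustrative): with collapsedRank == 2 and firstDimToCollapse == 1,
// `collapsedMap` is (d0, d1) -> (d1).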
+    // 2.2. New indices
+ SmallVector<Value> collapsedIndices =
+ getCollapsedIndices(rewriter, loc, sourceType.getShape(),
+ transferWriteOp.getIndices(), firstDimToCollapse);
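+    // E.g. collapsing dims [1, 2, 3] of memref<1x43x4x6xi32> folds indices
+    // [%c0, %i, %j, %c0] into [%c0, %i * 24 + %j * 6] (cf. the affine_map
+    // checked in the tests below).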
+
+ // 3. Create new vector.transfer_write that writes to the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
@@ -721,6 +736,9 @@ class FlattenContiguousRowMajorTransferWritePattern
rewriter.create<vector::TransferWriteOp>(
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+ // 4. Replace the old transfer_write with the new one writing the
+ // collapsed shape
rewriter.eraseOp(transferWriteOp);
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index d7365d25d21b4..e96c4b785b406 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,17 +1,23 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
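
// The second RUN line caps the target vector bitwidth at 128; the CHECK-128B
// prefix checks whether flattening (i.e. memref.collapse_shape) still fires
// under that cap.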
+///----------------------------------------------------------------------------------------
+/// vector.transfer_read
+/// [Pattern: FlattenContiguousRowMajorTransferReadPattern]
+///----------------------------------------------------------------------------------------
+
func.func @transfer_read_dims_match_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
// CHECK-LABEL: func @transfer_read_dims_match_contiguous
-// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
@@ -24,11 +30,12 @@ func.func @transfer_read_dims_match_contiguous(
func.func @transfer_read_dims_match_contiguous_empty_stride(
%arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
- return %v : vector<5x4x3x2xi8>
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
@@ -47,16 +54,17 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
// contiguous subset of the memref, so "flattenable".
func.func @transfer_read_dims_mismatch_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
- return %v : vector<1x1x2x2xi8>
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+ return %v : vector<1x1x2x2xi8>
}
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
@@ -70,51 +78,53 @@ func.func @transfer_read_dims_mismatch_contiguous(
// -----
func.func @transfer_read_dims_mismatch_non_zero_indices(
- %idx_1: index,
- %idx_2: index,
- %m_in: memref<1x43x4x6xi32>,
- %m_out: memref<1x2x6xi32>) {
+ %idx_1: index,
+ %idx_2: index,
+ %arg: memref<1x43x4x6xi32>) -> vector<1x2x6xi32>{
+
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %v = vector.transfer_read %arg[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x43x4x6xi32>, vector<1x2x6xi32>
- vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x6xi32>, memref<1x2x6xi32>
- return
+ return %v : vector<1x2x6xi32>
}
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32>,
-// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) {
+// CHECK-SAME: %[[M_IN:.*]]: memref<1x43x4x6xi32>
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[M_IN]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
-// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
-// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
// -----
+// Overall, the source memref is non-contiguous. However, the slice from which
+// the output vector is to be read _is_ contiguous. Hence the flattening works fine.
+
func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
- %subview : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
- %idx0 : index, %idx1 : index) -> vector<2x2xf32> {
+ %arg : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>,
+ %idx0 : index,
+ %idx1 : index) -> vector<2x2xf32> {
+
%c0 = arith.constant 0 : index
%cst_1 = arith.constant 0.000000e+00 : f32
- %8 = vector.transfer_read %subview[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
+ %8 = vector.transfer_read %arg[%c0, %idx0, %idx1, %c0], %cst_1 {in_bounds = [true, true]} :
+ memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>>, vector<2x2xf32>
return %8 : vector<2x2xf32>
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2)>
+
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
-// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
-// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : memref<1x3x3x2xf32, strided<[40, 10, 2, 1], offset: ?>> into memref<1x3x6xf32, strided<[40, 10, 1], offset: ?>>
+// CHECK: %[[APPLY:.*]] = affine.apply #[[$MAP]]()
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// CHECK-128B: memref.collapse_shape
@@ -125,80 +135,106 @@ func.func @transfer_read_dims_mismatch_non_contiguous_non_zero_indices(
// TODO: This case could be supported via memref.dim
func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
- %idx_1: index,
- %idx_2: index,
- %m_in: memref<1x?x4x6xi32>,
- %m_out: memref<1x2x6xi32>) {
+ %idx_1: index,
+ %idx_2: index,
+ %m_in: memref<1x?x4x6xi32>) -> vector<1x2x6xi32> {
+
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
- %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
+ %v = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
memref<1x?x4x6xi32>, vector<1x2x6xi32>
- vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x6xi32>, memref<1x2x6xi32>
- return
+ return %v : vector<1x2x6xi32>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
-// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[M_IN:.*]]: memref<1x?x4x6xi32>,
-// CHECK-SAME: %[[M_OUT:.*]]: memref<1x2x6xi32>) {
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[M_IN]]{{.*}} : memref<1x?x4x6xi32>, vector<1x2x6xi32>
-// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_OUT]]{{.*}} : memref<1x2x6xi32> into memref<12xi32>
-// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
// CHECK-128B-NOT: memref.collapse_shape
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
+// The vector to be read represents a _non-contiguous_ slice of the input
+// memref.
+
+func.func @transfer_read_dims_mismatch_non_contiguous_slice(
+ %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_slice(
// CHECK-128B-NOT: memref.collapse_shape
// -----
-func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
- %arg : memref<5x4x3x2xi8>) -> vector<2x1x2x2xi8> {
- %c0 = arith.constant 0 : index
- %cst = arith.constant 0 : i8
- %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8>, vector<2x1x2x2xi8>
- return %v : vector<2x1x2x2xi8>
+func.func @transfer_read_0d(
+ %arg : memref<i8>) -> vector<i8> {
+
+ %cst = arith.constant 0 : i8
+ %0 = vector.transfer_read %arg[], %cst : memref<i8>, vector<i8>
+ return %0 : vector<i8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func @transfer_read_0d(
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+// Strides make the input memref non-contiguous, hence non-flattenable.
+
+func.func @transfer_read_non_contiguous_src(
+ %arg : memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 8, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
+ return %v : vector<5x4x3x2xi8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_src
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_src
// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
// -----
+///----------------------------------------------------------------------------------------
+/// vector.transfer_write
+/// [Pattern: FlattenContiguousRowMajorTransferWritePattern]
+///----------------------------------------------------------------------------------------
+
func.func @transfer_write_dims_match_contiguous(
- %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
- vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
- return
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
}
// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
-// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
-// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
@@ -208,68 +244,161 @@ func.func @transfer_write_dims_match_contiguous(
// -----
+func.func @transfer_write_dims_match_contiguous_empty_stride(
+ %arg : memref<5x4x3x2xi8>,
+ %vec : vector<5x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+ return
+}
+
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous_empty_stride(
+// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// CHECK-128B-LABEL: func @transfer_write_dims_match_cont...
[truncated]
``````````
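
For context on how these two patterns are exercised: the test pass in the RUN lines above forwards its `target-vector-bitwidth` option to the flattening pattern populate function. A minimal sketch of driving the same rewrites from C++ (assuming the upstream `populateFlattenVectorTransferPatterns` helper and the standard greedy driver; the function name `flattenVectorTransfers` is hypothetical):

```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Flatten contiguous n-D vector.transfer_read/write ops under `root`, but
// only those whose trailing vector dim is narrower than 128 bits (mirroring
// the CHECK-128B RUN line above).
static LogicalResult flattenVectorTransfers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateFlattenVectorTransferPatterns(
      patterns, /*targetVectorBitwidth=*/128);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```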
</details>
https://github.com/llvm/llvm-project/pull/95744