[llvm-branch-commits] [mlir] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (PR #142422)
Momchil Velikov via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jun 9 09:42:16 PDT 2025
https://github.com/momchil-velikov updated https://github.com/llvm/llvm-project/pull/142422
>From b950757c234900db941ed950ea3469b520d2e28a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 2 Jun 2025 15:13:13 +0000
Subject: [PATCH 1/6] [MLIR] Fix incorrect slice contiguity inference in
`vector::isContiguousSlice`
Previously, slices were sometimes marked as non-contiguous when
they were actually contiguous. This occurred when the vector type had
leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`.
In such cases, only the trailing n dimensions of the memref need to be
contiguous, not the entire vector rank.
This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write` ops. The pattern used
to collapse a number of dimensions equal to the vector rank, which
may be incorrect when leading dimensions are unit-sized.
This patch fixes the issue by collapsing only as many trailing memref
dimensions as are actually contiguous.
---
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 54 ++++-----
.../Transforms/VectorTransferOpTransforms.cpp | 8 +-
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 25 ++--
.../Vector/vector-transfer-flatten.mlir | 108 +++++++++++++-----
4 files changed, 120 insertions(+), 75 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+/// a) the N trailing dimensions of the latter must be contiguous, and
+/// b) the trailing N dimensions of `vectorType` and `memrefType`,
+/// except the first of them, must match.
///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-/// of `memrefType` then the leading dim of `vectorType` can be
-/// arbitrary.
-///
-/// Ex. 1.1 contiguous slice, perfect match
-/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-/// corresponding trailing dim in `memrefType` then the remaining
-/// leading dims of `vectorType` have to be 1 (the first non-matching
-/// dim can be arbitrary).
+/// Examples:
///
-/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
-/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+/// Ex.1 contiguous slice, perfect match
+/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.3 non-contiguous slice, 2 != 3
+/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+/// 2 != 3 (allowed)
+/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored,
+/// 2 != 3 (allowed)
+/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
+/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+/// Ex.7 contiguous slice, memref needs to be contiguous only on the last
+/// dimension
+/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last
+/// two dimensions, and it isn't
+/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
/// Returns an iterator for all positions in the leading dimensions of `vType`
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..709716365f825 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -630,7 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
if (transferReadOp.getMask())
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // Determinine the first memref dimension to collapse
+ int64_t firstDimToCollapse =
+ sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
// 1. Collapse the source memref
Value collapsedSource =
@@ -722,7 +724,9 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // Determinine the first memref dimension to collapse
+ int64_t firstDimToCollapse =
+ sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
// 1. Collapse the source memref
Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 590d244daef40..109be454bba61 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (vectorType.isScalable())
return false;
- ArrayRef<int64_t> vectorShape = vectorType.getShape();
- auto vecRank = vectorType.getRank();
+ // Ignore a leading contiguous sequence of unit dimensions in the vector.
+ ArrayRef<int64_t> vectorShape =
+ vectorType.getShape().drop_while([](auto v) { return v == 1; });
+ auto vecRank = vectorShape.size();
if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;
- // Extract the trailing dims and strides of the input memref
+ // Extract the trailing dims of the input memref
auto memrefShape = memrefType.getShape().take_back(vecRank);
- // Compare the dims of `vectorType` against `memrefType` (in reverse).
- // In the most basic case, all dims will match.
- auto firstNonMatchingDim =
- std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
- memrefShape.rbegin(), memrefShape.rend());
- if (firstNonMatchingDim.first == vectorShape.rend())
- return true;
-
- // One non-matching dim is still fine, however the remaining leading dims of
- // `vectorType` need to be 1.
- SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
- vectorShape.rend());
-
- return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+ // Compare the dims of `vectorType` against `memrefType`.
+ // All of the dimensions, except the first must match.
+ return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
}
std::optional<StaticTileOffsetRange>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 2b012f1a97971..6da26ec277bd8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -116,10 +116,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>
// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1032xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
@@ -170,16 +171,18 @@ func.func @transfer_read_leading_dynamic_dims(
return %res : vector<8x4xi8>
}
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
// CHECK-LABEL: func @transfer_read_leading_dynamic_dims
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME: : memref<?x?xi8, {{.+}}>, vector<32xi8>
// CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
// CHECK: return %[[RES]] : vector<8x4xi8>
@@ -232,13 +235,12 @@ func.func @transfer_read_dynamic_dim_to_flatten(
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
-// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
@@ -419,11 +421,10 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>,
// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+// CHECK: vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
@@ -471,16 +472,18 @@ func.func @transfer_write_leading_dynamic_dims(
return
}
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-SAME: %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: [%[[ARG2]], %[[COLLAPSED_IDX]]]
// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+// CHECK-SAME: : vector<32xi8>, memref<?x?xi8, {{.+}}>
// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-128B: memref.collapse_shape
@@ -530,14 +533,13 @@ func.func @transfer_write_dynamic_to_flatten(
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
-// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<?xi32>
// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
@@ -615,8 +617,12 @@ func.func @negative_out_of_bound_transfer_read(
memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
return %res : vector<5x4x3x2xi8>
}
-// CHECK: func.func @negative_out_of_bound_transfer_read
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT: memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
// -----
@@ -629,5 +635,47 @@ func.func @negative_out_of_bound_transfer_write(
vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}
-// CHECK: func.func @negative_out_of_bound_transfer_write
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT: memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_contig_slice(
+ %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+ return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
+// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+
+// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_discontig_slice(
+ %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+ vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+ return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-128B-NOT: vector.shape_cast
+
>From 956d63ee70c2b6ee84d20db54235a5e1cd39a32e Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Tue, 3 Jun 2025 16:59:27 +0000
Subject: [PATCH 2/6] [fixup] Don't try to collapse non-leftmost dynamic
dimension
Even though it's possible in principle, the affected patterns need
strides to be determined statically.
---
.../Transforms/VectorTransferOpTransforms.cpp | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 709716365f825..2478cefc303ca 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -582,6 +582,15 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
namespace {
+/// Helper functon to return the index of the last dynamic dimension in `shape`.
+int64_t lastDynIndex(ArrayRef<int64_t> shape) {
+ return static_cast<int64_t>(
+ std::distance(
+ std::find(shape.rbegin(), shape.rend(), ShapedType::kDynamic),
+ shape.rend()) -
+ 1);
+}
+
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -631,8 +640,9 @@ class FlattenContiguousRowMajorTransferReadPattern
return failure();
// Determinine the first memref dimension to collapse
- int64_t firstDimToCollapse =
- sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
+ int64_t firstDimToCollapse = std::max(
+ lastDynIndex(sourceType.getShape()),
+ sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
// 1. Collapse the source memref
Value collapsedSource =
@@ -725,8 +735,9 @@ class FlattenContiguousRowMajorTransferWritePattern
return failure();
// Determinine the first memref dimension to collapse
- int64_t firstDimToCollapse =
- sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
+ int64_t firstDimToCollapse = std::max(
+ lastDynIndex(sourceType.getShape()),
+ sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
// 1. Collapse the source memref
Value collapsedSource =
>From 93f49b64ec47b5e265c3553b8e8d8b0b8679b02e Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Wed, 4 Jun 2025 11:06:39 +0000
Subject: [PATCH 3/6] [fixup] Update a member function name and fix a test
failure
---
.../Transforms/VectorTransferOpTransforms.cpp | 4 +-
.../Vector/vector-transfer-flatten.mlir | 49 ++++++++++---------
2 files changed, 27 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 2478cefc303ca..4d93f97d0748d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -642,7 +642,7 @@ class FlattenContiguousRowMajorTransferReadPattern
// Determinine the first memref dimension to collapse
int64_t firstDimToCollapse = std::max(
lastDynIndex(sourceType.getShape()),
- sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
+ sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
// 1. Collapse the source memref
Value collapsedSource =
@@ -737,7 +737,7 @@ class FlattenContiguousRowMajorTransferWritePattern
// Determinine the first memref dimension to collapse
int64_t firstDimToCollapse = std::max(
lastDynIndex(sourceType.getShape()),
- sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims());
+ sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
// 1. Collapse the source memref
Value collapsedSource =
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 6da26ec277bd8..6ad12a00d19e7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -113,14 +113,14 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
-// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>
-// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
+// CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32>
+// CHECK: %[[C_0:.+]] = arith.constant 0 : i32
+// CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1032xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
@@ -228,20 +228,21 @@ func.func @transfer_read_dynamic_dim_to_flatten(
return %res : vector<1x2x6xi32>
}
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
// CHECK-SAME: %[[IDX_1:arg0]]
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
-// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
-// CHECK-SAME: memref<1x?x4x6xi32> into memref<?xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
-// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
-// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
@@ -525,21 +526,21 @@ func.func @transfer_write_dynamic_to_flatten(
return
}
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
// CHECK-LABEL: func.func @transfer_write_dynamic_to_flatten
// CHECK-SAME: %[[IDX_1:arg0]]: index
// CHECK-SAME: %[[IDX_2:arg1]]: index
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
-
-// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
-// CHECK-SAME: : memref<1x?x4x6xi32> into memref<?xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
-// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
// CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
>From 366aa6a5a5b276931523df51607f20845fef0f12 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Thu, 5 Jun 2025 16:46:18 +0000
Subject: [PATCH 4/6] [fixup] One more rename
---
.../Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 4d93f97d0748d..8bc4b66797730 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -642,7 +642,7 @@ class FlattenContiguousRowMajorTransferReadPattern
// Determinine the first memref dimension to collapse
int64_t firstDimToCollapse = std::max(
lastDynIndex(sourceType.getShape()),
- sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
+ sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
// 1. Collapse the source memref
Value collapsedSource =
@@ -737,7 +737,7 @@ class FlattenContiguousRowMajorTransferWritePattern
// Determinine the first memref dimension to collapse
int64_t firstDimToCollapse = std::max(
lastDynIndex(sourceType.getShape()),
- sourceType.getRank() - sourceType.getMaxContiguousTrailingDims());
+ sourceType.getRank() - sourceType.getNumContiguousTrailingDims());
// 1. Collapse the source memref
Value collapsedSource =
>From 5e66da17ebccd6c8b5f50fe103907a6e48b5289d Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Fri, 6 Jun 2025 12:58:34 +0000
Subject: [PATCH 5/6] [fixup] Comment fixes
---
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h | 12 ++++++------
.../Vector/Transforms/VectorTransferOpTransforms.cpp | 3 ++-
2 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index ed06d7a029494..508fb017007ed 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -54,9 +54,9 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
/// dimensions after ignoring a leading sequence of unit ones.
///
/// For `vectorType` to be a contiguous slice of `memrefType`
-/// a) the N trailing dimensions of the latter must be contiguous, and
-/// b) the trailing N dimensions of `vectorType` and `memrefType`,
-/// except the first of them, must match.
+/// a) the N trailing dimensions of `memrefType` must be contiguous, and
+/// b) the N-1 trailing dimensions of `vectorType` and `memrefType`
+/// must match.
///
/// Examples:
///
@@ -69,15 +69,15 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
/// 2 != 3 (allowed)
/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex.5. contiguous slice, leasing two unit dims of the vector ignored,
+/// Ex.5. contiguous slice, leading two unit dims of the vector ignored,
/// 2 != 3 (allowed)
/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
-/// Ex.7 contiguous slice, memref needs to be contiguous only on the last
+/// Ex.7 contiguous slice, memref needs to be contiguous only in the last
/// dimension
/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
-/// Ex.8 non-contiguous slice, memref needs to be contiguous one the last
+/// Ex.8 non-contiguous slice, memref needs to be contiguous in the last
/// two dimensions, and it isn't
/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 8bc4b66797730..05cdcd37cfaa8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -582,7 +582,8 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
namespace {
-/// Helper functon to return the index of the last dynamic dimension in `shape`.
+/// Helper function to return the index of the last dynamic dimension
+/// in `shape` or -1 if there are no dynamic dimensions.
int64_t lastDynIndex(ArrayRef<int64_t> shape) {
return static_cast<int64_t>(
std::distance(
>From 3b17c94fcae7ce6784fb9899d50626ce5b5f954a Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.velikov at arm.com>
Date: Mon, 9 Jun 2025 13:31:01 +0000
Subject: [PATCH 6/6] [fixup] Test tweaks for better coverage
---
.../Vector/vector-transfer-flatten.mlir | 112 +++++++++++-------
1 file changed, 67 insertions(+), 45 deletions(-)
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 6ad12a00d19e7..6de9b9c9722f7 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -70,10 +70,11 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
// -----
-// The shape of the memref and the vector don't match, but the vector is a
-// contiguous subset of the memref, so "flattenable".
+// The shape of the memref and the vector don't match, but the vector,
+// ignoring the unit dimensions, is a contiguous subset of the memref,
+// so "flattenable"
-func.func @transfer_read_dims_mismatch_contiguous(
+func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
%c0 = arith.constant 0 : index
@@ -83,7 +84,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
return %res : vector<1x1x2x2xi8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -92,7 +93,37 @@ func.func @transfer_read_dims_mismatch_contiguous(
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
-// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+ %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
+ return %res : vector<2x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]]], %[[C0_I8]] {in_bounds = [true]}
+// CHECK-SAME: : memref<120xi8, strided<[1], offset: ?>>, vector<12xi8>
+// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
+// CHECK: return %[[VEC]] : vector<2x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
// CHECK-128B: memref.collapse_shape
// -----
@@ -380,7 +411,7 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
// -----
-func.func @transfer_write_dims_mismatch_contiguous(
+func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
%vec : vector<1x1x2x2xi8>) {
@@ -390,7 +421,7 @@ func.func @transfer_write_dims_mismatch_contiguous(
return
}
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
@@ -398,7 +429,33 @@ func.func @transfer_write_dims_mismatch_contiguous(
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
-// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
+ %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+ %vec : vector<2x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+ vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]]] {in_bounds = [true]}
+// CHECK-SAME: : vector<4xi8>, memref<120xi8, {{.+}}>
+
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims(
// CHECK-128B: memref.collapse_shape
// -----
@@ -620,6 +677,7 @@ func.func @negative_out_of_bound_transfer_read(
}
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
// CHECK-128B-NOT: memref.collapse_shape
@@ -638,45 +696,9 @@ func.func @negative_out_of_bound_transfer_write(
}
// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast
-// -----
-
-func.func @discontig_mem_contig_slice(
- %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
- return
-}
-
-// CHECK-LABEL: func.func @discontig_mem_contig_slice
-// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
-// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
-// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
-// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
-
-// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
-// CHECK-128B-NOT: vector.shape_cast
-
-// -----
-
-func.func @discontig_mem_discontig_slice(
- %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
- %c0 = arith.constant 0 : index
- vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
- vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
- return
-}
-
-// CHECK-LABEL: func.func @discontig_mem_discontig_slice
-// CHECK-NOT: vector.shape_cast
-
-// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
-// CHECK-128B-NOT: vector.shape_cast
-
More information about the llvm-branch-commits
mailing list