[Mlir-commits] [mlir] c7d9b6e - [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (#142422)
llvmlistbot at llvm.org
Mon Jun 23 06:13:00 PDT 2025
Author: Momchil Velikov
Date: 2025-06-23T14:12:56+01:00
New Revision: c7d9b6ed5d6d438ec3bcb0cad5f4d63f8e3e5eed
URL: https://github.com/llvm/llvm-project/commit/c7d9b6ed5d6d438ec3bcb0cad5f4d63f8e3e5eed
DIFF: https://github.com/llvm/llvm-project/commit/c7d9b6ed5d6d438ec3bcb0cad5f4d63f8e3e5eed.diff
LOG: [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice` (#142422)
Previously, slices were sometimes marked as non-contiguous when they
were actually contiguous. This occurred when the vector type had leading
unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such
cases, only the trailing `n` dimensions of the memref need to be
contiguous, rather than as many trailing dimensions as the full vector rank.
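The rule described above can be sketched outside of MLIR. The following is a simplified model, not the code from the patch: `MemRefModel`, `trailingDimsContiguousModel`, and `isContiguousSliceModel` are hypothetical names, only static shapes and explicit strides are handled, and an all-unit vector is treated as trivially contiguous (an assumption of this sketch).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a memref type: static shape plus explicit strides.
struct MemRefModel {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
};

// Simplified contiguity check: the trailing `n` dims are row-major with a
// unit innermost stride. Assumption: static shapes and strides only.
bool trailingDimsContiguousModel(const MemRefModel &m, std::size_t n) {
  assert(m.shape.size() == m.strides.size());
  if (n > m.shape.size())
    return false;
  int64_t expectedStride = 1;
  for (std::size_t i = 0; i < n; ++i) {
    std::size_t d = m.shape.size() - 1 - i;
    if (m.strides[d] != expectedStride)
      return false;
    expectedStride *= m.shape[d];
  }
  return true;
}

// Model of the rule in the commit message: drop the leading unit dims of the
// vector, require the N trailing memref dims to be contiguous, and require
// the N-1 trailing dims of both shapes to match.
bool isContiguousSliceModel(const MemRefModel &m,
                            const std::vector<int64_t> &vecShape) {
  auto firstNonUnit = std::find_if(vecShape.begin(), vecShape.end(),
                                   [](int64_t d) { return d != 1; });
  std::vector<int64_t> trimmed(firstNonUnit, vecShape.end());
  std::size_t n = trimmed.size();
  if (!trailingDimsContiguousModel(m, n))
    return false;
  if (n == 0)
    return true; // Assumption of this sketch: an all-unit vector is contiguous.
  // Compare everything except the first of the N dims.
  return std::equal(trimmed.begin() + 1, trimmed.end(),
                    m.shape.end() - static_cast<std::ptrdiff_t>(n - 1));
}
```

With a `5x4x3x2` memref and strides `[24, 6, 2, 1]`, this model accepts `1x1x2x2` and rejects `2x2x2`, in line with the examples in the updated header comment in the diff.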
This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write` ops.
The patterns used to collapse a number of dimensions equal to the vector
rank, which missed some opportunities when the leading unit dimensions of
the vector span non-contiguous dimensions of the memref.
Now that the contiguity of the slice is determined correctly, there is a
choice of how many dimensions of the memref to collapse, ranging from
a) the number of vector dimensions after ignoring the leading unit
dimensions, up to
b) the maximum number of contiguous memref dimensions.
This patch chooses minimal memref collapsing. The rationale behind this
decision is that this way the least amount of information is discarded.
(It follows that in some cases where the patterns used to trigger and
collapse some memref dimensions, after this patch the patterns may
collapse fewer dimensions.)
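The minimal-collapsing choice, option a) above, can be illustrated with a small standalone sketch. It is a model under stated assumptions rather than the pattern code itself: `firstDimToCollapseModel` and `collapsedShapeModel` are hypothetical helpers, shapes are static, and the vector is assumed to have at least one non-unit dimension.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <vector>

// Number of vector dims left after dropping the leading run of unit dims.
std::size_t rankWithoutLeadingUnitDims(const std::vector<int64_t> &vecShape) {
  auto firstNonUnit = std::find_if(vecShape.begin(), vecShape.end(),
                                   [](int64_t d) { return d != 1; });
  return static_cast<std::size_t>(std::distance(firstNonUnit, vecShape.end()));
}

// Index of the first memref dim that gets folded into the collapsed
// trailing dim: memref rank minus the trimmed vector rank.
std::size_t firstDimToCollapseModel(std::size_t memrefRank,
                                    const std::vector<int64_t> &vecShape) {
  std::size_t n = rankWithoutLeadingUnitDims(vecShape);
  // Assumption of this sketch: at least one non-unit vector dim, and the
  // trimmed vector rank does not exceed the memref rank.
  assert(n >= 1 && n <= memrefRank);
  return memrefRank - n;
}

// Shape of the memref after collapsing dims [first, rank) into one dim.
// Assumption: all dims are static.
std::vector<int64_t> collapsedShapeModel(const std::vector<int64_t> &memrefShape,
                                         const std::vector<int64_t> &vecShape) {
  std::size_t first = firstDimToCollapseModel(memrefShape.size(), vecShape);
  std::vector<int64_t> result(
      memrefShape.begin(),
      memrefShape.begin() + static_cast<std::ptrdiff_t>(first));
  int64_t folded = 1;
  for (std::size_t i = first; i < memrefShape.size(); ++i)
    folded *= memrefShape[i];
  result.push_back(folded);
  return result;
}
```

For a `5x4x3x2` memref and a `1x1x3x2` vector, the model collapses only the last two dimensions and yields `5x4x6`, which matches the `[[0], [1], [2, 3]]` reassociation expected by the updated tests in the diff.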
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..cc8421b23a074 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -47,37 +47,41 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
-/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+/// Return true if `vectorType` is a contiguous slice of `memrefType`,
+/// in the sense that it can be read/written from/to a contiguous area
+/// of the memref.
///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+/// a) the N trailing dimensions of `memrefType` must be contiguous, and
+/// b) the N-1 trailing dimensions of `vectorType` and `memrefType`
+/// must match.
///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-/// of `memrefType` then the leading dim of `vectorType` can be
-/// arbitrary.
-///
-/// Ex. 1.1 contiguous slice, perfect match
-/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-/// corresponding trailing dim in `memrefType` then the remaining
-/// leading dims of `vectorType` have to be 1 (the first non-matching
-/// dim can be arbitrary).
+/// Examples:
///
-/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
-/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+/// Ex.1 contiguous slice, perfect match
+/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.3 non-contiguous slice, 2 != 3
+/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+/// 2 != 3 (allowed)
+/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.5. contiguous slice, leading two unit dims of the vector ignored,
+/// 2 != 3 (allowed)
+/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
+/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+/// Ex.7 contiguous slice, memref needs to be contiguous only in the last
+/// dimension
+/// vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+/// Ex.8 non-contiguous slice, memref needs to be contiguous in the last
+/// two dimensions, and it isn't
+/// vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
/// Returns an iterator for all positions in the leading dimensions of `vType`
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 384717aeca665..785a8aaf3f0a9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -581,7 +581,6 @@ static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
}
namespace {
-
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -630,7 +629,11 @@ class FlattenContiguousRowMajorTransferReadPattern
if (transferReadOp.getMask())
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // Determine the first memref dimension to collapse - just enough so we can
+ // read a flattened vector.
+ int64_t firstDimToCollapse =
+ sourceType.getRank() -
+ vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
// 1. Collapse the source memref
Value collapsedSource =
@@ -722,7 +725,11 @@ class FlattenContiguousRowMajorTransferWritePattern
if (transferWriteOp.getMask())
return failure();
- int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+ // Determine the first memref dimension to collapse - just enough so we can
+ // read a flattened vector.
+ int64_t firstDimToCollapse =
+ sourceType.getRank() -
+ vectorType.getShape().drop_while([](auto v) { return v == 1; }).size();
// 1. Collapse the source memref
Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 590d244daef40..0c7cfc13f87da 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (vectorType.isScalable())
return false;
- ArrayRef<int64_t> vectorShape = vectorType.getShape();
- auto vecRank = vectorType.getRank();
+ // Ignore a leading sequence of adjacent unit dimensions in the vector.
+ ArrayRef<int64_t> vectorShape =
+ vectorType.getShape().drop_while([](auto v) { return v == 1; });
+ auto vecRank = vectorShape.size();
if (!memrefType.areTrailingDimsContiguous(vecRank))
return false;
- // Extract the trailing dims and strides of the input memref
+ // Extract the trailing dims of the input memref
auto memrefShape = memrefType.getShape().take_back(vecRank);
- // Compare the dims of `vectorType` against `memrefType` (in reverse).
- // In the most basic case, all dims will match.
- auto firstNonMatchingDim =
- std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
- memrefShape.rbegin(), memrefShape.rend());
- if (firstNonMatchingDim.first == vectorShape.rend())
- return true;
-
- // One non-matching dim is still fine, however the remaining leading dims of
- // `vectorType` need to be 1.
- SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
- vectorShape.rend());
-
- return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+ // Compare the dims of `vectorType` against `memrefType`.
+ // All of the dimensions, except the first must match.
+ return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
}
std::optional<StaticTileOffsetRange>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 45873aa93153d..0f04d3b79b535 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -71,30 +71,97 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
// -----
// The shape of the memref and the vector don't match, but the vector is a
-// contiguous subset of the memref, so "flattenable".
+// contiguous subset of the memref, so "flattenable"
func.func @transfer_read_dims_mismatch_contiguous(
- %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+ %mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
%res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
- memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
- return %res : vector<1x1x2x2xi8>
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x3x2xi8>
+ return %res : vector<2x3x2xi8>
}
-// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
-// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
-// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
-// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
-// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x24xi8, {{.+}}>
+// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED_MEM]][%[[C0]], %[[C0]]], %[[C0_I8]] {in_bounds = [true]}
+// CHECK-SAME: : memref<5x24xi8, strided<[24, 1], offset: ?>>, vector<12xi8>
+// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
+// CHECK: return %[[VEC]] : vector<2x3x2xi8>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
// CHECK-128B: memref.collapse_shape
+// -----
+
+// The shape of the memref and the vector don't match, but the mismatch is only
+// at the leading unit dimensions of the vector.
+
+func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+ %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>) -> vector<1x1x4x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0, %c0], %cst :
+ memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>, vector<1x1x4x3x2xi8>
+ return %res : vector<1x1x4x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
+// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
+// CHECK-SAME: -> vector<1x1x4x3x2xi8>
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+// CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
+// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
+// CHECK: return %[[VEC]] : vector<1x1x4x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+// The memref is non-contiguous, but the vector is a contiguous subset of the
+// memref, so "flattenable". The leading unit dimensions of the vector have no
+// effect on the memref area read even if they span the non-contiguous part of
+// the memref.
+
+func.func @transfer_read_non_contiguous_unit_dims(
+ %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %res = vector.transfer_read %mem[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>, vector<1x1x3x2xi8>
+ return %res : vector<1x1x3x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
+// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
+// CHECK: return %[[VAL_5]] : vector<1x1x3x2xi8>
+
+// CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+
// -----
func.func @transfer_read_dims_mismatch_non_zero_indices(
@@ -110,16 +177,18 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
return %res : vector<1x2x6xi32>
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 * 6)>
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
-// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
-// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>
-// CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK-SAME: %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index,
+// CHECK-SAME: %[[MEM:.+]]: memref<1x43x4x6xi32>
+// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK: %[[C_0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_IN:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1x43x24xi32>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]]]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0]], %[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I32]] {in_bounds = [true]} : memref<1x43x24xi32>, vector<12xi32>
// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
@@ -174,12 +243,12 @@ func.func @transfer_read_leading_dynamic_dims(
// CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
-// CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: {in_bounds = [true]} : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
// CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
// CHECK: return %[[RES]] : vector<8x4xi8>
@@ -229,21 +298,21 @@ func.func @transfer_read_dynamic_dim_to_flatten(
return %res : vector<1x2x6xi32>
}
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 6)>
// CHECK-LABEL: func.func @transfer_read_dynamic_dim_to_flatten
// CHECK-SAME: %[[IDX_1:arg0]]
// CHECK-SAME: %[[IDX_2:arg1]]
// CHECK-SAME: %[[MEM:arg2]]
-// CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
-// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
-// CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
+// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?x24xi32>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
+// CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[IDX_1]], %[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?x24xi32>, vector<12xi32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
// CHECK: return %[[RESULT]] : vector<1x2x6xi32>
@@ -381,29 +450,97 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
// -----
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
func.func @transfer_write_dims_mismatch_contiguous(
%mem : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
- %vec : vector<1x1x2x2xi8>) {
+ %vec : vector<2x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
- vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ vector<2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}
-// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
-// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
-// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
-// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
-// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<2x2xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
// CHECK-128B: memref.collapse_shape
// -----
+// The shape of the memref and the vector don't match, but the mismatch is only
+// at the leading unit dimensions of the vector.
+
+func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+ %mem : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>,
+ %vec : vector<1x1x4x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0, %c0] :
+ vector<1x1x4x3x2xi8>, memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
+// CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+// CHECK-SAME: %[[VEC:.+]]: vector<1x1x4x3x2xi8>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
+// CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
+// CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
+
+// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
+// The memref is non-contiguous, but the vector is a contiguous subset of the
+// memref, so "flattenable". The leading unit dimensions of the vector have no
+// effect on the memref area read even if they span the non-contiguous part of
+// the memref.
+
+func.func @transfer_write_non_contiguous_unit_dims(
+ %mem : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+ %vec : vector<1x1x3x2xi8>) {
+
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %mem [%c0, %c0, %c0, %c0] :
+ vector<1x1x3x2xi8>, memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_non_contiguous_unit_dims
+// CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>,
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x3x2xi8>) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8> to vector<6xi8>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<6xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
+
+// CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims(
+// CHECK-128B: memref.collapse_shape
+
+// -----
+
func.func @transfer_write_dims_mismatch_non_zero_indices(
%idx_1: index,
%idx_2: index,
@@ -417,17 +554,18 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
return
}
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 * 6)>
// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
// CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>,
// CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_2]]]
+// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-DAG-SAME{LITERAL}: [[0], [1], [2, 3]] : memref<1x43x4x6xi32> into memref<1x43x24xi32>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+// CHECK: vector.transfer_write %[[SC]], %[[CS]][%[[C0]], %[[IDX_1]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x43x24xi32>
// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape
@@ -478,12 +616,12 @@ func.func @transfer_write_leading_dynamic_dims(
// CHECK-LABEL: func @transfer_write_leading_dynamic_dims
// CHECK-SAME: %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
-// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]] {in_bounds = [true]}
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
// CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
@@ -528,22 +666,21 @@ func.func @transfer_write_dynamic_dim_to_flatten(
return
}
-// CHECK: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 24 + s1 * 6)>
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 6)>
// CHECK-LABEL: func.func @transfer_write_dynamic_dim_to_flatten
// CHECK-SAME: %[[IDX_1:arg0]]: index
// CHECK-SAME: %[[IDX_2:arg1]]: index
// CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
// CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
-
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
-// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
-// CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED_MEM:.+]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?x24xi32>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[IDX_1]], %[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?x24xi32>
// CHECK-128B-LABEL: func @transfer_write_dynamic_dim_to_flatten
// CHECK-128B-NOT: memref.collapse_shape
@@ -621,8 +758,13 @@ func.func @negative_out_of_bound_transfer_read(
memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
return %res : vector<5x4x3x2xi8>
}
-// CHECK: func.func @negative_out_of_bound_transfer_read
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
// -----
@@ -633,5 +775,10 @@ func.func @negative_out_of_bound_transfer_write(
vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
return
}
-// CHECK: func.func @negative_out_of_bound_transfer_write
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast