[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (PR #73522)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Nov 28 01:53:19 PST 2023
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/73522
>From 902ccc3b984b5a060052b87768003ef45870e08f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 25 Nov 2023 19:10:34 +0000
Subject: [PATCH 1/2] [mlir][Vector] Update patterns for flattening vector.xfer
Ops
Updates "flatten vector" patterns to support more cases, namely Ops that
read/write vectors with leading unit dims. For example:
```mlir
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0] ... :
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
```
Currently, this `vector.transfer_read` would not be flattened. With this
change, it will be transformed as follows:
```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] :
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
into memref<120xi8, strided<[1], offset: ?>>
%0 = vector.transfer_read %collapse_shape[%c0] ... :
memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
%1 = vector.shape_cast %0 : vector<4xi8> to vector<1x1x2x2xi8>
```
`hasMatchingInnerContigousShape` is generalised and renamed as
`isContiguousSlice` to better match the updated functionality. A few
test names are updated to better highlight what case is being exercised.
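As a side note, the generalised rule is easy to check in isolation. Below is a
minimal standalone sketch (plain C++, shapes only): it assumes an identity,
i.e. fully contiguous, layout and deliberately omits the stride/offset
verification that the in-tree `isContiguousSlice` performs; the helper name
and signature are mine, for illustration only.
```cpp
#include <cstdint>
#include <vector>

// Sketch: is a vector with shape `vecShape` a contiguous slice of a
// row-major buffer with shape `memShape`? Mirrors the dim-matching rule
// from this patch, minus the stride/offset checks.
static bool isContiguousSliceSketch(const std::vector<int64_t> &memShape,
                                    const std::vector<int64_t> &vecShape) {
  if (vecShape.empty() || vecShape.size() > memShape.size())
    return false;
  bool allTrailingDimsMatch = true;
  // Compare all but the leading vector dim against the trailing buffer
  // dims, starting from the innermost (right-most) dim.
  for (size_t i = 1; i < vecShape.size(); ++i) {
    int64_t vecDim = vecShape[vecShape.size() - i];
    int64_t memDim = memShape[memShape.size() - i];
    // Once a mismatch has been seen, every remaining (more leading)
    // vector dim has to be 1.
    if (!allTrailingDimsMatch && vecDim != 1)
      return false;
    allTrailingDimsMatch = (vecDim == memDim);
  }
  // The leading dim is unconstrained when all other dims matched;
  // otherwise it has to be 1. E.g. vector<1x1x2x2> of memref<5x4x3x2>
  // -> true, vector<2x1x2x2> of memref<5x4x3x2> -> false.
  return allTrailingDimsMatch || vecShape.front() == 1;
}
```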
---
.../Transforms/VectorTransferOpTransforms.cpp | 79 ++++++++++++----
.../Vector/vector-transfer-flatten.mlir | 92 ++++++++++++++++---
2 files changed, 140 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..c1c9659e7b1ab29 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -487,26 +487,75 @@ class TransferWriteDropUnitDimsPattern
} // namespace
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
- ArrayRef<int64_t> targetShape) {
- auto shape = memrefType.getShape();
- SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+///
+/// There are two cases:
+///
+/// 1. The trailing dimensions of `memrefType` match the dimensions of
+/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
+/// not matter in this case):
+///
+/// vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+/// vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// 2. The trailing dimensions of `memrefType` match the trailing dimensions of
+/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
+/// first dim of `vectorType` that does not match can be arbitrary, but the
+/// remaining leading dims have to be 1:
+///
+/// vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+/// vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+///
+/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
+/// TODO: Update
+static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+
+ ArrayRef<int64_t> targetShape = vectorType.getShape();
+ auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+ // Not used
int64_t offset;
+ SmallVector<int64_t> strides;
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
return false;
+
+ // Non-unit stride in the trailing dimension means that this memref is
+ // not contiguous.
if (strides.back() != 1)
return false;
- strides.pop_back();
+
+ // Do all dims of `vectorType`, except the leading one, match the trailing
+ // dims of `memrefType`?
+ bool allTrailingDimsMatch = true;
+
+ // The trailing dimension of `memrefType` after collapsing/flattening the
+ // current dim. This will be a product of the leading dims, hence initialising
+ // to 1.
int64_t flatDim = 1;
- for (auto [targetDim, memrefDim, memrefStride] :
- llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+ strides.pop_back();
+ for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+ targetShapeTrailingDims, memrefType.getShape(), strides))) {
flatDim *= memrefDim;
- if (flatDim != memrefStride || targetDim != memrefDim)
+ // If the memref stride does not match the flattened dim, then this memref
+ // is not contiguous.
+ if (flatDim != memrefStride)
+ return false;
+
+ // If a non-matching dim was found, then the remaining dims of `VectorType`
+ // should be 1.
+ if (!allTrailingDimsMatch && (targetDim != 1))
return false;
+
+ allTrailingDimsMatch = (targetDim == memrefDim);
}
- return true;
+
+ return allTrailingDimsMatch ? true : (targetShape[0] == 1);
}
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -568,9 +617,7 @@ class FlattenContiguousRowMajorTransferReadPattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
@@ -628,9 +675,7 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae62a5ba43d055a..08ce837be93ffd3 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
return %v : vector<5x4x3x2xi8>
}
-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,7 +18,44 @@ func.func @transfer_read_flattenable_with_offset(
// -----
-func.func @transfer_write_flattenable_with_offset(
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+ return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -26,7 +63,7 @@ func.func @transfer_write_flattenable_with_offset(
return
}
-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +72,46 @@ func.func @transfer_write_flattenable_with_offset(
// -----
+func.func @transfer_write_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @transfer_write_dims_mismatch_non_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
return
}
-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK-SAME: %[[VEC:.+]]: vector<i8>
-// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK: return
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
@@ -54,11 +121,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
return %0 : vector<i8>
}
-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK: %[[CST:.+]] = arith.constant 0 : i8
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK: return %[[READ]]
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
>From 9a3c60b07387164d39dc961f6227950a011579a9 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 28 Nov 2023 09:41:09 +0000
Subject: [PATCH 2/2] fixup! [mlir][Vector] Update patterns for flattening
vector.xfer Ops
Update comments
---
.../Transforms/VectorTransferOpTransforms.cpp | 57 ++++++++++++-------
1 file changed, 35 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c1c9659e7b1ab29..9f20f75bc7edbdd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -489,37 +489,43 @@ class TransferWriteDropUnitDimsPattern
/// Return true if `vectorType` is a contiguous slice of `memrefType`.
///
-/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
-/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+/// Compares `vectorType` against the trailing dimensions of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
+/// is implemented by iterating over the dims of `vectorType` and `memrefType`
+/// and comparing them starting from the inner-most/right-most dims.
///
-/// There are two cases:
+/// Note that there may be a restriction on the leading dim of
+/// `vectorType`:
+/// 1. if all the trailing dims of `vectorType` match the trailing dims
+/// of `memrefType`, then the leading dim of `vectorType` can be arbitrary:
///
-/// 1. The trailing dimensions of `memrefType` match the dimensions of
-/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
-/// not matter in this case):
+/// 1.1 contiguous slice, perfect match
+/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
+/// 1.2 contiguous slice, all dims match except the leading dim: 2 != 4
+/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
///
-/// vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
-/// vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+/// 2. if an "internal" dim of `vectorType` does not match the corresponding
+/// trailing dim in `memrefType` then the remaining leading dims of
+/// `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
///
-/// 2. The trailing dimensions of `memrefType` match the trailing dimensions of
-/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
-/// first dim of `vectorType` that does not match can be arbitrary, but the
-/// remaining leading dims have to be 1:
+/// 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
+/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
+/// 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
+/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
+/// 2.3 contiguous slice, 2 != 3 and the leading dims == <1x1>
+/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+/// 2.4 non-contiguous slice, 2 != 3 and the leading dims != <1x1>
+/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
///
-/// vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
-/// vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
-///
-/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// In all cases `memrefType` has to be contiguous (this is checked by looking
/// at strides).
-///
-/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
-/// TODO: Update
static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+ // Get the shape of `vectorType`. The leading dim is treated separately.
ArrayRef<int64_t> targetShape = vectorType.getShape();
auto targetShapeTrailingDims = targetShape.drop_front(1);
- // Not used
+ // Get the strides of the memref.
int64_t offset;
SmallVector<int64_t> strides;
if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
@@ -538,6 +544,9 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
// current dim. This will be a product of the leading dims, hence initialising
// to 1.
int64_t flatDim = 1;
+
+ // Iterate over all dims of `vectorType`, excluding the leading dim, and
+ // compare them against the trailing dims of `memrefType`.
strides.pop_back();
for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
targetShapeTrailingDims, memrefType.getShape(), strides))) {
@@ -547,14 +556,18 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
if (flatDim != memrefStride)
return false;
- // If a non-matching dim was found, then the remaining dims of `VectorType`
- // should be 1.
+ // If a non-matching dim was found previously, then the remaining dims of
+ // `vectorType` should be 1.
if (!allTrailingDimsMatch && (targetDim != 1))
return false;
allTrailingDimsMatch = (targetDim == memrefDim);
}
+ // If all dims of `vectorType` (excluding the leading dim) match the trailing
+ // dims of `memrefType`, then this is a contiguous load. If there was a
+ // mismatch, then the internal dims have already been verified to be unit
+ // dims, but the leading dim still has to be checked.
return allTrailingDimsMatch ? true : (targetShape[0] == 1);
}
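For anyone who wants to try the updated patterns outside of
`mlir-opt -test-vector-transfer-flatten-patterns` (the RUN line used by the
test file above), something along these lines should work. This is a sketch
under the assumption that the usual populate hook,
`populateFlattenVectorTransferPatterns` from `VectorRewritePatterns.h`, is
the entry point; double-check the include paths and signature against the
current tree.
```cpp
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Apply the transfer-flattening rewrites (including the ones updated in
// this PR) to everything nested under `root`.
static void flattenVectorTransfers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::vector::populateFlattenVectorTransferPatterns(patterns);
  // Rewrite greedily until a fixpoint; a failed result only means the
  // driver did not converge.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```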