[Mlir-commits] [mlir] 8171eac - [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (#73522)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 4 02:21:37 PST 2023
Author: Andrzej Warzyński
Date: 2023-12-04T10:21:32Z
New Revision: 8171eac23fe7756319444c2caa27216a1e9f046a
URL: https://github.com/llvm/llvm-project/commit/8171eac23fe7756319444c2caa27216a1e9f046a
DIFF: https://github.com/llvm/llvm-project/commit/8171eac23fe7756319444c2caa27216a1e9f046a.diff
LOG: [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (#73522)
Updates "flatten vector" patterns to support more cases, namely Ops that
read/write vectors with leading unit dims. For example:
```mlir
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0] ... :
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
```
Currently, the `vector.transfer_read` above would not be flattened. With
this change, it will be rewritten as follows:
```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] :
memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
into memref<120xi8, strided<[1], offset: ?>>
%0 = vector.transfer_read %collapse_shape[%c0] ... :
memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
%1 = vector.shape_cast %0 : vector<4xi8> to vector<1x1x2x2xi8>
```
`hasMatchingInnerContigousShape` is generalised and renamed as
`isContiguousSlice` to better match the updated functionality. A few
test names are updated to better highlight what case is being exercised.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index fc00769a4aaa8..2ab456d4fdbf1 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -42,6 +42,39 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
/// on a 2D slice. Otherwise, returns a failure.
FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
+/// checked (the other dims are not relevant). Note that for `vectorType` to be
+/// a contiguous slice of `memrefType`, the trailing dims of the latter have
+/// to be contiguous - this is checked by looking at the corresponding strides.
+///
+/// There might be some restriction on the leading dim of `VectorType`:
+///
+/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
+/// of `memrefType` then the leading dim of `vectorType` can be
+/// arbitrary.
+///
+/// Ex. 1.1 contiguous slice, perfect match
+/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
+/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///
+/// Case 2. If an "internal" dim of `vectorType` does not match the
+/// corresponding trailing dim in `memrefType` then the remaining
+/// leading dims of `vectorType` have to be 1 (the first non-matching
+/// dim can be arbitrary).
+///
+/// Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
+/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
+/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
+/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
+/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 75e1abead973f..aab7075006031 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -491,28 +491,6 @@ class TransferWriteDropUnitDimsPattern
} // namespace
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
- ArrayRef<int64_t> targetShape) {
- auto shape = memrefType.getShape();
- SmallVector<int64_t> strides;
- int64_t offset;
- if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
- return false;
- if (strides.back() != 1)
- return false;
- strides.pop_back();
- int64_t flatDim = 1;
- for (auto [targetDim, memrefDim, memrefStride] :
- llvm::reverse(llvm::zip(targetShape, shape, strides))) {
- flatDim *= memrefDim;
- if (flatDim != memrefStride || targetDim != memrefDim)
- return false;
- }
- return true;
-}
-
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
@@ -572,9 +550,7 @@ class FlattenContiguousRowMajorTransferReadPattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
@@ -632,9 +608,7 @@ class FlattenContiguousRowMajorTransferWritePattern
if (vectorType.getRank() <= 1)
// Already 0D/1D, nothing to do.
return failure();
- if (!hasMatchingInnerContigousShape(
- sourceType,
- vectorType.getShape().take_back(vectorType.getRank() - 1)))
+ if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
int64_t firstContiguousInnerDim =
sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 48cd67ad86c63..ac0fe64c70cd6 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,3 +249,47 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
// between parallel, reduction and possibly other cases.
return ratio.has_value();
}
+
+bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+ if (vectorType.isScalable())
+ return false;
+
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+ auto vecRank = vectorType.getRank();
+
+ // Extract the trailing dims and strides of the input memref
+ auto memrefShape = memrefType.getShape().take_back(vecRank);
+ int64_t offset;
+ SmallVector<int64_t> stridesFull;
+ if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
+ return false;
+ auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
+
+ // Cond 1: A contiguous memref will always have a unit trailing stride.
+ if (strides.back() != 1)
+ return false;
+
+ // Cond 2: Strides of a contiguous memref have to match the flattened dims.
+ strides = strides.drop_back(1);
+ SmallVector<int64_t> flattenedDims;
+ for (size_t i = 1; i < memrefShape.size(); i++)
+ flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+
+ if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
+ return false;
+
+ // Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
+ // In the most basic case, all dims will match.
+ auto firstNonMatchingDim =
+ std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
+ memrefShape.rbegin(), memrefShape.rend());
+ if (firstNonMatchingDim.first == vectorShape.rend())
+ return true;
+
+ // One non-matching dim is still fine, however the remaining leading dims of
+ // `vectorType` need to be 1.
+ SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
+ vectorShape.rend());
+
+ return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae62a5ba43d05..2ffe85bf3bfa6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
return %v : vector<5x4x3x2xi8>
}
-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,7 +18,45 @@ func.func @transfer_read_flattenable_with_offset(
// -----
-func.func @transfer_write_flattenable_with_offset(
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+ return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+ return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_dims_match_contiguous(
%arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -26,7 +64,7 @@ func.func @transfer_write_flattenable_with_offset(
return
}
-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +73,48 @@ func.func @transfer_write_flattenable_with_offset(
// -----
+func.func @transfer_write_dims_mismatch_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME: %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK: vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK: return
+// CHECK: }
+
+// -----
+
+func.func @transfer_write_dims_mismatch_non_contiguous(
+ %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+ return
+}
+
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
return
}
-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK-SAME: %[[VEC:.+]]: vector<i8>
-// CHECK: vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK: return
+// CHECK-LABEL: func.func @transfer_write_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
@@ -54,11 +124,9 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
return %0 : vector<i8>
}
-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME: %[[ARG:.+]]: memref<i8>
-// CHECK: %[[CST:.+]] = arith.constant 0 : i8
-// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK: return %[[READ]]
+// CHECK-LABEL: func.func @transfer_read_0d
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
// -----
More information about the Mlir-commits
mailing list