[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (PR #73522)
Benjamin Maxwell
llvmlistbot at llvm.org
Thu Nov 30 08:12:51 PST 2023
================
@@ -487,26 +487,88 @@ class TransferWriteDropUnitDimsPattern
} // namespace
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
- ArrayRef<int64_t> targetShape) {
- auto shape = memrefType.getShape();
- SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
+/// is implemented by iterating over the dims of `vectorType` and `memrefType`
+/// and comparing them starting from the inner-most/right-most dims.
+///
+/// Note that there might be some restriction on the leading dim of
+/// `VectorType`:
+/// 1. if all the trailing dims of `vectorType` match the trailing dims
+/// of `memrefType` then the leading dim of `vectorType` can be arbitrary:
+///
+/// 1.1 contiguous slice, perfect match
+/// vector<4x3x2xi32> from memref<5x4x3x2xi32>
+/// 1.2 contiguous slice, all dims match except the leading dim: 2 != 4
+/// vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///
+/// 2. if an "internal" dim of `vectorType` does not match the corresponding
+/// trailing dim in `memrefType` then the remaining leading dims of
+/// `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
+///
+/// 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
+/// vector<2x2x2xi32> from memref<5x4x3x2xi32>
+/// 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
+/// vector<1x2x2xi32> from memref<5x4x3x2xi32>
+/// 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
+/// vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+/// 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
+/// vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
+///
+/// In all cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+
+ // Get the shape of `vectorType`. The leading dim is treated separately.
+ ArrayRef<int64_t> targetShape = vectorType.getShape();
+ auto targetShapeTrailingDims = targetShape.drop_front(1);
----------------
MacDue wrote:
For now, explicitly returning `false` for scalable vectors should make it clear that they're not supported here (yet).
https://github.com/llvm/llvm-project/pull/73522
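
To make the suggestion concrete, here is a minimal standalone sketch of the contiguity rules from the doc comment above, with the early `false` for scalable vectors placed first. It is not the code from the PR: `MemRefDesc`, `VectorDesc`, `hasRowMajorStrides` and `isContiguousSliceSketch` are hypothetical stand-ins that model the types as plain shapes so the logic can be compiled without MLIR. It also assumes that "contiguous memref" means canonical row-major strides on every dim and that a vector dim may not exceed the matching memref dim; the actual patch may be less strict. In the real pattern the scalable check would presumably query the vector type itself rather than a `bool` field.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for MemRefType / VectorType (shapes only).
struct MemRefDesc {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides; // in elements, one per dim
};

struct VectorDesc {
  std::vector<int64_t> shape;
  bool scalable; // models "this vector has scalable dims"
};

// Assumption: the memref counts as contiguous only if its strides are the
// canonical row-major strides for its shape.
static bool hasRowMajorStrides(const MemRefDesc &m) {
  if (m.shape.size() != m.strides.size())
    return false;
  int64_t expected = 1;
  for (size_t i = m.shape.size(); i-- > 0;) {
    if (m.strides[i] != expected)
      return false;
    expected *= m.shape[i];
  }
  return true;
}

static bool isContiguousSliceSketch(const MemRefDesc &memref,
                                    const VectorDesc &vec) {
  assert(!vec.shape.empty() && "expected a vector of rank >= 1");

  // Scalable vectors are not supported (yet): bail out explicitly.
  if (vec.scalable)
    return false;
  if (vec.shape.size() > memref.shape.size())
    return false;
  if (!hasRowMajorStrides(memref))
    return false;

  // Walk the dims from the inner-most one outwards. Once a vector dim differs
  // from the memref dim, every remaining (more leading) vector dim must be 1.
  bool mismatchSeen = false;
  size_t rankDiff = memref.shape.size() - vec.shape.size();
  for (size_t i = vec.shape.size(); i-- > 0;) {
    int64_t vecDim = vec.shape[i];
    int64_t memDim = memref.shape[i + rankDiff];
    if (mismatchSeen && vecDim != 1)
      return false;
    if (vecDim > memDim) // assumption: a slice cannot exceed the memref dim
      return false;
    if (vecDim != memDim)
      mismatchSeen = true;
  }
  return true;
}
```

Fed the shapes from cases 1.1 to 2.4 of the doc comment, with `memref<5x4x3x2xi32>` modelled as shape `{5, 4, 3, 2}` and strides `{24, 6, 2, 1}`, the sketch accepts 1.1, 1.2, 2.2 and 2.3 and rejects 2.1 and 2.4. Any vector marked scalable is rejected before its shape is inspected.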