[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (PR #73522)

Cullen Rhodes llvmlistbot at llvm.org
Wed Nov 29 08:34:06 PST 2023


================
@@ -487,26 +487,88 @@ class TransferWriteDropUnitDimsPattern
 
 } // namespace
 
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
-                                              ArrayRef<int64_t> targetShape) {
-  auto shape = memrefType.getShape();
-  SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
+/// is implemented by iterating over the dims of `vectorType` and `memrefType`
+/// and comparing them starting from the inner-most/right-most dims.
+///
+/// Note that there may be restrictions on the leading dims of
+/// `vectorType`:
+///   1. if all the trailing dims of `vectorType` match the trailing dims
+///     of `memrefType` then the leading dim of `vectorType` can be arbitrary:
+///
+///       1.1 contiguous slice, perfect match
+///         vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///       1.2 contiguous slice, all dims match except the leading dim: 2 != 4
+///         vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///
+///   2. if an "internal" dim of `vectorType` does not match the corresponding
+///     trailing dim in `memrefType` then the remaining leading dims of
+///     `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
+///
+///       2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
+///         vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///       2.2 contiguous slice, 2 != 3 and the leading dim == <1>
+///         vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///       2.3 contiguous slice, 2 != 3 and the leading dims == <1x1>
+///         vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///       2.4 non-contiguous slice, 2 != 3 and the leading dims != <1x1>
+///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
+///
+/// In all cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+
+  // Get the shape of `vectorType`. The leading dim is treated separately.
+  ArrayRef<int64_t> targetShape = vectorType.getShape();
+  auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+  // Get the strides of the memref.
   int64_t offset;
+  SmallVector<int64_t> strides;
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
     return false;
+
+  // A non-unit stride in the trailing dimension means that this memref is
+  // not contiguous.
   if (strides.back() != 1)
     return false;
-  strides.pop_back();
+
+  // Do all but the leading dim of `vectorType` and the trailing dims of
+  // `memrefType` match?
+  bool allTrailingDimsMatch = true;
+
+  // The trailing dimension of `memrefType` after collapsing/flattening the
+  // current dim. This will be a product of the leading dims, hence initialising
+  // to 1.
   int64_t flatDim = 1;
-  for (auto [targetDim, memrefDim, memrefStride] :
-       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+
+  // Iterate over all dims of `vectorType`, excluding the leading dim, and
+  // compare them against the trailing dims of `memrefType`.
+  strides.pop_back();
+  for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+           targetShapeTrailingDims, memrefType.getShape(), strides))) {
----------------
c-rhodes wrote:

`memrefType.getShape()` has one more element than `targetShapeTrailingDims` and `strides`, so I initially thought this was comparing the trailing dims of the vector type with all but the trailing dim of the memref type. It seems, though, that `llvm::reverse` walks each zipped range back from its own end, so it is the leading dim of `memrefType` that gets dropped (as expected). I personally found this a bit baffling (this code is quite dense as it is); `memrefType.getShape().drop_front(1)` would be more explicit.
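To make the alignment concrete, here is a minimal, standalone C++ sketch that models the pairing with plain `std::vector`s instead of `ArrayRef`. It assumes, as observed in the comment, that reversing the zip aligns the ranges at their trailing ends and stops once the shorter range is exhausted. `reverseZipPairs`, `explicitTrailingPairs` and `frontZipPairs` are hypothetical helpers written for illustration, not LLVM or MLIR APIs; the patch zips three ranges, but two are enough to show the alignment.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// (vectorDim, memrefDim) pairs.
using DimPairs = std::vector<std::pair<int64_t, int64_t>>;

// Assumed model (per the review comment) of what iterating
// `llvm::reverse(llvm::zip(vecDims, memrefDims))` visits: both ranges are
// stepped backwards from their own ends in lockstep, and iteration stops once
// the shorter range is exhausted. Pairs come out innermost first.
DimPairs reverseZipPairs(const std::vector<int64_t> &vecDims,
                         const std::vector<int64_t> &memrefDims) {
  DimPairs pairs;
  auto vecIt = vecDims.rbegin();
  auto memIt = memrefDims.rbegin();
  for (; vecIt != vecDims.rend() && memIt != memrefDims.rend();
       ++vecIt, ++memIt)
    pairs.emplace_back(*vecIt, *memIt);
  return pairs;
}

// The explicit alternative suggested in the review: skip the leading memref
// dims up front (the `drop_front` step) so that vector dim i lines up with
// memref dim (skip + i), then walk innermost first.
DimPairs explicitTrailingPairs(const std::vector<int64_t> &vecDims,
                               const std::vector<int64_t> &memrefDims) {
  assert(memrefDims.size() >= vecDims.size() &&
         "vector rank exceeds memref rank");
  const std::size_t skip = memrefDims.size() - vecDims.size();
  DimPairs pairs;
  for (std::size_t i = vecDims.size(); i-- > 0;)
    pairs.emplace_back(vecDims[i], memrefDims[skip + i]);
  return pairs;
}

// The initial misreading: a front-aligned zip, which pairs the vector dims
// with the leading memref dims and silently drops the trailing memref dim.
DimPairs frontZipPairs(const std::vector<int64_t> &vecDims,
                       const std::vector<int64_t> &memrefDims) {
  DimPairs pairs;
  for (std::size_t i = 0; i < vecDims.size() && i < memrefDims.size(); ++i)
    pairs.emplace_back(vecDims[i], memrefDims[i]);
  return pairs;
}
```

For `vecDims = {3, 2}` (the trailing dims of `vector<4x3x2xi32>`) and `memrefDims = {4, 3, 2}`, `reverseZipPairs` and `explicitTrailingPairs` both yield `(2, 2), (3, 3)`, whereas `frontZipPairs` yields `(3, 4), (2, 3)`, i.e. the reading the comment initially assumed. In the patch, the same trailing alignment also pairs each memref dim with the stride of the next outer dim, because the innermost stride has been popped from `strides` before the loop.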

https://github.com/llvm/llvm-project/pull/73522
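The rules in the `isContiguousSlice` doc comment can likewise be checked on plain integer shapes. The sketch below is a hypothetical, standalone C++ model, not the patch's implementation: it assumes static shapes and a vector rank no greater than the memref rank, and it checks row-major contiguity of the whole memref rather than only the dims the loop visits. `isRowMajorContiguous` and `isContiguousSliceSketch` are illustrative names only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: true if a memref with this static shape and these
// strides is laid out contiguously in row-major order, i.e. the innermost
// stride is 1 and each outer stride equals the product of the inner dims.
bool isRowMajorContiguous(const std::vector<int64_t> &shape,
                          const std::vector<int64_t> &strides) {
  assert(shape.size() == strides.size() && "shape/strides rank mismatch");
  int64_t expected = 1;
  for (std::size_t i = shape.size(); i-- > 0;) {
    if (strides[i] != expected)
      return false;
    expected *= shape[i];
  }
  return true;
}

// Sketch of the doc-comment rules: compare the vector dims (except the
// leading one) against the trailing memref dims, innermost first. At the
// first mismatch every vector dim in front of it must be 1.
bool isContiguousSliceSketch(const std::vector<int64_t> &memShape,
                             const std::vector<int64_t> &memStrides,
                             const std::vector<int64_t> &vecShape) {
  if (!isRowMajorContiguous(memShape, memStrides))
    return false;
  // Not handled by this sketch: rank-0 vectors or vector rank > memref rank.
  if (vecShape.empty() || vecShape.size() > memShape.size())
    return false;

  // Vector dim i is aligned with memref dim (offset + i).
  const std::size_t offset = memShape.size() - vecShape.size();

  // Walk from the innermost dim outwards, skipping the leading vector dim.
  for (std::size_t i = vecShape.size() - 1; i >= 1; --i) {
    if (vecShape[i] == memShape[offset + i])
      continue;
    // First mismatch: this dim may be arbitrary (as the doc comment states),
    // but all dims in front of it must be 1.
    for (std::size_t j = 0; j < i; ++j)
      if (vecShape[j] != 1)
        return false;
    return true;
  }
  // All dims except the leading one match, so the leading dim is arbitrary.
  return true;
}
```

With the contiguous strides `{24, 6, 2, 1}` for `memref<5x4x3x2xi32>`, the sketch accepts cases 1.1, 1.2, 2.2 and 2.3 from the doc comment and rejects 2.1 and 2.4; with non-contiguous strides such as `{48, 12, 2, 1}` it rejects every case. Following the doc comment literally, the first mismatching dim is not compared against the corresponding memref bound.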

