[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (PR #73522)

Cullen Rhodes llvmlistbot at llvm.org
Wed Nov 29 08:34:09 PST 2023


================
@@ -487,26 +487,88 @@ class TransferWriteDropUnitDimsPattern
 
 } // namespace
 
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
-                                              ArrayRef<int64_t> targetShape) {
-  auto shape = memrefType.getShape();
-  SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
+/// is implemented by iterating over the dims of `vectorType` and `memrefType`
+/// and comparing them starting from the inner-most/right-most dims.
+///
+/// Note that there are some restrictions on the leading dim(s) of
+/// `vectorType`:
+///   1. if all the trailing dims of `vectorType` match the trailing dims
+///     of `memrefType` then the leading dim of `vectorType` can be arbitrary:
+///
+///       1.1 contiguous slice, perfect match
+///         vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///       1.2 contiguous slice, all dims match except the leading dim: 2 != 4
+///         vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///
+///   2. if an "internal" dim of `vectorType` does not match the corresponding
+///     trailing dim in `memrefType` then the remaining leading dims of
+///     `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
+///
+///       2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
+///         vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///       2.2 contiguous slice, 2 != 3 and the leading dim == <1>
+///         vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///       2.3 contiguous slice, 2 != 3 and the leading dims == <1x1>
+///         vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///       2.4 non-contiguous slice, 2 != 3 and the leading dims != <1x1>
+///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
+///
+/// In all cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+
+  // Get the shape of `vectorType`. The leading dim is treated separately.
+  ArrayRef<int64_t> targetShape = vectorType.getShape();
+  auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+  // Get the strides of the memref.
   int64_t offset;
+  SmallVector<int64_t> strides;
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
     return false;
+
+  // A non-unit stride in the trailing dimension means that this memref is
+  // not contiguous.
   if (strides.back() != 1)
     return false;
-  strides.pop_back();
+
+  // Do all but the leading dim of `vectorType` and the trailing dims of
+  // `memrefType` match?
+  bool allTrailingDimsMatch = true;
+
+  // Running product of the trailing dims of `memrefType` visited so far, i.e.
+  // the stride that the next outer dim must have if the memref is contiguous.
+  // Initialised to 1 (the empty product).
   int64_t flatDim = 1;
-  for (auto [targetDim, memrefDim, memrefStride] :
-       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+
+  // Iterate over all dims of `vectorType`, excluding the leading dim, and
+  // compare them against the trailing dims of `memrefType`.
+  strides.pop_back();
+  for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+           targetShapeTrailingDims, memrefType.getShape(), strides))) {
     flatDim *= memrefDim;
-    if (flatDim != memrefStride || targetDim != memrefDim)
+    // If the memref stride does not match the flattened dim, then this
+    // memref is not contiguous.
+    if (flatDim != memrefStride)
+      return false;
+
+    // If a non-matching dim was found previously, then the remaining dims of
+    // `vectorType` must be 1.
+    if (!allTrailingDimsMatch && (targetDim != 1))
       return false;
+
+    allTrailingDimsMatch = (targetDim == memrefDim);
   }
-  return true;
+
+  // If all dims of `vectorType` (excluding the leading dim) match the trailing
+  // dims of `memrefType`, then this is a contiguous load. If there was a
----------------
c-rhodes wrote:

```suggestion
  // If the trailing dims of `vectorType` and `memrefType` match, then this is a contiguous load. If there was a
```
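To experiment with the rule from the doc comment outside MLIR, here is a minimal standalone C++ sketch. It is not the MLIR API. `isContiguousSliceSketch` and `identityStrides` are hypothetical names, and shapes and strides are plain `std::vector<int64_t>` instead of `MemRefType`/`VectorType`. The strides are passed in directly rather than obtained from `getStridesAndOffset`. `identityStrides` is assumed to produce the canonical row-major strides of an identity-layout memref. The sketch also assumes the vector rank does not exceed the memref rank.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: row-major (identity layout) strides, where
// stride[i] is the product of shape[i + 1 ..] and the inner-most stride is 1.
std::vector<int64_t> identityStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return strides;
}

// Hypothetical standalone model of the contiguity rule from the doc comment.
// Shapes/strides are plain vectors instead of MemRefType/VectorType.
bool isContiguousSliceSketch(const std::vector<int64_t> &memShape,
                             const std::vector<int64_t> &memStrides,
                             const std::vector<int64_t> &vecShape) {
  assert(memStrides.size() == memShape.size() && "one stride per memref dim");
  const std::size_t n = memShape.size();
  const std::size_t k = vecShape.size();
  // Assumption: a non-empty vector whose rank does not exceed the memref rank.
  if (k == 0 || k > n)
    return false;

  // A non-unit stride in the inner-most dim means the memref is not contiguous.
  if (memStrides[n - 1] != 1)
    return false;

  bool allTrailingDimsMatch = true;
  // Running product of the memref dims visited so far (inner-most first).
  int64_t flatDim = 1;

  // Walk all vector dims except the leading one, inner-most first, pairing
  // vector dim j with memref dim i = n - k + j.
  for (std::size_t j = k - 1; j >= 1; --j) {
    const std::size_t i = n - k + j;
    flatDim *= memShape[i];
    // The next outer memref dim must step over exactly the flattened dims.
    if (memStrides[i - 1] != flatDim)
      return false;
    // After the first mismatch, every remaining outer vector dim must be 1.
    if (!allTrailingDimsMatch && vecShape[j] != 1)
      return false;
    // Sticky: once a mismatch has been seen, the flag stays false.
    if (vecShape[j] != memShape[i])
      allTrailingDimsMatch = false;
  }

  // If some trailing dim did not match, the leading vector dim must be 1 too.
  return allTrailingDimsMatch || vecShape[0] == 1;
}
```

With `identityStrides({5, 4, 3, 2})`, which gives `{24, 6, 2, 1}`, the sketch accepts examples 1.1, 1.2, 2.2 and 2.3 from the doc comment and rejects 2.1 and 2.4.

One design choice is worth flagging. The sketch only ever clears `allTrailingDimsMatch`, whereas the patch reassigns it with `allTrailingDimsMatch = (targetDim == memrefDim);` on every iteration. Unless I am misreading the loop, a later `1 == 1` pair could then set the flag back to `true`. For example, `vector<2x1x2x2xi32>` from `memref<5x1x3x2xi32>` might be reported as contiguous, depending on the final return that is cut off in this hunk, even though 2 != 3 and the leading dim is not 1. The sketch rejects it.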

https://github.com/llvm/llvm-project/pull/73522

