[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (1/N) (PR #73522)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Dec 3 08:47:40 PST 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/73522

>From 902ccc3b984b5a060052b87768003ef45870e08f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 25 Nov 2023 19:10:34 +0000
Subject: [PATCH 1/4] [mlir][Vector] Update patterns for flattening vector.xfer
 Ops

Updates "flatten vector" patterns to support more cases, namely Ops that
read/write vectors with leading unit dims. For example:

```mlir
%0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0] ... :
  memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
```

Currently, this `vector.transfer_read` would not be flattened. With this
change, it will be transformed as follows:
```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] :
  memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
  into memref<120xi8, strided<[1], offset: ?>>
%0 = vector.transfer_read %collapse_shape[%c0] ... :
  memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
%1 = vector.shape_cast %0 : vector<4xi8> to vector<1x1x2x2xi8>
```

`hasMatchingInnerContigousShape` is generalised and renamed as
`isContiguousSlice` to better match the updated functionality. A few
test names are updated to better highlight what case is being exercised.
---
 .../Transforms/VectorTransferOpTransforms.cpp | 79 ++++++++++++----
 .../Vector/vector-transfer-flatten.mlir       | 92 ++++++++++++++++---
 2 files changed, 140 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bb..c1c9659e7b1ab 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -487,26 +487,75 @@ class TransferWriteDropUnitDimsPattern
 
 } // namespace
 
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
-                                              ArrayRef<int64_t> targetShape) {
-  auto shape = memrefType.getShape();
-  SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+///
+/// There are two cases:
+///
+/// 1. The trailing dimensions of `memrefType` match the dimensions of
+/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
+/// not matter in this case):
+///
+///   vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+///   vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// 2. The trailing dimension of `memrefType` match the trailing dimensions of
+/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
+/// first dim of `vectorType` that does not match can be arbitrary, but the
+/// remaining leading dims have to be 1:
+///
+///   vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+///   vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+///
+/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
+/// TODO: Update
+static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+
+  ArrayRef<int64_t> targetShape = vectorType.getShape();
+  auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+  // Not used
   int64_t offset;
+  SmallVector<int64_t> strides;
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
     return false;
+
+  // Non-unit stride in the trailing dimension means that this is memref is
+  // not contiguous.
   if (strides.back() != 1)
     return false;
-  strides.pop_back();
+
+  // Do all but the leading dim of `vectorType` and the trailing dims of
+  // `memrefType` match?
+  bool allTrailingDimsMatch = true;
+
+  // The trailing dimension of `memrefType` after collapsing/flattening the
+  // current dim. This will be a product of the leading dims, hence initialising
+  // to 1.
   int64_t flatDim = 1;
-  for (auto [targetDim, memrefDim, memrefStride] :
-       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+  strides.pop_back();
+  for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+           targetShapeTrailingDims, memrefType.getShape(), strides))) {
     flatDim *= memrefDim;
-    if (flatDim != memrefStride || targetDim != memrefDim)
+    // If the memref stride does not match the flattened dim, then this is
+    // memref is not contiguous.
+    if (flatDim != memrefStride)
+      return false;
+
+    // If a non-matching dim was found, then the remaining dims of `VectorType`
+    // should be 1.
+    if (!allTrailingDimsMatch && (targetDim != 1))
       return false;
+
+    allTrailingDimsMatch = (targetDim == memrefDim);
   }
-  return true;
+
+  return allTrailingDimsMatch ? true : (targetShape[0] == 1);
 }
 
 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -568,9 +617,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!hasMatchingInnerContigousShape(
-            sourceType,
-            vectorType.getShape().take_back(vectorType.getRank() - 1)))
+    if (!isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
         sourceType.getRank() - vectorType.getRank();
@@ -628,9 +675,7 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!hasMatchingInnerContigousShape(
-            sourceType,
-            vectorType.getShape().take_back(vectorType.getRank() - 1)))
+    if (!isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
         sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae62a5ba43d05..08ce837be93ff 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
     %c0 = arith.constant 0 : index
     %cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
     return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,7 +18,44 @@ func.func @transfer_read_flattenable_with_offset(
 
 // -----
 
-func.func @transfer_write_flattenable_with_offset(
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+    return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+    return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -26,7 +63,7 @@ func.func @transfer_write_flattenable_with_offset(
     return
 }
 
-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +72,46 @@ func.func @transfer_write_flattenable_with_offset(
 
 // -----
 
+func.func @transfer_write_dims_mismatch_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME:                                            %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:                                            %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @transfer_write_dims_mismatch_non_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
 func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
       vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
       return
 }
 
-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME:       %[[ARG:.+]]: memref<i8>
-// CHECK-SAME:       %[[VEC:.+]]: vector<i8>
-// CHECK:          vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK:          return
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // -----
 
@@ -54,11 +121,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
       return %0 : vector<i8>
 }
 
-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME:       %[[ARG:.+]]: memref<i8>
-// CHECK:            %[[CST:.+]] = arith.constant 0 : i8
-// CHECK:            %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK:            return %[[READ]]
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // -----
 

>From 9a3c60b07387164d39dc961f6227950a011579a9 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 28 Nov 2023 09:41:09 +0000
Subject: [PATCH 2/4] fixup! [mlir][Vector] Update patterns for flattening
 vector.xfer Ops

Update comments
---
 .../Transforms/VectorTransferOpTransforms.cpp | 57 ++++++++++++-------
 1 file changed, 35 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c1c9659e7b1ab..9f20f75bc7edb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -489,37 +489,43 @@ class TransferWriteDropUnitDimsPattern
 
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
-/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+/// Compares `vectorType` against the trailing dimensions of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
+/// is implemented by iterating over the dims of `vectorType` and `memrefType`
+/// and comparing them starting from the inner-most/right-most dims.
 ///
-/// There are two cases:
+/// Note that there might be some restriction on the leading dim of
+/// `VectorType`:
+///   1. if all the trialing dims of `vectorType` match the trailing dims
+///     of `memrefType` then the leading dim of `vectorType` can be arbitrary:
 ///
-/// 1. The trailing dimensions of `memrefType` match the dimensions of
-/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
-/// not matter in this case):
+///       1.1 contiguous slice, perfect match
+///         vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///       1.2 contiguous slice, all dims match except the leading dim: 2 != 4
+///         vector<2x3x2xi32> from memref<5x4x3x2xi32>
 ///
-///   vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
-///   vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///   2. if an "internal" dim of `vectorType` does not match the corresponding
+///     trailing dim in `memrefType` then the remaining leading dims of
+///     `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
 ///
-/// 2. The trailing dimension of `memrefType` match the trailing dimensions of
-/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
-/// first dim of `vectorType` that does not match can be arbitrary, but the
-/// remaining leading dims have to be 1:
+///       2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
+///         vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///       2.2  contiguous slice, 2 != 3 and the leading dim == <1>
+///         vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///       2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
+///         vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///       2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
+///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
 ///
-///   vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
-///   vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
-///
-/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// In all cases `memrefType` has to be contiguous (this is checked by looking
 /// at strides).
-///
-/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
-/// TODO: Update
 static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
 
+  // Get the shape of `vectorType`. The leading dim is treated seperately.
   ArrayRef<int64_t> targetShape = vectorType.getShape();
   auto targetShapeTrailingDims = targetShape.drop_front(1);
 
-  // Not used
+  // Get the strides of the memref.
   int64_t offset;
   SmallVector<int64_t> strides;
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
@@ -538,6 +544,9 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   // current dim. This will be a product of the leading dims, hence initialising
   // to 1.
   int64_t flatDim = 1;
+
+  // Iterate overall all dim of `vectorType` excluding the leading dim and
+  // compare them against the trailing dims of `memrefType`.
   strides.pop_back();
   for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
            targetShapeTrailingDims, memrefType.getShape(), strides))) {
@@ -547,14 +556,18 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     if (flatDim != memrefStride)
       return false;
 
-    // If a non-matching dim was found, then the remaining dims of `VectorType`
-    // should be 1.
+    // If a non-matching dim was found previously, then the remaining dims of
+    // `VectorType` should be 1.
     if (!allTrailingDimsMatch && (targetDim != 1))
       return false;
 
     allTrailingDimsMatch = (targetDim == memrefDim);
   }
 
+  // If all dims of `vectorType` (excluding the leading dim) match the trailing
+  // dims `memrefType`, then this is a contiguous load. If there was a
+  // mismatch, then the internal dims have already been verified to be unit
+  // dims, but the leading dim still has to be checked.
   return allTrailingDimsMatch ? true : (targetShape[0] == 1);
 }
 

>From e3dabd34666d9b7bf1e6ab964570649847adcbc1 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 30 Nov 2023 15:39:28 +0000
Subject: [PATCH 3/4] fixup! [mlir][Vector] Update patterns for flattening
 vector.xfer Ops

Update comments
---
 .../Transforms/VectorTransferOpTransforms.cpp | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 9f20f75bc7edb..3ba5de690daef 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -496,7 +496,7 @@ class TransferWriteDropUnitDimsPattern
 ///
 /// Note that there might be some restriction on the leading dim of
 /// `VectorType`:
-///   1. if all the trialing dims of `vectorType` match the trailing dims
+///   1. if all the trailing dims of `vectorType` match the trailing dims
 ///     of `memrefType` then the leading dim of `vectorType` can be arbitrary:
 ///
 ///       1.1 contiguous slice, perfect match
@@ -521,7 +521,7 @@ class TransferWriteDropUnitDimsPattern
 /// at strides).
 static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
 
-  // Get the shape of `vectorType`. The leading dim is treated seperately.
+  // Get the shape of `vectorType`. The leading dim is treated separately.
   ArrayRef<int64_t> targetShape = vectorType.getShape();
   auto targetShapeTrailingDims = targetShape.drop_front(1);
 
@@ -531,13 +531,12 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
     return false;
 
-  // Non-unit stride in the trailing dimension means that this is memref is
+  // Non-unit stride in the trailing dimension means this memref is
   // not contiguous.
   if (strides.back() != 1)
     return false;
 
-  // Do all but the leading dim of `vectorType` and the trailing dims of
-  // `memrefType` match?
+  // Do all but the leading dim of `vectorType` and `memrefType` match?
   bool allTrailingDimsMatch = true;
 
   // The trailing dimension of `memrefType` after collapsing/flattening the
@@ -545,11 +544,12 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   // to 1.
   int64_t flatDim = 1;
 
-  // Iterate overall all dim of `vectorType` excluding the leading dim and
-  // compare them against the trailing dims of `memrefType`.
+  // Iterate over all dim of `vectorType` (in reverse) excluding the leading dim
+  // and compare them against the trailing dims of `memrefType`.
   strides.pop_back();
-  for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
-           targetShapeTrailingDims, memrefType.getShape(), strides))) {
+  for (auto [targetDim, memrefDim, memrefStride] :
+       llvm::reverse(llvm::zip(targetShapeTrailingDims,
+                               memrefType.getShape().drop_front(1), strides))) {
     flatDim *= memrefDim;
     // If the memref stride does not match the flattened dim, then this is
     // memref is not contiguous.
@@ -564,10 +564,10 @@ static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
     allTrailingDimsMatch = (targetDim == memrefDim);
   }
 
-  // If all dims of `vectorType` (excluding the leading dim) match the trailing
-  // dims `memrefType`, then this is a contiguous load. If there was a
-  // mismatch, then the internal dims have already been verified to be unit
-  // dims, but the leading dim still has to be checked.
+  // If the trailing dims of `vectorType` and `memrefType` match, then this is a
+  // contiguous load. If there was a mismatch, then the internal dims have
+  // already been verified to be unit dims, but the leading dim still has to be
+  // checked.
   return allTrailingDimsMatch ? true : (targetShape[0] == 1);
 }
 

>From c0bf526a032455cc904d3f37d28a6ea3c0d84750 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 3 Dec 2023 13:09:12 +0000
Subject: [PATCH 4/4] fixup! [mlir][Vector] Update patterns for flattening
 vector.xfer Ops

Final refactor requested by Nicolas
---
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 33 +++++++
 .../Transforms/VectorTransferOpTransforms.cpp | 89 +------------------
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 44 +++++++++
 .../Vector/vector-transfer-flatten.mlir       |  6 +-
 4 files changed, 85 insertions(+), 87 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index fc00769a4aaa8..2ab456d4fdbf1 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -42,6 +42,39 @@ Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
 /// on a 2D slice. Otherwise, returns a failure.
 FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
+/// checked (the other dims are not relevant); those dims have to be
+/// contiguous, which is checked by looking at the corresponding strides.
+/// Returns false for scalable vectors.
+///
+/// There may be some restrictions on the leading dims of `vectorType`:
+///
+/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
+///         of `memrefType` then the leading dim of `vectorType` can be
+///         arbitrary.
+///
+///        Ex. 1.1 contiguous slice, perfect match
+///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
+///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///
+/// Case 2. If an "internal" dim of `vectorType` does not match the
+///         corresponding trailing dim in `memrefType` then the remaining
+///         leading dims of `vectorType` have to be 1 (the first non-matching
+///         dim can be arbitrary).
+///
+///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
+///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///        Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
+///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///        Ex. 2.3 contiguous slice, 2 != 3 and the leading dims == <1x1>
+///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///        Ex. 2.4 non-contiguous slice, 2 != 3 and the leading dims != <1x1>
+///          vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
+bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 3ba5de690daef..1e0d6af568f50 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils//IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
@@ -487,90 +488,6 @@ class TransferWriteDropUnitDimsPattern
 
 } // namespace
 
-/// Return true if `vectorType` is a contiguous slice of `memrefType`.
-///
-/// Compares `vectorType` against the trailing dimensions of `memrefType`
-/// to check whether `vectorType` is a contiguous slice of `memrefType`. This
-/// is implemented by iterating over the dims of `vectorType` and `memrefType`
-/// and comparing them starting from the inner-most/right-most dims.
-///
-/// Note that there might be some restriction on the leading dim of
-/// `VectorType`:
-///   1. if all the trailing dims of `vectorType` match the trailing dims
-///     of `memrefType` then the leading dim of `vectorType` can be arbitrary:
-///
-///       1.1 contiguous slice, perfect match
-///         vector<4x3x2xi32> from memref<5x4x3x2xi32>
-///       1.2 contiguous slice, all dims match except the leading dim: 2 != 4
-///         vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-///   2. if an "internal" dim of `vectorType` does not match the corresponding
-///     trailing dim in `memrefType` then the remaining leading dims of
-///     `vectorType` have to be 1 (the first non-matching dim can be arbitrary):
-///
-///       2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-///         vector<2x2x2xi32> from memref<5x4x3x2xi32>
-///       2.2  contiguous slice, 2 != 3 and the leading dim == <1>
-///         vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///       2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-///         vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-///       2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
-///
-/// In all cases `memrefType` has to be contiguous (this is checked by looking
-/// at strides).
-static bool isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
-
-  // Get the shape of `vectorType`. The leading dim is treated separately.
-  ArrayRef<int64_t> targetShape = vectorType.getShape();
-  auto targetShapeTrailingDims = targetShape.drop_front(1);
-
-  // Get the strides of the memref.
-  int64_t offset;
-  SmallVector<int64_t> strides;
-  if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
-    return false;
-
-  // Non-unit stride in the trailing dimension means this memref is
-  // not contiguous.
-  if (strides.back() != 1)
-    return false;
-
-  // Do all but the leading dim of `vectorType` and `memrefType` match?
-  bool allTrailingDimsMatch = true;
-
-  // The trailing dimension of `memrefType` after collapsing/flattening the
-  // current dim. This will be a product of the leading dims, hence initialising
-  // to 1.
-  int64_t flatDim = 1;
-
-  // Iterate over all dim of `vectorType` (in reverse) excluding the leading dim
-  // and compare them against the trailing dims of `memrefType`.
-  strides.pop_back();
-  for (auto [targetDim, memrefDim, memrefStride] :
-       llvm::reverse(llvm::zip(targetShapeTrailingDims,
-                               memrefType.getShape().drop_front(1), strides))) {
-    flatDim *= memrefDim;
-    // If the memref stride does not match the flattened dim, then this is
-    // memref is not contiguous.
-    if (flatDim != memrefStride)
-      return false;
-
-    // If a non-matching dim was found previously, then the remaining dims of
-    // `VectorType` should be 1.
-    if (!allTrailingDimsMatch && (targetDim != 1))
-      return false;
-
-    allTrailingDimsMatch = (targetDim == memrefDim);
-  }
-
-  // If the trailing dims of `vectorType` and `memrefType` match, then this is a
-  // contiguous load. If there was a mismatch, then the internal dims have
-  // already been verified to be unit dims, but the leading dim still has to be
-  // checked.
-  return allTrailingDimsMatch ? true : (targetShape[0] == 1);
-}
-
 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
 /// input starting at `firstDimToCollapse`.
 static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
@@ -630,7 +547,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isContiguousSlice(sourceType, vectorType))
+    if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
         sourceType.getRank() - vectorType.getRank();
@@ -688,7 +605,7 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isContiguousSlice(sourceType, vectorType))
+    if (!vector::isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
         sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 48cd67ad86c63..ac0fe64c70cd6 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -249,3 +249,47 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
   // between parallel, reduction and possibly other cases.
   return ratio.has_value();
 }
+
+bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
+  if (vectorType.isScalable())
+    return false;
+
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+  auto vecRank = vectorType.getRank();
+
+  // Extract the trailing `vecRank` dims and strides of the input memref.
+  auto memrefShape = memrefType.getShape().take_back(vecRank);
+  int64_t offset;
+  SmallVector<int64_t> stridesFull;
+  if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
+    return false;
+  auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
+
+  // Cond 1: A contiguous memref will always have a unit trailing stride. An
+  // empty `strides` (0-D vector or memref) is conservatively rejected.
+  if (strides.empty() || strides.back() != 1)
+    return false;
+
+  // Cond 2: Strides of a contiguous memref have to match the flattened dims,
+  // so every dim that gets flattened (all but the outermost) must be static.
+  if (ShapedType::isDynamicShape(memrefShape.drop_front(1)))
+    return false;
+  strides = strides.drop_back(1);
+  SmallVector<int64_t> flattenedDims;
+  for (size_t i = 1; i < memrefShape.size(); i++)
+    flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+  if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
+    return false;
+
+  // Cond 3: Compare the dims of `vectorType` against `memrefType` (in reverse).
+  // In the most basic case, all dims will match.
+  auto firstNonMatchingDim =
+      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
+                    memrefShape.rbegin(), memrefShape.rend());
+  if (firstNonMatchingDim.first == vectorShape.rend())
+    return true;
+
+  // The first non-matching dim may be arbitrary; all outer ones must be 1.
+  return std::all_of(std::next(firstNonMatchingDim.first), vectorShape.rend(),
+                     [](int64_t dim) { return dim == 1; });
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 08ce837be93ff..2ffe85bf3bfa6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -41,7 +41,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
 
 // -----
 
-func.func @transfer_read_dims_mismatch_contiguous(
+func.func @transfer_read_dims_mismatch_non_contiguous(
     %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
     %c0 = arith.constant 0 : index
     %cst = arith.constant 0 : i8
@@ -50,6 +50,7 @@ func.func @transfer_read_dims_mismatch_contiguous(
     return %v : vector<2x1x2x2xi8>
 }
 
+// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
@@ -100,6 +101,7 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
     return
 }
 
+// CHECK-LABEL: func.func @transfer_write_dims_mismatch_non_contiguous
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
@@ -110,6 +112,7 @@ func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
       return
 }
 
+// CHECK-LABEL: func.func @transfer_write_0d
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 
@@ -121,6 +124,7 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
       return %0 : vector<i8>
 }
 
+// CHECK-LABEL: func.func @transfer_read_0d
 // CHECK-NOT: memref.collapse_shape
 // CHECK-NOT: vector.shape_cast
 



More information about the Mlir-commits mailing list