[Mlir-commits] [mlir] [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N) (PR #73523)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 06:50:02 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Updates patterns for flattening `vector.transfer_read` by relaxing the
requirement that the "collapsed" indices are all zero. This enables
collapsing cases like this one:

```mlir
  %2 = vector.transfer_read %arg4[%c0, %arg0, %arg1, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

Previously, only the following case would be considered for collapsing:

```mlir
  %2 = vector.transfer_read %arg4[%c0, %c0, %c0, %c0] ... :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>
```

The pattern itself, `FlattenContiguousRowMajorTransferReadPattern`, was
also lightly refactored:
  * added comments,
  * renamed `firstContiguousInnerDim` as `firstDimToCollapse` (the
    latter better matches the meaning and is already consistently used
    in various helper methods that use it).

Similar update for `vector.transfer_write` will be implemented in a
follow-up patch.


---
Full diff: https://github.com/llvm/llvm-project/pull/73523.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+125-27) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+110-14) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..951a378b84cf0e0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -487,26 +487,76 @@ class TransferWriteDropUnitDimsPattern
 
 } // namespace
 
-/// Return true if the memref type has its inner dimension matching the given
-/// shape. Otherwise return false.
-static int64_t hasMatchingInnerContigousShape(MemRefType memrefType,
-                                              ArrayRef<int64_t> targetShape) {
-  auto shape = memrefType.getShape();
-  SmallVector<int64_t> strides;
+/// Return true if `vectorType` is a contiguous slice of `memrefType`.
+///
+/// Compares `vectorType` against the trailing dimensions (*) of `memrefType`
+/// to check whether `vectorType` is a contiguous slice of `memrefType`.
+///
+/// There are two cases:
+///
+/// 1. The trailing dimensions of `memrefType` match the dimensions of
+/// `vectorType` excluding the front dim (the leading dim of `vectorType` does
+/// not matter in this case):
+///
+///   vector<2x4x3x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+///   vector<2x4x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// 2. The trailing dimension of `memrefType` match the trailing dimensions of
+/// `vectorType` (i.e. at least 2 leading dims of `vectorType` don't match). The
+/// first dim of `vectorType` that does not match can be arbitrary, but the
+/// remaining leading dims have to be 1:
+///
+///   vector<1x1x2x2xi32> vs memref<5x4x3x2xi32> (contiguous slice)
+///   vector<2x1x2x2xi32> vs memref<5x4x3x2xi32> (non-contiguous slice)
+///
+/// In both cases `memrefType` has to be contiguous (this is checked by looking
+/// at strides).
+///
+/// (*) Only relevant in cases when the rank(vectorType) < rank(memrefType)
+/// TODO: Update
+static bool isContiguousSlice(MemRefType memrefType,
+                              VectorType vectorType) {
+
+  ArrayRef<int64_t> targetShape = vectorType.getShape();
+  auto targetShapeTrailingDims = targetShape.drop_front(1);
+
+  // Not used
   int64_t offset;
+  SmallVector<int64_t> strides;
   if (!succeeded(getStridesAndOffset(memrefType, strides, offset)))
     return false;
+
+  // Non-unit stride in the trailing dimension means that this is memref is
+  // not contiguous.
   if (strides.back() != 1)
     return false;
-  strides.pop_back();
+
+  // Do all but the leading dim of `vectorType` and the trailing dims of
+  // `memrefType` match?
+  bool allTrailingDimsMatch = true;
+
+  // The trailing dimension of `memrefType` after collapsing/flattening the
+  // current dim. This will be a product of the leading dims, hence initialising
+  // to 1.
   int64_t flatDim = 1;
-  for (auto [targetDim, memrefDim, memrefStride] :
-       llvm::reverse(llvm::zip(targetShape, shape, strides))) {
+  strides.pop_back();
+  for (auto [targetDim, memrefDim, memrefStride] : llvm::reverse(llvm::zip(
+           targetShapeTrailingDims, memrefType.getShape(), strides))) {
     flatDim *= memrefDim;
-    if (flatDim != memrefStride || targetDim != memrefDim)
+    // If the memref stride does not match the flattened dim, then this is
+    // memref is not contiguous.
+    if (flatDim != memrefStride)
+      return false;
+
+    // If a non-matching dim was found, then the remaining dims of `VectorType`
+    // should be 1.
+    if (!allTrailingDimsMatch && (targetDim != 1))
       return false;
+
+    allTrailingDimsMatch = (targetDim == memrefDim);
   }
-  return true;
+
+  return allTrailingDimsMatch ? true : (targetShape[0] == 1);
 }
 
 /// Creates a memref.collapse_shape collapsing all inner dimensions of the
@@ -529,6 +579,8 @@ static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
 /// Checks that the indices corresponding to dimensions starting at
 /// `firstDimToCollapse` are constant 0, and writes to `outIndices`
 /// the truncated indices where `firstDimToCollapse` is now the innermost dim.
+/// TODO: Extract the logic that writes to outIndices so that this method
+/// simply checks one pre-condition.
 static LogicalResult
 checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
                                  SmallVector<Value> &outIndices) {
@@ -562,18 +614,16 @@ class FlattenContiguousRowMajorTransferReadPattern
     VectorType vectorType = cast<VectorType>(vector.getType());
     Value source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
+
+    // 0. Check pre-conditions
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
+    // If this is already 0D/1D, there's nothing to do.
     if (vectorType.getRank() <= 1)
-      // Already 0D/1D, nothing to do.
       return failure();
-    if (!hasMatchingInnerContigousShape(
-            sourceType,
-            vectorType.getShape().take_back(vectorType.getRank() - 1)))
+    if (!isContiguousSlice(sourceType, vectorType))
       return failure();
-    int64_t firstContiguousInnerDim =
-        sourceType.getRank() - vectorType.getRank();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
       return failure();
@@ -581,26 +631,76 @@ class FlattenContiguousRowMajorTransferReadPattern
       return failure();
     if (transferReadOp.getMask())
       return failure();
+
     SmallVector<Value> collapsedIndices;
-    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
-                                                firstContiguousInnerDim,
-                                                collapsedIndices)))
-      return failure();
+    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+
+    // 1. Collapse the source memref
     Value collapsedSource =
-        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+        collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
     MemRefType collapsedSourceType =
         dyn_cast<MemRefType>(collapsedSource.getType());
     int64_t collapsedRank = collapsedSourceType.getRank();
-    assert(collapsedRank == firstContiguousInnerDim + 1);
+    assert(collapsedRank == firstDimToCollapse + 1);
+
+    // 2. Generate input args for a new vector.transfer_read that will read
+    // from the collapsed memref.
+    // 2.1. New dim exprs + affine map
     SmallVector<AffineExpr, 1> dimExprs{
-        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+        getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
     auto collapsedMap =
         AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+
+    // 2.2 New indices
+    // If all the collapsed indices are zero then no extra logic is needed.
+    // Otherwise, a new offset/index has to be computed.
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstDimToCollapse,
+                                                collapsedIndices))) {
+      // Copy all the leading indices
+      collapsedIndices = transferReadOp.getIndices();
+      collapsedIndices.resize(firstDimToCollapse);
+
+      // Compute the remaining trailing index/offset required for reading from
+      // the collapsed memref:
+      //
+      //    offset = 0
+      //    for (i = firstDimToCollapse; i < outputRank; ++i)
+      //      offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
+      //
+      // For this example:
+      //   %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
+      //   memref<1x43x2xi32>, vector<1x2xi32>
+      // which would be collapsed to:
+      //   %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
+      //   memref<1x86xi32>, vector<2xi32>
+      // one would get the following offset:
+      //    %offset = %arg0 * 43
+      int64_t outputRank = transferReadOp.getIndices().size();
+      Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+      for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
+        Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+        auto sourceDimSize =
+            rewriter.create<memref::DimOp>(loc, source, dimIdx);
+
+        offset = rewriter.create<arith::AddIOp>(
+            loc,
+            rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
+                                           sourceDimSize),
+            offset);
+      }
+      collapsedIndices.push_back(offset);
+    }
+
+    // 3. Create new vector.transfer_read that reads from the collapsed memref
     VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
                                                 vectorType.getElementType());
     vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
         loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
     flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
+
+    // 4. Replace the old transfer_read with the new one reading from the
+    // collapsed shape
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, cast<VectorType>(vector.getType()), flatRead);
     return success();
@@ -628,9 +728,7 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!hasMatchingInnerContigousShape(
-            sourceType,
-            vectorType.getShape().take_back(vectorType.getRank() - 1)))
+    if (!isContiguousSlice(sourceType, vectorType))
       return failure();
     int64_t firstContiguousInnerDim =
         sourceType.getRank() - vectorType.getRank();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index ae62a5ba43d055a..8369069e31ab7c6 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
 
-func.func @transfer_read_flattenable_with_offset(
+func.func @transfer_read_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
     %c0 = arith.constant 0 : index
     %cst = arith.constant 0 : i8
@@ -9,7 +9,7 @@ func.func @transfer_read_flattenable_with_offset(
     return %v : vector<5x4x3x2xi8>
 }
 
-// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-LABEL: func @transfer_read_dims_match_contiguous
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
 // CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
@@ -18,7 +18,76 @@ func.func @transfer_read_flattenable_with_offset(
 
 // -----
 
-func.func @transfer_write_flattenable_with_offset(
+// The shape of the memref and the vector don't match, but the vector is a
+// contiguous subset of the memref, so "flattenable".
+
+func.func @transfer_read_dims_mismatch_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<1x1x2x2xi8>
+    return %v : vector<1x1x2x2xi8>
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_contiguous(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
+// CHECK:           %[[VAL_1:.*]] = arith.constant 0 : i8
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<120xi8, strided<[1], offset: ?>>, vector<4xi8>
+// CHECK:           %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
+// CHECK:           return %[[VAL_5]] : vector<1x1x2x2xi8>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_non_zero_indices(
+                     %idx_1: index,
+                     %idx_2: index,
+                     %m_in: memref<1x43x4x6xi32>,
+                     %m_out: memref<1x2x6xi32>) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} : 
+    memref<1x43x4x6xi32>, vector<1x2x6xi32>
+  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x6xi32>, memref<1x2x6xi32>
+  return
+}
+
+// CHECK-LABEL:   func.func @transfer_read_dims_mismatch_non_zero_indices(
+// CHECK-SAME:      %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
+// CHECK-SAME:      %[[VAL_2:.*]]: memref<1x43x4x6xi32>,
+// CHECK-SAME:      %[[VAL_3:.*]]: memref<1x2x6xi32>) {
+// CHECK:           %[[VAL_4:.*]] = arith.constant 43 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK:           %[[VAL_6:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.muli %[[VAL_0]], %[[VAL_4]] : index
+// CHECK:           %[[VAL_10:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
+// CHECK:           %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_9]] : index
+// CHECK:           %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_7]], %[[VAL_11]]], %[[VAL_6]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
+// CHECK:           vector.transfer_write %[[VAL_12]], %[[VAL_13]]{{\[}}%[[VAL_7]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
+
+// -----
+
+func.func @transfer_read_dims_mismatch_contiguous(
+    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<2x1x2x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<2x1x2x2xi8>
+    return %v : vector<2x1x2x2xi8>
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
+func.func @transfer_write_dims_match_contiguous(
       %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<5x4x3x2xi8>) {
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
@@ -26,7 +95,7 @@ func.func @transfer_write_flattenable_with_offset(
     return
 }
 
-// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-LABEL: func @transfer_write_dims_match_contiguous
 // CHECK-SAME:      %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
 // CHECK-SAME:      %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
 // CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
@@ -35,16 +104,46 @@ func.func @transfer_write_flattenable_with_offset(
 
 // -----
 
+func.func @transfer_write_dims_mismatch_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<1x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<1x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-LABEL:   func.func @transfer_write_dims_mismatch_contiguous
+// CHECK-SAME:                                            %[[VAL_0:.*]]: memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>,
+// CHECK-SAME:                                            %[[VAL_1:.*]]: vector<1x1x2x2xi8>) {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>> into memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           %[[VAL_4:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x1x2x2xi8> to vector<4xi8>
+// CHECK:           vector.transfer_write %[[VAL_4]], %[[VAL_3]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<4xi8>, memref<120xi8, strided<[1], offset: ?>>
+// CHECK:           return
+// CHECK:         }
+
+// -----
+
+func.func @transfer_write_dims_mismatch_non_contiguous(
+      %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, %vec : vector<2x1x2x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<2x1x2x2xi8>, memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
+    return
+}
+
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
+
+// -----
+
 func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
       vector.transfer_write %vec, %arg[] : vector<i8>, memref<i8>
       return
 }
 
-// CHECK-LABEL: func @transfer_write_0d
-// CHECK-SAME:       %[[ARG:.+]]: memref<i8>
-// CHECK-SAME:       %[[VEC:.+]]: vector<i8>
-// CHECK:          vector.transfer_write %[[VEC]], %[[ARG]][] : vector<i8>, memref<i8>
-// CHECK:          return
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // -----
 
@@ -54,11 +153,8 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
       return %0 : vector<i8>
 }
 
-// CHECK-LABEL: func @transfer_read_0d
-// CHECK-SAME:       %[[ARG:.+]]: memref<i8>
-// CHECK:            %[[CST:.+]] = arith.constant 0 : i8
-// CHECK:            %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>
-// CHECK:            return %[[READ]]
+// CHECK-NOT: memref.collapse_shape
+// CHECK-NOT: vector.shape_cast
 
 // -----
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/73523


More information about the Mlir-commits mailing list