[Mlir-commits] [mlir] [mlir][vector] Drop innermost unit dims on transfer_write. (PR #78554)

Han-Chung Wang llvmlistbot at llvm.org
Fri Jan 19 01:53:07 PST 2024


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/78554

>From c9343bd38df632b98a7beabb461e8d7036fa3d1c Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 18 Jan 2024 09:10:51 +0000
Subject: [PATCH 1/3] [mlir][vector] Drop innermost unit dims on
 transfer_write.

This revision renames DropInnerMostUnitDims to
DropInnerMostUnitDimsTransferRead and adds a new
DropInnerMostUnitDimsTransferWrite pattern for vector.transfer_write.

It factors the shared logic out into helper functions
(getTransferFoldableInnerUnitDims and getMemRefTypeWithDroppingInnerDims)
and uses them in both patterns.
---
 .../Vector/Transforms/VectorTransforms.cpp    | 197 +++++++++++++-----
 ...tor-transfer-collapse-inner-most-dims.mlir |  21 ++
 2 files changed, 164 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd02c07981466d9..7c276ca8101221b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1152,8 +1152,71 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
   }
 };
 
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of dims can be folded away from transfer ops. It returns
+/// a failure if strides and offsets can not be resolved.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+  SmallVector<int64_t> srcStrides;
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // slice. It pads the indices with `1` starting from beginning. Thus, we have
+  // to offset the check index with `rankDiff` in `srcStrides` and source dim
+  // sizes.
+  size_t result = 0;
+  int rankDiff = srcType.getRank() - vectorType.getRank();
+  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // Check that the inner dim size is 1 for both memref/tensor type and
+    // vector slice. It can be folded only if they are 1 and the stride is 1.
+    int dim = vectorType.getRank() - i - 1;
+    if (srcStrides[dim + rankDiff] == 1 &&
+        srcType.getDimSize(dim + rankDiff) == 1 &&
+        vectorType.getDimSize(dim) == 1) {
+      result++;
+    } else {
+      break;
+    }
+  }
+  return result;
+}
+
+/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+/// `srcType`.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+                                                     MemRefType srcType,
+                                                     size_t dimsToDrop) {
+  MemRefType resultMemrefType;
+  MemRefLayoutAttrInterface layout = srcType.getLayout();
+  if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), nullptr,
+                           srcType.getMemorySpace());
+  }
+  MemRefLayoutAttrInterface updatedLayout;
+  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+    auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+    updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+                                           strided.getOffset(), strides);
+  } else {
+    AffineMap map = srcType.getLayout().getAffineMap();
+    int numSymbols = map.getNumSymbols();
+    for (size_t i = 0; i < dimsToDrop; ++i) {
+      int dim = srcType.getRank() - i - 1;
+      map = map.replace(builder.getAffineDimExpr(dim),
+                        builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+                        numSymbols);
+    }
+  }
+  return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                         srcType.getElementType(), updatedLayout,
+                         srcType.getMemorySpace());
+}
+
+/// Drop inner most contiguous unit dimensions from transfer_read operand.
+class DropInnerMostUnitDimsTransferRead
+    : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,29 +1240,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     if (targetType.getRank() <= 1)
       return failure();
 
-    SmallVector<int64_t> srcStrides;
-    int64_t srcOffset;
-    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
-      return failure();
-
-    // According to vector.transfer_read semantics, the result can be a slice.
-    // It pads the indices with `1` starting from beginning. Thus, we have to
-    // offset the check index with `rankDiff` in `srcStrides` and source dim
-    // sizes.
-    size_t dimsToDrop = 0;
-    int rankDiff = srcType.getRank() - targetType.getRank();
-    for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
-      // Check that the inner dim size is 1 for both memref/tensor type and
-      // vector slice. It can be folded only if they are 1 and the stride is 1.
-      int dim = targetType.getRank() - i - 1;
-      if (srcStrides[dim + rankDiff] == 1 &&
-          srcType.getDimSize(dim + rankDiff) == 1 &&
-          targetType.getDimSize(dim) == 1) {
-        dimsToDrop++;
-      } else {
-        break;
-      }
-    }
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
     if (dimsToDrop == 0)
       return failure();
 
@@ -1207,35 +1253,9 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
-    MemRefType resultMemrefType;
-    MemRefLayoutAttrInterface layout = srcType.getLayout();
-    if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          nullptr, srcType.getMemorySpace());
-    } else {
-      MemRefLayoutAttrInterface updatedLayout;
-      if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
-        auto strides =
-            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
-                                               strided.getOffset(), strides);
-      } else {
-        AffineMap map = srcType.getLayout().getAffineMap();
-        int numSymbols = map.getNumSymbols();
-        for (size_t i = 0; i < dimsToDrop; ++i) {
-          int dim = srcType.getRank() - i - 1;
-          map = map.replace(rewriter.getAffineDimExpr(dim),
-                            rewriter.getAffineConstantExpr(0),
-                            map.getNumDims() - 1, numSymbols);
-        }
-      }
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          updatedLayout, srcType.getMemorySpace());
-    }
-
     auto loc = readOp.getLoc();
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
     SmallVector<int64_t> offsets(srcType.getRank(), 0);
     SmallVector<int64_t> strides(srcType.getRank(), 1);
 
@@ -1261,6 +1281,73 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+/// Drop inner most contiguous unit dimensions from transfer_write operand.
+class DropInnerMostUnitDimsTransferWrite
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
+    // TODO: support mask.
+    if (writeOp.getMask())
+      return failure();
+
+    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    auto targetType = writeOp.getVectorType();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    auto loc = writeOp.getLoc();
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
+
+    ArrayAttr inBoundsAttr =
+        writeOp.getInBounds()
+            ? rewriter.getArrayAttr(
+                  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+            : ArrayAttr();
+    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+        loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+        strides);
+    auto permMap = getTransferMinorIdentityMap(
+        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, resultTargetVecType, writeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        writeOp, shapeCast, rankedReducedView,
+        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+        // TODO: support mask; masked writes are rejected above.
+        /*mask=*/Value(), inBoundsAttr);
+    return success();
+  }
+};
+
 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
 /// with the RHS transposed) lowering.
@@ -1696,7 +1783,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
 void mlir::vector::
     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
         RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+  patterns.add<DropInnerMostUnitDimsTransferRead,
+               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+                                                   benefit);
 }
 
 void mlir::vector::populateSinkVectorBroadcastPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 0d2743b9fe2e7f5..59116c19b46ec23 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -76,3 +76,24 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
 //  CHECK-NOT:   memref.subview
 //      CHECK:   %[[READ:.+]] = vector.transfer_read %[[SRC]]
 //      CHECK:   return %[[READ]] : vector<4x8xf32>
+
+// -----
+
+func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
+    {in_bounds = [true, true, true, true]}
+    : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+  return
+}
+// CHECK:      func.func @drop_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
+// CHECK-SAME:     memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]

>From 58e6571395ff88ecbb9daf7f614644b57f7df775 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 19 Jan 2024 08:41:56 +0000
Subject: [PATCH 2/3] improve comments and tests

---
 .../Vector/Transforms/VectorTransforms.cpp    | 56 +++++++++++++------
 ...tor-transfer-collapse-inner-most-dims.mlir | 34 ++++++++++-
 2 files changed, 73 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7c276ca8101221b..21d855528fc07d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1153,7 +1153,12 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
 };
 
 /// Returns the number of dims can be folded away from transfer ops. It returns
-/// a failure if strides and offsets can not be resolved.
+/// a failure if it cannot determine the number of dims to be folded.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost unit dims
+/// have stride 1 and can be dropped by a memref.subview op.
+/// Example 2: it returns "1" if `srcType` is the same memref type but with
+/// [8192, 16, 8, 1] strides; the second innermost dim has stride 8, not 1.
 static FailureOr<size_t>
 getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
   SmallVector<int64_t> srcStrides;
@@ -1162,14 +1167,13 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
     return failure();
 
   // According to vector.transfer_read/write semantics, the vector can be a
-  // slice. It pads the indices with `1` starting from beginning. Thus, we have
-  // to offset the check index with `rankDiff` in `srcStrides` and source dim
-  // sizes.
+  // slice with lower rank than the source. Thus, the check index is offset by
+  // `rankDiff` when indexing `srcStrides` and the source dim sizes.
   size_t result = 0;
   int rankDiff = srcType.getRank() - vectorType.getRank();
   for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
-    // Check that the inner dim size is 1 for both memref/tensor type and
-    // vector slice. It can be folded only if they are 1 and the stride is 1.
+    // Check that the inner dim size is 1 for both memref type and  vector
+    // slice. It can be folded only if they are 1 and the stride is 1.
     int dim = vectorType.getRank() - i - 1;
     if (srcStrides[dim + rankDiff] == 1 &&
         srcType.getDimSize(dim + rankDiff) == 1 &&
@@ -1183,7 +1187,8 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
 }
 
 /// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
-/// `srcType`.
+/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is
+/// two, it returns memref<512x16xf32> type.
 static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
                                                      MemRefType srcType,
                                                      size_t dimsToDrop) {
@@ -1199,15 +1204,19 @@ static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
     auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
     updatedLayout = StridedLayoutAttr::get(strided.getContext(),
                                            strided.getOffset(), strides);
-  } else {
-    AffineMap map = srcType.getLayout().getAffineMap();
-    int numSymbols = map.getNumSymbols();
-    for (size_t i = 0; i < dimsToDrop; ++i) {
-      int dim = srcType.getRank() - i - 1;
-      map = map.replace(builder.getAffineDimExpr(dim),
-                        builder.getAffineConstantExpr(0), map.getNumDims() - 1,
-                        numSymbols);
-    }
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), updatedLayout,
+                           srcType.getMemorySpace());
+  }
+
+  // Affine-map (non-StridedLayoutAttr) layout: zero out the dropped dims.
+  AffineMap map = srcType.getLayout().getAffineMap();
+  int numSymbols = map.getNumSymbols();
+  for (size_t i = 0; i < dimsToDrop; ++i) {
+    int dim = srcType.getRank() - i - 1;
+    map = map.replace(builder.getAffineDimExpr(dim),
+                      builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+                      numSymbols);
   }
   return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
-                         srcType.getElementType(), updatedLayout,
+                         srcType.getElementType(), map,
@@ -1282,6 +1291,21 @@ class DropInnerMostUnitDimsTransferRead
 };
 
 /// Drop inner most contiguous unit dimensions from transfer_write operand.
+/// E.g.,
+///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+///      {in_bounds = [true, true, true, true, true]}
+///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+///    %subview = memref.subview %arg0
+///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+///      to vector<1x16x16xf32>
+///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+///      {in_bounds = [true, true, true]}
+///      : vector<1x16x16xf32>, memref<1x512x16xf32>
 class DropInnerMostUnitDimsTransferWrite
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 59116c19b46ec23..d6d69c8af88508d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -79,6 +79,26 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
 
 // -----
 
+func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+    {in_bounds = [true, true, true, true, true]}
+    : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+  return
+}
+// CHECK:      func.func @drop_two_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
 func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
@@ -92,8 +112,20 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
 // CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
 // CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
 // CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
-// CHECK-SAME:     [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
 // CHECK-SAME:     memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
 // CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
 // CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
 // CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
+    {in_bounds = [true, true, true]}
+    : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>
+  return
+}
+// The innermost unit dims cannot be dropped if their strides are not 1.
+// CHECK:     func.func @non_unit_strides
+// CHECK-NOT:   memref.subview

>From feda905282482401a454a0f07d08b76a36bdc7a1 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 19 Jan 2024 09:52:50 +0000
Subject: [PATCH 3/3] address comments!

---
 .../Vector/Transforms/VectorTransforms.cpp     | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 21d855528fc07d9..7d5f4d471e89bff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1172,16 +1172,14 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
   size_t result = 0;
   int rankDiff = srcType.getRank() - vectorType.getRank();
   for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
-    // Check that the inner dim size is 1 for both memref type and  vector
-    // slice. It can be folded only if they are 1 and the stride is 1.
+    // Check that the inner dim size is 1 for both memref type and vector slice.
+    // It can be folded only if they are 1 and the stride is 1.
     int dim = vectorType.getRank() - i - 1;
-    if (srcStrides[dim + rankDiff] == 1 &&
-        srcType.getDimSize(dim + rankDiff) == 1 &&
-        vectorType.getDimSize(dim) == 1) {
-      result++;
-    } else {
+    if (srcStrides[dim + rankDiff] != 1 ||
+        srcType.getDimSize(dim + rankDiff) != 1 ||
+        vectorType.getDimSize(dim) != 1)
       break;
-    }
+    result++;
   }
   return result;
 }
@@ -1344,17 +1342,17 @@ class DropInnerMostUnitDimsTransferWrite
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
-    auto loc = writeOp.getLoc();
     MemRefType resultMemrefType =
         getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
     SmallVector<int64_t> offsets(srcType.getRank(), 0);
     SmallVector<int64_t> strides(srcType.getRank(), 1);
-
     ArrayAttr inBoundsAttr =
         writeOp.getInBounds()
             ? rewriter.getArrayAttr(
                   writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
             : ArrayAttr();
+
+    Location loc = writeOp.getLoc();
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
         loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
         strides);



More information about the Mlir-commits mailing list