[Mlir-commits] [mlir] 12b676d - [mlir][vector] Drop innermost unit dims on transfer_write. (#78554)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 19 03:15:18 PST 2024


Author: Han-Chung Wang
Date: 2024-01-19T03:15:13-08:00
New Revision: 12b676de728ee9046ac5fea49e27b9bf1cde4a70

URL: https://github.com/llvm/llvm-project/commit/12b676de728ee9046ac5fea49e27b9bf1cde4a70
DIFF: https://github.com/llvm/llvm-project/commit/12b676de728ee9046ac5fea49e27b9bf1cde4a70.diff

LOG: [mlir][vector] Drop innermost unit dims on transfer_write. (#78554)

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
    mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd02c07981466d9..9c734e8cc2ad11a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1152,8 +1152,78 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
   }
 };
 
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of innermost unit dims that can be folded away from
+/// transfer ops. It fails if the strides of `srcType` cannot be computed.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims
+/// can be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type with
+/// [8192, 16, 8, 1] strides.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+  SmallVector<int64_t> srcStrides;
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // slice. Thus, we have to offset the check index with `rankDiff` in
+  // `srcStrides` and source dim sizes.
+  size_t result = 0;
+  int rankDiff = srcType.getRank() - vectorType.getRank();
+  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // Walk from the innermost dim outwards. A dim can be folded only if its
+    // size is 1 in both the memref and the vector, and its stride is 1.
+    int dim = vectorType.getRank() - i - 1;
+    if (srcStrides[dim + rankDiff] != 1 ||
+        srcType.getDimSize(dim + rankDiff) != 1 ||
+        vectorType.getDimSize(dim) != 1)
+      break;
+    result++;
+  }
+  return result;
+}
+
+/// Returns a MemRef type that drops the innermost `dimsToDrop` dimensions of
+/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is
+/// two, it returns memref<512x16xf32>.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+                                                     MemRefType srcType,
+                                                     size_t dimsToDrop) {
+  // Identity layout: the rank-reduced type keeps the default layout.
+  MemRefLayoutAttrInterface layout = srcType.getLayout();
+  if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), nullptr,
+                           srcType.getMemorySpace());
+  }
+  // Strided layout: drop the strides of the removed dims, keep the offset.
+  if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+    auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+    MemRefLayoutAttrInterface updatedLayout = StridedLayoutAttr::get(
+        strided.getContext(), strided.getOffset(), strides);
+    return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                           srcType.getElementType(), updatedLayout,
+                           srcType.getMemorySpace());
+  }
+
+  // Other layouts: pin each dropped dim to 0, shrinking the map by one dim.
+  AffineMap map = srcType.getLayout().getAffineMap();
+  int numSymbols = map.getNumSymbols();
+  for (size_t i = 0; i < dimsToDrop; ++i) {
+    int dim = srcType.getRank() - i - 1;
+    map = map.replace(builder.getAffineDimExpr(dim),
+                      builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+                      numSymbols);
+  }
+  return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+                         srcType.getElementType(), /*map=*/map,
+                         srcType.getMemorySpace());
+}
+
+/// Drop inner most contiguous unit dimensions from transfer_read operand.
+class DropInnerMostUnitDimsTransferRead
+    : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,29 +1247,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
     if (targetType.getRank() <= 1)
       return failure();
 
-    SmallVector<int64_t> srcStrides;
-    int64_t srcOffset;
-    if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
-      return failure();
-
-    // According to vector.transfer_read semantics, the result can be a slice.
-    // It pads the indices with `1` starting from beginning. Thus, we have to
-    // offset the check index with `rankDiff` in `srcStrides` and source dim
-    // sizes.
-    size_t dimsToDrop = 0;
-    int rankDiff = srcType.getRank() - targetType.getRank();
-    for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
-      // Check that the inner dim size is 1 for both memref/tensor type and
-      // vector slice. It can be folded only if they are 1 and the stride is 1.
-      int dim = targetType.getRank() - i - 1;
-      if (srcStrides[dim + rankDiff] == 1 &&
-          srcType.getDimSize(dim + rankDiff) == 1 &&
-          targetType.getDimSize(dim) == 1) {
-        dimsToDrop++;
-      } else {
-        break;
-      }
-    }
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
     if (dimsToDrop == 0)
       return failure();
 
@@ -1207,35 +1260,9 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
         VectorType::get(targetType.getShape().drop_back(dimsToDrop),
                         targetType.getElementType());
 
-    MemRefType resultMemrefType;
-    MemRefLayoutAttrInterface layout = srcType.getLayout();
-    if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          nullptr, srcType.getMemorySpace());
-    } else {
-      MemRefLayoutAttrInterface updatedLayout;
-      if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
-        auto strides =
-            llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
-        updatedLayout = StridedLayoutAttr::get(strided.getContext(),
-                                               strided.getOffset(), strides);
-      } else {
-        AffineMap map = srcType.getLayout().getAffineMap();
-        int numSymbols = map.getNumSymbols();
-        for (size_t i = 0; i < dimsToDrop; ++i) {
-          int dim = srcType.getRank() - i - 1;
-          map = map.replace(rewriter.getAffineDimExpr(dim),
-                            rewriter.getAffineConstantExpr(0),
-                            map.getNumDims() - 1, numSymbols);
-        }
-      }
-      resultMemrefType = MemRefType::get(
-          srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
-          updatedLayout, srcType.getMemorySpace());
-    }
-
     auto loc = readOp.getLoc();
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
     SmallVector<int64_t> offsets(srcType.getRank(), 0);
     SmallVector<int64_t> strides(srcType.getRank(), 1);
 
@@ -1261,6 +1288,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+/// Drop innermost contiguous unit dims of a transfer_write's vector and memref.
+/// E.g.,
+///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+///      {in_bounds = [true, true, true, true, true]}
+///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+///    %subview = memref.subview %arg0
+///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+///      to vector<1x16x16xf32>
+///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+///      {in_bounds = [true, true, true]}
+///      : vector<1x16x16xf32>, memref<1x512x16xf32>
+class DropInnerMostUnitDimsTransferWrite
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support the 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
+    // TODO: support masked writes.
+    if (writeOp.getMask())
+      return failure();
+
+    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    auto targetType = writeOp.getVectorType();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    MemRefType resultMemrefType =
+        getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+    SmallVector<int64_t> offsets(srcType.getRank(), 0);
+    SmallVector<int64_t> strides(srcType.getRank(), 1);
+    ArrayAttr inBoundsAttr =
+        writeOp.getInBounds()
+            ? rewriter.getArrayAttr(
+                  writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+            : ArrayAttr();
+
+    Location loc = writeOp.getLoc();
+    Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+        loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+        strides);
+    auto permMap = getTransferMinorIdentityMap(
+        cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+        loc, resultTargetVecType, writeOp.getVector());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        writeOp, shapeCast, rankedReducedView,
+        writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+        // TODO: support mask.
+        /*mask=*/Value(), inBoundsAttr);
+    return success();
+  }
+};
+
 /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
 /// semantics to a contraction suitable for MMT (matrix matrix multiplication
 /// with the RHS transposed) lowering.
@@ -1696,7 +1805,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
 void mlir::vector::
     populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
         RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+  patterns.add<DropInnerMostUnitDimsTransferRead,
+               DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+                                                   benefit);
 }
 
 void mlir::vector::populateSinkVectorBroadcastPatterns(

diff  --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 0d2743b9fe2e7f5..d6d69c8af88508d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -76,3 +76,56 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
 //  CHECK-NOT:   memref.subview
 //      CHECK:   %[[READ:.+]] = vector.transfer_read %[[SRC]]
 //      CHECK:   return %[[READ]] : vector<4x8xf32>
+
+// -----
+
+func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+    {in_bounds = [true, true, true, true, true]}
+    : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+  return
+}
+// CHECK:      func.func @drop_two_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
+    {in_bounds = [true, true, true, true]}
+    : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+  return
+}
+// CHECK:      func.func @drop_inner_most_dim_for_transfer_write
+// CHECK-SAME:   %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME:   %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// CHECK:        %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME:     memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK:        %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK:        vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME:     [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
+    {in_bounds = [true, true, true]}
+    : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>
+  return
+}
+// The innermost unit dim cannot be dropped if its stride is not 1.
+// CHECK:     func.func @non_unit_strides
+// CHECK-NOT:   memref.subview


        


More information about the Mlir-commits mailing list