[Mlir-commits] [mlir] [mlir][vector] Drop innermost unit dims on transfer_write. (PR #78554)
Han-Chung Wang
llvmlistbot at llvm.org
Fri Jan 19 02:16:28 PST 2024
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/78554
>From c9343bd38df632b98a7beabb461e8d7036fa3d1c Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 18 Jan 2024 09:10:51 +0000
Subject: [PATCH 1/4] [mlir][vector] Drop innermost unit dims on
transfer_write.
The revision renames DropInnerMostUnitDims to
DropInnerMostUnitDimsTransferRead and adds support for
vector.transfer_write.
It refactors the shared logic into helper functions
(getTransferFoldableInnerUnitDims and getMemRefTypeWithDroppingInnerDims)
and uses them in both patterns.
---
.../Vector/Transforms/VectorTransforms.cpp | 197 +++++++++++++-----
...tor-transfer-collapse-inner-most-dims.mlir | 21 ++
2 files changed, 164 insertions(+), 54 deletions(-)
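Note: the sketch below is not part of the patch. It is a minimal, hypothetical
example of how a pass could pull in these patterns through the existing
populateVectorTransferCollapseInnerMostContiguousDimsPatterns entry point,
assuming the standard greedy rewrite driver; header paths follow the usual
MLIR layout and the wrapper function name is made up for illustration.

  #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
  #include "mlir/IR/PatternMatch.h"
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  // Collects the collapse patterns (transfer_read and, with this patch,
  // transfer_write) and applies them greedily to `root`.
  static void collapseInnerMostContiguousDims(mlir::Operation *root) {
    mlir::RewritePatternSet patterns(root->getContext());
    mlir::vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
        patterns);
    (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
  }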
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index bd02c07981466d9..7c276ca8101221b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1152,8 +1152,71 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
}
};
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of dims that can be folded away from transfer ops. It
+/// returns a failure if strides and offsets cannot be resolved.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+ SmallVector<int64_t> srcStrides;
+ int64_t srcOffset;
+ if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ return failure();
+
+ // According to vector.transfer_read/write semantics, the vector can be a
+ // slice. It pads the indices with `1` starting from the beginning. Thus, we have
+ // to offset the check index with `rankDiff` in `srcStrides` and source dim
+ // sizes.
+ size_t result = 0;
+ int rankDiff = srcType.getRank() - vectorType.getRank();
+ for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+ // Check that the inner dim size is 1 for both memref/tensor type and
+ // vector slice. It can be folded only if they are 1 and the stride is 1.
+ int dim = vectorType.getRank() - i - 1;
+ if (srcStrides[dim + rankDiff] == 1 &&
+ srcType.getDimSize(dim + rankDiff) == 1 &&
+ vectorType.getDimSize(dim) == 1) {
+ result++;
+ } else {
+ break;
+ }
+ }
+ return result;
+}
+
+/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
+/// `srcType`.
+static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
+ MemRefType srcType,
+ size_t dimsToDrop) {
+ MemRefType resultMemrefType;
+ MemRefLayoutAttrInterface layout = srcType.getLayout();
+ if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
+ return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+ srcType.getElementType(), nullptr,
+ srcType.getMemorySpace());
+ }
+ MemRefLayoutAttrInterface updatedLayout;
+ if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
+ auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
+ updatedLayout = StridedLayoutAttr::get(strided.getContext(),
+ strided.getOffset(), strides);
+ } else {
+ AffineMap map = srcType.getLayout().getAffineMap();
+ int numSymbols = map.getNumSymbols();
+ for (size_t i = 0; i < dimsToDrop; ++i) {
+ int dim = srcType.getRank() - i - 1;
+ map = map.replace(builder.getAffineDimExpr(dim),
+ builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+ numSymbols);
+ }
+ }
+ return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+ srcType.getElementType(), updatedLayout,
+ srcType.getMemorySpace());
+}
+
+/// Drop inner most contiguous unit dimensions from transfer_read operand.
+class DropInnerMostUnitDimsTransferRead
+ : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,29 +1240,12 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
if (targetType.getRank() <= 1)
return failure();
- SmallVector<int64_t> srcStrides;
- int64_t srcOffset;
- if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
- return failure();
-
- // According to vector.transfer_read semantics, the result can be a slice.
- // It pads the indices with `1` starting from beginning. Thus, we have to
- // offset the check index with `rankDiff` in `srcStrides` and source dim
- // sizes.
- size_t dimsToDrop = 0;
- int rankDiff = srcType.getRank() - targetType.getRank();
- for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
- // Check that the inner dim size is 1 for both memref/tensor type and
- // vector slice. It can be folded only if they are 1 and the stride is 1.
- int dim = targetType.getRank() - i - 1;
- if (srcStrides[dim + rankDiff] == 1 &&
- srcType.getDimSize(dim + rankDiff) == 1 &&
- targetType.getDimSize(dim) == 1) {
- dimsToDrop++;
- } else {
- break;
- }
- }
+ FailureOr<size_t> maybeDimsToDrop =
+ getTransferFoldableInnerUnitDims(srcType, targetType);
+ if (failed(maybeDimsToDrop))
+ return failure();
+
+ size_t dimsToDrop = maybeDimsToDrop.value();
if (dimsToDrop == 0)
return failure();
@@ -1207,35 +1253,9 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType());
- MemRefType resultMemrefType;
- MemRefLayoutAttrInterface layout = srcType.getLayout();
- if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
- resultMemrefType = MemRefType::get(
- srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- nullptr, srcType.getMemorySpace());
- } else {
- MemRefLayoutAttrInterface updatedLayout;
- if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
- auto strides =
- llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
- updatedLayout = StridedLayoutAttr::get(strided.getContext(),
- strided.getOffset(), strides);
- } else {
- AffineMap map = srcType.getLayout().getAffineMap();
- int numSymbols = map.getNumSymbols();
- for (size_t i = 0; i < dimsToDrop; ++i) {
- int dim = srcType.getRank() - i - 1;
- map = map.replace(rewriter.getAffineDimExpr(dim),
- rewriter.getAffineConstantExpr(0),
- map.getNumDims() - 1, numSymbols);
- }
- }
- resultMemrefType = MemRefType::get(
- srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
- updatedLayout, srcType.getMemorySpace());
- }
-
auto loc = readOp.getLoc();
+ MemRefType resultMemrefType =
+ getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);
@@ -1261,6 +1281,73 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
}
};
+/// Drop inner most contiguous unit dimensions from transfer_write operand.
+class DropInnerMostUnitDimsTransferWrite
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (writeOp.getTransferRank() == 0)
+ return failure();
+
+ // TODO: support mask.
+ if (writeOp.getMask())
+ return failure();
+
+ auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+ if (!srcType || !srcType.hasStaticShape())
+ return failure();
+
+ if (!writeOp.getPermutationMap().isMinorIdentity())
+ return failure();
+
+ auto targetType = writeOp.getVectorType();
+ if (targetType.getRank() <= 1)
+ return failure();
+
+ FailureOr<size_t> maybeDimsToDrop =
+ getTransferFoldableInnerUnitDims(srcType, targetType);
+ if (failed(maybeDimsToDrop))
+ return failure();
+
+ size_t dimsToDrop = maybeDimsToDrop.value();
+ if (dimsToDrop == 0)
+ return failure();
+
+ auto resultTargetVecType =
+ VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+ targetType.getElementType());
+
+ auto loc = writeOp.getLoc();
+ MemRefType resultMemrefType =
+ getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
+ SmallVector<int64_t> offsets(srcType.getRank(), 0);
+ SmallVector<int64_t> strides(srcType.getRank(), 1);
+
+ ArrayAttr inBoundsAttr =
+ writeOp.getInBounds()
+ ? rewriter.getArrayAttr(
+ writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
+ : ArrayAttr();
+ Value rankedReducedView = rewriter.create<memref::SubViewOp>(
+ loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
+ strides);
+ auto permMap = getTransferMinorIdentityMap(
+ cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
+
+ auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, resultTargetVecType, writeOp.getVector());
+ rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+ writeOp, shapeCast, rankedReducedView,
+ writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
+ // TODO: support mask.
+ /*mask=*/Value(), inBoundsAttr);
+ return success();
+ }
+};
+
/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
@@ -1696,7 +1783,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
void mlir::vector::
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
+ patterns.add<DropInnerMostUnitDimsTransferRead,
+ DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
+ benefit);
}
void mlir::vector::populateSinkVectorBroadcastPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 0d2743b9fe2e7f5..59116c19b46ec23 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -76,3 +76,24 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
// CHECK-NOT: memref.subview
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SRC]]
// CHECK: return %[[READ]] : vector<4x8xf32>
+
+// -----
+
+func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
+ {in_bounds = [true, true, true, true]}
+ : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
+ return
+}
+// CHECK: func.func @drop_inner_most_dim_for_transfer_write
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
+// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
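For illustration only (not part of the patch): getTransferFoldableInnerUnitDims
is a file-local static, but its check can be restated as a standalone helper to
see why the test above folds exactly one dim. The sketch below is a hypothetical
re-statement; the helper name and includes are mine.

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/Support/LogicalResult.h"
  #include "llvm/ADT/SmallVector.h"

  // Counts innermost dims whose source stride, source size, and vector size
  // are all 1 -- the same scan the pattern performs before rewriting.
  static size_t countFoldableInnerUnitDims(mlir::MemRefType src,
                                           mlir::VectorType vec) {
    llvm::SmallVector<int64_t> strides;
    int64_t offset;
    if (mlir::failed(mlir::getStridesAndOffset(src, strides, offset)))
      return 0;
    size_t count = 0;
    int rankDiff = src.getRank() - vec.getRank();
    for (int d = vec.getRank() - 1; d >= 0; --d) {
      if (strides[d + rankDiff] != 1 || src.getDimSize(d + rankDiff) != 1 ||
          vec.getDimSize(d) != 1)
        break;
      ++count;
    }
    return count;
  }

  // For memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> paired
  // with vector<1x16x16x1xf32> this yields 1: the trailing dim folds, but the
  // next dim has size 16 in both types, so the scan stops there.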
>From 58e6571395ff88ecbb9daf7f614644b57f7df775 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 19 Jan 2024 08:41:56 +0000
Subject: [PATCH 2/4] improve comments and tests
---
.../Vector/Transforms/VectorTransforms.cpp | 56 +++++++++++++------
...tor-transfer-collapse-inner-most-dims.mlir | 34 ++++++++++-
2 files changed, 73 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7c276ca8101221b..21d855528fc07d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1153,7 +1153,12 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
};
/// Returns the number of dims that can be folded away from transfer ops. It
-/// returns a failure if strides and offsets cannot be resolved.
+/// returns a failure if it cannot determine the number of dims to be folded.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two inner most dims
+/// can be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type with
+/// [8192, 16, 8, 1] strides.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
SmallVector<int64_t> srcStrides;
@@ -1162,14 +1167,13 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
return failure();
// According to vector.transfer_read/write semantics, the vector can be a
- // slice. It pads the indices with `1` starting from the beginning. Thus, we have
- // to offset the check index with `rankDiff` in `srcStrides` and source dim
- // sizes.
+ // slice. Thus, we have to offset the check index with `rankDiff` in
+ // `srcStrides` and source dim sizes.
size_t result = 0;
int rankDiff = srcType.getRank() - vectorType.getRank();
for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
- // Check that the inner dim size is 1 for both memref/tensor type and
- // vector slice. It can be folded only if they are 1 and the stride is 1.
+ // Check that the inner dim size is 1 for both memref type and vector
+ // slice. It can be folded only if they are 1 and the stride is 1.
int dim = vectorType.getRank() - i - 1;
if (srcStrides[dim + rankDiff] == 1 &&
srcType.getDimSize(dim + rankDiff) == 1 &&
@@ -1183,7 +1187,8 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
}
/// Returns a MemRef type that drops inner `dimsToDrop` dimensions from
-/// `srcType`.
+/// `srcType`. E.g., if `srcType` is memref<512x16x1x1xf32> and `dimsToDrop` is
+/// two, it returns memref<512x16xf32> type.
static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
MemRefType srcType,
size_t dimsToDrop) {
@@ -1199,15 +1204,19 @@ static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
updatedLayout = StridedLayoutAttr::get(strided.getContext(),
strided.getOffset(), strides);
- } else {
- AffineMap map = srcType.getLayout().getAffineMap();
- int numSymbols = map.getNumSymbols();
- for (size_t i = 0; i < dimsToDrop; ++i) {
- int dim = srcType.getRank() - i - 1;
- map = map.replace(builder.getAffineDimExpr(dim),
- builder.getAffineConstantExpr(0), map.getNumDims() - 1,
- numSymbols);
- }
+ return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
+ srcType.getElementType(), updatedLayout,
+ srcType.getMemorySpace());
+ }
+
+ // Non-strided layout case.
+ AffineMap map = srcType.getLayout().getAffineMap();
+ int numSymbols = map.getNumSymbols();
+ for (size_t i = 0; i < dimsToDrop; ++i) {
+ int dim = srcType.getRank() - i - 1;
+ map = map.replace(builder.getAffineDimExpr(dim),
+ builder.getAffineConstantExpr(0), map.getNumDims() - 1,
+ numSymbols);
}
return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
srcType.getElementType(), updatedLayout,
@@ -1282,6 +1291,21 @@ class DropInnerMostUnitDimsTransferRead
};
/// Drop inner most contiguous unit dimensions from transfer_write operand.
+/// E.g.,
+/// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+/// {in_bounds = [true, true, true, true, true]}
+/// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+/// %subview = memref.subview %arg0
+/// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+/// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+/// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+/// to vector<1x16x16xf32>
+/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+/// {in_bounds = [true, true, true]}
+/// : vector<1x16x16xf32>, memref<1x512x16xf32>
class DropInnerMostUnitDimsTransferWrite
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
index 59116c19b46ec23..d6d69c8af88508d 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir
@@ -79,6 +79,26 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
// -----
+func.func @drop_two_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1x1xf32>, %arg1: vector<1x16x16x1x1xf32>, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+ {in_bounds = [true, true, true, true, true]}
+ : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+ return
+}
+// CHECK: func.func @drop_two_inner_most_dim_for_transfer_write
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
+// CHECK-SAME: memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1x1xf32> to vector<1x16x16xf32>
+// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
+// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
@@ -92,8 +112,20 @@ func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32,
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
-// CHECK-SAME: [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
// CHECK-SAME: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
+
+// -----
+
+func.func @non_unit_strides(%arg0: memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>, %arg1: vector<16x16x1xf32>, %arg2: index) {
+ %c0 = arith.constant 0 : index
+ vector.transfer_write %arg1, %arg0[%arg2, %c0, %c0]
+ {in_bounds = [true, true, true]}
+ : vector<16x16x1xf32>, memref<512x16x1xf32, strided<[8192, 16, 4], offset: ?>>
+ return
+}
+// The inner most unit dims cannot be dropped if the strides are not 1.
+// CHECK: func.func @non_unit_strides
+// CHECK-NOT: memref.subview
>From feda905282482401a454a0f07d08b76a36bdc7a1 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 19 Jan 2024 09:52:50 +0000
Subject: [PATCH 3/4] address comments!
---
.../Vector/Transforms/VectorTransforms.cpp | 18 ++++++++----------
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 21d855528fc07d9..7d5f4d471e89bff 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1172,16 +1172,14 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
size_t result = 0;
int rankDiff = srcType.getRank() - vectorType.getRank();
for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
- // Check that the inner dim size is 1 for both memref type and vector
- // slice. It can be folded only if they are 1 and the stride is 1.
+ // Check that the inner dim size is 1 for both memref type and vector slice.
+ // It can be folded only if they are 1 and the stride is 1.
int dim = vectorType.getRank() - i - 1;
- if (srcStrides[dim + rankDiff] == 1 &&
- srcType.getDimSize(dim + rankDiff) == 1 &&
- vectorType.getDimSize(dim) == 1) {
- result++;
- } else {
+ if (srcStrides[dim + rankDiff] != 1 ||
+ srcType.getDimSize(dim + rankDiff) != 1 ||
+ vectorType.getDimSize(dim) == 1)
break;
- }
+ result++;
}
return result;
}
@@ -1344,17 +1342,17 @@ class DropInnerMostUnitDimsTransferWrite
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType());
- auto loc = writeOp.getLoc();
MemRefType resultMemrefType =
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);
-
ArrayAttr inBoundsAttr =
writeOp.getInBounds()
? rewriter.getArrayAttr(
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
: ArrayAttr();
+
+ Location loc = writeOp.getLoc();
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
strides);
>From c3c2efbd3ab74fa604296a77bb021627a47f4714 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 19 Jan 2024 10:15:52 +0000
Subject: [PATCH 4/4] fix a bug introduced while addressing comments
---
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7d5f4d471e89bff..9c734e8cc2ad11a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1177,7 +1177,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
int dim = vectorType.getRank() - i - 1;
if (srcStrides[dim + rankDiff] != 1 ||
srcType.getDimSize(dim + rankDiff) != 1 ||
- vectorType.getDimSize(dim) == 1)
+ vectorType.getDimSize(dim) != 1)
break;
result++;
}
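For the record, the bug fixed here came from rewriting the fold condition as an
early break in PATCH 3 without negating the last term. A hypothetical predicate
form of the check (not part of the patch) makes the intended inversion explicit:

  // A dim can be folded only when all three quantities are 1.
  static bool isFoldableUnitDim(int64_t srcStride, int64_t srcDimSize,
                                int64_t vecDimSize) {
    return srcStride == 1 && srcDimSize == 1 && vecDimSize == 1;
  }

  // The loop should break on !isFoldableUnitDim(...), i.e. on
  //   srcStride != 1 || srcDimSize != 1 || vecDimSize != 1
  // PATCH 3 kept `vecDimSize == 1` in the last term, so the loop broke on
  // exactly the unit dims it was meant to fold; this patch restores `!= 1`.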