[Mlir-commits] [mlir] [mlir][vector] Drop innermost unit dims on transfer_write. (PR #78554)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 19 01:43:44 PST 2024
================
@@ -1152,8 +1152,80 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
}
};
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of dims that can be folded away from transfer ops. It
+/// returns a failure if it cannot determine the number of dims to be folded.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims
+/// can be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type with
+/// [8192, 16, 8, 1] strides.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+ SmallVector<int64_t> srcStrides;
+ int64_t srcOffset;
+ if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+ return failure();
+
+ // According to vector.transfer_read/write semantics, the vector can be a
+ // slice. Thus, we have to offset the check index with `rankDiff` in
+ // `srcStrides` and source dim sizes.
+ size_t result = 0;
+ int rankDiff = srcType.getRank() - vectorType.getRank();
+ for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+ // Check that the inner dim size is 1 for both memref type and vector
+ // slice. It can be folded only if they are 1 and the stride is 1.
+ int dim = vectorType.getRank() - i - 1;
+ if (srcStrides[dim + rankDiff] == 1 &&
+ srcType.getDimSize(dim + rankDiff) == 1 &&
+ vectorType.getDimSize(dim) == 1) {
+ result++;
+ } else {
+ break;
+ }
----------------
banach-space wrote:
[nit] This is really a matter of preference, so feel free to ignore. I'm just on a mission to remove `else` from C++ 😂
```suggestion
if (srcStrides[dim + rankDiff] != 1 ||
srcType.getDimSize(dim + rankDiff) != 1 ||
vectorType.getDimSize(dim) != 1)
break;
result++;
```
https://github.com/llvm/llvm-project/pull/78554
More information about the Mlir-commits
mailing list