[Mlir-commits] [mlir] [mlir][vector] Drop innermost unit dims on transfer_write. (PR #78554)

Han-Chung Wang llvmlistbot at llvm.org
Fri Jan 19 01:51:52 PST 2024


================
@@ -1152,8 +1152,80 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
   }
 };
 
-// Drop inner most contiguous unit dimensions from transfer_read operand.
-class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
+/// Returns the number of inner dims that can be folded away from transfer ops.
+/// It returns a failure if it cannot determine the number of dims to be folded.
+/// Example 1: it returns "2" if `srcType` is memref<512x16x1x1xf32> and
+/// `vectorType` is vector<16x16x1x1xf32>, because the two innermost dims
+/// can be dropped by memref.subview ops.
+/// Example 2: it returns "1" if `srcType` is the same memref type with
+/// [8192, 16, 8, 1] strides.
+static FailureOr<size_t>
+getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
+  SmallVector<int64_t> srcStrides;
+  int64_t srcOffset;
+  if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
+    return failure();
+
+  // According to vector.transfer_read/write semantics, the vector can be a
+  // slice. Thus, we have to offset the check index with `rankDiff` in
+  // `srcStrides` and source dim sizes.
+  size_t result = 0;
+  int rankDiff = srcType.getRank() - vectorType.getRank();
+  for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
+    // Check that the inner dim size is 1 for both the memref type and the
+    // vector slice. A dim can be folded only if both sizes are 1 and the
+    // stride is 1.
+    int dim = vectorType.getRank() - i - 1;
+    if (srcStrides[dim + rankDiff] == 1 &&
+        srcType.getDimSize(dim + rankDiff) == 1 &&
+        vectorType.getDimSize(dim) == 1) {
+      result++;
+    } else {
+      break;
+    }
----------------
hanhanW wrote:

SGTM, I'm also a fan of it!

https://github.com/llvm/llvm-project/pull/78554
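
For readers without an MLIR build at hand, the counting loop in the hunk above can be sketched in plain C++. This is only an illustration under stated assumptions: `countFoldableInnerUnitDims` is a hypothetical name, shapes and strides are passed as plain vectors instead of `MemRefType` and `VectorType`, and the strides are assumed to be already known, so the `getStridesAndOffset` failure path is not modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the helper in the patch. It counts how many
// trailing dims have size 1 in both the source and the vector, with a source
// stride of 1. All three parameters are illustrative plain-vector inputs.
std::size_t countFoldableInnerUnitDims(const std::vector<int64_t> &srcShape,
                                       const std::vector<int64_t> &srcStrides,
                                       const std::vector<int64_t> &vecShape) {
  assert(srcShape.size() == srcStrides.size());
  assert(vecShape.size() <= srcShape.size());

  // The vector may be a lower-rank slice of the source, so vector dim `dim`
  // lines up with source dim `dim + rankDiff`.
  const std::size_t rankDiff = srcShape.size() - vecShape.size();

  std::size_t result = 0;
  for (std::size_t i = 0; i < vecShape.size(); ++i) {
    // Walk from the innermost vector dim outwards.
    const std::size_t dim = vecShape.size() - i - 1;
    const bool foldable = srcStrides[dim + rankDiff] == 1 &&
                          srcShape[dim + rankDiff] == 1 &&
                          vecShape[dim] == 1;
    if (!foldable)
      break; // Stop at the first dim that cannot be folded.
    ++result;
  }
  return result;
}
```

With the shapes from the doc comment, source shape `{512, 16, 1, 1}` with identity-layout strides `{16, 1, 1, 1}` and vector shape `{16, 16, 1, 1}` yield 2, while the same shapes with strides `{8192, 16, 8, 1}` yield 1, because the stride of 8 on the second-innermost dim stops the walk.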

