[Mlir-commits] [mlir] [mlir][vector] Drop innermost unit dims on transfer_write. (PR #78554)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 19 01:43:43 PST 2024
================
@@ -1261,6 +1290,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
}
};
+/// Drop inner most contiguous unit dimensions from transfer_write operand.
+/// E.g.,
+/// vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+/// {in_bounds = [true, true, true, true, true]}
+/// : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+/// %subview = memref.subview %arg0
+/// [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+/// : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+/// %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+/// to vector<1x16x16xf32>
+/// vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+/// {in_bounds = [true, true, true]}
+/// : vector<1x16x16xf32>, memref<1x512x16xf32>
+class DropInnerMostUnitDimsTransferWrite
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ // TODO: support 0-d corner case.
+ if (writeOp.getTransferRank() == 0)
+ return failure();
+
+ // TODO: support mask.
+ if (writeOp.getMask())
+ return failure();
+
+ auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+ if (!srcType || !srcType.hasStaticShape())
+ return failure();
+
+ if (!writeOp.getPermutationMap().isMinorIdentity())
+ return failure();
+
+ auto targetType = writeOp.getVectorType();
+ if (targetType.getRank() <= 1)
+ return failure();
+
+ FailureOr<size_t> maybeDimsToDrop =
+ getTransferFoldableInnerUnitDims(srcType, targetType);
+ if (failed(maybeDimsToDrop))
+ return failure();
+
+ size_t dimsToDrop = maybeDimsToDrop.value();
+ if (dimsToDrop == 0)
+ return failure();
+
+ auto resultTargetVecType =
+ VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+ targetType.getElementType());
+
+ auto loc = writeOp.getLoc();
----------------
banach-space wrote:
[nit] `loc` is not needed until L1358 — consider declaring it closer to its first use.
https://github.com/llvm/llvm-project/pull/78554
More information about the Mlir-commits
mailing list