[Mlir-commits] [mlir] [mlir][vector] Drop innermost unit dims on transfer_write. (PR #78554)

Han-Chung Wang llvmlistbot at llvm.org
Fri Jan 19 01:53:16 PST 2024


================
@@ -1261,6 +1290,88 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
   }
 };
 
+/// Drop innermost contiguous unit dimensions from transfer_write operand.
+/// E.g.,
+///    vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0, %c0]
+///      {in_bounds = [true, true, true, true, true]}
+///      : vector<1x16x16x1x1xf32>, memref<1x512x16x1x1xf32>
+///
+/// will be replaced with
+///
+///    %subview = memref.subview %arg0
+///      [0, 0, 0, 0, 0] [1, 512, 16, 1, 1] [1, 1, 1, 1, 1]
+///      : memref<1x512x16x1x1xf32> to memref<1x512x16xf32>
+///    %0 = vector.shape_cast %arg1 : vector<1x16x16x1x1xf32>
+///      to vector<1x16x16xf32>
+///    vector.transfer_write %0, %subview[%c0, %arg2, %c0]
+///      {in_bounds = [true, true, true]}
+///      : vector<1x16x16xf32>, memref<1x512x16xf32>
+class DropInnerMostUnitDimsTransferWrite
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    // TODO: support 0-d corner case.
+    if (writeOp.getTransferRank() == 0)
+      return failure();
+
+    // TODO: support mask.
+    if (writeOp.getMask())
+      return failure();
+
+    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    if (!srcType || !srcType.hasStaticShape())
+      return failure();
+
+    if (!writeOp.getPermutationMap().isMinorIdentity())
+      return failure();
+
+    auto targetType = writeOp.getVectorType();
+    if (targetType.getRank() <= 1)
+      return failure();
+
+    FailureOr<size_t> maybeDimsToDrop =
+        getTransferFoldableInnerUnitDims(srcType, targetType);
+    if (failed(maybeDimsToDrop))
+      return failure();
+
+    size_t dimsToDrop = maybeDimsToDrop.value();
+    if (dimsToDrop == 0)
+      return failure();
+
+    auto resultTargetVecType =
+        VectorType::get(targetType.getShape().drop_back(dimsToDrop),
+                        targetType.getElementType());
+
+    auto loc = writeOp.getLoc();
----------------
hanhanW wrote:

I moved it to right before where it is used.

https://github.com/llvm/llvm-project/pull/78554


More information about the Mlir-commits mailing list