[Mlir-commits] [mlir] [mlir][linalg] Add support for masked vectorization of `tensor.insert_slice` (2/N) (PR #123031)

Han-Chung Wang llvmlistbot at llvm.org
Sun Feb 2 22:36:36 PST 2025


================
@@ -2716,56 +2716,60 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   }
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-  // 3. Generate TransferReadOp.
-  SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
+  // 3. Generate TransferReadOp + TransferWriteOp
+  ReifiedRankedShapedTypeDims reifiedSrcSizes;
+  Value maskOp;
 
-  // If vector sizes are user provided, make sure to mask xfer_read.
+  // If vector sizes are user provided, make sure to mask. First, generate the
+  // mask.
   if (!inputVectorSizes.empty()) {
     auto *srcDefOp = source.getDefiningOp();
     if (!srcDefOp) {
       LDBG("Unable to get the defining Op of " << sliceOp);
       return failure();
     }
 
-    ReifiedRankedShapedTypeDims reifiedSrcSizes;
     LogicalResult status =
         cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
             rewriter, reifiedSrcSizes);
     if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << sliceOp);
+      LDBG("Unable to reify result shapes of " << srcDefOp);
       return failure();
     }
 
     // Create the mask
-    SmallVector<int64_t> readMaskShape(
-        sliceOp.getSource().getType().getShape());
     auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+    maskOp = rewriter.create<vector::CreateMaskOp>(
         sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-
-    // Mask the xfer_read Op
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
   }
 
-  // 4. Generate TransferWriteOp.
-  if (!inputVectorSizes.empty() &&
-      ShapedType::isDynamicShape(resultType.getShape())) {
-    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
-    return failure();
+  // 3.a. TransferReadOp
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
+
+  // Mask the xfer_read Op
----------------
hanhanW wrote:

ditto: I think the comment at line 2732 already documents this. IMO, the two belong to the same function/snippet, so we do not need this comment.

>   // If vector sizes are user provided, make sure to mask. First, generate the
>   // mask.
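
For context, here is a minimal, self-contained sketch (not taken from the PR) of the read-then-mask pattern that the comments above describe: build the vector.transfer_read, create an i1 mask from the reified source sizes, and wrap the read with vector::maskOperation. The helper name buildMaskedRead and its parameters are hypothetical; the builder calls mirror the ones visible in the diff.

    #include "mlir/Dialect/Arith/IR/Arith.h"
    #include "mlir/Dialect/Vector/IR/VectorOps.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Hypothetical helper: reads `source` into a vector of type `vecType`,
    // padding with `padValue`, and masks the read with the (possibly
    // dynamic) `srcSizes` when vector sizes were user provided.
    static Operation *buildMaskedRead(RewriterBase &rewriter, Location loc,
                                      Value source, VectorType vecType,
                                      Value padValue,
                                      ArrayRef<OpFoldResult> srcSizes) {
      // Read from offset 0 in every dimension; out-of-bounds lanes take
      // `padValue`.
      SmallVector<Value> readIndices(
          vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
      SmallVector<bool> inBounds(vecType.getRank(), false);
      Operation *read = rewriter.create<vector::TransferReadOp>(
          loc, vecType, source, readIndices, padValue,
          ArrayRef<bool>(inBounds));

      // No user-provided vector sizes -> nothing to mask.
      if (srcSizes.empty())
        return read;

      // Create an i1 mask bounded by the reified source sizes and wrap the
      // read in a vector.mask region.
      auto maskType = VectorType::get(vecType.getShape(), rewriter.getI1Type());
      Value mask = rewriter.create<vector::CreateMaskOp>(loc, maskType, srcSizes);
      return vector::maskOperation(rewriter, read, mask);
    }

The result is the transfer_read wrapped in a vector.mask region, which is what the maskOperation call in the diff produces; the sketch only regroups the same steps into one place.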

https://github.com/llvm/llvm-project/pull/123031

