[Mlir-commits] [mlir] [mlir][linalg] Add support for masked vectorization of `tensor.insert_slice` (1/N) (PR #122927)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 22 08:36:57 PST 2025


================
@@ -2583,113 +2626,139 @@ static Value getStaticPadVal(Operation *op) {
   return {};
 }
 
-/// Rewrite tensor.insert.slice as a vector.transfer_read +
-/// vector.transfer_write pair. The vector size is inferred from the static
-/// dims in the input and output tensors. If a dim is dynamic in both the input
-/// and output tensors, bails out.
-///
-/// Before:
-///     !t_in_type = tensor<1x2x3xf32>
-///     !t_out_type = tensor<9x8x7x1x2x3xf32>
-///     !v_type = vector<1x2x3xf32>
-///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
-///     into !t_out_type
-/// After:
-///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
-///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
-///
-/// TODO: Support masking
-struct InsertSliceVectorizePattern
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(sliceOp);
 
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
-    auto sourceType = sliceOp.getSource().getType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
+  TypedValue<RankedTensorType> source = sliceOp.getSource();
+  auto sourceType = source.getType();
+  if (!VectorType::isValidElementType(sourceType.getElementType()))
+    return failure();
 
-    auto resultType = sliceOp.getResultType();
-
-    // 1. Get the pad value.
-    // TransferReadOp requires a scalar padding value. Note that:
-    //    * for in-bounds access, the value is actually irrelevant.
-    //  There are 2 cases in which xfer.read accesses are known to be in-bounds:
-    //  1. The source shape is static (output vector sizes would be based on
-    //     the source shape and hence all memory accesses would be in-bounds),
-    //  2. Masking is used (output vector sizes would be user-provided, in which
-    //     case it is assumed that all memory accesses are in-bounds). This
-    //     remains a TODO.
-    //
-    // When the value is not known and not needed, use 0. Otherwise, bail out.
-    Value padValue = getStaticPadVal(sliceOp);
-    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
-
-    if (!padValue && isOutOfBoundsRead) {
-      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+  auto resultType = sliceOp.getResultType();
+
+  // 1. Get the pad value.
+  // TransferReadOp requires a scalar padding value. Note that:
+  //    * for in-bounds access, the value is actually irrelevant.
+  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+  //  1. The source shape is static (output vector sizes would be based on
+  //     the source shape and hence all memory accesses would be in-bounds),
+  //  2. Masking is used (output vector sizes would be user-provided, in which
+  //     case it is assumed that all memory accesses are in-bounds). This
+  //     remains a TODO.
+  //
+  // When the value is not known and not needed, use 0. Otherwise, bail out.
+  Value padValue = getStaticPadVal(sliceOp);
+  bool isOutOfBoundsRead =
+      !sourceType.hasStaticShape() && inputVectorSizes.empty();
+
+  if (!padValue && isOutOfBoundsRead) {
+    LDBG("Failed to get a pad value for out-of-bounds read access\n");
+    return failure();
+  }
+
+  if (!padValue) {
+    auto elemType = sourceType.getElementType();
+    padValue = rewriter.create<arith::ConstantOp>(
+        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+  }
+
+  // 2. Get the vector shape and in-bounds attributes
+  SmallVector<int64_t> vecShape;
+  SmallVector<bool> readInBounds;
+  SmallVector<bool> writeInBounds;
+  size_t rankDiff = resultType.getRank() - sourceType.getRank();
+  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
----------------
banach-space wrote:

Done.

https://github.com/llvm/llvm-project/pull/122927


More information about the Mlir-commits mailing list