[Mlir-commits] [mlir] [mlir][linalg] Add support for masked vectorization of `tensor.insert_slice` (PR #122927)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Jan 15 00:55:43 PST 2025


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/122927

>From e499a2bf861a52ca46fd3757870500b34c918dac Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Thu, 9 Jan 2025 18:43:57 +0000
Subject: [PATCH] [mlir][linalg] Add support for masked vectorization of
 `tensor.insert_slice`

This PR refactors the `InsertSliceVectorizePattern` to enable masked
vectorization of `tensor.insert_slice`.

Note, `tensor.insert_slice` is vectorised using the
`vector.transfer_read` + `vector.transfer_write` pair. ATM, only
`vector.transfer_read` is masked. If `vector.transfer_write` also
requires masking, the vectorizer will bail out. This will be addressed
in a subsequent PR.

Summary of changes:
  * Added an argument to specify vector sizes (behavior remains
    unchanged if vector sizes are not specified).
  * Renamed `InsertSliceVectorizePattern` to `vectorizeAsInsertSliceOp`
    and integrated it (alongside other vectorization hooks) into
    `linalg::vectorize`.
  * Removed `populateInsertSliceVectorizationPatterns`, as
    `InsertSliceVectorizePattern` was its only pattern.
  * Updated `vectorizeAsInsertSliceOp` to support masking for the
    "read" operation.
  * Updated `@pad_and_insert_slice_dest` in
    "vectorization-pad-patterns.mlir" to reflect the removal of
    `populateInsertSliceVectorizationPatterns` from
    `ApplyPadVectorizationPatternsOps`.
---
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 -
 .../TransformOps/LinalgTransformOps.cpp       |   4 -
 .../Linalg/Transforms/Vectorization.cpp       | 270 +++++++++++-------
 .../Linalg/vectorization-pad-patterns.mlir    |  32 +--
 .../Linalg/vectorization-unsupported.mlir     |  23 ++
 mlir/test/Dialect/Linalg/vectorization.mlir   | 100 ++++++-
 6 files changed, 284 insertions(+), 150 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc700f22c2027..726ce22ac70dc3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1723,11 +1723,6 @@ void populateDecomposePadPatterns(RewritePatternSet &patterns);
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
 
-/// Populates `patterns` with vectorisation patterns for tensor.insert_slice.
-/// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either
-/// expand or merge with other `populate{}`.
-void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns);
-
 /// Populates `patterns` with patterns that vectorize tensor.pad.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 67dd21aafe4fe0..73a52ebc46f15b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -265,7 +265,6 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
-  linalg::populateInsertSliceVectorizationPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
@@ -3504,9 +3503,6 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
-  // Add misc. vectorization patterns (e.g. for tensor.insert_slice)
-  linalg::populateInsertSliceVectorizationPatterns(patterns);
-
   if (getVectorizePadding()) {
     linalg::populatePadOpVectorizationPatterns(patterns);
     // This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 863f2280e46ce6..dce513a13c8491 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -59,6 +59,30 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                      ArrayRef<bool> inputVecScalableFlags = {},
                      bool flatten1DDepthwiseConv = false);
 
+/// Vectorize tensor::InsertSliceOp with:
+///   * vector::TransferReadOp + vector::TransferWriteOp
+/// The vector sizes are either:
+///   * user-provided in `inputVectorSizes`, or
+///   * inferred from the static dims in the input and output tensors.
+/// Bails out if:
+///   * vector sizes are not user-provided, and
+///   * at least one dim is dynamic (in both the input and output
+///     tensors).
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults);
+
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
@@ -1557,6 +1581,7 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
@@ -1633,6 +1658,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {
 
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
@@ -1763,7 +1789,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();
 
-  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);
 
@@ -1874,6 +1900,15 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
   return success();
 }
 
+/// Precondition check for tensor.insert_slice (currently always succeeds).
+static LogicalResult
+vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
+                                   ArrayRef<int64_t> inputVectorSizes) {
+
+  // TODO: Move pre-conditions from the vectorization logic
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -2144,6 +2179,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::UnPackOp>([&](auto unpackOp) {
         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
       })
+      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -2163,8 +2201,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }
 
 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
-  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
-      op);
+  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+             tensor::InsertSliceOp>(op);
 }
 
 /// Emit a suitable vector form for an operation. If provided,
@@ -2178,6 +2216,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<bool> inputScalableVecDims,
                                       bool vectorizeNDExtract,
                                       bool flatten1DDepthwiseConv) {
+  rewriter.getInsertionPoint();
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2244,6 +2283,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                            results);
           })
+          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
+                                            results);
+          })
           .Case<tensor::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                              inputVectorSizes, results);
@@ -2583,113 +2626,139 @@ static Value getStaticPadVal(Operation *op) {
   return {};
 }
 
-/// Rewrite tensor.insert.slice as a vector.transfer_read +
-/// vector.transfer_write pair. The vector size is inferred from the static
-/// dims in the input and output tensors. If a dim is dynamic in both the input
-/// and output tensors, bails out.
-///
-/// Before:
-///     !t_in_type = tensor<1x2x3xf32>
-///     !t_out_type = tensor<9x8x7x1x2x3xf32>
-///     !v_type = vector<1x2x3xf32>
-///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
-///     into !t_out_type
-/// After:
-///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
-///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
-///
-/// TODO: Support masking
-struct InsertSliceVectorizePattern
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(sliceOp);
 
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
-    auto sourceType = sliceOp.getSource().getType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
+  TypedValue<RankedTensorType> source = sliceOp.getSource();
+  auto sourceType = source.getType();
+  if (!VectorType::isValidElementType(sourceType.getElementType()))
+    return failure();
 
-    auto resultType = sliceOp.getResultType();
-
-    // 1. Get the pad value.
-    // TransferReadOp requires a scalar padding value. Note that:
-    //    * for in-bounds access, the value is actually irrelevant.
-    //  There are 2 cases in which xfer.read accesses are known to be in-bounds:
-    //  1. The source shape is static (output vector sizes would be based on
-    //     the source shape and hence all memory accesses would be in-bounds),
-    //  2. Masking is used (output vector sizes would be user-provided, in which
-    //     case it is assumed that all memory accesses are in-bounds). This
-    //     remains a TODO.
-    //
-    // When the value is not known and not needed, use 0. Otherwise, bail out.
-    Value padValue = getStaticPadVal(sliceOp);
-    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
-
-    if (!padValue && isOutOfBoundsRead) {
-      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+  auto resultType = sliceOp.getResultType();
+
+  // 1. Get the pad value.
+  // TransferReadOp requires a scalar padding value. Note that:
+  //    * for in-bounds access, the value is actually irrelevant.
+  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+  //  1. The source shape is static (output vector sizes would be based on
+  //     the source shape and hence all memory accesses would be in-bounds),
+  //  2. Masking is used, i.e. the output vector sizes are user-provided, in
+  //     which case the read is masked and it is assumed that all memory
+  //     accesses are in-bounds.
+  //
+  // When the value is not known and not needed, use 0. Otherwise, bail out.
+  Value padValue = getStaticPadVal(sliceOp);
+  bool isOutOfBoundsRead =
+      !sourceType.hasStaticShape() && inputVectorSizes.empty();
+
+  if (!padValue && isOutOfBoundsRead) {
+    LDBG("Failed to get a pad value for out-of-bounds read access\n");
+    return failure();
+  }
+
+  if (!padValue) {
+    auto elemType = sourceType.getElementType();
+    padValue = rewriter.create<arith::ConstantOp>(
+        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+  }
+
+  // 2. Get the vector shape and in-bounds attributes
+  SmallVector<int64_t> vecShape;
+  SmallVector<bool> readInBounds;
+  SmallVector<bool> writeInBounds;
+  size_t rankDiff = resultType.getRank() - sourceType.getRank();
+  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+    if (!inputVectorSizes.empty()) {
+      vecShape.push_back(inputVectorSizes[i]);
+      readInBounds.push_back(false);
+      writeInBounds.push_back(false);
+    } else if (!sourceType.isDynamicDim(i)) {
+      vecShape.push_back(sourceType.getDimSize(i));
+      // Source shape is statically known: Neither read nor write are
+      // out-of-bounds.
+      readInBounds.push_back(true);
+      writeInBounds.push_back(true);
+    } else if (!resultType.isDynamicDim(i)) {
+      // Source shape is not statically known, but result shape is.
+      // Vectorize with size of result shape. This may be larger than the
+      // source size.
+      // FIXME: Using rankDiff implies that the source tensor is inserted at
+      // the end of the destination tensor. However, that's not required.
+      vecShape.push_back(resultType.getDimSize(rankDiff + i));
+      // Read may be out-of-bounds because the result size could be larger
+      // than the source size.
+      readInBounds.push_back(false);
+      // Write will be in-bounds provided that the corresponding write idx is 0.
+      // To keep this logic simple, conservatively mark as out-of-bounds.
+      writeInBounds.push_back(false);
+    } else {
+      // Neither source nor result dim of padOp is static. Cannot vectorize
+      // the copy.
+      // TODO: Add support for masking
       return failure();
     }
+  }
+  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-    if (!padValue) {
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
+  // 3. Generate TransferReadOp.
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
 
-    // 2. Get the vector shape and in-bounds attributes
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    size_t rankDiff = resultType.getRank() - sourceType.getRank();
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        // FIXME: Using rankDiff implies that the source tensor is inserted at
-        // the end of the destination tensor. However, that's not required.
-        vecShape.push_back(resultType.getDimSize(rankDiff + i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write will in-bounds provided that the corresponding write idx is 0.
-        // To keep this logic simple, conservatively mark as out-of-bounds.
-        writeInBounds.push_back(false);
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        // TODO: Add support for masking
-        return failure();
-      }
+  // If vector sizes are user provided, make sure to mask xfer_read.
+  if (!inputVectorSizes.empty()) {
+    auto *srcDefOp = source.getDefiningOp();
+    if (!srcDefOp) {
+      LDBG("Unable to get the defining Op of " << sliceOp);
+      return failure();
     }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-    // 3. Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
+    ReifiedRankedShapedTypeDims reifiedSrcSizes;
+    LogicalResult status =
+        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp)
+            .reifyResultShapes(rewriter, reifiedSrcSizes);
+    if (status.failed()) {
+      LDBG("Unable to reify result shapes of " << sliceOp);
+      return failure();
+    }
 
-    // 4. Generate TransferWriteOp.
-    auto writeIndices = getValueOrCreateConstantIndexOp(
-        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+    // Create the mask
+    SmallVector<int64_t> readMaskShape(
+        sliceOp.getSource().getType().getShape());
+    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
 
-    // 5. Finalize
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        sliceOp, read, sliceOp.getDest(), writeIndices,
-        ArrayRef<bool>{writeInBounds});
+    // Mask the xfer_read Op
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
+  }
 
-    return success();
+  // 4. Generate TransferWriteOp.
+  if (!inputVectorSizes.empty() &&
+      ShapedType::isDynamicShape(resultType.getShape())) {
+    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
+    return failure();
   }
-};
+
+  auto writeIndices = getValueOrCreateConstantIndexOp(
+      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+  // 5. Finalize
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
+      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
+      ArrayRef<bool>{writeInBounds});
+  newResults.push_back(write->getResult(0));
+
+  return success();
+}
 
 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
@@ -2778,11 +2847,6 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };
 
-void mlir::linalg::populateInsertSliceVectorizationPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
-}
-
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add<PadOpVectorizationWithTransferReadPattern,
diff --git a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
index 08a3bbbb301c87..747b6f6d90cc7f 100644
--- a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
@@ -224,34 +224,16 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-
 // -----
 
-///----------------------------------------------------------------------------------------
-/// tensor::PadOp -> tensor::EmptyOp + linalg::FillOp/tensor::GenerateOp + tensor::InsertSliceOp
-/// [Pattern: GenericPadOpVectorizationPattern + InsertSliceVectorizePattern]
-/// TODO: Split the test into two, one for each pattern.
-///----------------------------------------------------------------------------------------
-
 func.func private @make_vector() -> tensor<12x13xf32>
 
-// Same as @pad_and_insert_slice_dest in vectorization-with-patterns.mlir, but
-// over here linalg::fill is not vectorized (patterns for linalg.fill are not
-// included here)
-// CHECK-LABEL:   func.func @pad_and_insert_slice_dest(
-// CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
-//  CHECK-NOT:     tensor.pad
-//  CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
-//  CHECK-DAG:     %[[PAD:.*]] = arith.constant 5.000000e+00 : f32
-//  CHECK-DAG:     %[[PAD_READ:.*]] = arith.constant 0.000000e+00 : f32
-//      CHECK:     %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
-//      CHECK:     %[[FILL:.*]] = linalg.fill ins(%[[PAD]] : f32) outs(%[[EMPTY]] : tensor<1x12x13xf32>) -> tensor<1x12x13xf32>
-//      CHECK:     %[[READ_1:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
-//      CHECK:     %[[WRITE_1:.*]] = vector.transfer_write %[[READ_1]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
-//      CHECK:     %[[VEC:.*]] = call @make_vector() : () -> tensor<12x13xf32>
-//      CHECK:     %[[READ_2:.*]] = vector.transfer_read %[[VEC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD_READ]] {in_bounds = [true, true]} : tensor<12x13xf32>, vector<12x13xf32>
-//      CHECK:     %[[RES:.*]] = vector.transfer_write %[[READ_2]], %[[WRITE_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<12x13xf32>, tensor<1x12x13xf32>
-//      CHECK:     return %[[RES]] : tensor<1x12x13xf32>
+// The destination of tensor.insert_slice matches the result of tensor.pad -
+// not supported.
+
+// CHECK-LABEL:   func.func @pad_and_insert_slice_dest(
+// CHECK-NOT:     vector.transfer_read
+// CHECK-NOT:     vector.transfer_write
 
 func.func @pad_and_insert_slice_dest(
     %arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
@@ -270,8 +252,6 @@ module attributes {transform.with_named_sequence} {
     %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
 
     transform.apply_patterns to %func_op {
-      // TODO: Split into two tests, one for each pattern
-      transform.apply_patterns.linalg.decompose_pad
       transform.apply_patterns.linalg.pad_vectorization
     } : !transform.op<"func.func">
     transform.yield
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 8fbc74ec345c6b..b1e3a69df203e1 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -280,3 +280,26 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// One of the _write_ dimensions is dynamic (but _read_ dimensions are static).
+
+func.func private @insert_slice_dynamic_write_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<?x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty(%size) : tensor<?x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<?x3xi32>
+
+  return %res : tensor<?x3xi32>
+}
+
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  } 
+ }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 0f2abe06569d64..9bd4a76354fe25 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -66,7 +66,7 @@ func.func @vectorize_dynamic_identity_with_constant(%arg0: tensor<?xf32>,
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %size = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op 
+    %size = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [%size] : !transform.any_op, !transform.any_op
     transform.yield
   }
@@ -690,7 +690,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [4, 1, 32] : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -727,7 +727,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [32, 4, 1] : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -768,7 +768,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize %0 vector_sizes [4, 1] : !transform.any_op
-    transform.yield 
+    transform.yield
   }
 }
 
@@ -933,7 +933,7 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [512, 128] : !transform.any_op
     transform.yield
-  } 
+  }
 }
 
 // -----
@@ -957,7 +957,7 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
     transform.yield
-  } 
+  }
  }
 
   // -----
@@ -981,7 +981,7 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 vector_sizes [256, 128] : !transform.any_op
     transform.yield
-  } 
+  }
 }
 
   // -----
@@ -1022,7 +1022,7 @@ func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>
 //  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
 //  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
 //      CHECK: %[[transfer_read:.*]] =  vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
-// CHECK-SAME:   {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32> 
+// CHECK-SAME:   {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
 //      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[transfer_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
 //      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
 //  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
@@ -1059,7 +1059,7 @@ func.func @test_vectorize_unpack_no_vector_sizes(%source: tensor<8x8x32x16xf32>,
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
     transform.yield
-  } 
+  }
  }
 
   // -----
@@ -1083,10 +1083,10 @@ func.func @test_vectorize_unpack_no_vector_sizes_slice_output(%source: tensor<8x
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
     transform.yield
-  } 
+  }
  }
 
-  // -----
+// -----
 
 func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf32>, %dest: tensor<7x16xf32>) -> tensor<7x16xf32> {
    %0 = tensor.unpack %source outer_dims_perm=[1, 0] inner_dims_pos = [1] inner_tiles = [4] into %dest : tensor<4x7x4xf32> -> tensor<7x16xf32>
@@ -1106,5 +1106,81 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     %0 = transform.structured.match ops{["tensor.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
     transform.yield
-  } 
+  }
+ }
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// tensor.insert_slice
+///----------------------------------------------------------------------------------------
+
+func.func private @insert_slice_static_sizes(%source: tensor<?x3x?x1xi32>) -> tensor<5x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty() : tensor<5x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [5, 1] [1, 1] : tensor<5x1xi32> into tensor<5x3xi32>
+
+  return %res : tensor<5x3xi32>
+}
+
+// CHECK-LABEL:   func.func private @insert_slice_static_sizes(
+// CHECK-SAME:      %[[SEC:.*]]: tensor<?x3x?x1xi32>) -> tensor<5x3xi32> {
+// CHECK:           %[[C_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[INIT:.*]] = tensor.empty() : tensor<5x3xi32>
+// CHECK:           %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SEC]][0, %[[C_2]], 0, 0] [1, 1, 5, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<5x1xi32>
+// CHECK:           %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C_5:.*]] = arith.constant 5 : index
+// CHECK:           %[[C_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[C_5]], %[[C_1]] : vector<8x1xi1>
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C0]], %[[C0]]], %[[PAD]] : tensor<5x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK:           %[[C_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK:           return %[[RES]] : tensor<5x3xi32>
+
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
+ }
+
+// -----
+
+// One of the _read_ dimensions is dynamic (but _write_ dimensions are static).
+
+func.func private @insert_slice_dynamic_read_dim(%source: tensor<?x3x?x1xi32>, %size: index) -> tensor<5x3xi32> {
+  %c2 = arith.constant 2 : index
+  %init = tensor.empty() : tensor<5x3xi32>
+
+  %source_slice = tensor.extract_slice %source[0, %c2, 0, 0] [1, 1, %size, 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+  %res = tensor.insert_slice %source_slice into %init[0, %c2] [%size, 1] [1, 1] : tensor<?x1xi32> into tensor<5x3xi32>
+
+  return %res : tensor<5x3xi32>
+}
+
+// CHECK-LABEL:   func.func private @insert_slice_dynamic_read_dim(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<?x3x?x1xi32>,
+// CHECK-SAME:      %[[SIZE:.*]]: index) -> tensor<5x3xi32> {
+// CHECK:           %[[C_2:.*]] = arith.constant 2 : index
+// CHECK:           %[[INIT:.*]] = tensor.empty() : tensor<5x3xi32>
+// CHECK:           %[[SRC_SLICE:.*]] = tensor.extract_slice %[[SRC]][0, %[[C_2]], 0, 0] [1, 1, %[[SIZE]], 1] [1, 1, 1, 1] : tensor<?x3x?x1xi32> to tensor<?x1xi32>
+// CHECK-DAG:       %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK-DAG:       %[[C_1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C_0:.*]] = arith.constant 0 : index
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[SIZE]], %[[C_1]] : vector<8x1xi1>
+// CHECK:           %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC_SLICE]][%[[C_0]], %[[C_0]]], %[[PAD]] : tensor<?x1xi32>, vector<8x1xi32> } : vector<8x1xi1> -> vector<8x1xi32>
+// CHECK:           %[[C_0_1:.*]] = arith.constant 0 : index
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C_0_1]], %[[C_2]]] : vector<8x1xi32>, tensor<5x3xi32>
+// CHECK:           return %[[RES]] : tensor<5x3xi32>
+
+ module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 1] : !transform.any_op
+    transform.yield
+  }
  }



More information about the Mlir-commits mailing list