[Mlir-commits] [mlir] [mlir][linalg] Split GenericPadOpVectorizationPattern into two patterns (PR #111349)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Oct 27 10:57:15 PDT 2024


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/111349

From 45318f3a950b19d567608528ad90920a0c23680e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 1 Oct 2024 16:26:22 +0100
Subject: [PATCH 1/2] [mlir][linalg] Split GenericPadOpVectorizationPattern
 into two patterns

At the moment, `GenericPadOpVectorizationPattern` implements two
orthogonal transformations:
  1. Rewrites `tensor::PadOp` into a sequence of `tensor::EmptyOp`,
    `linalg::FillOp` and `tensor::InsertSliceOp`.
  2. Vectorizes (where possible) `tensor::InsertSliceOp` (see
    `tryVectorizeCopy`).
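For illustration only, the first transformation can be modelled without MLIR. The sketch below assumes a 2-D, row-major `float` buffer as a stand-in for a ranked tensor and a constant pad value. `Tensor2D` and `padViaFillAndInsert` are hypothetical names, not MLIR APIs. The sketch shows what the `tensor.empty` + `linalg.fill` + `tensor.insert_slice` sequence computes:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical row-major 2-D buffer standing in for a ranked tensor.
struct Tensor2D {
  std::size_t rows = 0, cols = 0;
  std::vector<float> data;
  float &at(std::size_t r, std::size_t c) { return data[r * cols + c]; }
  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Hypothetical helper modelling
//   tensor.pad %src low[lowR, lowC] high[highR, highC] (constant padValue)
// decomposed the way GeneralizePadOpPattern does it.
Tensor2D padViaFillAndInsert(const Tensor2D &src, std::size_t lowR,
                             std::size_t lowC, std::size_t highR,
                             std::size_t highC, float padValue) {
  // 1. "tensor.empty": allocate a buffer with the padded result shape.
  Tensor2D result;
  result.rows = lowR + src.rows + highR;
  result.cols = lowC + src.cols + highC;
  // 2. "linalg.fill": every element gets the pad value.
  result.data.assign(result.rows * result.cols, padValue);
  // 3. "tensor.insert_slice": copy the source in at offsets [lowR, lowC].
  for (std::size_t r = 0; r < src.rows; ++r)
    for (std::size_t c = 0; c < src.cols; ++c)
      result.at(lowR + r, lowC + c) = src.at(r, c);
  return result;
}
```

Step 3 is the `tensor.insert_slice` that the second transformation then vectorizes. Because step 3 can be handled on its own, the two transformations can be expressed as independent patterns.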

This patch splits `GenericPadOpVectorizationPattern` into two separate
patterns:
  1. `GeneralizePadOpPattern` for the first transformation (note that
    currently `GenericPadOpVectorizationPattern` inherits from
    `GeneralizePadOpPattern`).
  2. `InsertSliceVectorizePattern` to vectorize `tensor::InsertSliceOp`.

With this change, we gain the following:
  * a clear separation between pre-processing and vectorization
    transformations/stages,
  * a path to support masked vectorization for `tensor.insert_slice`
    (with a dedicated vectorization pattern, it is much easier to
    specify the input vector sizes used for masking),
  * more opportunities to vectorize `tensor.insert_slice`.

Note for downstream users:
--------------------------

If you were using `populatePadOpVectorizationPatterns`, you will now also
need to call `populateInsertSliceVectorizationPatterns` to keep
vectorizing the `tensor.insert_slice` Ops generated from `tensor.pad`.

Finer implementation details:
-----------------------------

1. Most of the changes in this patch are copy-and-paste plus minor edits.

  1.1 The only functional change is that the vectorization of
    `tensor.insert_slice` is now broadly available (as opposed to being
    constrained to the pad vectorization pattern:
    `GenericPadOpVectorizationPattern`).

  1.2 Following on from the above, `@pad_and_insert_slice_dest` is
    updated. As expected, the input `tensor.insert_slice` Op is no
    longer "preserved" and instead gets vectorized successfully.

2. The `linalg.fill` case in `getStaticPadVal` works under the
   assumption that only _scalar_ source values can be used. That's
   consistent with the definition of the Op, but it's not tested at the
   moment. Hence, a test case is added to Linalg/invalid.mlir.

3. The behaviour of the two TD vectorization Ops,
   `transform.structured.vectorize_children_and_apply_patterns` and
   `transform.structured.vectorize` is preserved.
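The pad-value lookup from item 2 (`getStaticPadVal`) walks the producer chain. It follows `tensor.insert_slice` to its destination and `vector.transfer_write` to the written vector, and it stops at a scalar `vector.broadcast` or a `linalg.fill`. The standalone C++ model below mirrors that walk, plus the pattern's fallback: use 0 when the read is known to be in-bounds, otherwise fail. `OpKind`, `ToyOp`, `findStaticPadValue` and `choosePadValue` are hypothetical stand-ins, not MLIR types. The real function returns an `mlir::Value`; here values are modelled as `float`s.

```cpp
#include <optional>

// Hypothetical stand-ins for the ops inspected by getStaticPadVal.
enum class OpKind { Broadcast, Fill, Generate, TransferWrite, InsertSlice, Other };

struct ToyOp {
  OpKind kind = OpKind::Other;
  // Op defining the operand that is followed: the vector written by
  // transfer_write, or the destination of insert_slice.
  const ToyOp *producer = nullptr;
  // Broadcast only: true for vector -> vector broadcasts.
  bool sourceIsVector = false;
  // Scalar broadcast/fill value (an mlir::Value in the real code).
  std::optional<float> scalar;
};

std::optional<float> findStaticPadValue(const ToyOp *op) {
  if (!op)
    return std::nullopt;
  switch (op->kind) {
  case OpKind::Broadcast:
    // Only a scalar -> vector broadcast yields a single pad value.
    if (op->sourceIsVector)
      return std::nullopt;
    return op->scalar;
  case OpKind::Fill:
    return op->scalar; // linalg.fill requires a scalar input
  case OpKind::Generate:
    return std::nullopt; // would require analysing the region: bail out
  case OpKind::TransferWrite:
  case OpKind::InsertSlice:
    return findStaticPadValue(op->producer);
  case OpKind::Other:
    break;
  }
  return std::nullopt;
}

// Models step 1 of InsertSliceVectorizePattern: a missing pad value is only
// acceptable when the source shape is static (all reads are in-bounds).
std::optional<float> choosePadValue(const ToyOp *sliceOp, bool sourceIsStatic) {
  if (auto v = findStaticPadValue(sliceOp))
    return v;
  if (!sourceIsStatic)
    return std::nullopt; // the pattern returns failure()
  return 0.0f;           // value is irrelevant for in-bounds reads
}
```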
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  14 +-
 .../TransformOps/LinalgTransformOps.cpp       |   3 +
 .../Dialect/Linalg/Transforms/Transforms.cpp  |   7 +-
 .../Linalg/Transforms/Vectorization.cpp       | 289 +++++++++++-------
 mlir/test/Dialect/Linalg/invalid.mlir         |   9 +
 .../Linalg/vectorization-with-patterns.mlir   |  52 +++-
 6 files changed, 240 insertions(+), 134 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 70b086641bdc18..b5710bd78f0089 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1503,18 +1503,13 @@ using OptimizeCopyFn =
 
 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
 /// InsertSliceOp. For now, only constant padding values are supported.
-/// `OptimizeCopyFn` can be used to customize copying step optimization.
 struct GeneralizePadOpPattern : public OpRewritePattern<tensor::PadOp> {
-  GeneralizePadOpPattern(MLIRContext *context,
-                         OptimizeCopyFn optimizeCopyFn = nullptr,
-                         PatternBenefit benefit = 1)
-      : OpRewritePattern<tensor::PadOp>(context, benefit),
-        optimizeCopyFn(std::move(optimizeCopyFn)) {}
+  GeneralizePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PadOp>(context, benefit) {}
   LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                 PatternRewriter &rewriter) const override;
 
 protected:
-  OptimizeCopyFn optimizeCopyFn;
   Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp,
                                Value dest,
                                const SmallVector<Value> &dynSizes) const;
@@ -1663,6 +1658,11 @@ void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
 /// \see rewriteInIm2Col for more details.
 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with vectorization patterns for tensor.insert_slice.
+/// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either
+/// expand or merge with other `populate{}`.
+void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that vectorize tensor.pad.
 /// These patterns are meant to apply in a complementary fashion. Benefits
 /// are used to encode a certain ordering of pattern application. To avoid
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 3d3f0a93a3829b..7ce4860c44dc3d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3482,6 +3482,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
+  // Add misc. vectorization patterns (e.g. for tensor.insert_slice)
+  linalg::populateInsertSliceVectorizationPatterns(patterns);
+
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 0fe096863d7b01..da5233049aaf69 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -973,12 +973,7 @@ GeneralizePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
       padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
   Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);
 
-  // Try optimize the copy of source.
-  if (optimizeCopyFn && optimizeCopyFn(rewriter, padOp, fill).succeeded())
-    return success();
-
-  // tensor::PadOps cannot be optimized. Generate a InsertSliceOp instead
-  // for copying the PadOp source.
+  // Generate an InsertSliceOp for copying the PadOp source.
   auto sourceType = padOp.getSourceType();
   // Compute size of source of tensor::PadOp.
   SmallVector<OpFoldResult> srcSizes =
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a2457176a1d47..cd4b46d8412d80 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2281,115 +2281,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
 //----------------------------------------------------------------------------//
 // Misc. vectorization patterns.
 //----------------------------------------------------------------------------//
-
-/// Helper function that retrieves the value of an IntegerAttr.
-static int64_t getIntFromAttr(Attribute attr) {
-  return cast<IntegerAttr>(attr).getInt();
-}
-
-/// Given an ArrayRef of OpFoldResults, return a vector of Values.
-/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
-/// not supported.
-static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
-                                           ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, getIntFromAttr(o.template get<Attribute>())));
-    }
-  }
-  return result;
-}
-
-/// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
-/// InsertSliceOp. For now, only constant padding values are supported.
-/// If there is enough static type information, TransferReadOps and
-/// TransferWriteOps may be generated instead of InsertSliceOps.
-struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
-  GenericPadOpVectorizationPattern(MLIRContext *context,
-                                   PatternBenefit benefit = 1)
-      : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
-  /// Vectorize the copying of a tensor::PadOp's source. This is possible if
-  /// each dimension size is statically know in the source type or the result
-  /// type (or both).
-  static LogicalResult tryVectorizeCopy(RewriterBase &rewriter,
-                                        tensor::PadOp padOp, Value dest) {
-    auto sourceType = padOp.getSourceType();
-    auto resultType = padOp.getResultType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
-
-    // Copy cannot be vectorized if pad value is non-constant and source shape
-    // is dynamic. In case of a dynamic source shape, padding must be appended
-    // by TransferReadOp, but TransferReadOp supports only constant padding.
-    auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) {
-      if (!sourceType.hasStaticShape())
-        return failure();
-      // Create dummy padding value.
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          padOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
-
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of- bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        vecShape.push_back(resultType.getDimSize(i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write is out-of-bounds if low padding > 0.
-        writeInBounds.push_back(
-            getConstantIntValue(padOp.getMixedLowPad()[i]) ==
-            static_cast<int64_t>(0));
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        return failure();
-      }
-    }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
-
-    // Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
-
-    // If `dest` is a FillOp and the TransferWriteOp would overwrite the
-    // entire tensor, write directly to the FillOp's operand.
-    if (llvm::equal(vecShape, resultType.getShape()) &&
-        llvm::all_of(writeInBounds, [](bool b) { return b; }))
-      if (auto fill = dest.getDefiningOp<FillOp>())
-        dest = fill.output();
-
-    // Generate TransferWriteOp.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        padOp, read, dest, writeIndices, ArrayRef<bool>{writeInBounds});
-
-    return success();
-  }
-};
-
 /// Base pattern for rewriting tensor::PadOps whose result is consumed by a
 /// given operation type OpTy.
 template <typename OpTy>
@@ -2623,6 +2514,177 @@ struct PadOpVectorizationWithTransferWritePattern
   }
 };
 
+/// Given an ArrayRef of OpFoldResults, return a vector of Values.
+/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
+/// not supported.
+static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
+                                           ArrayRef<OpFoldResult> ofrs) {
+  SmallVector<Value> result;
+  for (auto o : ofrs) {
+    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
+      result.push_back(val);
+    } else {
+      result.push_back(rewriter.create<arith::ConstantIndexOp>(
+          loc, cast<IntegerAttr>(cast<Attribute>(o)).getInt()));
+    }
+  }
+  return result;
+}
+
+/// Returns the effective Pad value for the input op, provided it's a scalar.
+///
+/// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
+/// this Op performs padding, retrieve the padding value provided that it's
+/// a scalar and static/fixed for all the padded values. Returns an empty value
+/// otherwise.
+static Value getStaticPadVl(Operation *op) {
+  if (!op)
+    return {};
+
+  // 1. vector.broadcast - return the value that's being broadcast,
+  // provided that it's a scalar.
+  if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
+    auto source = bcast.getSource();
+    if (llvm::dyn_cast<VectorType>(source.getType()))
+      return {};
+
+    return source;
+  }
+
+  // 1. linalg.fill - use the scalar input value that used to fill the output
+  // tensor.
+  if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
+    return fill.getInputs()[0];
+  }
+
+  // 2. tensor.generateOp - can't guarantee the value is fixed without
+  // analysing, bail out.
+  if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
+    return {};
+  }
+
+  // 3. vector.transfer_write - inspect the input vector that's written from. If
+  // if contains a single value that has been broadcast (e.g. via
+  // vector.broadcast), extract it, fail otherwise.
+  if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
+    return getStaticPadVl(xferWrite.getVector().getDefiningOp());
+
+  // 4. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // than the input tensor, then, provided it's constant, we'll extract the
+  // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
+  // TODO: Clarify the semantics when the input tensor is larger than the
+  // destination.
+  if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
+    return getStaticPadVl(slice.getDest().getDefiningOp());
+
+  return {};
+}
+
+/// Rewrite tensor.insert_slice as a vector.transfer_read +
+/// vector.transfer_write pair. The vector size is inferred from the static
+/// dims in the input and output tensors. If a dim is dynamic in both the input
+/// and output tensors, bails out.
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+///
+/// TODO: Support masking
+struct InsertSliceVectorizePattern
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
+                                PatternRewriter &rewriter) const final {
+    auto sourceType = sliceOp.getSource().getType();
+    if (!VectorType::isValidElementType(sourceType.getElementType()))
+      return failure();
+
+    auto resultType = sliceOp.getResultType();
+
+    // 1. Get the pad value.
+    // TransferReadOp requires a scalar padding value. Note that:
+    //    * for in-bounds access, the value is actually irrelevant.
+    //  There are 2 cases in which xfer.read accesses are known to be in-bounds:
+    //  1. The source shape is static (output vector sizes would be based on
+    //     the source shape and hence all memory accesses would be in-bounds),
+    //  2. Masking is used (output vector sizes would be user-provided, in which
+    //     case it is assumed that all memory accesses are in-bounds). This
+    //     remains a TODO.
+    //
+    // When the value is not known and not needed, use 0. Otherwise, bail out.
+    Value padValue = getStaticPadVl(sliceOp);
+    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
+
+    if (!padValue && isOutOfBoundsRead) {
+      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+      return failure();
+    }
+
+    if (!padValue) {
+      auto elemType = sourceType.getElementType();
+      padValue = rewriter.create<arith::ConstantOp>(
+          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+    }
+
+    // 2. Get the vector shape and in-bounds attributes
+    SmallVector<int64_t> vecShape;
+    SmallVector<bool> readInBounds;
+    SmallVector<bool> writeInBounds;
+    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+      if (!sourceType.isDynamicDim(i)) {
+        vecShape.push_back(sourceType.getDimSize(i));
+        // Source shape is statically known: Neither read nor write are
+        // out-of-bounds.
+        readInBounds.push_back(true);
+        writeInBounds.push_back(true);
+      } else if (!resultType.isDynamicDim(i)) {
+        // Source shape is not statically known, but result shape is.
+        // Vectorize with size of result shape. This may be larger than the
+        // source size.
+        vecShape.push_back(resultType.getDimSize(i));
+        // Read may be out-of-bounds because the result size could be larger
+        // than the source size.
+        readInBounds.push_back(false);
+        // Write will be in-bounds provided that the corresponding write idx is 0.
+        // To keep this logic simple, conservatively mark as out-of-bounds.
+        writeInBounds.push_back(false);
+      } else {
+        // Neither source nor result dim of sliceOp is static. Cannot vectorize
+        // the copy.
+        // TODO: Add support for masking
+        return failure();
+      }
+    }
+    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
+
+    // 3. Generate TransferReadOp.
+    SmallVector<Value> readIndices(
+        vecType.getRank(),
+        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+    auto read = rewriter.create<vector::TransferReadOp>(
+        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
+        ArrayRef<bool>{readInBounds});
+
+    // 4. Generate TransferWriteOp.
+    auto writeIndices =
+        ofrToIndexValues(rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+    // 5. Finalize
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        sliceOp, read, sliceOp.getDest(), writeIndices,
+        ArrayRef<bool>{writeInBounds});
+
+    return success();
+  }
+};
+
 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
 /// %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -2710,13 +2772,18 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };
 
+void mlir::linalg::populateInsertSliceVectorizationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
+}
+
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   // TODO: The following pattern implements "decomposition" and
   // optional "vectorization". Seperate "decomposition" into a sepereate
   // pre-processing pattern group.
-  patterns.add<GenericPadOpVectorizationPattern>(patterns.getContext(),
-                                                 baseBenefit);
+  patterns.add<GeneralizePadOpPattern>(patterns.getContext(), baseBenefit);
+
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadOpVectorizationWithTransferReadPattern,
                PadOpVectorizationWithTransferWritePattern,
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index c481a723c5623c..4b5a66f8fb5b92 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -352,6 +352,15 @@ func.func @illegal_fill_tensor_with_memref_return
 
 // -----
 
+func.func @illegal_fill_value_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xf32>) -> tensor<2x2xf32>
+{
+  // expected-error @+1 {{expected op with scalar input}}
+  %0 = linalg.fill ins(%arg1 : tensor<2xf32>) outs(%arg0 : tensor<2x2xf32>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
 func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
   // expected-error @+1 {{inferred input/output operand #1 has shape's dimension #0 to be 4, but found 3}}
   linalg.matmul ins(%arg0, %arg1 : memref<2x4xf32>, memref<3x4xf32>)
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 189507d97d6dc2..6ee759fe30e3d6 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -939,16 +939,20 @@ module attributes {transform.with_named_sequence} {
 
 func.func private @make_vector() -> tensor<12x13xf32>
 
-// CHECK-LABEL: func @pad_and_insert_slice_dest
-//  CHECK-SAME:     %[[ARG0:.*]]: tensor<1x5x6xf32>
-// Check the insert slice is not rewritten if the padded result is used by the destination operand.
-//   CHECK-NOT:   tensor.pad
-//       CHECK:   %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
-//       CHECK:   %[[WRITE_1:.*]] = vector.transfer_write %{{.*}}, %[[EMPTY]]{{.*}} : vector<1x12x13xf32>, tensor<1x12x13xf32>
-//       CHECK:   %[[READ:.*]]  = vector.transfer_read %[[ARG0:.*]]{{.*}} : tensor<1x5x6xf32>, vector<1x5x6xf32>
-//       CHECK:   %[[WRITE_2:.*]] = vector.transfer_write %[[READ]], %[[WRITE_1]]{{.*}} : vector<1x5x6xf32>, tensor<1x12x13xf32>
-//       CHECK:   %[[T1:.*]] = call @make_vector() : () -> tensor<12x13xf32>
-//       CHECK:   tensor.insert_slice %[[T1]] into %[[WRITE_2]]
+// CHECK-LABEL:   func.func @pad_and_insert_slice_dest(
+// CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[CST:.*]] = arith.constant dense<5.000000e+00> : vector<1x12x13xf32>
+// CHECK:           %[[C0_IDX:.*]] = arith.constant 0 : index
+// CHECK:           %[[PAD_VAL:.*]] = arith.constant 5.000000e+00 : f32
+// CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
+// CHECK:           %[[WRITE_1:.*]] = vector.transfer_write %[[CST]], %[[EMPTY]]{{\[}}%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]] {in_bounds = [true, true, true]} : vector<1x12x13xf32>, tensor<1x12x13xf32>
+// CHECK:           %[[READ_1:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]], %[[PAD_VAL]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
+// CHECK:           %[[WRITE_2:.*]] = vector.transfer_write %[[READ_1]], %[[WRITE_1]]{{\[}}%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
+// CHECK:           %[[MAKE_VEC:.*]] = call @make_vector() : () -> tensor<12x13xf32>
+// CHECK:           %[[READ_2:.*]] = vector.transfer_read %[[MAKE_VEC]]{{\[}}%[[C0_IDX]], %[[C0_IDX]]], %[[C0]] {in_bounds = [true, true]} : tensor<12x13xf32>, vector<12x13xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ_2]], %[[WRITE_2]]{{\[}}%[[C0_IDX]], %[[C0_IDX]], %[[C0_IDX]]] {in_bounds = [true, true]} : vector<12x13xf32>, tensor<1x12x13xf32>
+// CHECK:           return %[[RES]] : tensor<1x12x13xf32>
 func.func @pad_and_insert_slice_dest(
     %arg0: tensor<1x5x6xf32>) -> tensor<1x12x13xf32> {
   %c5 = arith.constant 5.0 : f32
@@ -1924,3 +1928,31 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// tensor.insert_slice
+///----------------------------------------------------------------------------------------
+
+// CHECK-LABEL: func @insert_slice
+// CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x2x3xf32>,
+// CHECK-SAME:      %[[ARG_1:.*]]: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x2x3xf32>
+// CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK:           return %[[WRITE]] : tensor<9x8x7x1x2x3xf32>
+func.func @insert_slice(%arg0: tensor<1x2x3xf32>, %arg1: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
+  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3][1, 1, 1, 1, 1, 1] : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %0 : tensor<9x8x7x1x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

From a8406b3d35dbb7aefdebe99c8c06a4745e305a66 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sun, 27 Oct 2024 17:09:23 +0000
Subject: [PATCH 2/2] fixup! [mlir][linalg] Split
 GenericPadOpVectorizationPattern into two patterns

* Incorporate suggestions from Hanhan
* Add a negative test to document when vectorization of
  tensor.insert_slice might fail
* Update `@pad_and_insert_slice_dest`, which was added in #112504 (with
  this change, _all_ qualifying `tensor.insert_slice` Ops are
  vectorized).
* Add more tests to demonstrate other cases (e.g. default vs
  non-default pad value).
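For reference, the vector-shape and `in_bounds` inference performed by `InsertSliceVectorizePattern` (including the `rankDiff` change in this fixup) can be modelled without MLIR. In this sketch, dynamic sizes are encoded as `kDynamic` (-1), and `VecPlan`/`inferVectorPlan` are hypothetical names. As the FIXME in the patch notes, indexing the destination with `rankDiff + i` assumes that the source occupies the trailing dimensions of the destination. Returning `std::nullopt` corresponds to the case exercised by the new negative test.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

struct VecPlan {
  std::vector<int64_t> vecShape;
  std::vector<bool> readInBounds;
  std::vector<bool> writeInBounds;
};

// Returns std::nullopt where the pattern would return failure().
// Assumes dstShape.size() >= srcShape.size() (insert_slice may be
// rank-reducing, but never rank-increasing).
std::optional<VecPlan> inferVectorPlan(const std::vector<int64_t> &srcShape,
                                       const std::vector<int64_t> &dstShape) {
  VecPlan plan;
  std::size_t rankDiff = dstShape.size() - srcShape.size();
  for (std::size_t i = 0; i < srcShape.size(); ++i) {
    if (srcShape[i] != kDynamic) {
      // Static source dim: neither the read nor the write is out-of-bounds.
      plan.vecShape.push_back(srcShape[i]);
      plan.readInBounds.push_back(true);
      plan.writeInBounds.push_back(true);
    } else if (dstShape[rankDiff + i] != kDynamic) {
      // Dynamic source dim, static (trailing-aligned) destination dim: use
      // the destination size; the read may run past the end of the source.
      plan.vecShape.push_back(dstShape[rankDiff + i]);
      plan.readInBounds.push_back(false);
      plan.writeInBounds.push_back(false); // conservative, as in the patch
    } else {
      return std::nullopt; // dynamic in both: would need masking (TODO)
    }
  }
  return plan;
}
```

In this model, the `rankDiff` offset matters for rank-reducing inserts. For a `tensor<?x4xf32>` source inserted into a `tensor<2x8x4xf32>` destination, indexing the destination with `i` instead of `rankDiff + i` would pick 2 rather than 8 as the vector size for the dynamic dimension.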
---
 .../TransformOps/LinalgTransformOps.cpp       |  1 +
 .../Linalg/Transforms/Vectorization.cpp       | 50 +++++--------
 .../Linalg/vectorization-pad-patterns.mlir    | 11 +--
 .../Linalg/vectorization-unsupported.mlir     | 22 ++++++
 .../Linalg/vectorization-with-patterns.mlir   | 71 +++++++++++++++++--
 5 files changed, 115 insertions(+), 40 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7ce4860c44dc3d..9c0ab4f41b855a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -256,6 +256,7 @@ void transform::ApplyFoldAddIntoDestPatternsOp::populatePatterns(
 void transform::ApplyPadVectorizationPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   linalg::populatePadOpVectorizationPatterns(patterns);
+  linalg::populateInsertSliceVectorizationPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cd4b46d8412d80..090e0b46768d7e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2514,35 +2514,18 @@ struct PadOpVectorizationWithTransferWritePattern
   }
 };
 
-/// Given an ArrayRef of OpFoldResults, return a vector of Values.
-/// IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
-/// not supported.
-static SmallVector<Value> ofrToIndexValues(RewriterBase &rewriter, Location loc,
-                                           ArrayRef<OpFoldResult> ofrs) {
-  SmallVector<Value> result;
-  for (auto o : ofrs) {
-    if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
-      result.push_back(val);
-    } else {
-      result.push_back(rewriter.create<arith::ConstantIndexOp>(
-          loc, cast<IntegerAttr>(cast<Attribute>(o)).getInt()));
-    }
-  }
-  return result;
-}
-
 /// Returns the effective Pad value for the input op, provided it's a scalar.
 ///
 /// Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
 /// this Op performs padding, retrieve the padding value provided that it's
 /// a scalar and static/fixed for all the padded values. Returns an empty value
 /// otherwise.
-static Value getStaticPadVl(Operation *op) {
+static Value getStaticPadVal(Operation *op) {
   if (!op)
     return {};
 
-  // 1. vector.broadcast - return the value that's being broadcast,
-  // provided that it's a scalar.
+  // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
+  // being broadcast, provided that it's a scalar.
   if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
     auto source = bcast.getSource();
     if (llvm::dyn_cast<VectorType>(source.getType()))
@@ -2551,31 +2534,31 @@ static Value getStaticPadVl(Operation *op) {
     return source;
   }
 
-  // 1. linalg.fill - use the scalar input value that used to fill the output
+  // 2. linalg.fill - use the scalar input value that used to fill the output
   // tensor.
   if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
     return fill.getInputs()[0];
   }
 
-  // 2. tensor.generateOp - can't guarantee the value is fixed without
+  // 3. tensor.generateOp - can't guarantee the value is fixed without
   // analysing, bail out.
   if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
     return {};
   }
 
-  // 3. vector.transfer_write - inspect the input vector that's written from. If
+  // 4. vector.transfer_write - inspect the input vector that's written from. If
   // if contains a single value that has been broadcast (e.g. via
   // vector.broadcast), extract it, fail otherwise.
   if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
-    return getStaticPadVl(xferWrite.getVector().getDefiningOp());
+    return getStaticPadVal(xferWrite.getVector().getDefiningOp());
 
-  // 4. tensor.insert_slice - inspect the destination tensor. If it's larger
+  // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
   // than the input tensor, then, provided it's constant, we'll extract the
   // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
   // TODO: Clarify the semantics when the input tensor is larger than the
   // destination.
   if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
-    return getStaticPadVl(slice.getDest().getDefiningOp());
+    return getStaticPadVal(slice.getDest().getDefiningOp());
 
   return {};
 }
@@ -2619,7 +2602,7 @@ struct InsertSliceVectorizePattern
     //     remains a TODO.
     //
     // When the value is not known and not needed, use 0. Otherwise, bail out.
-    Value padValue = getStaticPadVl(sliceOp);
+    Value padValue = getStaticPadVal(sliceOp);
     bool isOutOfBoundsRead = !sourceType.hasStaticShape();
 
     if (!padValue && isOutOfBoundsRead) {
@@ -2637,6 +2620,7 @@ struct InsertSliceVectorizePattern
     SmallVector<int64_t> vecShape;
     SmallVector<bool> readInBounds;
     SmallVector<bool> writeInBounds;
+    size_t rankDiff = resultType.getRank() - sourceType.getRank();
     for (unsigned i = 0; i < sourceType.getRank(); ++i) {
       if (!sourceType.isDynamicDim(i)) {
         vecShape.push_back(sourceType.getDimSize(i));
@@ -2648,7 +2632,9 @@ struct InsertSliceVectorizePattern
         // Source shape is not statically known, but result shape is.
         // Vectorize with size of result shape. This may be larger than the
         // source size.
-        vecShape.push_back(resultType.getDimSize(i));
+        // FIXME: Using rankDiff implies that the source tensor is inserted at
+        // the end of the destination tensor. However, that's not required.
+        vecShape.push_back(resultType.getDimSize(rankDiff + i));
         // Read may be out-of-bounds because the result size could be larger
         // than the source size.
         readInBounds.push_back(false);
@@ -2673,8 +2659,8 @@ struct InsertSliceVectorizePattern
         ArrayRef<bool>{readInBounds});
 
     // 4. Generate TransferWriteOp.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+    auto writeIndices = getValueOrCreateConstantIndexOp(
+        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
 
     // 5. Finalize
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
@@ -2761,8 +2747,8 @@ struct PadOpVectorizationWithInsertSlicePattern
     // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
     // specified offsets. Write is fully in-bounds because a InsertSliceOp's
     // source must fit into the destination at the specified offsets.
-    auto writeIndices =
-        ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
+    auto writeIndices = getValueOrCreateConstantIndexOp(
+        rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
     SmallVector<bool> inBounds(vecRank, true);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         insertOp, read, insertOp.getDest(), writeIndices,
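The producer walk in `getStaticPadVal` (cases 1 to 5 above) can be illustrated without MLIR. The sketch below is a simplified, hypothetical model: `ToyKind`, `ToyOp` and `findStaticPadValue` are invented names, not MLIR API, and a "pad value" is just an optional `double`. Case 1 is only partly visible in this excerpt. It appears to return the broadcast source when that source is a scalar, so the model treats a non-scalar broadcast as "no value".

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for the producers that getStaticPadVal inspects.
enum class ToyKind { Broadcast, Fill, Generate, TransferWrite, InsertSlice, Other };

struct ToyOp {
  ToyKind kind = ToyKind::Other;
  // Scalar operand of Broadcast/Fill; nullopt models a non-scalar source.
  std::optional<double> scalar;
  // Producer of the written vector (TransferWrite) or of the destination
  // tensor (InsertSlice); nullptr models e.g. a function argument.
  const ToyOp *producer = nullptr;
};

// Returns the statically known pad value, or nullopt if it cannot be found.
std::optional<double> findStaticPadValue(const ToyOp *op) {
  if (!op)
    return std::nullopt;
  switch (op->kind) {
  case ToyKind::Broadcast:     // 1. only a broadcast scalar qualifies
  case ToyKind::Fill:          // 2. the scalar used to fill the tensor
    return op->scalar;
  case ToyKind::Generate:      // 3. would require analysing the body: bail out
    return std::nullopt;
  case ToyKind::TransferWrite: // 4. inspect the vector that is written
  case ToyKind::InsertSlice:   // 5. inspect the destination tensor
    return findStaticPadValue(op->producer);
  default:
    return std::nullopt;
  }
}
```

A chain of `InsertSlice` -> `TransferWrite` -> `Broadcast` in this model resembles the `insert_static_slice_non_zero_pad` test later in the patch, where the destination of `tensor.insert_slice` ends up as a `vector.transfer_write` of a `vector.broadcast`.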
diff --git a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
index 2aa4638af3f0f3..640de85cc5f12e 100644
--- a/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-pad-patterns.mlir
@@ -161,7 +161,8 @@ module attributes {transform.with_named_sequence} {
 
 ///----------------------------------------------------------------------------------------
 /// tensor::PadOp -> tensor::EmptyOp + linalg::FillOp/tensor::GenerateOp + tensor::InsertSliceOp
-/// [Pattern: GenericPadOpVectorizationPattern]
+/// [Pattern: GenericPadOpVectorizationPattern + InsertSliceVectorizePattern]
+/// TODO: Split the test into two, one for each pattern.
 ///----------------------------------------------------------------------------------------
 
 func.func private @make_vector() -> tensor<12x13xf32>
@@ -174,12 +175,14 @@ func.func private @make_vector() -> tensor<12x13xf32>
 //  CHECK-NOT:     tensor.pad
 //  CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 //  CHECK-DAG:     %[[PAD:.*]] = arith.constant 5.000000e+00 : f32
+//  CHECK-DAG:     %[[PAD_READ:.*]] = arith.constant 0.000000e+00 : f32
 //      CHECK:     %[[EMPTY:.*]] = tensor.empty() : tensor<1x12x13xf32>
 //      CHECK:     %[[FILL:.*]] = linalg.fill ins(%[[PAD]] : f32) outs(%[[EMPTY]] : tensor<1x12x13xf32>) -> tensor<1x12x13xf32>
-//      CHECK:     %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
-//      CHECK:     %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
+//      CHECK:     %[[READ_1:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x5x6xf32>, vector<1x5x6xf32>
+//      CHECK:     %[[WRITE_1:.*]] = vector.transfer_write %[[READ_1]], %[[FILL]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x5x6xf32>, tensor<1x12x13xf32>
 //      CHECK:     %[[VEC:.*]] = call @make_vector() : () -> tensor<12x13xf32>
-//      CHECK:     %[[RES:.*]] = tensor.insert_slice %[[VEC]] into %[[WRITE]][0, 0, 0] [1, 12, 13] [1, 1, 1] : tensor<12x13xf32> into tensor<1x12x13xf32>
+//      CHECK:     %[[READ_2:.*]] = vector.transfer_read %[[VEC]]{{\[}}%[[C0]], %[[C0]]], %[[PAD_READ]] {in_bounds = [true, true]} : tensor<12x13xf32>, vector<12x13xf32>
+//      CHECK:     %[[RES:.*]] = vector.transfer_write %[[READ_2]], %[[WRITE_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<12x13xf32>, tensor<1x12x13xf32>
 //      CHECK:     return %[[RES]] : tensor<1x12x13xf32>
 
 func.func @pad_and_insert_slice_dest(
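The `PAD_READ` constant (0.0) in the CHECK lines above follows the rule stated in `InsertSliceVectorizePattern`: "When the value is not known and not needed, use 0. Otherwise, bail out." No pad value is recovered for the read of `%[[VEC]]`, and since `tensor<12x13xf32>` is statically shaped the read is in-bounds, so the default is used. A minimal sketch of that decision, with a hypothetical helper `choosePadValue` (not MLIR API) where `std::nullopt` stands for "bail out":

```cpp
#include <cassert>
#include <optional>

// Hypothetical sketch of how the pad value for vector.transfer_read is chosen.
//   recovered    - pad value found by walking the producers, if any
//   staticSource - whether the source tensor has a fully static shape
// An empty result means the pattern fails to vectorize.
std::optional<double> choosePadValue(std::optional<double> recovered,
                                     bool staticSource) {
  // Mirrors `isOutOfBoundsRead = !sourceType.hasStaticShape()`.
  const bool isOutOfBoundsRead = !staticSource;
  if (!recovered && isOutOfBoundsRead)
    return std::nullopt; // padding is needed but unknown: bail out
  // In-bounds reads never observe the pad value, so 0 is a safe default.
  return recovered.value_or(0.0);
}
```

The four combinations correspond to `insert_static_slice_default_pad` (0.0), `insert_static_slice_non_zero_pad` and `insert_dynamic_slice_non_zero_pad` (recovered value), and the failing `insert_slice_default_pad` test.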
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index e9f8e08ca0c6b4..843751ba98525b 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -253,3 +253,25 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// With a dynamically shaped source, the vectorizer infers the vector size for
+// xfer Ops from the destination tensor and, conservatively, assumes
+// out-of-bounds accesses. Out-of-bounds accesses require a pad value, but
+// that's impossible to recover in this example. Hence the vectorization fails.
+
+func.func @insert_slice_default_pad(%arg0: tensor<1x?x3xf32>, %arg1: tensor<9x8x7x1x2x3xf32>, %size: index) -> tensor<9x8x7x1x2x3xf32> {
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %res = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, %size, 3][1, 1, 1, 1, 1, 1] : tensor<1x?x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
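The test comment above says the vector size is inferred from the destination tensor. In `InsertSliceVectorizePattern` this is the `rankDiff` computation: source dim `i` is matched with destination dim `rankDiff + i`, which, as the FIXME notes, assumes the source is inserted into the trailing destination dims. The sketch below reproduces only that shape logic on plain integers. `kDynamic`, `VecInfo` and `inferVectorShape` are hypothetical names, and giving up when the matching destination dim is also dynamic is an assumption, since that branch is not shown in this excerpt.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical marker for a dynamic ('?') dimension in this sketch.
constexpr int64_t kDynamic = -1;

struct VecInfo {
  std::vector<int64_t> shape;     // vector shape used for the xfer ops
  std::vector<bool> readInBounds; // in_bounds attribute of transfer_read
};

// Static source dims are used as-is; a dynamic source dim i takes the size
// of destination dim rankDiff + i, and its read is marked out-of-bounds.
std::optional<VecInfo> inferVectorShape(const std::vector<int64_t> &src,
                                        const std::vector<int64_t> &dst) {
  if (dst.size() < src.size())
    return std::nullopt;
  const size_t rankDiff = dst.size() - src.size();
  VecInfo info;
  for (size_t i = 0; i < src.size(); ++i) {
    if (src[i] != kDynamic) {
      info.shape.push_back(src[i]);
      info.readInBounds.push_back(true);
    } else if (dst[rankDiff + i] != kDynamic) {
      // Assumes the source maps onto the trailing destination dims.
      info.shape.push_back(dst[rankDiff + i]);
      info.readInBounds.push_back(false);
    } else {
      return std::nullopt; // assumed unsupported without masking
    }
  }
  return info;
}
```

For `tensor<1x?x3xf32>` inserted into `tensor<9x8x7x1x2x3xf32>` this yields `vector<1x2x3xf32>` with `in_bounds = [true, false, true]`, matching the `insert_dynamic_slice_non_zero_pad` CHECK lines. The pre-patch indexing (`resultType.getDimSize(i)`) would have picked destination dim 1 (size 8) instead. In the failing test above, the shape can be inferred, but the out-of-bounds read still needs a pad value that cannot be recovered.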
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 6ee759fe30e3d6..93fe208a9e0c0f 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -1935,7 +1935,9 @@ module attributes {transform.with_named_sequence} {
 /// tensor.insert_slice
 ///----------------------------------------------------------------------------------------
 
-// CHECK-LABEL: func @insert_slice
+// The pad value for xfer-read is neither needed nor available - use the default (0.0).
+
+// CHECK-LABEL: func @insert_static_slice_default_pad
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x2x3xf32>,
 // CHECK-SAME:      %[[ARG_1:.*]]: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
 // CHECK:           %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
@@ -1943,9 +1945,70 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x2x3xf32>
 // CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG_1]]{{\[}}%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
 // CHECK:           return %[[WRITE]] : tensor<9x8x7x1x2x3xf32>
-func.func @insert_slice(%arg0: tensor<1x2x3xf32>, %arg1: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
-  %0 = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3][1, 1, 1, 1, 1, 1] : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>
-  return %0 : tensor<9x8x7x1x2x3xf32>
+func.func @insert_static_slice_default_pad(%arg0: tensor<1x2x3xf32>, %arg1: tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32> {
+  %res = tensor.insert_slice %arg0 into %arg1[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3][1, 1, 1, 1, 1, 1] : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Same as above, but there's a pad value available that should be used instead of the default value.
+
+// CHECK-LABEL:   func.func @insert_static_slice_non_zero_pad
+// CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x2x3xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+// CHECK:           %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
+// CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7x1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{.*}}, %[[PAD]] {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x2x3xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ]], %[[WRITE]]{{.*}} {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK:           return %[[RES]] : tensor<9x8x7x1x2x3xf32>
+func.func @insert_static_slice_non_zero_pad(%arg0: tensor<1x2x3xf32>, %pad : f32) -> tensor<9x8x7x1x2x3xf32> {
+  %init = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+  %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32>
+  %res = tensor.insert_slice %arg0 into %fill[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, 3][1, 1, 1, 1, 1, 1] : tensor<1x2x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Same as above, but the source type is dynamically shaped. This means
+// that the pad value is now required and the vector dim corresponding to the
+// dynamic shape has to be inferred from the shape of the destination tensor.
+
+// CHECK-LABEL:   func.func @insert_dynamic_slice_non_zero_pad(
+// CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x?x3xf32>,
+// CHECK-SAME:      %[[PAD:.*]]: f32,
+// CHECK-SAME:      %[[SIZE:.*]]: index) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+// CHECK:           %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
+// CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7x1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[ARG_0]]{{.*}}, %[[PAD]] {in_bounds = [true, false, true]} : tensor<1x?x3xf32>, vector<1x2x3xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[READ]], %[[WRITE]]{{.*}} {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
+// CHECK:           return %[[RES]] : tensor<9x8x7x1x2x3xf32>
+func.func @insert_dynamic_slice_non_zero_pad(%arg0: tensor<1x?x3xf32>, %pad : f32, %size: index) -> tensor<9x8x7x1x2x3xf32> {
+  %init = tensor.empty() : tensor<9x8x7x1x2x3xf32>
+  %fill = linalg.fill ins(%pad : f32) outs(%init : tensor<9x8x7x1x2x3xf32>) -> tensor<9x8x7x1x2x3xf32>
+  %res = tensor.insert_slice %arg0 into %fill[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, %size, 3][1, 1, 1, 1, 1, 1] : tensor<1x?x3xf32> into tensor<9x8x7x1x2x3xf32>
+  return %res : tensor<9x8x7x1x2x3xf32>
 }
 
 module attributes {transform.with_named_sequence} {



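Both `InsertSliceVectorizePattern` and `PadOpVectorizationWithInsertSlicePattern` now build `writeIndices` with `getValueOrCreateConstantIndexOp` instead of `ofrToIndexValues`. Each entry of `getMixedOffsets()` is an `OpFoldResult`, i.e. either a static attribute or an SSA `Value`. The helper reuses dynamic offsets and materializes an `arith.constant ... : index` for static ones, which is where the `%[[C0]]` index operands in the CHECK lines come from. A standalone model, using `std::variant` in place of `OpFoldResult` and strings in place of SSA values (`ToyOfr` and `toIndexValues` are hypothetical names):

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for OpFoldResult: a static offset or a named SSA value.
using ToyOfr = std::variant<int64_t, std::string>;

// Dynamic offsets are reused as-is; static offsets become the text of a
// materialized index constant.
std::vector<std::string> toIndexValues(const std::vector<ToyOfr> &ofrs) {
  std::vector<std::string> values;
  values.reserve(ofrs.size());
  for (const ToyOfr &ofr : ofrs) {
    if (const auto *ssa = std::get_if<std::string>(&ofr))
      values.push_back(*ssa);
    else
      values.push_back("arith.constant " +
                       std::to_string(std::get<int64_t>(ofr)) + " : index");
  }
  return values;
}
```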