[Mlir-commits] [mlir] [mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) (PR #138725)

James Newling llvmlistbot at llvm.org
Tue May 6 10:29:48 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/138725

Extends the set of vector operations that can be linearized to include vector.insert_strided_slice. Like the existing vector.extract_strided_slice linearization, the new pattern flattens its operands to rank 1 and lowers the op to a single vector.shuffle.

>From d708a63f43cd08a9ba21fdcfacb9a26dd93734b0 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 10:12:13 -0700
Subject: [PATCH 1/3] linearize with shuffle

---
 .../Vector/Transforms/VectorLinearize.cpp     | 278 ++++++++++++------
 mlir/test/Dialect/Vector/linearize.mlir       |  55 +++-
 2 files changed, 241 insertions(+), 92 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..58e00e357f6cb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,17 +109,103 @@ struct LinearizeVectorizable final
   }
 };
 
-/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
-/// on a linearized vector.
-/// Following,
+/// Returns true if every entry of `op`'s strides attribute is the constant 1.
+template <typename TOp>
+static bool stridesAllOne(TOp op) {
+  static_assert(
+      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
+          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
+      "expected vector.extract_strided_slice or vector.insert_strided_slice");
+  return llvm::all_of(op.getStrides(),
+                      [](auto s) { return isConstantIntValue(s, 1); });
+}
+
+/// Convert `attrs` to int64_t; fails if null or any entry isn't an IntegerAttr.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+  if (!attrs)
+    return failure();
+  SmallVector<int64_t> ints;
+  ints.reserve(attrs.size());
+  for (auto attr : attrs) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+      ints.push_back(intAttr.getInt());
+    } else {
+      return failure();
+    }
+  }
+  return ints;
+}
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+static SmallVector<int64_t> getFlattenedStridedSliceIndices(
+    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
+    ArrayRef<int64_t> offsets) {
+  // The result has product(small) entries, in row-major order of `small`.
+  // Example of alignment between `large`, `small` and `offsets`:
+  //    large  =  4, 5, 6, 7, 8
+  //    small  =     1, 6, 7, 8
+  //  offsets  =  2, 3, 0
+  //
+  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+  assert(large.size() >= small.size());
+  assert(large.size() >= offsets.size());
+  unsigned delta = large.size() - small.size();
+  unsigned nOffsets = offsets.size();
+  auto getSmall = [&](int64_t i) { return i >= delta ? small[i - delta] : 1; };
+  auto getOffset = [&](int64_t i) { return i < nOffsets ? offsets[i] : 0; };
+
+  // Build the result one dimension at a time, innermost first: each pass
+  // replicates the current index set once per position along dimension `i`
+  // of `small`, shifted by that position's linearized offset in `large`.
+  SmallVector<int64_t> indices{0};
+  SmallVector<int64_t> nextIndices;
+  int64_t stride = 1;
+  for (int64_t i = static_cast<int64_t>(large.size()) - 1; i >= 0; --i) {
+    auto currentSize = indices.size();
+    auto smallSize = getSmall(i);
+    auto nextSize = currentSize * smallSize;
+    nextIndices.resize(nextSize);
+    int64_t *base = nextIndices.begin();
+    int64_t offset = getOffset(i) * stride;
+    for (int64_t j = 0; j < smallSize; ++j) {
+      for (uint64_t k = 0; k < currentSize; ++k) {
+        base[k] = indices[k] + offset;
+      }
+      offset += stride;
+      base += currentSize;
+    }
+    stride *= large[i];
+    std::swap(indices, nextIndices);
+    nextIndices.clear();
+  }
+  return indices;
+}
+
+/// This pattern converts a vector.extract_strided_slice operation into a
+/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
+///
+/// For example, the following:
+///
+/// ```
 ///   vector.extract_strided_slice %source
 ///         { offsets = [..], strides = [..], sizes = [..] }
+/// ```
+///
 /// is converted to :
+/// ```
 ///   %source_1d = vector.shape_cast %source
-///   %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
-///   %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the offsets and sizes of the
-/// extraction.
+///   %out_1d    = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
+///   %out_nd    = vector.shape_cast %out_1d
+/// ```
+///
+/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
+/// vector.extract_strided_slice operation.
 struct LinearizeVectorExtractStridedSlice final
     : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -129,88 +215,109 @@ struct LinearizeVectorExtractStridedSlice final
       : OpConversionPattern(typeConverter, context, benefit) {}
 
   LogicalResult
-  matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
+  matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
+                  OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType dstType =
-        getTypeConverter()->convertType<VectorType>(extractOp.getType());
-    assert(dstType && "vector type destination expected.");
-    if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
-      return rewriter.notifyMatchFailure(extractOp,
-                                         "scalable vectors are not supported.");
 
-    ArrayAttr offsets = extractOp.getOffsets();
-    ArrayAttr sizes = extractOp.getSizes();
-    ArrayAttr strides = extractOp.getStrides();
-    if (!isConstantIntValue(strides[0], 1))
-      return rewriter.notifyMatchFailure(
-          extractOp, "Strided slice with stride != 1 is not supported.");
-    Value srcVector = adaptor.getVector();
-    // If kD offsets are specified for nD source vector (n > k), the granularity
-    // of the extraction is greater than 1. In this case last (n-k) dimensions
-    // form the extraction granularity.
-    // Example :
-    //  vector.extract_strided_slice %src {
-    //      offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
-    //      vector<4x8x8xf32> to vector<2x2x8xf32>
-    // Here, extraction granularity is 8.
-    int64_t extractGranularitySize = 1;
-    int64_t nD = extractOp.getSourceVectorType().getRank();
-    int64_t kD = (int64_t)offsets.size();
-    int64_t k = kD;
-    while (k < nD) {
-      extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
-      ++k;
+    VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
+        extractStridedSliceOp.getType());
+    assert(flatOutputType && "vector type expected");
+
+    if (!stridesAllOne(extractStridedSliceOp)) {
+      return rewriter.notifyMatchFailure(extractStridedSliceOp,
+                                         "strides other than 1 not supported");
     }
-    // Get total number of extracted slices.
-    int64_t nExtractedSlices = 1;
-    for (Attribute size : sizes) {
-      nExtractedSlices *= cast<IntegerAttr>(size).getInt();
+
+    ArrayRef<int64_t> inputShape =
+        extractStridedSliceOp.getSourceVectorType().getShape();
+
+    ArrayRef<int64_t> outputType = extractStridedSliceOp.getType().getShape();
+
+    auto maybeIntOffsets =
+        intsFromArrayAttr(extractStridedSliceOp.getOffsets());
+    if (failed(maybeIntOffsets)) {
+      return rewriter.notifyMatchFailure(extractStridedSliceOp,
+                                         "failed to get integer offsets");
     }
-    // Compute the strides of the source vector considering first k dimensions.
-    llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
-    for (int i = kD - 2; i >= 0; --i) {
-      sourceStrides[i] = sourceStrides[i + 1] *
-                         extractOp.getSourceVectorType().getShape()[i + 1];
+
+    SmallVector<int64_t> indices = getFlattenedStridedSliceIndices(
+        outputType, inputShape, maybeIntOffsets.value());
+
+    Value srcVector = adaptor.getVector();
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
+        extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
+    return success();
+  }
+};
+
+/// This pattern converts a vector.insert_strided_slice operation into a
+/// vector.shuffle operation that has rank-1 (linearized) operands and result.
+///
+/// For example, the following:
+/// ```
+///  %0 = vector.insert_strided_slice %to_store, %into
+///             {offsets = [1, 0, 0, 0], strides = [1, 1]}
+///                  : vector<2x2xi8> into vector<2x1x3x2xi8>
+/// ```
+///
+/// is converted to
+/// ```
+///  %to_store_1d
+///           = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
+///  %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
+///  %out_1d  = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
+///  %out_nd  = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
+/// ```
+///
+/// where shuffle_indices_1d in this case is
+///     [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
+///                        ^^^^^^^^^^^^^^
+///                          to_store_1d
+///
+struct LinearizeVectorInsertStridedSlice final
+    : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
+                                    MLIRContext *context,
+                                    PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!stridesAllOne(insertStridedSliceOp)) {
+      return rewriter.notifyMatchFailure(insertStridedSliceOp,
+                                         "strides other than 1 not supported");
     }
-    // Final shuffle indices has nExtractedSlices * extractGranularitySize
-    // elements.
-    llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
-                                          extractGranularitySize);
-    // Compute the strides of the extracted kD vector.
-    llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
-    // Compute extractedStrides.
-    for (int i = kD - 2; i >= 0; --i) {
-      extractedStrides[i] =
-          extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
+
+    VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
+    ArrayRef<int64_t> inputShape = inputType.getShape();
+
+    VectorType outputType = insertStridedSliceOp.getType();
+    ArrayRef<int64_t> outputShape = outputType.getShape();
+    int64_t nOutputElements = outputType.getNumElements();
+
+    auto maybeIntOffsets = intsFromArrayAttr(insertStridedSliceOp.getOffsets());
+    if (failed(maybeIntOffsets)) {
+      return rewriter.notifyMatchFailure(insertStridedSliceOp,
+                                         "failed to get integer offsets");
     }
-    // Iterate over all extracted slices from 0 to nExtractedSlices - 1
-    // and compute the multi-dimensional index and the corresponding linearized
-    // index within the source vector.
-    for (int64_t i = 0; i < nExtractedSlices; ++i) {
-      int64_t index = i;
-      // Compute the corresponding multi-dimensional index.
-      llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
-      for (int64_t j = 0; j < kD; ++j) {
-        multiDimIndex[j] = (index / extractedStrides[j]);
-        index -= multiDimIndex[j] * extractedStrides[j];
-      }
-      // Compute the corresponding linearized index in the source vector
-      // i.e. shift the multiDimIndex by the offsets.
-      int64_t linearizedIndex = 0;
-      for (int64_t j = 0; j < kD; ++j) {
-        linearizedIndex +=
-            (cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
-            sourceStrides[j];
-      }
-      // Fill the indices array form linearizedIndex to linearizedIndex +
-      // extractGranularitySize.
-      for (int64_t j = 0; j < extractGranularitySize; ++j) {
-        indices[i * extractGranularitySize + j] = linearizedIndex + j;
-      }
+    SmallVector<int64_t> sliceIndices = getFlattenedStridedSliceIndices(
+        inputShape, outputShape, maybeIntOffsets.value());
+
+    SmallVector<int64_t> indices(nOutputElements, 0);
+    std::iota(indices.begin(), indices.end(), 0);
+    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
+      indices[sliceIndex] = index + nOutputElements;
     }
-    // Perform a shuffle to extract the kD vector.
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractOp, dstType, srcVector, srcVector, indices);
+    // Mask values below nOutputElements select flatDest, others flatToStore.
+    Value flatToStore = adaptor.getValueToStore();
+    Value flatDest = adaptor.getDest();
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
+                                                   flatDest.getType(), flatDest,
+                                                   flatToStore, indices);
     return success();
   }
 };
@@ -296,7 +403,7 @@ struct LinearizeVectorExtract final
     // Skip if result is not a vector type
     if (!isa<VectorType>(extractOp.getType()))
       return rewriter.notifyMatchFailure(extractOp,
-                                         "scalar extract is not supported.");
+                                         "scalar extract not supported");
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
     assert(dstTy && "expected 1-D vector type");
 
@@ -539,6 +646,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
     const TypeConverter &typeConverter, const ConversionTarget &target,
     RewritePatternSet &patterns) {
   patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
-               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
-      typeConverter, patterns.getContext());
+               LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
+               LinearizeVectorInsertStridedSlice>(typeConverter,
+                                                  patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..5e9877cae4b18 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -131,9 +131,9 @@ func.func @test_0d_vector() -> vector<f32> {
 
 // -----
 
-// CHECK-LABEL: test_extract_strided_slice_1
+// CHECK-LABEL: test_extract_strided_slice_2D
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
-func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
+func.func @test_extract_strided_slice_2D(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
 
   // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
   // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
@@ -147,13 +147,13 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
 
 // -----
 
-// CHECK-LABEL:   func.func @test_extract_strided_slice_1_scalable(
+// CHECK-LABEL:   func.func @test_extract_strided_slice_2D_scalable(
 // CHECK-SAME:    %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_2D_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
 
   // CHECK-NOT: vector.shuffle
   // CHECK-NOT: vector.shape_cast
-  // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
+  // CHECK: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] 
   %0 = vector.extract_strided_slice %arg0 { sizes = [2, 8], strides = [1, 1], offsets = [1, 0] } : vector<4x[8]xf32> to vector<2x[8]xf32>
 
   // CHECK: return %[[RES]] : vector<2x[8]xf32>
@@ -162,9 +162,9 @@ func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> ve
 
 // -----
 
-// CHECK-LABEL: test_extract_strided_slice_2
+// CHECK-LABEL: test_extract_strided_slice_3D
 // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
-func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
+func.func @test_extract_strided_slice_3D(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
 
   // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
   // CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
@@ -176,6 +176,46 @@ func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4
   return %0 : vector<1x4x2xf32>
 }
 
+// -----
+
+// Test of insert_strided_slice -> shuffle.
+// This is a contiguous insertion of 4 elements at offset 6 into a vector of 12 elements.
+// CHECK-LABEL: insert_strided_slice_2D_into_4D
+func.func @insert_strided_slice_2D_into_4D(%arg0 : vector<2x2xi8>, %arg1 : vector<2x1x3x2xi8>) -> vector<2x1x3x2xi8> {
+
+//   CHECK-DAG:    %[[ARG0:.*]] = vector.shape_cast {{.*}}  to vector<4xi8>
+//   CHECK-DAG:    %[[ARG1:.*]] = vector.shape_cast {{.*}}  to vector<12xi8>
+//       CHECK:    vector.shuffle %[[ARG1]], %[[ARG0]]
+//  CHECK-SAME:      [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11] : vector<12xi8>, vector<4xi8>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [1, 0, 0, 0], strides = [1, 1]} : vector<2x2xi8> into vector<2x1x3x2xi8>
+
+//       CHECK:    %[[RES:.*]] = vector.shape_cast {{.*}} to vector<2x1x3x2xi8>
+//       CHECK:    return %[[RES]] : vector<2x1x3x2xi8>
+  return %0 : vector<2x1x3x2xi8>
+}
+
+// -----
+
+// Test of insert_strided_slice -> shuffle.
+// [[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]], [[12, 13], [14, 15], [16, 17]]]
+//                                         ^        ^
+//                                         |        |
+//                       positions 9 and 11 receive the 2 elements of %arg0
+// CHECK-LABEL: insert_strided_slice_3D
+func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg2 : vector<3x3x2xi8>) -> vector<3x3x2xi8> {
+
+//   CHECK-DAG:     %[[ARG0:.*]] = vector.shape_cast {{.*}}  to vector<2xi8>
+//   CHECK-DAG:     %[[ARG1:.*]] = vector.shape_cast {{.*}}  to vector<18xi8>
+//       CHECK:     vector.shuffle %[[ARG1]], %[[ARG0]]
+//  CHECK-SAME:       [0, 1, 2, 3, 4, 5, 6, 7, 8, 18, 10, 19, 12, 13, 14, 15, 16, 17] : vector<18xi8>, vector<2xi8>
+  %0 = vector.insert_strided_slice %arg0, %arg2 {offsets = [1, 1, 1], strides = [1, 1, 1]} : vector<1x2x1xi8> into vector<3x3x2xi8>
+
+//       CHECK:     %[[RES:.*]] = vector.shape_cast {{.*}} to vector<3x3x2xi8>
+//       CHECK:     return %[[RES]] : vector<3x3x2xi8>
+  return %0 : vector<3x3x2xi8>
+}
+
+
 // -----
 
 // CHECK-LABEL: test_vector_shuffle
@@ -345,3 +385,4 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
 }
+

>From c578d7b0355e973e9c7b3827a73de35641f7f460 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 10:23:11 -0700
Subject: [PATCH 2/3] name improvement

---
 .../Vector/Transforms/VectorLinearize.cpp     | 23 ++++++++++---------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 58e00e357f6cb..6daaeea66526a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -228,20 +228,20 @@ struct LinearizeVectorExtractStridedSlice final
                                          "strides other than 1 not supported");
     }
 
-    ArrayRef<int64_t> inputShape =
-        extractStridedSliceOp.getSourceVectorType().getShape();
-
-    ArrayRef<int64_t> outputType = extractStridedSliceOp.getType().getShape();
-
-    auto maybeIntOffsets =
+    FailureOr<SmallVector<int64_t>> offsets =
         intsFromArrayAttr(extractStridedSliceOp.getOffsets());
-    if (failed(maybeIntOffsets)) {
+    if (failed(offsets)) {
       return rewriter.notifyMatchFailure(extractStridedSliceOp,
                                          "failed to get integer offsets");
     }
 
+    ArrayRef<int64_t> inputShape =
+        extractStridedSliceOp.getSourceVectorType().getShape();
+
+    ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
+
     SmallVector<int64_t> indices = getFlattenedStridedSliceIndices(
-        outputType, inputShape, maybeIntOffsets.value());
+        outputShape, inputShape, offsets.value());
 
     Value srcVector = adaptor.getVector();
     rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
@@ -299,13 +299,14 @@ struct LinearizeVectorInsertStridedSlice final
     ArrayRef<int64_t> outputShape = outputType.getShape();
     int64_t nOutputElements = outputType.getNumElements();
 
-    auto maybeIntOffsets = intsFromArrayAttr(insertStridedSliceOp.getOffsets());
-    if (failed(maybeIntOffsets)) {
+    FailureOr<SmallVector<int64_t>> offsets =
+        intsFromArrayAttr(insertStridedSliceOp.getOffsets());
+    if (failed(offsets)) {
       return rewriter.notifyMatchFailure(insertStridedSliceOp,
                                          "failed to get integer offsets");
     }
     SmallVector<int64_t> sliceIndices = getFlattenedStridedSliceIndices(
-        inputShape, outputShape, maybeIntOffsets.value());
+        inputShape, outputShape, offsets.value());
 
     SmallVector<int64_t> indices(nOutputElements, 0);
     std::iota(indices.begin(), indices.end(), 0);

>From cd894d8c1e547a42b37ee402a8647cb9db46180b Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 6 May 2025 10:29:18 -0700
Subject: [PATCH 3/3] scalable blacklist

---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 4 ++--
 mlir/test/Dialect/Vector/linearize.mlir                | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 6daaeea66526a..8ffb3b0cb2c42 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -561,8 +561,8 @@ struct LinearizeVectorSplat final
 static bool isNotLinearizableBecauseScalable(Operation *op) {
 
   bool unsupported =
-      isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
-          op);
+      isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
+          vector::ExtractOp, vector::InsertOp>(op);
   if (!unsupported)
     return false;
 
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 5e9877cae4b18..508bce689b14e 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -215,7 +215,6 @@ func.func @insert_strided_slice_3D(%arg0 : vector<1x2x1xi8>, %arg2 : vector<3x3x
   return %0 : vector<3x3x2xi8>
 }
 
-
 // -----
 
 // CHECK-LABEL: test_vector_shuffle



More information about the Mlir-commits mailing list