[Mlir-commits] [mlir] bc9fce7 - [MLIR][Vector] Add a pattern that folds consecutive extract_strided_slice ops (#175738)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 14 09:00:20 PST 2026


Author: Adam Paszke
Date: 2026-01-14T09:00:13-08:00
New Revision: bc9fce7e8b734c1a2a491c23ca247e4411b561f5

URL: https://github.com/llvm/llvm-project/commit/bc9fce7e8b734c1a2a491c23ca247e4411b561f5
DIFF: https://github.com/llvm/llvm-project/commit/bc9fce7e8b734c1a2a491c23ca247e4411b561f5.diff

LOG: [MLIR][Vector] Add a pattern that folds consecutive extract_strided_slice ops (#175738)

A slice of a slice is just a slice.

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b9ea336051de6..085f879c2d0e6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4366,6 +4366,68 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
+// Pattern to fold an ExtractStridedSliceOp whose source is itself produced by
+// an ExtractStridedSliceOp into a single ExtractStridedSliceOp on the
+// original source. With unit strides a slice of a slice is a slice: per
+// dimension the offsets add up and the size comes from the second op when it
+// slices that dimension (dimensions an op does not list are taken whole).
+//
+// Example:
+//
+// %0 = vector.extract_strided_slice %arg0
+//        {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
+//          : vector<4x8x16xf32> to vector<3x4x16xf32>
+// %1 = vector.extract_strided_slice %0
+//        {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
+//          : vector<3x4x16xf32> to vector<2x2x16xf32>
+//
+// to
+//
+// %1 = vector.extract_strided_slice %arg0
+//        {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
+//          : vector<4x8x16xf32> to vector<2x2x16xf32>
+class StridedSliceFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
+                                PatternRewriter &rewriter) const override {
+    auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
+    if (!firstOp)
+      return failure();
+
+    // Offsets only compose by plain addition when both ops use unit strides.
+    if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
+      return failure();
+
+    SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
+    SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
+    SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
+    SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
+
+    // The op verifier keeps offsets and sizes the same length, so every
+    // dimension below newRank is sliced by at least one of the two ops.
+    unsigned newRank = std::max(firstOffsets.size(), secondOffsets.size());
+    SmallVector<int64_t> combinedOffsets(newRank, 0);
+    SmallVector<int64_t> combinedSizes(newRank);
+    for (unsigned i = 0; i < newRank; ++i) {
+      int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
+      int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
+      combinedOffsets[i] = off1 + off2;
+      // Take the second op's size if it slices dim i, else the first op's.
+      combinedSizes[i] =
+          (i < secondSizes.size()) ? secondSizes[i] : firstSizes[i];
+    }
+
+    SmallVector<int64_t> combinedStrides(newRank, 1);
+    rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
+        secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
+        combinedStrides);
+    return success();
+  }
+};
+
 // Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
 // CreateMaskOp.
 //
@@ -4641,9 +4703,10 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
-  // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
-              StridedSliceBroadcast, StridedSliceSplat,
-              ContiguousExtractStridedSliceToExtract>(context);
+  // Rewrite slices of slices, of create_mask/constant_mask, of broadcasts
+  // and of splats; turn contiguous slices into vector.extract.
+  results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
+              StridedSliceConstantMaskFolder, StridedSliceBroadcast,
+              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e17b1cfbe5e0d..a30eda1e06cf8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -4027,3 +4027,61 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
   %v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
   return %v_1 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: extract_strided_slice_of_extract_strided_slice
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4x8x16xf32>
+//       CHECK: %[[RESULT:.*]] = vector.extract_strided_slice %[[ARG0]]
+//  CHECK-SAME: {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
+//  CHECK-SAME: : vector<4x8x16xf32> to vector<2x2x16xf32>
+//       CHECK: return %[[RESULT]]
+func.func @extract_strided_slice_of_extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]} : vector<4x8x16xf32> to vector<3x4x16xf32>
+  %1 = vector.extract_strided_slice %0 {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]} : vector<3x4x16xf32> to vector<2x2x16xf32>
+  return %1 : vector<2x2x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_slice_inner_shorter
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4x8x16xf32>
+//       CHECK: %[[RESULT:.*]] = vector.extract_strided_slice %[[ARG0]]
+//  CHECK-SAME: {offsets = [1, 1, 2], sizes = [2, 2, 4], strides = [1, 1, 1]}
+//  CHECK-SAME: : vector<4x8x16xf32> to vector<2x2x4xf32>
+//       CHECK: return %[[RESULT]]
+func.func @extract_strided_slice_inner_shorter(%arg0: vector<4x8x16xf32>) -> vector<2x2x4xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [3], strides = [1]} : vector<4x8x16xf32> to vector<3x8x16xf32>
+  %1 = vector.extract_strided_slice %0 {offsets = [0, 1, 2], sizes = [2, 2, 4], strides = [1, 1, 1]} : vector<3x8x16xf32> to vector<2x2x4xf32>
+  return %1 : vector<2x2x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_slice_outer_shorter
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4x8x16xf32>
+//       CHECK: %[[RESULT:.*]] = vector.extract_strided_slice %[[ARG0]]
+//  CHECK-SAME: {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
+//  CHECK-SAME: : vector<4x8x16xf32> to vector<2x4x16xf32>
+//       CHECK: return %[[RESULT]]
+func.func @extract_strided_slice_outer_shorter(%arg0: vector<4x8x16xf32>) -> vector<2x4x16xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]} : vector<4x8x16xf32> to vector<3x4x16xf32>
+  %1 = vector.extract_strided_slice %0 {offsets = [0], sizes = [2], strides = [1]} : vector<3x4x16xf32> to vector<2x4x16xf32>
+  return %1 : vector<2x4x16xf32>
+}
+
+// -----
+
+// A chain of three slices collapses into one slice of the original source.
+// CHECK-LABEL: extract_strided_slice_chain_of_three
+//  CHECK-SAME: %[[ARG0:.*]]: vector<4x8x16xf32>
+//       CHECK: %[[RESULT:.*]] = vector.extract_strided_slice %[[ARG0]]
+//  CHECK-SAME: {offsets = [1, 3, 3], sizes = [2, 2, 8], strides = [1, 1, 1]}
+//  CHECK-SAME: : vector<4x8x16xf32> to vector<2x2x8xf32>
+//       CHECK: return %[[RESULT]]
+func.func @extract_strided_slice_chain_of_three(%arg0: vector<4x8x16xf32>) -> vector<2x2x8xf32> {
+  %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [3], strides = [1]} : vector<4x8x16xf32> to vector<3x8x16xf32>
+  %1 = vector.extract_strided_slice %0 {offsets = [0, 2], sizes = [2, 4], strides = [1, 1]} : vector<3x8x16xf32> to vector<2x4x16xf32>
+  %2 = vector.extract_strided_slice %1 {offsets = [0, 1, 3], sizes = [2, 2, 8], strides = [1, 1, 1]} : vector<2x4x16xf32> to vector<2x2x8xf32>
+  return %2 : vector<2x2x8xf32>
+}


        


More information about the Mlir-commits mailing list