[Mlir-commits] [mlir] bc9fce7 - [MLIR][Vector] Add a pattern that folds consecutive extract_strided_slice ops (#175738)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 14 09:00:20 PST 2026
Author: Adam Paszke
Date: 2026-01-14T09:00:13-08:00
New Revision: bc9fce7e8b734c1a2a491c23ca247e4411b561f5
URL: https://github.com/llvm/llvm-project/commit/bc9fce7e8b734c1a2a491c23ca247e4411b561f5
DIFF: https://github.com/llvm/llvm-project/commit/bc9fce7e8b734c1a2a491c23ca247e4411b561f5.diff
LOG: [MLIR][Vector] Add a pattern that folds consecutive extract_strided_slice ops (#175738)
A slice of a slice is just a slice.
Added:
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b9ea336051de6..085f879c2d0e6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4366,6 +4366,68 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
+// Pattern to rewrite nested ExtractStridedSliceOp into a single one.
+//
+// Only unit strides are handled; with unit strides, composing two slices is
+// plain offset addition. The combined op slices every leading dimension that
+// either op slices:
+//   - offsets are added component-wise (a dimension sliced by only one of the
+//     ops keeps that op's offset);
+//   - the size of a dimension comes from the second (consumer) op when it
+//     slices that dimension, otherwise from the first (producer) op.
+// Dimensions sliced by neither op are taken whole by both, so they are simply
+// left out of the combined attributes.
+//
+// Example:
+//
+//   %0 = vector.extract_strided_slice %arg0
+//       {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]}
+//       : vector<4x8x16xf32> to vector<3x4x16xf32>
+//   %1 = vector.extract_strided_slice %0
+//       {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]}
+//       : vector<3x4x16xf32> to vector<2x2x16xf32>
+//
+// to
+//
+//   %1 = vector.extract_strided_slice %arg0
+//       {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
+//       : vector<4x8x16xf32> to vector<2x2x16xf32>
+class StridedSliceFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp secondOp,
+                                PatternRewriter &rewriter) const override {
+    auto firstOp = secondOp.getSource().getDefiningOp<ExtractStridedSliceOp>();
+    if (!firstOp)
+      return failure();
+
+    // Offset addition below is only valid when both slices use unit strides.
+    if (secondOp.hasNonUnitStrides() || firstOp.hasNonUnitStrides())
+      return failure();
+
+    SmallVector<int64_t> firstOffsets = getI64SubArray(firstOp.getOffsets());
+    SmallVector<int64_t> firstSizes = getI64SubArray(firstOp.getSizes());
+    SmallVector<int64_t> secondOffsets = getI64SubArray(secondOp.getOffsets());
+    SmallVector<int64_t> secondSizes = getI64SubArray(secondOp.getSizes());
+
+    // Every index below newRank is sliced by at least one of the two ops, so
+    // each combined size comes from one of the two size lists; the original
+    // source shape is never needed.
+    size_t newRank = std::max(firstOffsets.size(), secondOffsets.size());
+    SmallVector<int64_t> combinedOffsets(newRank, 0);
+    SmallVector<int64_t> combinedSizes(newRank, 0);
+    for (size_t i = 0; i < newRank; ++i) {
+      int64_t off1 = (i < firstOffsets.size()) ? firstOffsets[i] : 0;
+      int64_t off2 = (i < secondOffsets.size()) ? secondOffsets[i] : 0;
+      combinedOffsets[i] = off1 + off2;
+
+      // The second op indexes into the first op's result, so when it slices
+      // this dimension its size is the final extent.
+      if (i < secondSizes.size()) {
+        combinedSizes[i] = secondSizes[i];
+      } else {
+        assert(i < firstSizes.size() &&
+               "extract_strided_slice offsets and sizes must have equal "
+               "length");
+        combinedSizes[i] = firstSizes[i];
+      }
+    }
+
+    SmallVector<int64_t> combinedStrides(newRank, 1);
+    rewriter.replaceOpWithNewOp<ExtractStridedSliceOp>(
+        secondOp, firstOp.getSource(), combinedOffsets, combinedSizes,
+        combinedStrides);
+    return success();
+  }
+};
+
// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
// CreateMaskOp.
//
@@ -4641,9 +4703,10 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
// Registers the canonicalization patterns for ExtractStridedSliceOp, among
// them:
//  - StridedSliceFolder: a slice of a (unit-stride) slice -> a single slice;
//  - StridedSliceCreateMaskFolder: ExtractStridedSliceOp(CreateMaskOp) ->
//    CreateMaskOp;
//  - ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp and
//    ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
// NOTE(review): StridedSliceBroadcast, StridedSliceSplat and
// ContiguousExtractStridedSliceToExtract are defined outside this view;
// consult their definitions for exactly what they match.
- results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
- StridedSliceBroadcast, StridedSliceSplat,
- ContiguousExtractStridedSliceToExtract>(context);
+ results.add<StridedSliceFolder, StridedSliceCreateMaskFolder,
+ StridedSliceConstantMaskFolder, StridedSliceBroadcast,
+ StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e17b1cfbe5e0d..a30eda1e06cf8 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -4027,3 +4027,45 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
%v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
return %v_1 : vector<4xf32>
}
+
+// -----
+
+// Both slices cover the same two leading dimensions. They fold into one slice
+// of %arg0: offsets add component-wise ([1, 2] + [0, 1] = [1, 3]) and the
+// second op's sizes [2, 2] are kept; the trailing dimension stays whole.
+// CHECK-LABEL: extract_strided_slice_of_extract_strided_slice
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x8x16xf32>
+// CHECK: %[[RESULT:.*]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1, 3], sizes = [2, 2], strides = [1, 1]}
+// CHECK-SAME: : vector<4x8x16xf32> to vector<2x2x16xf32>
+// CHECK: return %[[RESULT]]
+func.func @extract_strided_slice_of_extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]} : vector<4x8x16xf32> to vector<3x4x16xf32>
+ %1 = vector.extract_strided_slice %0 {offsets = [0, 1], sizes = [2, 2], strides = [1, 1]} : vector<3x4x16xf32> to vector<2x2x16xf32>
+ return %1 : vector<2x2x16xf32>
+}
+
+// -----
+
+// The first slice covers only dimension 0, the second covers all three.
+// Offsets [1] and [0, 1, 2] combine to [1, 1, 2]; since the second op slices
+// every dimension, its sizes [2, 2, 4] are used throughout.
+// CHECK-LABEL: extract_strided_slice_inner_shorter
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x8x16xf32>
+// CHECK: %[[RESULT:.*]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1, 1, 2], sizes = [2, 2, 4], strides = [1, 1, 1]}
+// CHECK-SAME: : vector<4x8x16xf32> to vector<2x2x4xf32>
+// CHECK: return %[[RESULT]]
+func.func @extract_strided_slice_inner_shorter(%arg0: vector<4x8x16xf32>) -> vector<2x2x4xf32> {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [1], sizes = [3], strides = [1]} : vector<4x8x16xf32> to vector<3x8x16xf32>
+ %1 = vector.extract_strided_slice %0 {offsets = [0, 1, 2], sizes = [2, 2, 4], strides = [1, 1, 1]} : vector<3x8x16xf32> to vector<2x2x4xf32>
+ return %1 : vector<2x2x4xf32>
+}
+
+// -----
+
+// The second slice covers only dimension 0. Offsets [1, 2] and [0] combine to
+// [1, 2]; dimension 0 takes the second op's size (2), while dimension 1 keeps
+// the first op's size (4).
+// CHECK-LABEL: extract_strided_slice_outer_shorter
+// CHECK-SAME: %[[ARG0:.*]]: vector<4x8x16xf32>
+// CHECK: %[[RESULT:.*]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1, 2], sizes = [2, 4], strides = [1, 1]}
+// CHECK-SAME: : vector<4x8x16xf32> to vector<2x4x16xf32>
+// CHECK: return %[[RESULT]]
+func.func @extract_strided_slice_outer_shorter(%arg0: vector<4x8x16xf32>) -> vector<2x4x16xf32> {
+ %0 = vector.extract_strided_slice %arg0 {offsets = [1, 2], sizes = [3, 4], strides = [1, 1]} : vector<4x8x16xf32> to vector<3x4x16xf32>
+ %1 = vector.extract_strided_slice %0 {offsets = [0], sizes = [2], strides = [1]} : vector<3x4x16xf32> to vector<2x4x16xf32>
+ return %1 : vector<2x4x16xf32>
+}
More information about the Mlir-commits
mailing list