[Mlir-commits] [mlir] [mlir][tensor] Enhance pattern to fold extract_slice(insert_slice) (PR #195045)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 30 02:22:48 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Longsheng Mou (CoTinker)

<details>
<summary>Changes</summary>

Extend the DropRedundantRankExpansionOnExtractSliceOfInsertSlice pattern to support cases where the expanded dimensions are a subset of the dropped dimensions, rather than requiring them to be exactly equal.
For example:
```
%inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
%extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 1] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123xf32>
```
can be folded into:
```
%extracted_slice = tensor.extract_slice %src[0, 0] [123, 1] [1, 1] : tensor<128x480xf32> to tensor<123xf32>
```

---
Full diff: https://github.com/llvm/llvm-project/pull/195045.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tensor/Transforms/DropRedundantRankExpansionPatterns.cpp (+17-10) 
- (modified) mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir (+12) 


``````````diff
diff --git a/mlir/lib/Dialect/Tensor/Transforms/DropRedundantRankExpansionPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/DropRedundantRankExpansionPatterns.cpp
index 4253548d11f49..55ad6256b0f6c 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/DropRedundantRankExpansionPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/DropRedundantRankExpansionPatterns.cpp
@@ -20,9 +20,14 @@ using namespace mlir::tensor;
 namespace {
 /// Drop redundant rank expansion of insert_slice that are directly followed
 /// by extract_slice. E.g.:
-/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
+/// %0 = tensor.insert_slice %in... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
 /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
 ///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
+///
+/// can be folded into:
+///
+/// %1 = tensor.extract_slice %in[2, 3] [2, 2] [1, 1]
+///     : tensor<5x10xf32> to tensor<2x2xf32>
 struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
     : public OpRewritePattern<ExtractSliceOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -41,9 +46,8 @@ struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
       return failure();
     llvm::SmallBitVector expandedDims = insertSliceOp.getDroppedDims();
 
-    // TODO: This could be extended to support cases where the dropped dims are
-    // a subset of the expanded dims.
-    if (expandedDims != droppedDims)
+    // Support cases where the expanded dims are a subset of the dropped dims.
+    if (!expandedDims.subsetOf(droppedDims))
       return failure();
 
     // The tensor.insert_slice may not be redundant if it has multiple users.
@@ -58,18 +62,21 @@ struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
     // Extract directly from the source.
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(extractSliceOp);
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
     SmallVector<OpFoldResult> newOffsets, newSizes, newStrides;
     for (int64_t i = 0, e = extractSliceOp.getSourceType().getRank(); i < e;
          ++i) {
-      if (droppedDims.test(i))
+      if (expandedDims.test(i))
         continue;
-      newOffsets.push_back(extractSliceOp.getMixedOffsets()[i]);
-      newSizes.push_back(extractSliceOp.getMixedSizes()[i]);
-      newStrides.push_back(extractSliceOp.getMixedStrides()[i]);
+      newOffsets.push_back(mixedOffsets[i]);
+      newSizes.push_back(mixedSizes[i]);
+      newStrides.push_back(mixedStrides[i]);
     }
     rewriter.replaceOpWithNewOp<ExtractSliceOp>(
-        extractSliceOp, /*source=*/insertSliceOp.getSource(), newOffsets,
-        newSizes, newStrides);
+        extractSliceOp, extractSliceOp.getResultType(),
+        /*source=*/insertSliceOp.getSource(), newOffsets, newSizes, newStrides);
     rewriter.eraseOp(insertSliceOp);
     return success();
   }
diff --git a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
index 0496b93257a9e..e21d85411b2a7 100644
--- a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
+++ b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
@@ -12,6 +12,18 @@ func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1
 
 // -----
 
+// CHECK-LABEL: func @test_drop_rank_expansion(
+//  CHECK-SAME:     %[[src:.*]]: tensor<128x480xf32>,
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[src]][0, 0] [123, 1] [1, 1] : tensor<128x480xf32> to tensor<123xf32>
+//       CHECK:   return %[[extract]]
+func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1x128x480xf32>) -> tensor<123xf32> {
+  %inserted_slice = tensor.insert_slice %src into %dest[0, 0, 0, 0] [1, 1, 128, 480] [1, 1, 1, 1] : tensor<128x480xf32> into tensor<1x1x128x480xf32>
+  %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 1] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123xf32>
+  return %extracted_slice : tensor<123xf32>
+}
+
+// -----
+
 func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
   %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
   %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/195045


More information about the Mlir-commits mailing list