[Mlir-commits] [mlir] [MLIR] Add pattern to fold insert_slice of extract_slice (PR #86328)

Jerry Wu llvmlistbot at llvm.org
Wed Mar 27 13:58:45 PDT 2024


================
@@ -134,6 +134,86 @@ struct DropRedundantInsertSliceRankExpansion
     return success();
   }
 };
+
+/// Drop redundant rank expansion of insert_slice that directly follows
+/// extract_slice.
+///
+/// This can be done when the insert_slice op purely expands ranks (adds unit
+/// dims) and the extract_slice drops corresponding unit dims. For example:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<8xf32>
+/// %inserted_slice = tensor.insert_slice %extracted_slice
+///     into %dest[0, 0] [1, 8] [1, 1]
+///     : tensor<8xf32> into tensor<1x8xf32>
+///
+/// can be folded into:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<1x8xf32>
+struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const {
+    auto extractSliceOp =
+        insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractSliceOp)
+      return failure();
+
+    // Can't fold if the extract_slice op has other users.
+    if (!extractSliceOp->hasOneUse())
+      return failure();
+
+    // Check if the insert_slice op purely expands ranks (add unit dims).
+    if (!isCastLikeInsertSliceOp(insertSliceOp))
+      return failure();
+
+    llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
+    llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+    // Can't fold if the insert_slice op expands to more dims.
+    if (extractDroppedDims.size() < insertExpandedDims.size())
+      return failure();
----------------
pzread wrote:

Done.

https://github.com/llvm/llvm-project/pull/86328


More information about the Mlir-commits mailing list