[Mlir-commits] [mlir] f566b07 - [MLIR] Add pattern to fold insert_slice of extract_slice (#86328)
llvmlistbot at llvm.org
Thu Mar 28 08:18:50 PDT 2024
Author: Jerry Wu
Date: 2024-03-28T11:18:47-04:00
New Revision: f566b079f171f28366a66b8afa4a975bc4005529
URL: https://github.com/llvm/llvm-project/commit/f566b079f171f28366a66b8afa4a975bc4005529
DIFF: https://github.com/llvm/llvm-project/commit/f566b079f171f28366a66b8afa4a975bc4005529.diff
LOG: [MLIR] Add pattern to fold insert_slice of extract_slice (#86328)
Fold the `tensor.insert_slice` of a `tensor.extract_slice` into a single
`tensor.extract_slice` when the `insert_slice` simply re-expands unit
dims dropped by the `extract_slice`.
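For example (taken from the new tests in this change):

    %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
        : tensor<?x8x2x8xf32> to tensor<8x8xf32>
    %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1]
        : tensor<8x8xf32> into tensor<8x1x8xf32>

folds to:

    %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
        : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>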
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
mlir/lib/Dialect/Tensor/Utils/Utils.cpp
mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
index 5257310f5b005b..59aa4322217583 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -78,12 +78,12 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
}
};
-/// Drop redundant rank expansion. I.e., rank expansions that are directly
-/// followed by rank reductions. E.g.:
+/// Drop the redundant rank expansion of an insert_slice that is directly
+/// followed by an extract_slice. E.g.:
/// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
/// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
/// : tensor<1x1x5x10xf32> to tensor<2x2xf32>
-struct DropRedundantInsertSliceRankExpansion
+struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
: public OpRewritePattern<ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
@@ -134,6 +134,97 @@ struct DropRedundantInsertSliceRankExpansion
return success();
}
};
+
+/// Drop the redundant rank expansion of an insert_slice that directly
+/// follows an extract_slice.
+///
+/// This can be done when the insert_slice op purely expands the rank (adds
+/// unit dims) and the extract_slice drops the corresponding unit dims. For
+/// example:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+/// : tensor<2x8xf32> to tensor<8xf32>
+/// %inserted_slice = tensor.insert_slice %extracted_slice
+/// into %dest[0, 0] [1, 8] [1, 1]
+/// : tensor<8xf32> into tensor<1x8xf32>
+///
+/// can be folded into:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+/// : tensor<2x8xf32> to tensor<1x8xf32>
+struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const {
+ auto extractSliceOp =
+ insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractSliceOp) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "source is not extract_slice");
+ }
+
+ // Can't fold if the extract_slice op has other users.
+ if (!extractSliceOp->hasOneUse()) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+                                         "source has multiple uses");
+ }
+
+    // Check if the insert_slice op purely expands the rank (adds unit dims).
+ if (!isCastLikeInsertSliceOp(insertSliceOp)) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "insert_slice is not cast-like");
+ }
+
+ llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
+ llvm::SmallBitVector insertDroppedDims = insertSliceOp.getDroppedDims();
+ // Can't fold if the insert_slice op expands to more dims.
+ if (extractDroppedDims.size() < insertDroppedDims.size()) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "insert_slice expands more dims");
+ }
+
+    // Try to match the extract dropped dims to the insert dropped dims. This
+    // is done by scanning the dims of extract_slice and finding the
+    // left-most one that matches the current dim of insert_slice. Whenever a
+    // match is found, advance to the next dim of insert_slice.
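+    // For the example in the pattern documentation above, extractDroppedDims
+    // = [1, 0] and insertDroppedDims = [1, 0]: the dims match pairwise. In
+    // the first new test below, extractDroppedDims = [1, 0, 1, 0] and
+    // insertDroppedDims = [0, 1, 0]: extract dim 0 is skipped, and extract
+    // dims 1..3 match insert dims 0..2.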
+ unsigned insertDimPos = 0;
+ for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
+ ++extractDimPos) {
+ // Matched all dims.
+ if (insertDimPos == insertDroppedDims.size())
+ break;
+
+ bool isExtractDropped = extractDroppedDims[extractDimPos];
+ bool isInsertDropped = insertDroppedDims[insertDimPos];
+ // Match if both sides drop/keep the dim. Advance and match the next dim
+ // of insert_slice.
+ if (isExtractDropped == isInsertDropped) {
+ insertDimPos += 1;
+ } else if (!isExtractDropped && isInsertDropped) {
+ // Not enough extract dropped dims to match the insert dropped dims.
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "insert_slice drops more unit dims");
+ }
+    // If the dim is dropped by extract_slice but not by insert_slice, look
+    // at the next dim of extract_slice to see if it can match the current
+    // dim of insert_slice.
+ }
+ // Can't match some insert dims.
+ if (insertDimPos != insertDroppedDims.size()) {
+ return rewriter.notifyMatchFailure(insertSliceOp,
+ "insert_slice has unmatched dims");
+ }
+
+ rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+ insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
+ extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+ extractSliceOp.getMixedStrides());
+ rewriter.eraseOp(extractSliceOp);
+
+ return success();
+ }
+};
} // namespace
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
@@ -146,5 +237,7 @@ void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
RewritePatternSet &patterns) {
- patterns.add<DropRedundantInsertSliceRankExpansion>(patterns.getContext());
+ patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
+ DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
+ patterns.getContext());
}
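
For reference, a minimal sketch of how these patterns are typically applied
(assuming a pass context, where `op` is the root operation to transform and
`applyPatternsAndFoldGreedily` is the standard greedy rewrite driver):

    RewritePatternSet patterns(op->getContext());
    tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
      signalPassFailure();
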
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 186f85d2ce20a6..2dd91e2f7a1700 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -142,11 +142,15 @@ mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
+ RankedTensorType resultType = op.getDestType();
// Source dims and destination dims (apart from dropped dims) must have the
// same size.
- for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
- ++resultDim) {
+ for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
if (droppedDims.test(resultDim)) {
+      // A dropped dim corresponds to inserting a size-1 slice. If the
+      // matching result dim has a size other than 1, the insert_slice writes
+      // only part of that dim and is not cast-like.
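+      // E.g., inserting tensor<8x8xf32> into tensor<2x8x8xf32> drops result
+      // dim 0, which has size 2, so the op is not cast-like (see the
+      // no_fold_non_casting_insert_slice_of_extract_slice test).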
+ if (resultType.getDimSize(resultDim) != 1)
+ return false;
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
diff --git a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
index e337fdd9321424..88e55062f47702 100644
--- a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
+++ b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
@@ -9,3 +9,68 @@ func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1
%extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
return %extracted_slice : tensor<123x456xf32>
}
+
+// -----
+
+func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
+ return %inserted_slice : tensor<8x1x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
+// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
+
+// -----
+
+func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
+ return %inserted_slice : tensor<1x4x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
+// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
+
+// -----
+
+func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
+ return %inserted_slice : tensor<1x1x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
+
+// -----
+
+func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
+ return %inserted_slice : tensor<1x4x4xf32>
+}
+// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
+
+// -----
+
+func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
+ return %inserted_slice : tensor<2x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>