[Mlir-commits] [mlir] [MLIR] Add pattern to fold insert_slice of extract_slice (PR #86328)
Jerry Wu
llvmlistbot at llvm.org
Fri Mar 22 14:09:23 PDT 2024
https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/86328
>From 2fb7b5c91d1df0b61dd7a1b1f5d3972255b40644 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Wed, 20 Mar 2024 21:26:44 +0000
Subject: [PATCH 1/3] Fold extract_slice+insert_slice
---
.../Tensor/Transforms/FoldTensorSubsetOps.cpp | 61 ++++++++++++++++++-
1 file changed, 59 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb7314..94de49b301e692 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineMap.h"
@@ -65,6 +66,16 @@ class InsertSliceOfTransferWriteOpFolder final
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
PatternRewriter &rewriter) const override;
};
+
+class InsertSliceOfExtractSliceFolder final
+ : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+ PatternRewriter &rewriter) const override;
+};
+
} // namespace
template <typename XferOp, typename ExtractOrInsertOp>
@@ -147,6 +158,52 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
return success();
}
+LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
+ tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
+ auto extractSliceOp =
+ insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+ if (!extractSliceOp)
+ return failure();
+
+ if (!extractSliceOp->hasOneUse())
+ return failure();
+
+ if (!isCastLikeInsertSliceOp(insertSliceOp))
+ return failure();
+
+ llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
+ llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+ if (extractDroppedDims.size() < insertExpandedDims.size())
+ return failure();
+
+ int64_t insertPos = 0;
+ for (int64_t extractPos = 0; extractPos < extractDroppedDims.size();
+ ++extractPos) {
+ if (insertPos == insertExpandedDims.size())
+ break;
+
+ bool isDropped = extractDroppedDims[extractPos];
+ bool isExpanded = insertExpandedDims[insertPos];
+ if (isDropped == isExpanded) {
+ insertPos += 1;
+ } else {
+ if (!isDropped && isExpanded) {
+ return failure();
+ }
+ }
+ }
+ if (insertPos != insertExpandedDims.size())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+ insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
+ extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+ extractSliceOp.getMixedStrides());
+ rewriter.eraseOp(extractSliceOp);
+
+ return success();
+}
+
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -224,8 +281,8 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
- InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
- patterns.getContext());
+ InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>,
+ InsertSliceOfExtractSliceFolder>(patterns.getContext());
}
void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
>From ca2607a4532f561a76f96e3a9af9824a51bbd4df Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 21 Mar 2024 20:06:13 +0000
Subject: [PATCH 2/3] Fix isCastLikeInsertSliceOp
---
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 186f85d2ce20a6..4bc966f2079d8a 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -142,11 +142,13 @@ mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
llvm::SmallBitVector droppedDims = op.getDroppedDims();
int64_t srcDim = 0;
+ RankedTensorType resultType = op.getDestType();
// Source dims and destination dims (apart from dropped dims) must have the
// same size.
- for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
- ++resultDim) {
+ for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
if (droppedDims.test(resultDim)) {
+ if (resultType.getDimSize(resultDim) != 1)
+ return false;
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
>From dceb7b78f571423032364fae90e66d3670e70c8f Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 22 Mar 2024 19:03:48 +0000
Subject: [PATCH 3/3] Add pattern to fold insert_slice of extract_slice
---
.../Tensor/Transforms/FoldTensorSubsetOps.cpp | 53 +++++++++++----
mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 2 +
.../Tensor/fold-tensor-subset-ops.mlir | 65 +++++++++++++++++++
3 files changed, 108 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 94de49b301e692..e65edcb05cb7f3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -67,6 +67,7 @@ class InsertSliceOfTransferWriteOpFolder final
PatternRewriter &rewriter) const override;
};
+/// Merge insert_slice operation with extract_slice operation.
class InsertSliceOfExtractSliceFolder final
: public OpRewritePattern<tensor::InsertSliceOp> {
public:
@@ -158,6 +159,21 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
return success();
}
+/// Merge insert_slice operation with extract_slice operation.
+///
+/// This can be done when the insert_slice op purely expands ranks (adds unit
+/// dims) and the extract_slice drops corresponding unit dims. For example:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+/// : tensor<2x8xf32> to tensor<8xf32>
+/// %inserted_slice = tensor.insert_slice %extracted_slice
+/// into %dest[0, 0] [1, 8] [1, 1]
+/// : tensor<8xf32> into tensor<1x8xf32>
+///
+/// can be folded into:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+/// : tensor<2x8xf32> to tensor<1x8xf32>
LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
auto extractSliceOp =
@@ -165,34 +181,47 @@ LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
if (!extractSliceOp)
return failure();
+ // Can't fold if the extract_slice op has other users.
if (!extractSliceOp->hasOneUse())
return failure();
+ // Check if the insert_slice op purely expands ranks (adds unit dims).
if (!isCastLikeInsertSliceOp(insertSliceOp))
return failure();
llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+ // Can't fold if the insert_slice op expands to more dims.
if (extractDroppedDims.size() < insertExpandedDims.size())
return failure();
- int64_t insertPos = 0;
- for (int64_t extractPos = 0; extractPos < extractDroppedDims.size();
- ++extractPos) {
- if (insertPos == insertExpandedDims.size())
+ // Try to match the dropped unit dims to the expanded unit dims. This is done
+ // by scanning the dims of extract_slice and finding the left-most one that
+ // can match the current dim of insert_slice. If a match is found, advance to
+ // the next dim of insert_slice.
+ unsigned insertDimPos = 0;
+ for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
+ ++extractDimPos) {
+ // Matched all expanded dims.
+ if (insertDimPos == insertExpandedDims.size())
break;
- bool isDropped = extractDroppedDims[extractPos];
- bool isExpanded = insertExpandedDims[insertPos];
+ bool isDropped = extractDroppedDims[extractDimPos];
+ bool isExpanded = insertExpandedDims[insertDimPos];
+ // Match if both sides drop/keep the dim. Advance and match the next dim of
+ // insert_slice.
if (isDropped == isExpanded) {
- insertPos += 1;
- } else {
- if (!isDropped && isExpanded) {
- return failure();
- }
+ insertDimPos += 1;
+ } else if (!isDropped && isExpanded) {
+ // Not enough dropped unit dims to match the expanded unit dims.
+ return failure();
}
+ // If the dim is dropped by extract_slice and not by insert_slice, look at
+ // the next dim of extract_slice to see if it can match the current dim of
+ // insert_slice.
}
- if (insertPos != insertExpandedDims.size())
+ // Can't match some expanded dims.
+ if (insertDimPos != insertExpandedDims.size())
return failure();
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 4bc966f2079d8a..2dd91e2f7a1700 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -147,6 +147,8 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
// same size.
for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
if (droppedDims.test(resultDim)) {
+ // A dropped dim is a size-1 slice, but the result dim it is inserted into
+ // may be larger than 1; such an insert_slice is not cast-like.
if (resultType.getDimSize(resultDim) != 1)
return false;
continue;
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index f2e529b4cac950..bb1df99d4c97ee 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -390,3 +390,68 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
}
return %0: tensor<12x34xf32>
}
+
+// -----
+
+func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
+ return %inserted_slice : tensor<8x1x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
+// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
+
+// -----
+
+func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
+ return %inserted_slice : tensor<1x4x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
+// CHECK-SAME: : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
+// CHECK: return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
+
+// -----
+
+func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
+ return %inserted_slice : tensor<1x1x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
+
+// -----
+
+func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
+ return %inserted_slice : tensor<1x4x4xf32>
+}
+// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
+
+// -----
+
+func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
+ %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
+ %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
+ return %inserted_slice : tensor<2x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK: return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>
More information about the Mlir-commits
mailing list