[Mlir-commits] [mlir] [MLIR] Add pattern to fold insert_slice of extract_slice (PR #86328)

Jerry Wu llvmlistbot at llvm.org
Tue Mar 26 10:49:17 PDT 2024


https://github.com/pzread updated https://github.com/llvm/llvm-project/pull/86328

>From 2fb7b5c91d1df0b61dd7a1b1f5d3972255b40644 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Wed, 20 Mar 2024 21:26:44 +0000
Subject: [PATCH 1/4] Fold extract_slice+insert_slice

---
 .../Tensor/Transforms/FoldTensorSubsetOps.cpp | 61 ++++++++++++++++++-
 1 file changed, 59 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 3b8d3708bb7314..94de49b301e692 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
@@ -65,6 +66,16 @@ class InsertSliceOfTransferWriteOpFolder final
   LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                 PatternRewriter &rewriter) const override;
 };
+
+class InsertSliceOfExtractSliceFolder final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+public:
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 template <typename XferOp, typename ExtractOrInsertOp>
@@ -147,6 +158,52 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   return success();
 }
 
+LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
+    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
+  auto extractSliceOp =
+      insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+  if (!extractSliceOp)
+    return failure();
+
+  if (!extractSliceOp->hasOneUse())
+    return failure();
+
+  if (!isCastLikeInsertSliceOp(insertSliceOp))
+    return failure();
+
+  llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
+  llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+  if (extractDroppedDims.size() < insertExpandedDims.size())
+    return failure();
+
+  int64_t insertPos = 0;
+  for (int64_t extractPos = 0; extractPos < extractDroppedDims.size();
+       ++extractPos) {
+    if (insertPos == insertExpandedDims.size())
+      break;
+
+    bool isDropped = extractDroppedDims[extractPos];
+    bool isExpanded = insertExpandedDims[insertPos];
+    if (isDropped == isExpanded) {
+      insertPos += 1;
+    } else {
+      if (!isDropped && isExpanded) {
+        return failure();
+      }
+    }
+  }
+  if (insertPos != insertExpandedDims.size())
+    return failure();
+
+  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+      insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
+      extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+      extractSliceOp.getMixedStrides());
+  rewriter.eraseOp(extractSliceOp);
+
+  return success();
+}
+
 template <typename OpTy>
 struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -224,8 +281,8 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
   populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
   patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
-               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
-      patterns.getContext());
+               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>,
+               InsertSliceOfExtractSliceFolder>(patterns.getContext());
 }
 
 void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(

>From ca2607a4532f561a76f96e3a9af9824a51bbd4df Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Thu, 21 Mar 2024 20:06:13 +0000
Subject: [PATCH 2/4] Fix isCastLikeInsertSliceOp

---
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 186f85d2ce20a6..4bc966f2079d8a 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -142,11 +142,13 @@ mlir::tensor::getUnPackInverseSrcPerm(UnPackOp unpackOp,
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
+  RankedTensorType resultType = op.getDestType();
   // Source dims and destination dims (apart from dropped dims) must have the
   // same size.
-  for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
-       ++resultDim) {
+  for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
     if (droppedDims.test(resultDim)) {
+      if (resultType.getDimSize(resultDim) != 1)
+        return false;
       continue;
     }
     FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(

>From dceb7b78f571423032364fae90e66d3670e70c8f Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Fri, 22 Mar 2024 19:03:48 +0000
Subject: [PATCH 3/4] Add pattern to fold insert_slice of extract_slice

---
 .../Tensor/Transforms/FoldTensorSubsetOps.cpp | 53 +++++++++++----
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       |  2 +
 .../Tensor/fold-tensor-subset-ops.mlir        | 65 +++++++++++++++++++
 3 files changed, 108 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index 94de49b301e692..e65edcb05cb7f3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -67,6 +67,7 @@ class InsertSliceOfTransferWriteOpFolder final
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Merge insert_slice operation with extract_slice operation.
 class InsertSliceOfExtractSliceFolder final
     : public OpRewritePattern<tensor::InsertSliceOp> {
 public:
@@ -158,6 +159,21 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   return success();
 }
 
+/// Merge insert_slice operation with extract_slice operation.
+///
+/// This can be done when the insert_slice op purely expands ranks (adds unit
+/// dims) and the extrace_slice drops corresponding unit dims. For example:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<8xf32>
+/// %inserted_slice = tensor.insert_slice %extracted_slice
+///     into %dest[0, 0] [1, 8] [1, 1]
+///     : tensor<8xf32> into tensor<1x8xf32>
+///
+/// can be folded into:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<1x8xf32>
 LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
     tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
   auto extractSliceOp =
@@ -165,34 +181,47 @@ LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
   if (!extractSliceOp)
     return failure();
 
+  // Can't fold if the extract_slice op has other users.
   if (!extractSliceOp->hasOneUse())
     return failure();
 
+  // Check if the insert_slice op purely expands ranks (add unit dims).
   if (!isCastLikeInsertSliceOp(insertSliceOp))
     return failure();
 
   llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
   llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+  // Can't fold if the insert_slice op expands to more dims.
   if (extractDroppedDims.size() < insertExpandedDims.size())
     return failure();
 
-  int64_t insertPos = 0;
-  for (int64_t extractPos = 0; extractPos < extractDroppedDims.size();
-       ++extractPos) {
-    if (insertPos == insertExpandedDims.size())
+  // Try to match the dropped unit dims to the expanded unit dims. This is done
+  // by scanning the dims of extract_slice and find the left-most one can match
+  // the dim of insert_slice. If a match is found, advance the dim of
+  // insert_slice to match the next one.
+  unsigned insertDimPos = 0;
+  for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
+       ++extractDimPos) {
+    // Matched all expanded dims.
+    if (insertDimPos == insertExpandedDims.size())
       break;
 
-    bool isDropped = extractDroppedDims[extractPos];
-    bool isExpanded = insertExpandedDims[insertPos];
+    bool isDropped = extractDroppedDims[extractDimPos];
+    bool isExpanded = insertExpandedDims[insertDimPos];
+    // Match if both sides drop/keep the dim. Advance and match the next dim of
+    // insert_slice.
     if (isDropped == isExpanded) {
-      insertPos += 1;
-    } else {
-      if (!isDropped && isExpanded) {
-        return failure();
-      }
+      insertDimPos += 1;
+    } else if (!isDropped && isExpanded) {
+      // Not enough dropped unit dims to match the expanded unit dims.
+      return failure();
     }
+    // If the dim is dropped by extract_slice and not by insert_slice, look the
+    // next dim of extract_slice to see if it can match the current dim of
+    // insert_slice.
   }
-  if (insertPos != insertExpandedDims.size())
+  // Can't match some expanded dims.
+  if (insertDimPos != insertExpandedDims.size())
     return failure();
 
   rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 4bc966f2079d8a..2dd91e2f7a1700 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -147,6 +147,8 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   // same size.
   for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
     if (droppedDims.test(resultDim)) {
+      // InsertSlice may expand unit dimensions that result from inserting a
+      // size-1 slice into a non-size-1 result dimension.
       if (resultType.getDimSize(resultDim) != 1)
         return false;
       continue;
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index f2e529b4cac950..bb1df99d4c97ee 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -390,3 +390,68 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
   }
   return %0: tensor<12x34xf32>
 }
+
+// -----
+
+func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
+  return %inserted_slice : tensor<8x1x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
+// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
+// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
+
+// -----
+
+func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
+  return %inserted_slice : tensor<1x4x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
+// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
+// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
+
+// -----
+
+func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
+  return %inserted_slice : tensor<1x1x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
+
+// -----
+
+func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
+  return %inserted_slice : tensor<1x4x4xf32>
+}
+// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
+
+// -----
+
+func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
+  return %inserted_slice : tensor<2x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>

>From 669c1bd973e3a342c29ebcfd52f17a9d2f3d9cb6 Mon Sep 17 00:00:00 2001
From: Jerry Wu <cheyuw at google.com>
Date: Tue, 26 Mar 2024 17:49:01 +0000
Subject: [PATCH 4/4] Move code

---
 .../Tensor/Transforms/FoldTensorSubsetOps.cpp | 88 +-----------------
 ...eConsecutiveInsertExtractSlicePatterns.cpp | 90 ++++++++++++++++++-
 ...redundant-insert-slice-rank-expansion.mlir | 65 ++++++++++++++
 .../Tensor/fold-tensor-subset-ops.mlir        | 67 +-------------
 4 files changed, 154 insertions(+), 156 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index e65edcb05cb7f3..6bb66f9ce4f971 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -67,16 +67,6 @@ class InsertSliceOfTransferWriteOpFolder final
                                 PatternRewriter &rewriter) const override;
 };
 
-/// Merge insert_slice operation with extract_slice operation.
-class InsertSliceOfExtractSliceFolder final
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-public:
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
-                                PatternRewriter &rewriter) const override;
-};
-
 } // namespace
 
 template <typename XferOp, typename ExtractOrInsertOp>
@@ -159,80 +149,6 @@ LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
   return success();
 }
 
-/// Merge insert_slice operation with extract_slice operation.
-///
-/// This can be done when the insert_slice op purely expands ranks (adds unit
-/// dims) and the extrace_slice drops corresponding unit dims. For example:
-///
-/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
-///     : tensor<2x8xf32> to tensor<8xf32>
-/// %inserted_slice = tensor.insert_slice %extracted_slice
-///     into %dest[0, 0] [1, 8] [1, 1]
-///     : tensor<8xf32> into tensor<1x8xf32>
-///
-/// can be folded into:
-///
-/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
-///     : tensor<2x8xf32> to tensor<1x8xf32>
-LogicalResult InsertSliceOfExtractSliceFolder::matchAndRewrite(
-    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
-  auto extractSliceOp =
-      insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
-  if (!extractSliceOp)
-    return failure();
-
-  // Can't fold if the extract_slice op has other users.
-  if (!extractSliceOp->hasOneUse())
-    return failure();
-
-  // Check if the insert_slice op purely expands ranks (add unit dims).
-  if (!isCastLikeInsertSliceOp(insertSliceOp))
-    return failure();
-
-  llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
-  llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
-  // Can't fold if the insert_slice op expands to more dims.
-  if (extractDroppedDims.size() < insertExpandedDims.size())
-    return failure();
-
-  // Try to match the dropped unit dims to the expanded unit dims. This is done
-  // by scanning the dims of extract_slice and find the left-most one can match
-  // the dim of insert_slice. If a match is found, advance the dim of
-  // insert_slice to match the next one.
-  unsigned insertDimPos = 0;
-  for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
-       ++extractDimPos) {
-    // Matched all expanded dims.
-    if (insertDimPos == insertExpandedDims.size())
-      break;
-
-    bool isDropped = extractDroppedDims[extractDimPos];
-    bool isExpanded = insertExpandedDims[insertDimPos];
-    // Match if both sides drop/keep the dim. Advance and match the next dim of
-    // insert_slice.
-    if (isDropped == isExpanded) {
-      insertDimPos += 1;
-    } else if (!isDropped && isExpanded) {
-      // Not enough dropped unit dims to match the expanded unit dims.
-      return failure();
-    }
-    // If the dim is dropped by extract_slice and not by insert_slice, look the
-    // next dim of extract_slice to see if it can match the current dim of
-    // insert_slice.
-  }
-  // Can't match some expanded dims.
-  if (insertDimPos != insertExpandedDims.size())
-    return failure();
-
-  rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
-      insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
-      extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
-      extractSliceOp.getMixedStrides());
-  rewriter.eraseOp(extractSliceOp);
-
-  return success();
-}
-
 template <typename OpTy>
 struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
   using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -310,8 +226,8 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
 void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
   populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
   patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
-               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>,
-               InsertSliceOfExtractSliceFolder>(patterns.getContext());
+               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
+      patterns.getContext());
 }
 
 void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
index 5257310f5b005b..b66f9197057278 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -78,12 +78,12 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
   }
 };
 
-/// Drop redundant rank expansion. I.e., rank expansions that are directly
-/// followed by rank reductions. E.g.:
+/// Drop redundant rank expansion of insert_slice that is directly followed
+/// by extract_slice. E.g.:
 /// %0 = tensor.insert_slice ... : tensor<5x10xf32> into tensor<1x1x5x10xf32>
 /// %1 = tensor.extract_slice %0[0, 0, 2, 3] [1, 1, 2, 2] [1, 1, 1, 1]
 ///     : tensor<1x1x5x10xf32> to tensor<2x2xf32>
-struct DropRedundantInsertSliceRankExpansion
+struct DropRedundantRankExpansionOnExtractSliceOfInsertSlice
     : public OpRewritePattern<ExtractSliceOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -134,6 +134,86 @@ struct DropRedundantInsertSliceRankExpansion
     return success();
   }
 };
+
+/// Drop redundant rank expansion of insert_slice that directly follows
+/// extract_slice.
+///
+/// This can be done when the insert_slice op purely expands ranks (adds unit
+/// dims) and the extract_slice drops corresponding unit dims. For example:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<8xf32>
+/// %inserted_slice = tensor.insert_slice %extracted_slice
+///     into %dest[0, 0] [1, 8] [1, 1]
+///     : tensor<8xf32> into tensor<1x8xf32>
+///
+/// can be folded into:
+///
+/// %extracted_slice = tensor.extract_slice %in[0, 0] [1, 8] [1, 1]
+///     : tensor<2x8xf32> to tensor<1x8xf32>
+struct DropRedundantRankExpansionOnInsertSliceOfExtractSlice final
+    : public OpRewritePattern<tensor::InsertSliceOp> {
+  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const {
+    auto extractSliceOp =
+        insertSliceOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractSliceOp)
+      return failure();
+
+    // Can't fold if the extract_slice op has other users.
+    if (!extractSliceOp->hasOneUse())
+      return failure();
+
+    // Check if the insert_slice op purely expands ranks (add unit dims).
+    if (!isCastLikeInsertSliceOp(insertSliceOp))
+      return failure();
+
+    llvm::SmallBitVector extractDroppedDims = extractSliceOp.getDroppedDims();
+    llvm::SmallBitVector insertExpandedDims = insertSliceOp.getDroppedDims();
+    // Can't fold if the insert_slice op expands to more dims.
+    if (extractDroppedDims.size() < insertExpandedDims.size())
+      return failure();
+
+    // Try to match the dropped unit dims to the expanded unit dims. This is
+    // done by scanning the dims of extract_slice to find the left-most one
+    // that can match the dim of insert_slice. If a match is found, advance the
+    // dim of insert_slice to match the next one.
+    unsigned insertDimPos = 0;
+    for (unsigned extractDimPos = 0; extractDimPos < extractDroppedDims.size();
+         ++extractDimPos) {
+      // Matched all expanded dims.
+      if (insertDimPos == insertExpandedDims.size())
+        break;
+
+      bool isDropped = extractDroppedDims[extractDimPos];
+      bool isExpanded = insertExpandedDims[insertDimPos];
+      // Match if both sides drop/keep the dim. Advance and match the next dim
+      // of insert_slice.
+      if (isDropped == isExpanded) {
+        insertDimPos += 1;
+      } else if (!isDropped && isExpanded) {
+        // Not enough dropped unit dims to match the expanded unit dims.
+        return failure();
+      }
+      // If the dim is dropped by extract_slice and not by insert_slice, look at
+      // the next dim of extract_slice to see if it can match the current dim of
+      // insert_slice.
+    }
+    // Can't match some expanded dims.
+    if (insertDimPos != insertExpandedDims.size())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        insertSliceOp, insertSliceOp.getType(), extractSliceOp.getSource(),
+        extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
+        extractSliceOp.getMixedStrides());
+    rewriter.eraseOp(extractSliceOp);
+
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
@@ -146,5 +226,7 @@ void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
 
 void mlir::tensor::populateDropRedundantInsertSliceRankExpansionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<DropRedundantInsertSliceRankExpansion>(patterns.getContext());
+  patterns.add<DropRedundantRankExpansionOnExtractSliceOfInsertSlice,
+               DropRedundantRankExpansionOnInsertSliceOfExtractSlice>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
index e337fdd9321424..88e55062f47702 100644
--- a/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
+++ b/mlir/test/Dialect/Tensor/drop-redundant-insert-slice-rank-expansion.mlir
@@ -9,3 +9,68 @@ func.func @test_drop_rank_expansion(%src: tensor<128x480xf32>, %dest: tensor<1x1
   %extracted_slice = tensor.extract_slice %inserted_slice[0, 0, 0, 0] [1, 1, 123, 456] [1, 1, 1, 1] : tensor<1x1x128x480xf32> to tensor<123x456xf32>
   return %extracted_slice : tensor<123x456xf32>
 }
+
+// -----
+
+func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
+  return %inserted_slice : tensor<8x1x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
+// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
+// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
+
+// -----
+
+func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
+  return %inserted_slice : tensor<1x4x8xf32>
+}
+// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
+// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
+// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
+
+// -----
+
+func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
+  return %inserted_slice : tensor<1x1x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
+
+// -----
+
+func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
+  return %inserted_slice : tensor<1x4x4xf32>
+}
+// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
+
+// -----
+
+func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
+  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
+  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
+  return %inserted_slice : tensor<2x8x8xf32>
+}
+// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
+// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
+// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
+// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
+// CHECK:         return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>
diff --git a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
index bb1df99d4c97ee..f66f443eb91f19 100644
--- a/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
+++ b/mlir/test/Dialect/Tensor/fold-tensor-subset-ops.mlir
@@ -389,69 +389,4 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
     }
   }
   return %0: tensor<12x34xf32>
-}
-
-// -----
-
-func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
-  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
-  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
-  return %inserted_slice : tensor<8x1x8xf32>
-}
-// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
-// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
-// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
-// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
-// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>
-
-// -----
-
-func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
-  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
-  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
-  return %inserted_slice : tensor<1x4x8xf32>
-}
-// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
-// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
-// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
-// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
-// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>
-
-// -----
-
-func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
-  %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
-  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
-  return %inserted_slice : tensor<1x1x8x8xf32>
-}
-// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
-// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x8xf32>
-// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
-// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
-// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>
-
-// -----
-
-func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
-  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
-  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
-  return %inserted_slice : tensor<1x4x4xf32>
-}
-// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
-// CHECK-SAME:      %[[ARG0:.*]]: tensor<?x8x2x8xf32>
-// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
-// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
-// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>
-
-// -----
-
-func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
-  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
-  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
-  return %inserted_slice : tensor<2x8x8xf32>
-}
-// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
-// CHECK-SAME:      %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
-// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
-// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
-// CHECK:         return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>
+}
\ No newline at end of file



More information about the Mlir-commits mailing list