[Mlir-commits] [mlir] [MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape (PR #131982)

ofri frishman llvmlistbot at llvm.org
Tue Mar 25 07:08:49 PDT 2025


https://github.com/ofri-frishman updated https://github.com/llvm/llvm-project/pull/131982

>From 72fbf710523e18d544b713652a6ee020f3063161 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Wed, 19 Mar 2025 09:02:42 +0200
Subject: [PATCH 1/5] [MLIR] Bubble up tensor.extract_slice through
 tensor.collapse_shape

Add a pattern that bubbles up tensor.extract_slice through
tensor.collapse_shape.
The pattern is registered in a pattern population function
that is used by the transform op
transform.apply_patterns.tensor.bubble_up_extract_slice
and by the transform op transform.structured.fuse as a
cleanup pattern.
This pattern enables tiling and fusing op chains which contain
tensor.collapse_shape if added as a cleanup pattern of tile and fuse
utility.
Without this pattern that would not be possible, as
tensor.collapse_shape does not implement the tiling interface.
This is an additional pattern to the one added in PR #126898
---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 189 +++++++++++++++++-
 .../Dialect/Linalg/transform-op-fuse.mlir     |  50 +++++
 .../Tensor/bubble-up-extract-slice-op.mlir    | 153 ++++++++++++++
 3 files changed, 391 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index acedf51d0e240..efa4d10817e39 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -12,8 +12,10 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
+#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -428,6 +430,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.collapse_shape(%src))` to
+/// `tensor.collapse_shape(tensor.extract_slice(%src))`.
+///
+/// For this transformation to be possible, the slice must be representable as a
+/// contiguous slice within each reassociation group of the src.
+///
+/// In case the size and offset extracted are static then this is possible if
+/// the following conditions are met:
+/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
+/// be the shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous block of memory if and only if there exists an index k in {0, 1,
+/// ..., n} such that:
+///      S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+///      1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
+///                       one dimension),
+///      S_i = A_i for all i > k (that is, all trailing dimensions are preserved
+///      in full).
+/// In other words, the slice shape S must be of the form:
+/// [ 1, 1, ..., 1, S_k, A_(k+1), A_(k+2), ..., A_n ]
+///
+/// In case the size and/or offset extracted are dynamic then this is possible
+/// only if there is a single dimension in the reassociation group that has a
+/// size not equal to 1.
+/// In other words, the tensor shape must be of the form:
+/// [ 1, 1, ..., 1, A, 1, ...,1 ]
+/// Note - it might be possible to enable this pattern for more cases when the
+/// size/offset are dynamic via performing an analysis of the possible values
+/// that could be given to the size/offset.
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
+/// [20->10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
+///     tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+///     tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
+/// ```
+struct BubbleUpCollapseShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "tensor.extract_slice source not produced by tensor.collapse_shape");
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    // The tensor.extract_slice before applying the pattern works on the result
+    // of the tensor.collapse_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "collapsed", and ones referring to the state after
+    // applying the pattern are named with the prefix "expanded".
+    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        collapsedSizes.size())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+
+    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
+    SmallVector<ReassociationIndices, 4> reassociationIndices =
+        collapseShapeOp.getReassociationIndices();
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank
+    // equal to the rank of the src of the collapse_shape. In each iteration of
+    // the loop, the offsets and sizes will be computed per reassociation group.
+    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
+    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
+                                              rewriter.getIndexAttr(1));
+
+    for (auto [groupIdx, reassocIndices] :
+         enumerate(collapseShapeOp.getReassociationIndices())) {
+      OpFoldResult collapsedSize = collapsedSizes[groupIdx];
+      OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+      // Case #1 - size and/or offset are dynamic.
+      // In this case, the slice can be represented as a contiguous slice only
+      // if there is a single dimension in the reassociation group that has a
+      // size not equal to 1.
+      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+        int nonUnitSizeCount = 0;
+        for (int64_t expandedShapeIdx : reassocIndices) {
+          if (srcShape[expandedShapeIdx] != 1) {
+            nonUnitSizeCount++;
+            expandedSizes.emplace_back(collapsedSize);
+            expandedOffsets.emplace_back(collapsedOffset);
+            continue;
+          }
+
+          expandedSizes.emplace_back(rewriter.getIndexAttr(1));
+          expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
+        }
+
+        if (nonUnitSizeCount != 1) {
+          return rewriter.notifyMatchFailure(
+              sliceOp,
+              "unsupported: slice cannot be verified to be contiguous");
+        }
+        continue;
+      }
+
+      // Case #2 - size and offset are static.
+      // Verify that the slice can be represented as a contiguous slice of the
+      // src of the collapse_shape.
+      // This check must be done starting from the innermost
+      // dimensions, so traversal is done in reverse order of the
+      // reassociation group.
+      int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
+      int64_t collapsedOffsetValue =
+          getConstantIntValue(collapsedOffset).value();
+
+      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+
+      for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
+        int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+
+        // This dimension is taken in full, so need to make sure that the slice
+        // size can be set to the shape size and the offset to 0.
+        if (collapsedSizeValue >= expandedShapeSize &&
+            (collapsedSizeValue % expandedShapeSize != 0 ||
+             collapsedOffsetValue % expandedShapeSize != 0)) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: cannot be extracted as a contiguous slice "
+                       "of the src of the collapse_shape");
+        }
+
+        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+
+        // This is the dimension that slicing will occur along, so need to make
+        // sure that the slice size + offset will not exceed the shape size.
+        if (collapsedSizeValue < expandedShapeSize &&
+            (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "unsupported: slice cannot be extracted as a contiguous "
+                       "slice of the src of the collapse_shape");
+        }
+
+        groupExpandedSizes.emplace_back(rewriter.getIndexAttr(
+            std::min(collapsedSizeValue, expandedShapeSize)));
+        groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim));
+
+        // Remove the size and offset of trailing dimensions from the size and
+        // offset of the slice.
+        collapsedSizeValue /= expandedShapeSize;
+        collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
+        collapsedOffsetValue /= expandedShapeSize;
+      }
+
+      expandedSizes.append(groupExpandedSizes.rbegin(),
+                           groupExpandedSizes.rend());
+      expandedOffsets.append(groupExpandedOffsets.rbegin(),
+                             groupExpandedOffsets.rend());
+    }
+
+    Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+        collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
+        expandedSizes, expandedStrides);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        sliceOp, sliceOp.getResultType(), newSliceOp,
+        collapseShapeOp.getReassociationIndices());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -448,5 +634,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
 
 void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
+               BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 9bcc125ce1ba9..441020f1cddfc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -438,3 +438,53 @@ module attributes {transform.with_named_sequence} {
     transform.yield 
   }
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape(
+// CHECK:      scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) {
+// CHECK:             %[[EXTRACT1:.*]] = tensor.extract_slice
+// CHECK:             %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[EXTRACT1]]
+// CHECK:             %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE1]]
+func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+  %expand = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+  %empty = tensor.empty() : tensor<8x1800x32xf32>
+  %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+  return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield 
+  }
+}
+
+
+// -----
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
+// CHECK:           scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK:             %[[VAL_9:.*]] = tensor.extract_slice
+// CHECK:             %[[VAL_11:.*]] = linalg.abs ins(%[[VAL_9]]
+// CHECK:             %[[VAL_12:.*]] = tensor.collapse_shape %[[VAL_11]]
+// CHECK:             %[[VAL_14:.*]] = linalg.exp ins(%[[VAL_12]]
+func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+  %empty1 = tensor.empty() : tensor<1x8x1800x32xf32>
+  %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32>
+  %expand = tensor.collapse_shape %abs [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+  %empty2 = tensor.empty() : tensor<8x1800x32xf32>
+  %exp = linalg.exp ins(%expand : tensor<8x1800x32xf32>) outs(%empty2 : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+  return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield 
+  }
+}
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index 3900bc56f433d..d05bf1bf76f29 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -113,6 +113,159 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>,
   return %extract : tensor<?x5x2xf32>
 }
 
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(
+// CHECK-SAME:                   %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1]
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK:           return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%src: tensor<6x5x2xf32>) -> tensor<1xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32>
+  %extract = tensor.extract_slice %collapse[0][1][1] : tensor<60xf32> to tensor<1xf32>
+  return %extract : tensor<1xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(
+// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
+// CHECK:           %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1]
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3]]
+// CHECK:           return %[[VAL_2]]
+func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<6x5x3x10xf32> into tensor<30x30xf32>
+  %extract = tensor.extract_slice %collapse[5, 10][15, 10][1, 1] : tensor<30x30xf32> to tensor<15x10xf32>
+  return %extract : tensor<15x10xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(
+// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> {
+// CHECK:           %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][2, 0, 0] [1, 2, 2] [1, 1, 1] 
+// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2]]
+// CHECK:           return %[[VAL_2]]
+func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32>
+  %extract = tensor.extract_slice %collapse[20][4][1] : tensor<60xf32> to tensor<4xf32>
+  return %extract : tensor<4xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(
+// CHECK-SAME:                    %[[SRC:.*]]: tensor<1x5x1xf32>,
+// CHECK-SAME:                    %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1]
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK:           return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(%src: tensor<1x5x1xf32>, %size : index) -> tensor<?xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32>
+  %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<5xf32> to tensor<?xf32>
+  return %extract : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(
+// CHECK-SAME:                    %[[SRC:.*]]: tensor<1x?x1xf32>,
+// CHECK-SAME:                    %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1]
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK:           return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(%src: tensor<1x?x1xf32>, %size : index) -> tensor<?xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x?x1xf32> into tensor<?xf32>
+  %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<?xf32> to tensor<?xf32>
+  return %extract : tensor<?xf32>
+}
+
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(
+// CHECK-SAME:                       %[[SRC:.*]]: tensor<1x5x1xf32>,
+// CHECK-SAME:                       %[[OFFSET:.*]]: index) -> tensor<3xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, %[[OFFSET]], 0] [1, 3, 1] [1, 1, 1]
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK:           return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(%src: tensor<1x5x1xf32>, %offset : index) -> tensor<3xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32>
+  %extract = tensor.extract_slice %collapse[%offset][3][1] : tensor<5xf32> to tensor<3xf32>
+  return %extract : tensor<3xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size(
+// CHECK-SAME:                     %[[SRC:.*]]: tensor<14x1xf32>,
+// CHECK-SAME:                     %[[OFFSET:.*]]: index,
+// CHECK-SAME:                     %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[OFFSET]], 0] {{\[}}%[[SIZE]], 1] [1, 1]
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1]]
+// CHECK:           return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size(%src: tensor<14x1xf32>, %offset : index, %size : index) -> tensor<?xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<14x1xf32> into tensor<14xf32>
+  %extract = tensor.extract_slice %collapse[%offset][%size][1] : tensor<14xf32> to tensor<?xf32>
+  return %extract : tensor<?xf32>
+}
+
+// CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(
+// CHECK-SAME:                      %[[SRC:.*]]: tensor<5x10x1x1x40xf32>,
+// CHECK-SAME:                      %[[OFFSET:.*]]: index,
+// CHECK-SAME:                      %[[SIZE:.*]]: index) -> tensor<20x?xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 0, 0, %[[OFFSET]]] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1]
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3, 4]]
+// CHECK:           return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(%src: tensor<5x10x1x1x40xf32>, %offset : index, %size : index) -> tensor<20x?xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1], [2, 3, 4]] : tensor<5x10x1x1x40xf32> into tensor<50x40xf32>
+  %extract = tensor.extract_slice %collapse[10, %offset][20, %size][1, 1] : tensor<50x40xf32> to tensor<20x?xf32>
+  return %extract : tensor<20x?xf32>
+}
+
+// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
+// shape cannot be defined as a contiguous size in the expanded shape due to size extracted not being suited
+// for the expanded shape.
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(
+// CHECK-SAME:                              %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<15xf32> {
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(%src: tensor<2x3x10xf32>) -> tensor<15xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+  %extract = tensor.extract_slice %collapse[0][15][1] : tensor<60xf32> to tensor<15xf32>
+  return %extract : tensor<15xf32>
+}
+
+// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
+// shape cannot be defined as a contiguous size in the expanded shape due to an unsuitable offset even though
+// the size extracted is suited for the expanded shape.
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(
+// CHECK-SAME:                              %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<20xf32> {
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(%src: tensor<2x3x10xf32>) -> tensor<20xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+  %extract = tensor.extract_slice %collapse[20][20][1] : tensor<60xf32> to tensor<20xf32>
+  return %extract : tensor<20xf32>
+}
+
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride(
+// CHECK-SAME:                              %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<5xf32> {
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride(%src: tensor<2x3x10xf32>) -> tensor<5xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+  %extract = tensor.extract_slice %collapse[0][5][2] : tensor<60xf32> to tensor<5xf32>
+  return %extract : tensor<5xf32>
+}
+
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing(
+// CHECK-SAME:                              %[[SRC:.*]]: tensor<6x5x2x1xf32>) -> tensor<1xf32> {
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing(%src: tensor<6x5x2x1xf32>) -> tensor<1xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<6x5x2x1xf32> into tensor<60x1xf32>
+  %extract = tensor.extract_slice %collapse[0, 0][1, 1][1, 1] : tensor<60x1xf32> to tensor<1xf32>
+  return %extract : tensor<1xf32>
+}
+
+// CHECK-LABEL:   func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic(
+// CHECK-SAME:                              %[[SRC:.*]]: tensor<1x5x2xf32>,
+// CHECK-SAME:                              %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic(%src: tensor<1x5x2xf32>, %size : index) -> tensor<?xf32> {
+  %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x2xf32> into tensor<10xf32>
+  %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<10xf32> to tensor<?xf32>
+  return %extract : tensor<?xf32>
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

>From c0291d068b66e6b450521813ebb2ad9a3a4c75e2 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Sun, 23 Mar 2025 12:08:23 +0200
Subject: [PATCH 2/5] Updates for code review

---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 13 ++++----
 .../Dialect/Linalg/transform-op-fuse.mlir     | 14 ++++----
 .../Tensor/bubble-up-extract-slice-op.mlir    | 33 +++++++++++++++----
 3 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index efa4d10817e39..cc73011151b17 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -15,7 +15,6 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/LogicalResult.h"
-#include <algorithm>
 
 using namespace mlir;
 using namespace mlir::tensor;
@@ -533,13 +532,13 @@ struct BubbleUpCollapseShapeThroughExtractSlice
         for (int64_t expandedShapeIdx : reassocIndices) {
           if (srcShape[expandedShapeIdx] != 1) {
             nonUnitSizeCount++;
-            expandedSizes.emplace_back(collapsedSize);
-            expandedOffsets.emplace_back(collapsedOffset);
+            expandedSizes.push_back(collapsedSize);
+            expandedOffsets.push_back(collapsedOffset);
             continue;
           }
 
-          expandedSizes.emplace_back(rewriter.getIndexAttr(1));
-          expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
+          expandedSizes.push_back(rewriter.getIndexAttr(1));
+          expandedOffsets.push_back(rewriter.getIndexAttr(0));
         }
 
         if (nonUnitSizeCount != 1) {
@@ -586,9 +585,9 @@ struct BubbleUpCollapseShapeThroughExtractSlice
                        "slice of the src of the collapse_shape");
         }
 
-        groupExpandedSizes.emplace_back(rewriter.getIndexAttr(
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(
             std::min(collapsedSizeValue, expandedShapeSize)));
-        groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
 
         // Remove the size and offset of trailing dimensions from the size and
         // offset of the slice.
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 441020f1cddfc..d7339fa3c0be4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -443,9 +443,9 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape(
 // CHECK:      scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) {
-// CHECK:             %[[EXTRACT1:.*]] = tensor.extract_slice
-// CHECK:             %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[EXTRACT1]]
-// CHECK:             %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE1]]
+// CHECK:             %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK:             %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]]
+// CHECK:             %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE]]
 func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
   %expand = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
   %empty = tensor.empty() : tensor<8x1800x32xf32>
@@ -467,10 +467,10 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
 // CHECK:           scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
-// CHECK:             %[[VAL_9:.*]] = tensor.extract_slice
-// CHECK:             %[[VAL_11:.*]] = linalg.abs ins(%[[VAL_9]]
-// CHECK:             %[[VAL_12:.*]] = tensor.collapse_shape %[[VAL_11]]
-// CHECK:             %[[VAL_14:.*]] = linalg.exp ins(%[[VAL_12]]
+// CHECK:             %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK:             %[[ABS:.*]] = linalg.abs ins(%[[EXTRACT]]
+// CHECK:             %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]]
+// CHECK:             %[[EXP:.*]] = linalg.exp ins(%[[COLLAPSE]]
 func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
   %empty1 = tensor.empty() : tensor<1x8x1800x32xf32>
   %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32>
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index d05bf1bf76f29..c0755d7125091 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -1,5 +1,15 @@
 // RUN: mlir-opt -split-input-file -transform-interpreter  %s | FileCheck %s
 
+///----------------------------------------------------------------------------------------
+/// [Pattern: BubbleUpExpandShapeThroughExtractSlice]
+///
+/// IN: tensor.expand_shape(tensor.extract_slice)
+/// OUT:tensor.extract_slice(tensor.expand_shape)
+///
+/// Note: tensor.extract_slice is bubbled up to be before tensor.expand_shape.
+///       Some tests are negative tests for cases where the pattern cannot be applied.
+///----------------------------------------------------------------------------------------
+
 // CHECK-LABEL:   func.func @bubble_up_extract_slice_through_expand_shape(
 // CHECK-SAME:                %[[SRC:.*]]: tensor<60xf32>) -> tensor<1x1x5xf32> {
 // CHECK:           %[[C1:.+]] = arith.constant 5 : index
@@ -113,6 +123,16 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>,
   return %extract : tensor<?x5x2xf32>
 }
 
+///----------------------------------------------------------------------------------------
+/// [Pattern: BubbleUpCollapseShapeThroughExtractSlice]
+///
+/// IN: tensor.extract_slice(tensor.collapse_shape)
+/// OUT: tensor.collapse_shape(tensor.extract_slice)
+///
+/// Note: tensor.extract_slice is bubbled up to be before tensor.collapse_shape.
+///       Some tests are negative tests for cases where the pattern cannot be applied.
+///----------------------------------------------------------------------------------------
+
 // CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(
 // CHECK-SAME:                   %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> {
 // CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1]
@@ -209,9 +229,13 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_gro
   return %extract : tensor<20x?xf32>
 }
 
-// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
-// shape cannot be defined as a contiguous size in the expanded shape due to size extracted not being suited
-// for the expanded shape.
+/// The following 2 tests are cases where the bubble up cannot occur because the contiguous size extracted
+/// from the collapsed shape cannot be expressed via a single extract_slice op.
+/// In the first test it is because the size extracted cannot be expressed as a slice
+/// of the form [1, 1, ..., 1, Sk, A(k+1), A(k+2), ..., An] (see the pattern documentation for more details).
+/// In the second test, the size can be expressed in the required form, but the offset is such that the pattern
+/// cannot be applied.
+
 // CHECK-LABEL:   func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(
 // CHECK-SAME:                              %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<15xf32> {
 // CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape
@@ -222,9 +246,6 @@ func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1
   return %extract : tensor<15xf32>
 }
 
-// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
-// shape cannot be defined as a contiguous size in the expanded shape due to an unsuitable offset even though
-// the size extracted is suited for the expanded shape.
 // CHECK-LABEL:   func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(
 // CHECK-SAME:                              %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<20xf32> {
 // CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape

>From 72b0be3bd808f94bcf062cdcfee684be553bf507 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Mon, 24 Mar 2025 12:13:56 +0200
Subject: [PATCH 3/5] Updates for CR

---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 111 +++++++++++++-----
 .../Dialect/Linalg/transform-op-fuse.mlir     |   1 -
 2 files changed, 80 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index cc73011151b17..6ec7c58f85b4e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -429,17 +429,19 @@ struct BubbleUpExpandShapeThroughExtractSlice
   }
 };
 
-/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
-/// `tensor.extract_slice(tensor.collapse_shape)`.
+/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
+///          `tensor.collapse_shape(tensor.extract_slice)`.
 ///
-/// For this transformation to be possible, the slice must be representable as a
-/// contiguous slice within each reassociation group of the src.
+/// For this transformation to be possible - after bubbling up, the extraction
+/// of the contiguous slice must be representable as a single slice obtained via
+/// tensor.extract_slice within each reassociation group of the src.
 ///
 /// In case the size and offset extracted are static then this is possible if
-/// the following conditions are met:
-/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
-/// be the shape of a desired slice. A slice of shape S can be extracted as a
-/// contiguous block of memory if and only if there exists an index k in {0, 1,
+/// the following conditions are met within each reassociation group:
+/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
+/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
+/// shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous span of elements if and only if there exists an index k in {0, 1,
 /// ..., n} such that:
 ///      S_i = 1 for all i < k (that is, all leading dimensions are singleton),
 ///      1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
@@ -475,6 +477,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
 /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
 ///     tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
 /// ```
+///
+/// Negative example:
+/// The transformation is not possible because we cannot use a single slice to
+/// represent the reassociation group [2x3x10->???]. Collapsing after the
+/// extraction would require extracting several slices and concatenating them.
+/// ```
+/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] :
+///                               tensor<2x3x10xf32> into tensor<60xf32>
+/// %extract = tensor.extract_slice %collapse[0][15][1] :
+///                               tensor<60xf32> to tensor<15xf32>
+/// ```
+/// If we would want the collapse to be after the extraction, a possible
+/// alternate transformation could be to extract multiple slices and concat them
+/// together:
+/// ```
+/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10][1, 1, 1] :
+///                               tensor<2x3x10xf32> to tensor<1x1x10xf32>
+/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5][1, 1, 1] :
+///                               tensor<2x3x10xf32> to tensor<1x1x5xf32>
+/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
+///     (tensor<1x1x10xf32>, tensor<1x1x5xf32>) -> tensor<1x1x15xf32>
+/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
+///                                                       into tensor<15xf32>
+/// ```
+/// But this is not the intended purpose of the transformation.
 struct BubbleUpCollapseShapeThroughExtractSlice
     : public OpRewritePattern<tensor::ExtractSliceOp> {
   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -552,47 +579,69 @@ struct BubbleUpCollapseShapeThroughExtractSlice
       // Case #2 = size and offset are static.
       // Verify that the slice can be represented as a contiguous slice of the
       // src of the collapse_shape.
-      // Checking this must be done on order of most
-      // internal dimensions first, so traversal is done in reverse order of the
-      // reassociation group.
+      // Checking this is done on order of most internal dimensions first,
+      // so traversal is done in reverse order of the reassociation group.
+      // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
+      // ...,An] then we first find the size and offset for n...k+1 then for k
+      // and then for k-1...0.
       int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
       int64_t collapsedOffsetValue =
           getConstantIntValue(collapsedOffset).value();
 
       SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
 
-      for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
-        int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+      ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+                                                  reassocIndices.rend());
+      int64_t idx = 0;
+      int64_t reassocGroupSize = reassocIndices.size();
+
+      // First handle the trailing dimensions where the slice size should be
+      // equal to the tensor shape and the offset should be 0 (n...k+1).
+      for (; idx < reassocGroupSize; ++idx) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
 
-        // This is a dimension that slicing will occur on, so need to make sure
-        // that the slice size can be set to the shape size and the offset to 0.
-        if (collapsedSizeValue >= expandedShapeSize &&
-            (collapsedSizeValue % expandedShapeSize != 0 ||
-             collapsedOffsetValue % expandedShapeSize != 0)) {
+        if (collapsedSizeValue < expandedShapeSize)
+          break;
+
+        // We need to make sure that the slice size can be set to the shape size
+        // and the offset to 0.
+        if ((collapsedSizeValue % expandedShapeSize) != 0 ||
+            (collapsedOffsetValue % expandedShapeSize) != 0)
           return rewriter.notifyMatchFailure(
               sliceOp, "unsupported: cannot be extracted as a contiguous slice "
                        "of the src of the collapse_shape");
-        }
 
-        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+
+        collapsedSizeValue /= expandedShapeSize;
+        collapsedOffsetValue /= expandedShapeSize;
+      }
 
-        // This is the dimension that slicing will occur along, so need to make
-        // sure that the slice size + offset will not exceed the shape size.
-        if (collapsedSizeValue < expandedShapeSize &&
-            (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
+      // Now handle the first dim where slicing occurs on (k).
+      if (idx < reassocGroupSize) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        // We need to make sure that the slice size in this dim + offset will
+        // not exceed the shape size.
+        if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
           return rewriter.notifyMatchFailure(
               sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                        "slice of the src of the collapse_shape");
-        }
 
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(
-            std::min(collapsedSizeValue, expandedShapeSize)));
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue));
         groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
 
-        // Remove the size and offset of trailing dimensions from the size and
-        // offset of the slice.
-        collapsedSizeValue /= expandedShapeSize;
-        collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
+        collapsedOffsetValue /= expandedShapeSize;
+      }
+
+      // Now handle the leading dimensions where the slice size is equal to 1
+      // (k-1...0).
+      for (idx++; idx < reassocGroupSize; ++idx) {
+        int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
+        groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
         collapsedOffsetValue /= expandedShapeSize;
       }
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index d7339fa3c0be4..962858076db93 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -462,7 +462,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-
 // -----
 
 // CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(

>From 1aaf3c907272a1bf1d4a2a870e93565079390d07 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Tue, 25 Mar 2025 10:45:32 +0200
Subject: [PATCH 4/5] Fix lit test checker variable names

---
 .../Tensor/bubble-up-extract-slice-op.mlir       | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index c0755d7125091..34128d6a5ec8b 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -145,10 +145,10 @@ func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%
 }
 
 // CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(
-// CHECK-SAME:                      %[[VAL_0:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
-// CHECK:           %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1]
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3]]
-// CHECK:           return %[[VAL_2]]
+// CHECK-SAME:                      %[[SRC:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1]
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3]]
+// CHECK:           return %[[COLLAPSE]]
 func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
   %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<6x5x3x10xf32> into tensor<30x30xf32>
   %extract = tensor.extract_slice %collapse[5, 10][15, 10][1, 1] : tensor<30x30xf32> to tensor<15x10xf32>
@@ -156,10 +156,10 @@ func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group
 }
 
 // CHECK-LABEL:   func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(
-// CHECK-SAME:                         %[[VAL_0:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> {
-// CHECK:           %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][2, 0, 0] [1, 2, 2] [1, 1, 1] 
-// CHECK:           %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2]]
-// CHECK:           return %[[VAL_2]]
+// CHECK-SAME:                         %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> {
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][2, 0, 0] [1, 2, 2] [1, 1, 1] 
+// CHECK:           %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK:           return %[[COLLAPSE]]
 func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> {
   %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32>
   %extract = tensor.extract_slice %collapse[20][4][1] : tensor<60xf32> to tensor<4xf32>

>From 5845db6a8de787a5df0c684f85286ae3a4ddf4f4 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Tue, 25 Mar 2025 16:08:09 +0200
Subject: [PATCH 5/5] Additional clarifications

---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 44 ++++++++++++-------
 1 file changed, 29 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 6ec7c58f85b4e..0a7fcba7a71cd 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -550,7 +550,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice
          enumerate(collapseShapeOp.getReassociationIndices())) {
       OpFoldResult collapsedSize = collapsedSizes[groupIdx];
       OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
-      // Case #1 - size and/or offset are dynamic.
+      // CASE #1 - size and/or offset are dynamic.
       // In this case, the slice can be represented as a contiguous slice only
       // if there is a single dimension in the reassociation group that has a
       // size not equal to 1.
@@ -576,7 +576,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice
         continue;
       }
 
-      // Case #2 = size and offset are static.
+      // CASE #2 - size and offset are static.
       // Verify that the slice can be represented as a contiguous slice of the
       // src of the collapse_shape.
       // Checking this is done on order of most internal dimensions first,
@@ -584,8 +584,16 @@ struct BubbleUpCollapseShapeThroughExtractSlice
       // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
       // ...,An] then we first find the size and offset for n...k+1 then for k
       // and then for k-1...0.
-      int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
-      int64_t collapsedOffsetValue =
+
+      // currentCollapsedsize and currentCollapsedOffset are initialized with
+      // the original collapsed size and offset and divided by the expanded
+      // shape size in each dimension as we go along the reassociation group.
+      // In essence we are spreading the original collapsed size and offset over
+      // the various expanded slice dimensions.
+      // The variables are used both to check the validity of the slice and to
+      // compute the expanded sizes and offsets.
+      int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
+      int64_t currentCollapsedOffset =
           getConstantIntValue(collapsedOffset).value();
 
       SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
@@ -600,13 +608,13 @@ struct BubbleUpCollapseShapeThroughExtractSlice
       for (; idx < reassocGroupSize; ++idx) {
         int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
 
-        if (collapsedSizeValue < expandedShapeSize)
+        if (currentCollapsedsize < expandedShapeSize)
           break;
 
         // We need to make sure that the slice size can be set to the shape size
         // and the offset to 0.
-        if ((collapsedSizeValue % expandedShapeSize) != 0 ||
-            (collapsedOffsetValue % expandedShapeSize) != 0)
+        if ((currentCollapsedsize % expandedShapeSize) != 0 ||
+            (currentCollapsedOffset % expandedShapeSize) != 0)
           return rewriter.notifyMatchFailure(
               sliceOp, "unsupported: cannot be extracted as a contiguous slice "
                        "of the src of the collapse_shape");
@@ -614,35 +622,41 @@ struct BubbleUpCollapseShapeThroughExtractSlice
         groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
         groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
 
-        collapsedSizeValue /= expandedShapeSize;
-        collapsedOffsetValue /= expandedShapeSize;
+        currentCollapsedsize /= expandedShapeSize;
+        currentCollapsedOffset /= expandedShapeSize;
       }
 
       // Now handle the first dim where slicing occurs on (k).
       if (idx < reassocGroupSize) {
         int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
         // We need to make sure that the slice size in this dim + offset will
         // not exceed the shape size.
-        if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
+        if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
           return rewriter.notifyMatchFailure(
               sliceOp, "unsupported: slice cannot be extracted as a contiguous "
                        "slice of the src of the collapse_shape");
 
-        groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue));
+        groupExpandedSizes.push_back(
+            rewriter.getIndexAttr(currentCollapsedsize));
         groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
 
-        collapsedOffsetValue /= expandedShapeSize;
+        currentCollapsedOffset /= expandedShapeSize;
       }
 
       // Now handle the leading dimensions where the slice size is equal to 1
       // (k-1...0).
+      // The size for these dimensions must be 1 because of how we constructed
+      // the slice size of the expanded shape. We spread the original collapsed
+      // size over the expanded shape sizes until we reached dimension k where
+      // the remaining size was smaller than the expanded shape size, and spread
+      // the remaining size on it. So, now we are left with only 1s.
       for (idx++; idx < reassocGroupSize; ++idx) {
         int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-        int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+        int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
         groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
         groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
-        collapsedOffsetValue /= expandedShapeSize;
+        currentCollapsedOffset /= expandedShapeSize;
       }
 
       expandedSizes.append(groupExpandedSizes.rbegin(),



More information about the Mlir-commits mailing list