[Mlir-commits] [mlir] [MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape (PR #131982)
ofri frishman
llvmlistbot at llvm.org
Tue Mar 25 07:08:49 PDT 2025
https://github.com/ofri-frishman updated https://github.com/llvm/llvm-project/pull/131982
From 72fbf710523e18d544b713652a6ee020f3063161 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Wed, 19 Mar 2025 09:02:42 +0200
Subject: [PATCH 1/5] [MLIR] Bubble up tensor.extract_slice through
tensor.collapse_shape
Add a pattern that bubbles up tensor.extract_slice through
tensor.collapse_shape.
The pattern is registered in a pattern population function
that is used by the transform op
transform.apply_patterns.tensor.bubble_up_extract_slice
and by the transform op transform.structured.fuse as a
cleanup pattern.
This pattern enables tiling and fusing op chains that contain
tensor.collapse_shape when it is added as a cleanup pattern of the
tile-and-fuse utility.
Without this pattern that would not be possible, as
tensor.collapse_shape does not implement the tiling interface.
This is an additional pattern to the one added in PR #126898
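As an illustration, the rewrite performed by the pattern turns IR of
the form (a simplified variant of the example given in the pattern
documentation in the patch below):

```
%collapse = tensor.collapse_shape %src [[0, 1]]
    : tensor<8x16xf32> into tensor<128xf32>
%slice = tensor.extract_slice %collapse[0][32][1]
    : tensor<128xf32> to tensor<32xf32>
```

into:

```
%slice = tensor.extract_slice %src[0, 0] [2, 16] [1, 1]
    : tensor<8x16xf32> to tensor<2x16xf32>
%collapse = tensor.collapse_shape %slice [[0, 1]]
    : tensor<2x16xf32> into tensor<32xf32>
```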
---
.../Tensor/Transforms/ReshapePatterns.cpp | 189 +++++++++++++++++-
.../Dialect/Linalg/transform-op-fuse.mlir | 50 +++++
.../Tensor/bubble-up-extract-slice-op.mlir | 153 ++++++++++++++
3 files changed, 391 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index acedf51d0e240..efa4d10817e39 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -12,8 +12,10 @@
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
+#include <algorithm>
using namespace mlir;
using namespace mlir::tensor;
@@ -428,6 +430,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
}
};
+/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
+/// `tensor.extract_slice(tensor.collapse_shape)`.
+///
+/// For this transformation to be possible, the slice must be representable as a
+/// contiguous slice within each reassociation group of the src.
+///
+/// In case the size and offset extracted are static then this is possible if
+/// the following conditions are met:
+/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
+/// be the shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous block of memory if and only if there exists an index k in {0, 1,
+/// ..., n} such that:
+/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+/// 1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
+/// one dimension),
+/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
+/// in full).
+/// In other words, the slice shape S must be of the form:
+/// [ 1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An ]
+///
+/// In case the size and/or offset extracted are dynamic then this is possible
+/// only if there is a single dimension in the reassociation group that has a
+/// size not equal to 1.
+/// In other words, the tensor shape must be of the form:
+/// [ 1, 1, ..., 1, A, 1, ..., 1 ]
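+/// For example (an illustration mirroring the dynamic-size tests added
+/// below): for a src of shape [1, 5, 1] collapsed into [5], a slice
+/// [%off][%size][1] of the collapsed result can be bubbled up to the slice
+/// [0, %off, 0][1, %size, 1][1, 1, 1] of the src, since only the middle
+/// dimension has a size other than 1.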
+/// Note - it might be possible to enable this pattern for more cases when the
+/// size/offset are dynamic via performing an analysis of the possible values
+/// that could be given to the size/offset.
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
+/// [20->10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
+/// tensor<8x16x1x7x20xf32> into tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+/// tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+/// tensor<2x16x1x?x10xf32> into tensor<32x?x10xf32>
+/// ```
+struct BubbleUpCollapseShapeThroughExtractSlice
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp =
+ sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!collapseShapeOp)
+ return rewriter.notifyMatchFailure(
+ sliceOp,
+ "tensor.extract_slice source not produced by tensor.collapse_shape");
+
+ if (!sliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+ "be supported in this transformation.");
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.collapse_shape, so variables (i.e. inputs for
+ // ExtractSliceOp) referring to the state before applying the pattern are
+ // named with the prefix "collapsed", and ones referring to the state after
+ // applying the pattern are named with the prefix "expanded".
+ SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+ collapsedSizes.size())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "unimplemented: rank reducing slice");
+
+ ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
+ SmallVector<ReassociationIndices, 4> reassociationIndices =
+ collapseShapeOp.getReassociationIndices();
+
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has a rank
+ // equal to the rank of the src of the collapse_shape. In each iteration of
+ // the loop, the offsets and sizes will be computed per reassociation group.
+ SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
+ SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
+ rewriter.getIndexAttr(1));
+
+ for (auto [groupIdx, reassocIndices] :
+ enumerate(collapseShapeOp.getReassociationIndices())) {
+ OpFoldResult collapsedSize = collapsedSizes[groupIdx];
+ OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+ // Case #1 - size and/or offset are dynamic.
+ // In this case, the slice can be represented as a contiguous slice only
+ // if there is a single dimension in the reassociation group that has a
+ // size not equal to 1.
+ if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+ int nonUnitSizeCount = 0;
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ if (srcShape[expandedShapeIdx] != 1) {
+ nonUnitSizeCount++;
+ expandedSizes.emplace_back(collapsedSize);
+ expandedOffsets.emplace_back(collapsedOffset);
+ continue;
+ }
+
+ expandedSizes.emplace_back(rewriter.getIndexAttr(1));
+ expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
+ }
+
+ if (nonUnitSizeCount != 1) {
+ return rewriter.notifyMatchFailure(
+ sliceOp,
+ "unsupported: slice cannot be verified to be contiguous");
+ }
+ continue;
+ }
+
+ // Case #2 - size and offset are static.
+ // Verify that the slice can be represented as a contiguous slice of the
+ // src of the collapse_shape.
+ // Checking this must be done on order of most
+ // internal dimensions first, so traversal is done in reverse order of the
+ // reassociation group.
+ int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
+ int64_t collapsedOffsetValue =
+ getConstantIntValue(collapsedOffset).value();
+
+ SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+
+ for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
+ int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+
+ // This is a dimension that slicing will occur on, so need to make sure
+ // that the slice size can be set to the shape size and the offset to 0.
+ if (collapsedSizeValue >= expandedShapeSize &&
+ (collapsedSizeValue % expandedShapeSize != 0 ||
+ collapsedOffsetValue % expandedShapeSize != 0)) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "unsupported: cannot be extracted as a contiguous slice "
+ "of the src of the collapse_shape");
+ }
+
+ int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+
+ // This is the dimension that slicing will occur along, so need to make
+ // sure that the slice size + offset will not exceed the shape size.
+ if (collapsedSizeValue < expandedShapeSize &&
+ (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "unsupported: slice cannot be extracted as a contiguous "
+ "slice of the src of the collapse_shape");
+ }
+
+ groupExpandedSizes.emplace_back(rewriter.getIndexAttr(
+ std::min(collapsedSizeValue, expandedShapeSize)));
+ groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim));
+
+ // Remove the size and offset of trailing dimensions from the size and
+ // offset of the slice.
+ collapsedSizeValue /= expandedShapeSize;
+ collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
+ collapsedOffsetValue /= expandedShapeSize;
+ }
+
+ expandedSizes.append(groupExpandedSizes.rbegin(),
+ groupExpandedSizes.rend());
+ expandedOffsets.append(groupExpandedOffsets.rbegin(),
+ groupExpandedOffsets.rend());
+ }
+
+ Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
+ expandedSizes, expandedStrides);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ sliceOp, sliceOp.getResultType(), newSliceOp,
+ collapseShapeOp.getReassociationIndices());
+
+ return success();
+ }
+};
+
} // namespace
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -448,5 +634,6 @@ void mlir::tensor::populateBubbleUpExpandShapePatterns(
void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
RewritePatternSet &patterns) {
- patterns.add<BubbleUpExpandShapeThroughExtractSlice>(patterns.getContext());
+ patterns.add<BubbleUpExpandShapeThroughExtractSlice,
+ BubbleUpCollapseShapeThroughExtractSlice>(patterns.getContext());
}
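For readers who want to drive these patterns outside the transform dialect,
here is a minimal sketch of a driver. It assumes the standard greedy rewrite
entry point and anchors on a func::FuncOp; it is not part of this patch,
which only registers the pattern in the population function above.

```
// Hypothetical driver (not part of this patch): greedily apply the
// bubble-up patterns to a function until a fixed point is reached.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult bubbleUpExtractSlices(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  // Registers both BubbleUpExpandShapeThroughExtractSlice and
  // BubbleUpCollapseShapeThroughExtractSlice.
  tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
  return applyPatternsGreedily(funcOp, std::move(patterns));
}
```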
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 9bcc125ce1ba9..441020f1cddfc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -438,3 +438,53 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape(
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) {
+// CHECK: %[[EXTRACT1:.*]] = tensor.extract_slice
+// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[EXTRACT1]]
+// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE1]]
+func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+ %collapse = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+ %empty = tensor.empty() : tensor<8x1800x32xf32>
+ %exp = linalg.exp ins(%collapse : tensor<8x1800x32xf32>) outs(%empty : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+ return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+ (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+ transform.yield
+ }
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[VAL_9:.*]] = tensor.extract_slice
+// CHECK: %[[VAL_11:.*]] = linalg.abs ins(%[[VAL_9]]
+// CHECK: %[[VAL_12:.*]] = tensor.collapse_shape %[[VAL_11]]
+// CHECK: %[[VAL_14:.*]] = linalg.exp ins(%[[VAL_12]]
+func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
+ %empty1 = tensor.empty() : tensor<1x8x1800x32xf32>
+ %abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32>
+ %collapse = tensor.collapse_shape %abs [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
+ %empty2 = tensor.empty() : tensor<8x1800x32xf32>
+ %exp = linalg.exp ins(%collapse : tensor<8x1800x32xf32>) outs(%empty2 : tensor<8x1800x32xf32>) -> tensor<8x1800x32xf32>
+ return %exp : tensor<8x1800x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %transformed, %loops:1 = transform.structured.fuse %0 [1, 0, 0] interchange [0, 1, 2] apply_cleanup = true :
+ (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+ transform.yield
+ }
+}
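To make the CHECK lines above concrete: with tile sizes [1, 0, 0], the
tiled-and-fused loop for the first test is expected to look roughly like the
following (a hand-written sketch, not compiler output; the SSA names and the
iter_args plumbing are illustrative):

```
scf.for %iv = %c0 to %c8 step %c1 iter_args(%acc = %empty) -> (tensor<8x1800x32xf32>) {
  %extract = tensor.extract_slice %src[0, %iv, 0, 0] [1, 1, 1800, 32] [1, 1, 1, 1]
      : tensor<1x8x1800x32xf32> to tensor<1x1x1800x32xf32>
  %collapse = tensor.collapse_shape %extract [[0, 1], [2], [3]]
      : tensor<1x1x1800x32xf32> into tensor<1x1800x32xf32>
  %out = tensor.extract_slice %acc[%iv, 0, 0] [1, 1800, 32] [1, 1, 1]
      : tensor<8x1800x32xf32> to tensor<1x1800x32xf32>
  %exp = linalg.exp ins(%collapse : tensor<1x1800x32xf32>)
      outs(%out : tensor<1x1800x32xf32>) -> tensor<1x1800x32xf32>
  %insert = tensor.insert_slice %exp into %acc[%iv, 0, 0] [1, 1800, 32] [1, 1, 1]
      : tensor<1x1800x32xf32> into tensor<8x1800x32xf32>
  scf.yield %insert : tensor<8x1800x32xf32>
}
```

Without the bubble-up cleanup pattern, the extract_slice produced by tiling
would stay below the collapse_shape, which does not implement the tiling
interface, and the fusion through it would not be possible.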
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index 3900bc56f433d..d05bf1bf76f29 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -113,6 +113,159 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>,
return %extract : tensor<?x5x2xf32>
}
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(
+// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%src: tensor<6x5x2xf32>) -> tensor<1xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32>
+ %extract = tensor.extract_slice %collapse[0][1][1] : tensor<60xf32> to tensor<1xf32>
+ return %extract : tensor<1xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
+// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1]
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3]]
+// CHECK: return %[[VAL_2]]
+func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<6x5x3x10xf32> into tensor<30x30xf32>
+ %extract = tensor.extract_slice %collapse[5, 10][15, 10][1, 1] : tensor<30x30xf32> to tensor<15x10xf32>
+ return %extract : tensor<15x10xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> {
+// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][2, 0, 0] [1, 2, 2] [1, 1, 1]
+// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[VAL_2]]
+func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32>
+ %extract = tensor.extract_slice %collapse[20][4][1] : tensor<60xf32> to tensor<4xf32>
+ return %extract : tensor<4xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size(%src: tensor<1x5x1xf32>, %size : index) -> tensor<?xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32>
+ %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<5xf32> to tensor<?xf32>
+ return %extract : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x?x1xf32>,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, %[[SIZE]], 1] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_size_and_src(%src: tensor<1x?x1xf32>, %size : index) -> tensor<?xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x?x1xf32> into tensor<?xf32>
+ %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<?xf32> to tensor<?xf32>
+ return %extract : tensor<?xf32>
+}
+
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x1xf32>,
+// CHECK-SAME: %[[OFFSET:.*]]: index) -> tensor<3xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, %[[OFFSET]], 0] [1, 3, 1] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset(%src: tensor<1x5x1xf32>, %offset : index) -> tensor<3xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x1xf32> into tensor<5xf32>
+ %extract = tensor.extract_slice %collapse[%offset][3][1] : tensor<5xf32> to tensor<3xf32>
+ return %extract : tensor<3xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size(
+// CHECK-SAME: %[[SRC:.*]]: tensor<14x1xf32>,
+// CHECK-SAME: %[[OFFSET:.*]]: index,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]]{{\[}}%[[OFFSET]], 0] {{\[}}%[[SIZE]], 1] [1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_offset_and_size(%src: tensor<14x1xf32>, %offset : index, %size : index) -> tensor<?xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1]] : tensor<14x1xf32> into tensor<14xf32>
+ %extract = tensor.extract_slice %collapse[%offset][%size][1] : tensor<14xf32> to tensor<?xf32>
+ return %extract : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(
+// CHECK-SAME: %[[SRC:.*]]: tensor<5x10x1x1x40xf32>,
+// CHECK-SAME: %[[OFFSET:.*]]: index,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<20x?xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 0, 0, %[[OFFSET]]] [2, 10, 1, 1, %[[SIZE]]] [1, 1, 1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3, 4]]
+// CHECK: return %[[COLLAPSE]]
+func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_groups(%src: tensor<5x10x1x1x40xf32>, %offset : index, %size : index) -> tensor<20x?xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1], [2, 3, 4]] : tensor<5x10x1x1x40xf32> into tensor<50x40xf32>
+ %extract = tensor.extract_slice %collapse[10, %offset][20, %size][1, 1] : tensor<50x40xf32> to tensor<20x?xf32>
+ return %extract : tensor<20x?xf32>
+}
+
+// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
+// shape cannot be defined as a contiguous size in the expanded shape due to size extracted not being suited
+// for the expanded shape.
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<15xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(%src: tensor<2x3x10xf32>) -> tensor<15xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ %extract = tensor.extract_slice %collapse[0][15][1] : tensor<60xf32> to tensor<15xf32>
+ return %extract : tensor<15xf32>
+}
+
+// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
+// shape cannot be defined as a contiguous size in the expanded shape due to an unsuitable offset even though
+// the size extracted is suited for the expanded shape.
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<20xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(%src: tensor<2x3x10xf32>) -> tensor<20xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ %extract = tensor.extract_slice %collapse[20][20][1] : tensor<60xf32> to tensor<20xf32>
+ return %extract : tensor<20xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride(
+// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<5xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_stride(%src: tensor<2x3x10xf32>) -> tensor<5xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into tensor<60xf32>
+ %extract = tensor.extract_slice %collapse[0][5][2] : tensor<60xf32> to tensor<5xf32>
+ return %extract : tensor<5xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing(
+// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2x1xf32>) -> tensor<1xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_rank_reducing(%src: tensor<6x5x2x1xf32>) -> tensor<1xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2], [3]] : tensor<6x5x2x1xf32> into tensor<60x1xf32>
+ %extract = tensor.extract_slice %collapse[0, 0][1, 1][1, 1] : tensor<60x1xf32> to tensor<1xf32>
+ return %extract : tensor<1xf32>
+}
+
+// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic(
+// CHECK-SAME: %[[SRC:.*]]: tensor<1x5x2xf32>,
+// CHECK-SAME: %[[SIZE:.*]]: index) -> tensor<?xf32> {
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+func.func @no_bubble_up_extract_slice_through_collapse_shape_on_unsupported_dynamic(%src: tensor<1x5x2xf32>, %size : index) -> tensor<?xf32> {
+ %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<1x5x2xf32> into tensor<10xf32>
+ %extract = tensor.extract_slice %collapse[0][%size][1] : tensor<10xf32> to tensor<?xf32>
+ return %extract : tensor<?xf32>
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
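(The diff context cuts off inside the transform script here. Based on the op
named in the commit message, the sequence driving these tests is along the
following lines; this is a sketch, and the exact script lives in the test
file:)

```
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.tensor.bubble_up_extract_slice
    } : !transform.op<"func.func">
    transform.yield
  }
}
```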
From c0291d068b66e6b450521813ebb2ad9a3a4c75e2 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Sun, 23 Mar 2025 12:08:23 +0200
Subject: [PATCH 2/5] Updates for code review
---
.../Tensor/Transforms/ReshapePatterns.cpp | 13 ++++----
.../Dialect/Linalg/transform-op-fuse.mlir | 14 ++++----
.../Tensor/bubble-up-extract-slice-op.mlir | 33 +++++++++++++++----
3 files changed, 40 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index efa4d10817e39..cc73011151b17 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -15,7 +15,6 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
-#include <algorithm>
using namespace mlir;
using namespace mlir::tensor;
@@ -533,13 +532,13 @@ struct BubbleUpCollapseShapeThroughExtractSlice
for (int64_t expandedShapeIdx : reassocIndices) {
if (srcShape[expandedShapeIdx] != 1) {
nonUnitSizeCount++;
- expandedSizes.emplace_back(collapsedSize);
- expandedOffsets.emplace_back(collapsedOffset);
+ expandedSizes.push_back(collapsedSize);
+ expandedOffsets.push_back(collapsedOffset);
continue;
}
- expandedSizes.emplace_back(rewriter.getIndexAttr(1));
- expandedOffsets.emplace_back(rewriter.getIndexAttr(0));
+ expandedSizes.push_back(rewriter.getIndexAttr(1));
+ expandedOffsets.push_back(rewriter.getIndexAttr(0));
}
if (nonUnitSizeCount != 1) {
@@ -586,9 +585,9 @@ struct BubbleUpCollapseShapeThroughExtractSlice
"slice of the src of the collapse_shape");
}
- groupExpandedSizes.emplace_back(rewriter.getIndexAttr(
+ groupExpandedSizes.push_back(rewriter.getIndexAttr(
std::min(collapsedSizeValue, expandedShapeSize)));
- groupExpandedOffsets.emplace_back(rewriter.getIndexAttr(offsetInDim));
+ groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
// Remove the size and offset of trailing dimensions from the size and
// offset of the slice.
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index 441020f1cddfc..d7339fa3c0be4 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -443,9 +443,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape(
// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}} -> (tensor<8x1800x32xf32>) {
-// CHECK: %[[EXTRACT1:.*]] = tensor.extract_slice
-// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %[[EXTRACT1]]
-// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE1]]
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]]
+// CHECK: %[[EXP1:.*]] = linalg.exp ins(%[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
%collapse = tensor.collapse_shape %0 [[0, 1], [2], [3]] : tensor<1x8x1800x32xf32> into tensor<8x1800x32xf32>
%empty = tensor.empty() : tensor<8x1800x32xf32>
@@ -467,10 +467,10 @@ module attributes {transform.with_named_sequence} {
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
-// CHECK: %[[VAL_9:.*]] = tensor.extract_slice
-// CHECK: %[[VAL_11:.*]] = linalg.abs ins(%[[VAL_9]]
-// CHECK: %[[VAL_12:.*]] = tensor.collapse_shape %[[VAL_11]]
-// CHECK: %[[VAL_14:.*]] = linalg.exp ins(%[[VAL_12]]
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice
+// CHECK: %[[ABS:.*]] = linalg.abs ins(%[[EXTRACT]]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[ABS]]
+// CHECK: %[[EXP:.*]] = linalg.exp ins(%[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(%0: tensor<1x8x1800x32xf32>) -> tensor<8x1800x32xf32> {
%empty1 = tensor.empty() : tensor<1x8x1800x32xf32>
%abs = linalg.abs ins(%0 : tensor<1x8x1800x32xf32>) outs(%empty1 : tensor<1x8x1800x32xf32>) -> tensor<1x8x1800x32xf32>
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index d05bf1bf76f29..c0755d7125091 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -1,5 +1,15 @@
// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+///----------------------------------------------------------------------------------------
+/// [Pattern: BubbleUpExpandShapeThroughExtractSlice]
+///
+/// IN: tensor.extract_slice(tensor.expand_shape)
+/// OUT: tensor.expand_shape(tensor.extract_slice)
+///
+/// Note: tensor.extract_slice is bubbled up to be before tensor.expand_shape.
+/// Some tests are negative tests for cases where the pattern cannot be applied.
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_expand_shape(
// CHECK-SAME: %[[SRC:.*]]: tensor<60xf32>) -> tensor<1x1x5xf32> {
// CHECK: %[[C1:.+]] = arith.constant 5 : index
@@ -113,6 +123,16 @@ func.func @bubble_up_extract_slice_affine_apply_not_folded(%src: tensor<60xf32>,
return %extract : tensor<?x5x2xf32>
}
+///----------------------------------------------------------------------------------------
+/// [Pattern: BubbleUpCollapseShapeThroughExtractSlice]
+///
+/// IN: tensor.extract_slice(tensor.collapse_shape)
+/// OUT: tensor.collapse_shape(tensor.extract_slice)
+///
+/// Note: tensor.extract_slice is bubbled up to be before tensor.collapse_shape.
+/// Some tests are negative tests for cases where the pattern cannot be applied.
+///----------------------------------------------------------------------------------------
+
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(
// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<1xf32> {
// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][0, 0, 0] [1, 1, 1] [1, 1, 1]
@@ -209,9 +229,13 @@ func.func @bubble_up_extract_slice_through_collapse_shape_dynamic_and_static_gro
return %extract : tensor<20x?xf32>
}
-// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
-// shape cannot be defined as a contiguous size in the expanded shape due to size extracted not being suited
-// for the expanded shape.
+/// The following two tests are cases where the bubble-up cannot occur because the contiguous slice
+/// extracted from the collapsed shape cannot be expressed via a single extract_slice op.
+/// In the first test, this is because the extracted size cannot be expressed as a slice
+/// of the form [ 1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An ] (see the pattern documentation for more details).
+/// In the second test, the size can be expressed in the required form, but the offset is such that the
+/// pattern cannot be applied.
+
// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1(
// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<15xf32> {
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
@@ -222,9 +246,6 @@ func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_1
return %extract : tensor<15xf32>
}
-// This is a case where the bubble up cannot occur because the contiguous size extracted from the collapsed
-// shape cannot be defined as a contiguous size in the expanded shape due to an unsuitable offset even though
-// the size extracted is suited for the expanded shape.
// CHECK-LABEL: func.func @no_bubble_up_extract_slice_through_collapse_shape_on_non_contiguous_2(
// CHECK-SAME: %[[SRC:.*]]: tensor<2x3x10xf32>) -> tensor<20xf32> {
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape
From 72b0be3bd808f94bcf062cdcfee684be553bf507 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Mon, 24 Mar 2025 12:13:56 +0200
Subject: [PATCH 3/5] Updates for CR
---
.../Tensor/Transforms/ReshapePatterns.cpp | 111 +++++++++++++-----
.../Dialect/Linalg/transform-op-fuse.mlir | 1 -
2 files changed, 80 insertions(+), 32 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index cc73011151b17..6ec7c58f85b4e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -429,17 +429,19 @@ struct BubbleUpExpandShapeThroughExtractSlice
}
};
-/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
-/// `tensor.extract_slice(tensor.collapse_shape)`.
+/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
+/// `tensor.collapse_shape(tensor.extract_slice)`.
///
-/// For this transformation to be possible, the slice must be representable as a
-/// contiguous slice within each reassociation group of the src.
+/// For this transformation to be possible, after bubbling up, the extraction
+/// of the contiguous slice must be representable as a single slice obtained
+/// via tensor.extract_slice within each reassociation group of the src.
///
/// In case the size and offset extracted are static then this is possible if
-/// the following conditions are met:
-/// Let T be a tensor of shape [A0, A1, ..., An], and let S = [S0, S1, ..., Sn]
-/// be the shape of a desired slice. A slice of shape S can be extracted as a
-/// contiguous block of memory if and only if there exists an index k in {0, 1,
+/// the following conditions are met within each reassociation group:
+/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
+/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
+/// shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous span of elements if and only if there exists an index k in {0, 1,
/// ..., n} such that:
/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
/// 1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
@@ -475,6 +477,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
/// tensor<2x16x1x?x10xf32> into tensor<32x?x10xf32>
/// ```
+///
+/// Negative example:
+/// The transformation is not possible because we cannot use a single slice to
+/// represent the reassociation group [2x3x10->???]:
+/// ```
+/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] :
+///     tensor<2x3x10xf32> into tensor<60xf32>
+/// %extract = tensor.extract_slice %collapse[0][15][1] :
+///     tensor<60xf32> to tensor<15xf32>
+/// ```
+/// If we wanted the collapse to be after the extraction, a possible
+/// alternate transformation could be to extract multiple slices and concat
+/// them together:
+/// ```
+/// %extract_1 = tensor.extract_slice %src[0, 0, 0] [1, 1, 10] [1, 1, 1] :
+///     tensor<2x3x10xf32> to tensor<1x1x10xf32>
+/// %extract_2 = tensor.extract_slice %src[0, 1, 0] [1, 1, 5] [1, 1, 1] :
+///     tensor<2x3x10xf32> to tensor<1x1x5xf32>
+/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
+///     (tensor<1x1x10xf32>, tensor<1x1x5xf32>) -> tensor<1x1x15xf32>
+/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] :
+///     tensor<1x1x15xf32> into tensor<15xf32>
+/// ```
+/// But this is not the intended purpose of the transformation.
struct BubbleUpCollapseShapeThroughExtractSlice
: public OpRewritePattern<tensor::ExtractSliceOp> {
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
@@ -552,47 +579,69 @@ struct BubbleUpCollapseShapeThroughExtractSlice
// Case #2 - size and offset are static.
// Verify that the slice can be represented as a contiguous slice of the
// src of the collapse_shape.
- // Checking this must be done on order of most
- // internal dimensions first, so traversal is done in reverse order of the
- // reassociation group.
+ // Checking this is done in order of the innermost dimensions first,
+ // so traversal is done in reverse order of the reassociation group.
+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An],
+ // then we first find the size and offset for dimensions n...k+1, then for
+ // k, and then for k-1...0.
int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
int64_t collapsedOffsetValue =
getConstantIntValue(collapsedOffset).value();
SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
- for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
- int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+ ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+ reassocIndices.rend());
+ int64_t idx = 0;
+ int64_t reassocGroupSize = reassocIndices.size();
+
+ // First handle the trailing dimensions where the slice size should be
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
+ for (; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- // This is a dimension that slicing will occur on, so need to make sure
- // that the slice size can be set to the shape size and the offset to 0.
- if (collapsedSizeValue >= expandedShapeSize &&
- (collapsedSizeValue % expandedShapeSize != 0 ||
- collapsedOffsetValue % expandedShapeSize != 0)) {
+ if (collapsedSizeValue < expandedShapeSize)
+ break;
+
+ // We need to make sure that the slice size can be set to the shape size
+ // and the offset to 0.
+ if ((collapsedSizeValue % expandedShapeSize) != 0 ||
+ (collapsedOffsetValue % expandedShapeSize) != 0)
return rewriter.notifyMatchFailure(
sliceOp, "unsupported: cannot be extracted as a contiguous slice "
"of the src of the collapse_shape");
- }
- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+ groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
+ groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+
+ collapsedSizeValue /= expandedShapeSize;
+ collapsedOffsetValue /= expandedShapeSize;
+ }
- // This is the dimension that slicing will occur along, so need to make
- // sure that the slice size + offset will not exceed the shape size.
- if (collapsedSizeValue < expandedShapeSize &&
- (collapsedSizeValue + offsetInDim) >= expandedShapeSize) {
+ // Now handle the first dim on which slicing occurs (k).
+ if (idx < reassocGroupSize) {
+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+ // We need to make sure that the slice size in this dim + offset will
+ // not exceed the shape size.
+ if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
return rewriter.notifyMatchFailure(
sliceOp, "unsupported: slice cannot be extracted as a contiguous "
"slice of the src of the collapse_shape");
- }
- groupExpandedSizes.push_back(rewriter.getIndexAttr(
- std::min(collapsedSizeValue, expandedShapeSize)));
+ groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue));
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
- // Remove the size and offset of trailing dimensions from the size and
- // offset of the slice.
- collapsedSizeValue /= expandedShapeSize;
- collapsedSizeValue = std::max<int64_t>(collapsedSizeValue, 1);
+ collapsedOffsetValue /= expandedShapeSize;
+ }
+
+ // Now handle the leading dimensions where the slice size is equal to 1
+ // (k-1...0).
+ for (idx++; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+ groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
+ groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
collapsedOffsetValue /= expandedShapeSize;
}
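To make the three phases above concrete, here is a standalone sketch of the
same per-group decomposition on plain integers. The helper and type names
are invented for illustration; it mirrors, rather than reproduces, the
pattern code, including its conservative size-plus-offset check.

```
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

using Dims = std::vector<int64_t>;

// Decompose the static (size, offset) sliced out of one collapsed dim onto
// the expanded dims of its reassociation group (given outermost first).
// Returns std::nullopt when the slice is not expressible as a single
// contiguous extract_slice, mirroring the match failures in the pattern.
std::optional<std::pair<Dims, Dims>>
decomposeGroup(const Dims &groupShape, int64_t size, int64_t offset) {
  Dims sizes(groupShape.size()), offsets(groupShape.size());
  int64_t i = static_cast<int64_t>(groupShape.size()) - 1;
  // Trailing dims (n...k+1): must be taken in full with offset 0.
  for (; i >= 0 && size >= groupShape[i]; --i) {
    if (size % groupShape[i] != 0 || offset % groupShape[i] != 0)
      return std::nullopt;
    sizes[i] = groupShape[i];
    offsets[i] = 0;
    size /= groupShape[i];
    offset /= groupShape[i];
  }
  // Dim k: the single dim with non-trivial slicing.
  if (i >= 0) {
    int64_t offsetInDim = offset % groupShape[i];
    if (size + offsetInDim >= groupShape[i]) // conservative, as in the patch
      return std::nullopt;
    sizes[i] = size;
    offsets[i] = offsetInDim;
    offset /= groupShape[i];
    --i;
  }
  // Leading dims (k-1...0): unit size, remaining offset "digits".
  for (; i >= 0; --i) {
    sizes[i] = 1;
    offsets[i] = offset % groupShape[i];
    offset /= groupShape[i];
  }
  return std::make_pair(sizes, offsets);
}
```

On the numbers from the multiple_reassoc_group test, decomposeGroup({6, 5},
15, 5) yields sizes {3, 5} with offsets {1, 0}, and decomposeGroup({3, 10},
10, 10) yields sizes {1, 10} with offsets {1, 0}, which together give the
expected tensor.extract_slice %src[1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1].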
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index d7339fa3c0be4..962858076db93 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -462,7 +462,6 @@ module attributes {transform.with_named_sequence} {
}
}
-
// -----
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_with_collapse_producer(
From 1aaf3c907272a1bf1d4a2a870e93565079390d07 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Tue, 25 Mar 2025 10:45:32 +0200
Subject: [PATCH 4/5] Fix lit test checker variable names
---
.../Tensor/bubble-up-extract-slice-op.mlir | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
index c0755d7125091..34128d6a5ec8b 100644
--- a/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-up-extract-slice-op.mlir
@@ -145,10 +145,10 @@ func.func @bubble_up_extract_slice_through_collapse_shape_single_reassoc_group(%
}
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
-// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1]
-// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1], [2, 3]]
-// CHECK: return %[[VAL_2]]
+// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][1, 0, 1, 0] [3, 5, 1, 10] [1, 1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1], [2, 3]]
+// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group(%src: tensor<6x5x3x10xf32>) -> tensor<15x10xf32> {
%collapse = tensor.collapse_shape %src [[0, 1], [2, 3]] : tensor<6x5x3x10xf32> into tensor<30x30xf32>
%extract = tensor.extract_slice %collapse[5, 10][15, 10][1, 1] : tensor<30x30xf32> to tensor<15x10xf32>
@@ -156,10 +156,10 @@ func.func @bubble_up_extract_slice_through_collapse_shape_multiple_reassoc_group
}
// CHECK-LABEL: func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> {
-// CHECK: %[[VAL_1:.*]] = tensor.extract_slice %[[VAL_0]][2, 0, 0] [1, 2, 2] [1, 1, 1]
-// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2]]
-// CHECK: return %[[VAL_2]]
+// CHECK-SAME: %[[SRC:.*]]: tensor<6x5x2xf32>) -> tensor<4xf32> {
+// CHECK: %[[EXTRACT:.*]] = tensor.extract_slice %[[SRC]][2, 0, 0] [1, 2, 2] [1, 1, 1]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[EXTRACT]] {{\[\[}}0, 1, 2]]
+// CHECK: return %[[COLLAPSE]]
func.func @bubble_up_extract_slice_through_collapse_shape_offset_on_leading_dim(%src: tensor<6x5x2xf32>) -> tensor<4xf32> {
%collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<6x5x2xf32> into tensor<60xf32>
%extract = tensor.extract_slice %collapse[20][4][1] : tensor<60xf32> to tensor<4xf32>
From 5845db6a8de787a5df0c684f85286ae3a4ddf4f4 Mon Sep 17 00:00:00 2001
From: Ofri Frishman <ofri4321 at gmail.com>
Date: Tue, 25 Mar 2025 16:08:09 +0200
Subject: [PATCH 5/5] Additional clarifications
---
.../Tensor/Transforms/ReshapePatterns.cpp | 44 ++++++++++++-------
1 file changed, 29 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 6ec7c58f85b4e..0a7fcba7a71cd 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -550,7 +550,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice
enumerate(collapseShapeOp.getReassociationIndices())) {
OpFoldResult collapsedSize = collapsedSizes[groupIdx];
OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
- // Case #1 - size and/or offset are dynamic.
+ // CASE #1 - size and/or offset are dynamic.
// In this case, the slice can be represented as a contiguous slice only
// if there is a single dimension in the reassociation group that has a
// size not equal to 1.
@@ -576,7 +576,7 @@ struct BubbleUpCollapseShapeThroughExtractSlice
continue;
}
- // Case #2 - size and offset are static.
+ // CASE #2 - size and offset are static.
// Verify that the slice can be represented as a contiguous slice of the
// src of the collapse_shape.
// Checking this is done in order of the innermost dimensions first,
@@ -584,8 +584,16 @@ struct BubbleUpCollapseShapeThroughExtractSlice
// If the expected slice shape is [1, 1, ..., 1, Sk, Ak+1, Ak+2, ..., An],
// then we first find the size and offset for dimensions n...k+1, then for
// k, and then for k-1...0.
- int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
- int64_t collapsedOffsetValue =
+
+ // currentCollapsedSize and currentCollapsedOffset are initialized with
+ // the original collapsed size and offset and divided by the expanded
+ // shape size in each dimension as we go along the reassociation group.
+ // In essence, we are spreading the original collapsed size and offset over
+ // the various expanded slice dimensions.
+ // The variables are used both to check the validity of the slice and to
+ // compute the expanded sizes and offsets.
+ int64_t currentCollapsedSize = getConstantIntValue(collapsedSize).value();
+ int64_t currentCollapsedOffset =
getConstantIntValue(collapsedOffset).value();
SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
@@ -600,13 +608,13 @@ struct BubbleUpCollapseShapeThroughExtractSlice
for (; idx < reassocGroupSize; ++idx) {
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- if (collapsedSizeValue < expandedShapeSize)
+ if (currentCollapsedSize < expandedShapeSize)
break;
// We need to make sure that the slice size can be set to the shape size
// and the offset to 0.
- if ((collapsedSizeValue % expandedShapeSize) != 0 ||
- (collapsedOffsetValue % expandedShapeSize) != 0)
+ if ((currentCollapsedSize % expandedShapeSize) != 0 ||
+ (currentCollapsedOffset % expandedShapeSize) != 0)
return rewriter.notifyMatchFailure(
sliceOp, "unsupported: cannot be extracted as a contiguous slice "
"of the src of the collapse_shape");
@@ -614,35 +622,41 @@ struct BubbleUpCollapseShapeThroughExtractSlice
groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
- collapsedSizeValue /= expandedShapeSize;
- collapsedOffsetValue /= expandedShapeSize;
+ currentCollapsedSize /= expandedShapeSize;
+ currentCollapsedOffset /= expandedShapeSize;
}
// Now handle the first dim on which slicing occurs (k).
if (idx < reassocGroupSize) {
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
// We need to make sure that the slice size in this dim + offset will
// not exceed the shape size.
- if ((collapsedSizeValue + offsetInDim) >= expandedShapeSize)
+ if ((currentCollapsedSize + offsetInDim) >= expandedShapeSize)
return rewriter.notifyMatchFailure(
sliceOp, "unsupported: slice cannot be extracted as a contiguous "
"slice of the src of the collapse_shape");
- groupExpandedSizes.push_back(rewriter.getIndexAttr(collapsedSizeValue));
+ groupExpandedSizes.push_back(
+ rewriter.getIndexAttr(currentCollapsedSize));
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
- collapsedOffsetValue /= expandedShapeSize;
+ currentCollapsedOffset /= expandedShapeSize;
}
// Now handle the leading dimensions where the slice size is equal to 1
// (k-1...0).
+ // The size for these dimensions must be 1 because of how we constructed
+ // the slice size of the expanded shape. We spread the original collapsed
+ // size over the expanded shape sizes until we reached dimension k, where
+ // the remaining size was smaller than the expanded shape size, and placed
+ // the remaining size there. So now we are left with only 1s.
for (idx++; idx < reassocGroupSize; ++idx) {
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = collapsedOffsetValue % expandedShapeSize;
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
- collapsedOffsetValue /= expandedShapeSize;
+ currentCollapsedOffset /= expandedShapeSize;
}
expandedSizes.append(groupExpandedSizes.rbegin(),