[Mlir-commits] [mlir] [MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape (PR #131982)
ofri frishman
llvmlistbot at llvm.org
Tue Mar 25 07:10:48 PDT 2025
================
@@ -428,6 +429,239 @@ struct BubbleUpExpandShapeThroughExtractSlice
}
};
+/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
+/// `tensor.collapse_shape(tensor.extract_slice)`.
+///
+/// For this transformation to be possible - after bubbling up, the extraction
+/// of the contiguous slice must be representable as a single slice obtained via
+/// tensor.extract_slice within each reassociation group of the src.
+///
+/// In case the size and offset extracted are static then this is possible if
+/// the following conditions are met within each reassociation group:
+/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
+/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
+/// shape of a desired slice. A slice of shape S can be extracted as a
+/// contiguous span of elements if and only if there exists an index k in {0, 1,
+/// ..., n} such that:
+/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+/// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
+/// one dimension),
+/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
+/// in full).
+/// In other words, the slice shape S must be of the form:
+/// [ 1, 1, ..., 1, S_k, A_{k+1}, A_{k+2}, ..., A_n ]
+///
+/// In case the size and/or offset extracted are dynamic then this is possible
+/// only if there is a single dimension in the reassociation group that has a
+/// size not equal to 1.
+/// In other words, the tensor shape must be of the form:
+/// [ 1, 1, ..., 1, A, 1, ...,1 ]
+/// Note - it might be possible to enable this pattern for more cases when the
+/// size/offset are dynamic via performing an analysis of the possible values
+/// that could be given to the size/offset.
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
+/// [20->10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
+/// tensor<8x16x1x7x20xf32> to tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+/// tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+/// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
+/// ```
+///
+/// Negative example:
+/// The transformation is not possible because we cannot use a single slice to
+/// represent the reassociation group [2x3x10->???]. If we would want the
+/// collapse to be after the extraction, we would need to extract multiple
+/// slices and concat them together.
+/// ```
+/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] :
+/// tensor<2x3x10xf32> into tensor<60xf32>
+/// %extract = tensor.extract_slice %collapse[0][15][1] :
+/// tensor<60xf32> to tensor<15xf32>
+/// ```
+/// If we would want the collapse to be after the extraction, a possible
+/// alternate transformation could be to extract multiple slices and concat them
+/// together:
+/// ```
+/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
+/// tensor<2x3x10xf32> to tensor <1x1x10xf32>
+/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
+/// tensor<2x3x10xf32> to tensor <1x1x5xf32>
+/// %concat = tosa.concat %extract_1, %extract_2 {axis = 2 : i32} :
+/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
+/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
+/// to tensor<15xf32>
+/// ```
+/// But this is not the intended purpose of the transformation.
+struct BubbleUpCollapseShapeThroughExtractSlice
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto collapseShapeOp =
+ sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+ if (!collapseShapeOp)
+ return rewriter.notifyMatchFailure(
+ sliceOp,
+ "tensor.extract_slice source not produced by tensor.collapse_shape");
+
+ if (!sliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+ "be supported in this transformation.");
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.collapse_shape, so variables (i.e. inputs for
+ // ExtractSliceOp) referring to the state before applying the pattern are
+ // named with the prefix "collapsed", and ones referring to the state after
+ // applying the pattern are named with the prefix "expanded".
+ SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+ collapsedSizes.size())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "unimplemented: rank reducing slice");
+
+ ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
+ SmallVector<ReassociationIndices, 4> reassociationIndices =
+ collapseShapeOp.getReassociationIndices();
+
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has a rank
+ // equal to the rank of the src of the collapse_shape. In each iteration of
+ // the loop, the offsets and sizes will be computed per reassociation group.
+ SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
+ SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
+ rewriter.getIndexAttr(1));
+
+ for (auto [groupIdx, reassocIndices] :
+ enumerate(collapseShapeOp.getReassociationIndices())) {
+ OpFoldResult collapsedSize = collapsedSizes[groupIdx];
+ OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+ // Case #1 - size and/or offset are dynamic.
+ // In this case, the slice can be represented as a contiguous slice only
+ // if there is a single dimension in the reassociation group that has a
+ // size not equal to 1.
+ if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+ int nonUnitSizeCount = 0;
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ if (srcShape[expandedShapeIdx] != 1) {
+ nonUnitSizeCount++;
+ expandedSizes.push_back(collapsedSize);
+ expandedOffsets.push_back(collapsedOffset);
+ continue;
+ }
+
+ expandedSizes.push_back(rewriter.getIndexAttr(1));
+ expandedOffsets.push_back(rewriter.getIndexAttr(0));
+ }
+
+ if (nonUnitSizeCount != 1) {
+ return rewriter.notifyMatchFailure(
+ sliceOp,
+ "unsupported: slice cannot be verified to be contiguous");
+ }
+ continue;
+ }
+
+ // Case #2 - size and offset are static.
+ // Verify that the slice can be represented as a contiguous slice of the
+ // src of the collapse_shape.
+ // Checking this is done in order of most internal dimensions first,
+ // so traversal is done in reverse order of the reassociation group.
+ // If the expected slice shape is [1, 1, ..., 1, S_k, A_{k+1}, A_{k+2},
+ // ..., A_n] then we first find the size and offset for n...k+1 then for k
+ // and then for k-1...0.
+ int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
+ int64_t collapsedOffsetValue =
+ getConstantIntValue(collapsedOffset).value();
+
+ SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+
+ ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+ reassocIndices.rend());
+ int64_t idx = 0;
+ int64_t reassocGroupSize = reassocIndices.size();
+
+ // First handle the trailing dimensions where the slice size should be
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
+ for (; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
+
+ if (collapsedSizeValue < expandedShapeSize)
----------------
ofri-frishman wrote:
Added documentation to the variables and changed their name
https://github.com/llvm/llvm-project/pull/131982
More information about the Mlir-commits
mailing list