[Mlir-commits] [mlir] [MLIR] Bubble up tensor.extract_slice through tensor.collapse_shape (PR #131982)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Mar 23 06:44:34 PDT 2025


================
@@ -428,6 +429,190 @@ struct BubbleUpExpandShapeThroughExtractSlice
   }
 };
 
+/// Converts `tensor.collapse_shape(tensor.extract_slice)` to
+/// `tensor.extract_slice(tensor.collapse_shape)`.
+///
+/// For this transformation to be possible, the slice must be representable as a
+/// contiguous slice within each reassociation group of the src.
+///
+/// If the extracted sizes and offsets are static, this is possible if the
+/// following conditions are met:
+/// Let T be a tensor of shape [A_0, A_1, ..., A_n], and let
+/// S = [S_0, S_1, ..., S_n] be the shape of a desired slice. A slice of shape S
+/// can be extracted as a contiguous block of memory if and only if there exists
+/// an index k in {0, 1, ..., n} such that:
+///      S_i = 1 for all i < k (that is, all leading dimensions are singleton),
+///      1 <= S_k <= A_k (that is, non-trivial slicing occurs along exactly
+///                       one dimension),
+///      S_i = A_i for all i > k (that is, all trailing dimensions are preserved
+///      in full).
+/// In other words, the slice shape S must be of the form:
+/// [1, 1, ..., 1, S_k, A_(k+1), A_(k+2), ..., A_n]
+///
+/// If the extracted size and/or offset are dynamic, this is possible only if
+/// there is a single dimension in the reassociation group whose size is not
+/// equal to 1.
+/// In other words, the shape of the reassociation group must be of the form:
+/// [1, 1, ..., 1, A, 1, ..., 1]
+/// Note: it might be possible to enable this pattern for more cases with
+/// dynamic sizes/offsets by analyzing the possible values that could be given
+/// to the size/offset.
+///
+/// Example:
+/// The transformation is possible because each reassociation group can be
+/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
+/// [20->10]).
+/// ```
+/// BEFORE:
+/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
+///     tensor<8x16x1x7x20xf32> into tensor<128x7x20xf32>
+/// %slice = tensor.extract_slice %collapse [0, 0, 0][32, %size, 10][1, 1, 1]
+///     : tensor<128x7x20xf32> to tensor<32x?x10xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
+///     [1, 1, 1, 1, 1] : tensor<8x16x1x7x20xf32> to tensor<2x16x1x?x10xf32>
+/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
+///     tensor<2x16x1x?x10xf32> into tensor<32x?x10xf32>
+/// ```
+struct BubbleUpCollapseShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "tensor.extract_slice source not produced by tensor.collapse_shape");
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
+                   "be supported in this transformation.");
+    }
+
+    // The tensor.extract_slice before applying the pattern works on the result
+    // of the tensor.collapse_shape, so variables (i.e. inputs for
+    // ExtractSliceOp) referring to the state before applying the pattern are
+    // named with the prefix "collapsed", and ones referring to the state after
+    // applying the pattern are named with the prefix "expanded".
+    SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        collapsedSizes.size())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+
+    ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
+    SmallVector<ReassociationIndices, 4> reassociationIndices =
+        collapseShapeOp.getReassociationIndices();
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor that has a rank
+    // equal to the rank of the src of the collapse_shape. In each iteration of
+    // the loop, the offsets and sizes will be computed per reassociation group.
+    SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
+    SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
+                                              rewriter.getIndexAttr(1));
+
+    for (auto [groupIdx, reassocIndices] :
+         enumerate(reassociationIndices)) {
+      OpFoldResult collapsedSize = collapsedSizes[groupIdx];
+      OpFoldResult collapsedOffset = collapsedOffsets[groupIdx];
+      // Case #1 - size and/or offset are dynamic.
+      // In this case, the slice can be represented as a contiguous slice only
+      // if there is a single dimension in the reassociation group that has a
+      // size not equal to 1.
+      if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+        int nonUnitSizeCount = 0;
+        for (int64_t expandedShapeIdx : reassocIndices) {
+          if (srcShape[expandedShapeIdx] != 1) {
+            nonUnitSizeCount++;
+            expandedSizes.push_back(collapsedSize);
+            expandedOffsets.push_back(collapsedOffset);
+            continue;
+          }
+
+          expandedSizes.push_back(rewriter.getIndexAttr(1));
+          expandedOffsets.push_back(rewriter.getIndexAttr(0));
+        }
+
+        if (nonUnitSizeCount != 1) {
+          return rewriter.notifyMatchFailure(
+              sliceOp,
+              "unsupported: slice cannot be verified to be contiguous");
+        }
+        continue;
+      }
+
+      // Case #2 - size and offset are static.
+      // Verify that the slice can be represented as a contiguous slice of the
+      // src of the collapse_shape. This check must start from the innermost
+      // dimension, so the reassociation group is traversed in reverse order.
+      int64_t collapsedSizeValue = getConstantIntValue(collapsedSize).value();
+      int64_t collapsedOffsetValue =
+          getConstantIntValue(collapsedOffset).value();
+
+      SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+
+      for (int64_t expandedShapeIdx : llvm::reverse(reassocIndices)) {
+        int64_t expandedShapeSize = srcShape[expandedShapeIdx];
+
+        // This is a dimension that slicing will occur on, so we need to make
+        // sure that the slice size can be set to the shape size and the offset
+        // to 0.
+        if (collapsedSizeValue >= expandedShapeSize &&
----------------
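For reference, the two contiguity conditions stated in the doc comment of `BubbleUpCollapseShapeThroughExtractSlice` can be exercised in isolation. Below is a minimal standalone C++ sketch over plain shape vectors, not code from the PR: `isContiguousSliceShape` and `hasSingleNonUnitDim` are hypothetical helpers, and all extents are assumed to be static and positive.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Static case (hypothetical helper): a slice of shape `slice` taken from a
// row-major tensor of shape `shape` is a single contiguous block iff it has
// the form [1, ..., 1, S_k, A_(k+1), ..., A_n] with 1 <= S_k <= A_k.
bool isContiguousSliceShape(const std::vector<int64_t> &shape,
                            const std::vector<int64_t> &slice) {
  if (shape.size() != slice.size())
    return false;
  // Skip the trailing dimensions that are taken in full (S_i == A_i).
  std::size_t j = shape.size();
  while (j > 0 && slice[j - 1] == shape[j - 1])
    --j;
  if (j == 0)
    return true; // The slice covers the whole tensor.
  // Dimension k = j - 1 is the only one that may be partially sliced.
  std::size_t k = j - 1;
  if (slice[k] < 1 || slice[k] > shape[k])
    return false;
  // All dimensions before k must be singleton.
  for (std::size_t i = 0; i < k; ++i)
    if (slice[i] != 1)
      return false;
  return true;
}

// Dynamic case (hypothetical helper): mirrors the `nonUnitSizeCount != 1`
// check in the patch, i.e. the group must have exactly one non-unit dim.
bool hasSingleNonUnitDim(const std::vector<int64_t> &groupShape) {
  int nonUnitCount = 0;
  for (int64_t dim : groupShape)
    if (dim != 1)
      ++nonUnitCount;
  return nonUnitCount == 1;
}
```

For example, `isContiguousSliceShape({8, 16}, {2, 16})` holds (the `[8x16->2x16]` group from the doc-comment example), while `isContiguousSliceShape({8, 16}, {2, 8})` does not. As with the patch's `nonUnitSizeCount != 1` check, `hasSingleNonUnitDim` also rejects a group whose dimensions are all 1.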
banach-space wrote:

This condition is a bit unclear to me. Say we are collapsing like this:
```mlir
  %collapse = tensor.collapse_shape %src [[0, 1], [2, 3, 4]] : tensor<5x10x1x2x20xf32> into tensor<50x40xf32>
  %extract = tensor.extract_slice %collapse[0, 0][20, 40][1, 1] : tensor<50x40xf32> to tensor<20x40xf32>
```

Let's look at the trailing reassociation group. Here, `collapsedSizeValue` is 40 and `expandedShapeSize` is 1, 2, or 20, so `collapsedSizeValue >= expandedShapeSize` in each case. However, this is not the dimension/size/group along which we slice, is it?
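A minimal standalone C++ sketch may help make the question concrete. It is not the PR's implementation; `GroupSlice` and `expandGroupSlice` are hypothetical names, and offsets and sizes are assumed to be static and non-negative. It maps the 1-D slice `[offset, offset + size)` of one collapsed dimension back onto the row-major shape of its reassociation group, walking from the innermost dimension and dividing `size` and `offset` by each dimension it visits. In this formulation, `size >= dim` means "this dimension is taken in full": for the trailing group `[1, 2, 20]` (offset 0, size 40) every dimension takes that branch, so the group is not sliced at all, whereas for the leading group `[5, 10]` (offset 0, size 20) only the inner dimension does and the outer one is sliced to 2.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical result type: per-dimension offsets and sizes of a slice of one
// reassociation group (outermost dimension first).
struct GroupSlice {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Maps the static 1-D slice [offset, offset + size) of a collapsed dimension
// back onto the group's expanded, row-major shape `groupShape`. Returns
// std::nullopt if the range is empty, does not fit in the group, or is not a
// single rectangular box of the group.
std::optional<GroupSlice>
expandGroupSlice(const std::vector<int64_t> &groupShape, int64_t offset,
                 int64_t size) {
  if (size < 1 || offset < 0)
    return std::nullopt;
  GroupSlice result;
  result.offsets.assign(groupShape.size(), 0);
  result.sizes.assign(groupShape.size(), 1);
  // Walk from the innermost dimension outwards; `offset` and `size` are
  // measured in units of the product of the dimensions already visited.
  for (std::size_t i = groupShape.size(); i-- > 0;) {
    int64_t dim = groupShape[i];
    if (size >= dim) {
      // Dimension taken in full: the range must cover whole rows of it.
      if (size % dim != 0 || offset % dim != 0)
        return std::nullopt;
      result.sizes[i] = dim;
      result.offsets[i] = 0;
      size /= dim;
    } else {
      // Partially sliced dimension: the range must not wrap around.
      if (offset % dim + size > dim)
        return std::nullopt;
      result.sizes[i] = size;
      result.offsets[i] = offset % dim;
      size = 1; // Every outer dimension is then sliced with size 1.
    }
    offset /= dim;
  }
  // Anything left over means the range ran past the end of the group.
  if (size != 1 || offset != 0)
    return std::nullopt;
  return result;
}
```

With the doc-comment example, `expandGroupSlice({8, 16}, 0, 32)` gives sizes `{2, 16}` and offsets `{0, 0}`, while `expandGroupSlice({8, 16}, 8, 16)` is rejected because elements `[8, 24)` straddle two rows.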

https://github.com/llvm/llvm-project/pull/131982

