[Mlir-commits] [mlir] [MLIR] Add pattern to bubble up tensor.extract_slice (PR #126898)

ofri frishman llvmlistbot at llvm.org
Fri Feb 28 01:18:04 PST 2025


================
@@ -210,6 +214,214 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+///
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. A slice is fully
+/// contiguous within a reassociation group if, after flattening the group to
+/// a single 1D range, the part of the slice taken from that group is a single
+/// contiguous subrange of that range.
+///
+/// Rank-reducing slices are not supported.
+///
+/// Example:
+/// The transformation is possible because each reassociation group has a
+/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
+/// ```
+/// BEFORE:
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// AFTER:
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note: this pattern could be reworked to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+                                                 rewriter)
+            .failed())
+      return failure();
+
+    // The tensor.extract_slice before applying the pattern works on the result
+    // of the tensor.expand_shape, so variables (i.e., inputs for ExtractSliceOp)
+    // referring to the state before applying the pattern are named with the
+    // prefix "expanded", and ones referring to the state after applying the
+    // pattern are named with the prefix "collapsed".
+    SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> expandedShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    // Helper variables and function for accumulating the size values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two index values, folding to a constant when possible.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    // Compute new offsets, sizes, and strides for tensor.extract_slice.
+    // The new tensor.extract_slice will work on a tensor whose rank equals the
+    // number of reassociation groups. The loop computes a single offset, size,
+    // and stride value per reassociation group.
+    SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
+        collapsedStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      // collapsedSize will hold the size of the single dim that represents the
+      // reassociation group in the non-expanded tensor.
+      OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
+      // reassocGroupSizes and reassocGroupOffsets are used to create an
+      // affine.linearize_index op that computes the single linearized offset
+      // for this reassociation group.
+      SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
----------------
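To make the contiguity condition from the doc comment concrete, here is a minimal sketch in plain C++, assuming static shapes and unit strides. `isContiguousInGroup` is a hypothetical helper, not the check used in the PR; it only illustrates the rule on integer offsets and sizes for a single reassociation group.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLIR): returns true if a unit-stride slice
// with static `offsets`/`sizes`, taken from one reassociation group whose
// expanded dims have sizes `groupShape`, covers a single contiguous subrange
// of the group once it is flattened to 1D in row-major order.
bool isContiguousInGroup(const std::vector<int64_t> &groupShape,
                         const std::vector<int64_t> &offsets,
                         const std::vector<int64_t> &sizes) {
  assert(groupShape.size() == offsets.size() &&
         groupShape.size() == sizes.size());
  // Walk from the innermost dim outwards, skipping dims taken in full.
  size_t i = groupShape.size();
  while (i > 0 && offsets[i - 1] == 0 && sizes[i - 1] == groupShape[i - 1])
    --i;
  // Every dim is taken in full: the slice is the whole group.
  if (i == 0)
    return true;
  // Dim i - 1 is the outermost dim not taken in full. It may be partial, but
  // every dim outside it must have size 1, otherwise the flattened slice
  // would contain gaps.
  for (size_t j = 0; j + 1 < i; ++j)
    if (sizes[j] != 1)
      return false;
  return true;
}
```

With the groups from the doc comment, [2x4->2x4], [2x8->1x5] and [4x2x4->1x1x4] all pass. Taking a 2x5 slice out of a 2x8 group fails, because the two rows of 5 elements are separated by a gap of 3 elements in the flattened range.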
ofri-frishman wrote:

I updated the names as you suggested and removed the comment about them. 
If we can have better names and less comments to explain them then that seems preferable to me
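The per-group computation that the quoted loop starts can be sketched with plain integers, using the same "expanded"/"collapsed" naming. This assumes static shapes, a slice that already passed the contiguity check, and that the truncated loop body multiplies the per-dim slice sizes into `collapsedSize` and linearizes the per-dim offsets the way `affine.linearize_index` does. `CollapsedSlice` and `collapseSlice` are hypothetical names; the pattern itself works on `OpFoldResult`s and emits affine ops.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical result type: one offset and one size per reassociation group.
struct CollapsedSlice {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Hypothetical plain-integer stand-in for the rewrite: for each group, the
// collapsed size is the product of the expanded slice sizes, and the collapsed
// offset is the row-major linearization of the expanded offsets over the
// group's expanded dims. Only meaningful for slices that are contiguous in
// every group.
CollapsedSlice collapseSlice(
    const std::vector<std::vector<int64_t>> &reassociation,
    const std::vector<int64_t> &expandedShape,
    const std::vector<int64_t> &expandedOffsets,
    const std::vector<int64_t> &expandedSizes) {
  assert(expandedOffsets.size() == expandedShape.size() &&
         expandedSizes.size() == expandedShape.size());
  CollapsedSlice result;
  for (const std::vector<int64_t> &indices : reassociation) {
    int64_t collapsedSize = 1; // mirrors rewriter.getIndexAttr(1)
    int64_t collapsedOffset = 0;
    for (int64_t idx : indices) {
      collapsedSize *= expandedSizes[idx];
      // Horner-style row-major linearization: offset = offset * dim + o.
      collapsedOffset =
          collapsedOffset * expandedShape[idx] + expandedOffsets[idx];
    }
    result.offsets.push_back(collapsedOffset);
    result.sizes.push_back(collapsedSize);
  }
  return result;
}
```

Applied to the doc-comment example with expanded offsets [0, 0, 1, 3, 2, 1, 0] (chosen for illustration), this yields collapsed sizes [8, 5, 4], matching `tensor<8x5x4xf32>`, and collapsed offsets [0, 11, 20].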

https://github.com/llvm/llvm-project/pull/126898

