[Mlir-commits] [mlir] [MLIR] Add pattern to bubble up tensor.extract_slice (PR #126898)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Feb 25 08:13:35 PST 2025


================
@@ -210,6 +214,200 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. If the transformation
+/// is not possible, or if the slice is rank-reducing, the pattern returns
+/// failure (i.e., it does not match).
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note - this pattern could be reworked to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+    if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+                                                 rewriter)
+            .failed())
+      return failure();
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    // Helper variables and function for accumulating the new offset and length
+    // values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    // Compute new offsets, lengths, and strides.
+    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      OpFoldResult newSize = rewriter.getIndexAttr(1);
+      SmallVector<OpFoldResult> basis, delinOffsets;
+
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Offset = cumulative product of leading unit extracted dims.
----------------
banach-space wrote:

This comment seems to apply specifically to "offsets", but this loop computes more than just offsets?

Perhaps document every loop so that it's clear what it's dealing with. Currently that's rather unclear.

https://github.com/llvm/llvm-project/pull/126898


More information about the Mlir-commits mailing list