[Mlir-commits] [mlir] [MLIR] Add pattern to bubble up tensor.extract_slice (PR #126898)

Andrzej Warzyński llvmlistbot at llvm.org
Sun Feb 23 09:28:10 PST 2025


================
@@ -210,6 +213,178 @@ struct BubbleUpExpandThroughParallelCollapse
   }
 };
 
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. If the transformation
+/// is not possible, or if the slice is rank reducing, the function returns
+/// failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note - this pattern could be reworked to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp =
+        sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "slice source not produced by expand_shape");
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unsupported: non-unit stride");
+    }
+
+    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+    if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+        sizes.size()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unimplemented: rank reducing slice");
+    }
+
+    // Helper variables and function for accumulating the new offset and length
+    // values.
+    Location loc = expandShapeOp->getLoc();
+    AffineExpr d0, d1, d2;
+    bindDims(rewriter.getContext(), d0, d1, d2);
+    // Multiply two integers.
+    auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+      auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+      return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                   {v1, v2});
+    };
+
+    SmallVector<OpFoldResult> outputShape =
+        getMixedValues(expandShapeOp.getStaticOutputShape(),
+                       expandShapeOp.getOutputShape(), rewriter);
+
+    auto isZeroOffsetAndFullSize =
+        [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
+          if (!isConstantIntValue(offset, 0))
+            return false;
+          FailureOr<bool> maybeEqual =
+              ValueBoundsConstraintSet::areEqual(sliceSize, size);
+          return llvm::succeeded(maybeEqual) && maybeEqual.value();
+        };
+
+    // First verify that this is a full slice of the expanded tensor.
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Find the first expanded dim after the first dim with non-unit extracted
+      // size.
+      for (; i < e; ++i) {
+        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+          // +1 to skip the first non-unit size dim.
+          i++;
+          break;
+        }
+      }
+
+      // Verify that all subsequent dimensions extract the full size of the
+      // source tensor.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                     outputShape[expandedDim])) {
+          return rewriter.notifyMatchFailure(
+              sliceOp, "Not a contiguous slice of the expanded tensor.");
+        }
+      }
+    }
+
+    // Compute new offsets, lengths, and strides.
+    SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+    for (const ReassociationIndices &indices :
+         expandShapeOp.getReassociationIndices()) {
+      OpFoldResult newSize = rewriter.getIndexAttr(1);
+      SmallVector<OpFoldResult> basis, delinOffsets;
+
+      int64_t i = 0;
+      int64_t e = indices.size();
+      // Offset = cumulative product of leading unit extracted dims.
+      for (; i < e; ++i) {
+        int64_t expandedDim = indices[i];
+        if (!isConstantIntValue(sizes[expandedDim], 1))
+          break;
+
+        basis.push_back(outputShape[expandedDim]);
+        delinOffsets.push_back(offsets[expandedDim]);
+      }
+
+      if (i != e) {
+        int64_t expandedDim = indices[i];
+        basis.push_back(outputShape[expandedDim]);
+        delinOffsets.push_back(offsets[expandedDim]);
+        newSize = sizes[expandedDim];
+        i++;
+      }
+
+      for (; i < e; ++i) {
+        OpFoldResult fullSize = outputShape[indices[i]];
+        basis.push_back(fullSize);
+        delinOffsets.push_back(rewriter.getIndexAttr(0));
+        newSize = mul(newSize, fullSize);
----------------
banach-space wrote:

I didn't see this size re-calculation in [bubble-up-extract-slice-op.mlir](https://github.com/llvm/llvm-project/pull/126898/files#diff-5b706598efea999ee91a9207f5bb81fb6e566754507af09c36456b8ee7f72252). It would be good to test it as well.

https://github.com/llvm/llvm-project/pull/126898


More information about the Mlir-commits mailing list