[Mlir-commits] [mlir] [MLIR] Add pattern to bubble up tensor.extract_slice (PR #126898)
ofri frishman
llvmlistbot at llvm.org
Fri Feb 28 01:11:55 PST 2025
================
@@ -210,6 +214,217 @@ struct BubbleUpExpandThroughParallelCollapse
}
};
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape.
+/// A slice is defined as fully contiguous within a reassociation group if,
+/// after flattening the reassociation group to a single 1D range, the slice
+/// taken out of the group can be expressed as a single contiguous subrange
+/// within that range.
+/// If the transformation is not possible, or if the slice is rank reducing, the
+/// function returns failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+///
+/// Note - this pattern could be reworked to be a swap pattern between
+/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
+/// implemented only as a bubble up pattern for `tensor.extract_slice`.
+struct BubbleUpExpandShapeThroughExtractSlice
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto expandShapeOp =
+ sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+
+ if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
+ rewriter)
+ .failed())
+ return failure();
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.expand_shape, so variables referring to the state before
+ // applying the pattern are named with the prefix "expanded", and ones
+ // referring to the state after applying the pattern are named with the
+ // prefix "collapsed".
+ SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> expandedShape =
+ getMixedValues(expandShapeOp.getStaticOutputShape(),
+ expandShapeOp.getOutputShape(), rewriter);
+
+ // Helper variables and function for accumulating the size values.
+ Location loc = expandShapeOp->getLoc();
+ AffineExpr d0, d1, d2;
+ bindDims(rewriter.getContext(), d0, d1, d2);
+ // Multiply two integers.
+ auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+ auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+ {v1, v2});
+ };
+
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has a rank of
+ // ReassociationIndices.size(). In the loop a single offset, size, and
+ // stride value is computed per reassociation group.
+ SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
+ collapsedStrides;
+ for (const ReassociationIndices &indices :
+ expandShapeOp.getReassociationIndices()) {
+ // collapsedSize will hold the size of the single dim that represents the
+ // reassociation group in the non expanded tensor.
+ OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
+ // The basis and delinOffsets are used to create an affine.linearize_index
+ // op to linearize the single offset value required for this reassociation
+ // group.
+ // basis holds the full sizes of the reassociation group dimensions
+ // of the expanded tensor.
+ // delinOffsets as in "delinearized offsets", holds the offsets within the
+ // reassociation group dimensions of the expanded tensor.
+ SmallVector<OpFoldResult> basis, delinOffsets;
+
+ for (long expandedDim : indices) {
+ // basis and delinOffsets can be obtained directly from the expanded
+ // state, but the collapsed size requires calculation as it did not
+ // previously exist.
+ basis.push_back(expandedShape[expandedDim]);
+ delinOffsets.push_back(expandedOffsets[expandedDim]);
+ collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+ }
+
+ SmallVector<Value> offsetVals =
+ llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+ });
+ OpFoldResult collapsedOffset =
+ rewriter
+ .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis,
+ /*disjoint=*/true)
+ .getResult();
+ collapsedOffsets.push_back(collapsedOffset);
+ collapsedSizes.push_back(collapsedSize);
+
+ // Only unit stride supported.
----------------
ofri-frishman wrote:
updated
https://github.com/llvm/llvm-project/pull/126898
More information about the Mlir-commits
mailing list