[Mlir-commits] [mlir] [mlir][tensor] Move extract_slice reshaping into two functions (PR #153675)
Ian Wood
llvmlistbot at llvm.org
Thu Aug 14 13:43:19 PDT 2025
https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/153675
Exposes the `tensor.extract_slice` reshaping logic from `BubbleUpExpandShapeThroughExtractSlice` and `BubbleUpCollapseShapeThroughExtractSlice` as two corresponding utility functions, `getCollapsedExtractSliceInfo` and `getExpandedExtractSliceInfo`. These compute the offsets/sizes/strides needed to build the equivalent `extract_slice` on the collapsed or expanded tensor, respectively.
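Roughly, the expand-shape pattern in this PR now reduces to the following (a minimal sketch mirroring the diff below; the pattern boilerplate and match-failure messages are elided):

```c++
// Sketch: core of BubbleUpExpandShapeThroughExtractSlice after this change.
// `sliceOp` extracts from the result of `expandShapeOp`.
SmallVector<OpFoldResult> offsets, sizes, strides;
if (failed(tensor::getCollapsedExtractSliceInfo(
        sliceOp, expandShapeOp.getReassociationIndices(), offsets, sizes,
        strides, rewriter)))
  return failure();

// Slice the pre-expansion source, then expand the smaller slice back.
Value newSlice = tensor::ExtractSliceOp::create(
    rewriter, sliceOp.getLoc(), expandShapeOp.getSrc(), offsets, sizes,
    strides);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
    sliceOp, sliceOp.getResultType(), newSlice,
    expandShapeOp.getReassociationIndices(), sliceOp.getMixedSizes());
```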
This should also make it easier to implement the two other bubbling cases: (1) the `collapse_shape` is a consumer of the slice, or (2) the `expand_shape` is a consumer of the slice.
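The collapse-shape pattern is symmetric (again a sketch mirroring the diff, with boilerplate elided); a consumer-direction pattern would presumably compute the same offsets/sizes/strides with these utilities and build the slice and reshape in the opposite order:

```c++
// Sketch: core of BubbleUpCollapseShapeThroughExtractSlice after this change.
// `sliceOp` extracts from the result of `collapseShapeOp`.
SmallVector<OpFoldResult> offsets, sizes, strides;
if (failed(tensor::getExpandedExtractSliceInfo(
        sliceOp, collapseShapeOp.getReassociationIndices(),
        collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides,
        rewriter)))
  return failure();

// Slice the pre-collapse source, then collapse the smaller slice back.
Value newSlice = tensor::ExtractSliceOp::create(
    rewriter, collapseShapeOp.getLoc(), collapseShapeOp.getSrc(), offsets,
    sizes, strides);
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
    sliceOp, sliceOp.getResultType(), newSlice,
    collapseShapeOp.getReassociationIndices());
```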
From bee87f64b5fef6c35e5c578c44fa7a90f2e87d57 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood at u.northwestern.edu>
Date: Thu, 14 Aug 2025 11:55:21 -0700
Subject: [PATCH] [mlir][tensor] Refactor extract_slice reshaping into util
funcs
Signed-off-by: Ian Wood <ianwood at u.northwestern.edu>
---
.../Dialect/Tensor/Transforms/Transforms.h | 28 +
.../Tensor/Transforms/ReshapePatterns.cpp | 569 +++++++++---------
2 files changed, 298 insertions(+), 299 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 87deef9ca7466..2602252916388 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -142,6 +142,34 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
ValueRange independencies);
+/// Computes the offsets, sizes, and strides needed to build a collapsed
+/// `sliceOp`. The dimensions to collapse are specified by `reassociation`.
+///
+/// This fails when the specified collapse cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getCollapsedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides,
+ OpBuilder &b);
+
+/// Computes the offsets, sizes, and strides needed to build an expanded
+/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
+/// `expandedShape`.
+///
+/// This fails when the specified expansion cannot be represented by a valid
+/// ExtractSliceOp.
+LogicalResult
+getExpandedExtractSliceInfo(tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &expandedOffsets,
+ SmallVectorImpl<OpFoldResult> &expandedSizes,
+ SmallVectorImpl<OpFoldResult> &expandedStrides,
+ OpBuilder &b);
+
} // namespace tensor
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 2ec23e1fb35ce..a93681b1fce92 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice
PatternRewriter &rewriter) const override {
auto expandShapeOp =
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandShapeOp) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "tensor.extract_slice source not produced by expand_shape");
+ }
+ SmallVector<ReassociationIndices> reassociation =
+ expandShapeOp.getReassociationIndices();
- if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
- rewriter)
- .failed())
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getCollapsedExtractSliceInfo(sliceOp, reassociation, offsets,
+ sizes, strides, rewriter)))
return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
- // referring to the state before applying the pattern are named with the
- // prefix "expanded", and ones referring to the state after applying the
- // pattern are named with the prefix "collapsed".
- SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
- SmallVector<OpFoldResult> expandedShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- // Helper variables and function for accumulating the size values.
- Location loc = expandShapeOp->getLoc();
- AffineExpr d0, d1, d2;
- bindDims(rewriter.getContext(), d0, d1, d2);
- // Multiply two integers.
- auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
- auto mulMap = AffineMap::get(2, 0, {d0 * d1});
- return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
- {v1, v2});
- };
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank of
- // ReassociationIndices.size(). In the loop a single offset, size, and
- // stride value is computed per reassociation group.
- SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
- collapsedStrides;
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- // collapsedSize will hold the size of the single dim that represents the
- // reassociation group in the non expanded tensor.
- OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
- // The reassocGroupSizes and reassocGroupOffsets are used to create an
- // affine.linearize_index op to linearize the single offset value required
- // for this reassociation group.
- SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
-
- for (long expandedDim : indices) {
- // reassocGroupSizes and reassocGroupOffsets can be obtained directly
- // from the expanded state, but the collapsed size requires calculation
- // as it did not previously exist.
- reassocGroupSizes.push_back(expandedShape[expandedDim]);
- reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
- collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
- }
-
- SmallVector<Value> offsetVals =
- llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
- });
- OpFoldResult collapsedOffset =
- affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals,
- reassocGroupSizes,
- /*disjoint=*/true)
- .getResult();
- collapsedOffsets.push_back(collapsedOffset);
- collapsedSizes.push_back(collapsedSize);
-
- // Only unit stride is supported.
- collapsedStrides.push_back(rewriter.getIndexAttr(1));
- }
-
// The shape of the result can be obtained from the sizes passed in.
- SmallVector<Value> dynDims;
- SmallVector<int64_t> shape;
- dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
- RankedTensorType resultType = RankedTensorType::get(
- shape, expandShapeOp.getResultType().getElementType());
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ RankedTensorType resultType = sliceOp.getResultType();
// Create a new ExtractSliceOp and ExpandShapeOp.
+ Location loc = sliceOp.getLoc();
Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
- collapsedStrides);
+ rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
sliceOp, resultType, newSliceOp,
expandShapeOp.getReassociationIndices(), expandedSizes);
return success();
}
-
- // Helper function to check if all the required conditions for the
- // tensor.extract_slice to be bubbled up through the tensor.expand_shape are
- // met.
- LogicalResult
- checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
- tensor::ExpandShapeOp expandShapeOp,
- PatternRewriter &rewriter) const {
-
- if (!expandShapeOp) {
- return rewriter.notifyMatchFailure(
- sliceOp, "tensor.extract_slice source not produced by expand_shape");
- }
-
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
-
- SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- sizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
-
- SmallVector<OpFoldResult> outputShape =
- getMixedValues(expandShapeOp.getStaticOutputShape(),
- expandShapeOp.getOutputShape(), rewriter);
-
- std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
- isZeroOffsetAndFullSize =
- [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isZeroInteger(offset))
- return false;
- FailureOr<bool> maybeEqual =
- ValueBoundsConstraintSet::areEqual(sliceSize, size);
- return llvm::succeeded(maybeEqual) && maybeEqual.value();
- };
-
- // Check that the slice is contiguous within each reassociation group.
- // The slice is contiguous only if after the first dimension where a non
- // unit slice is taken, the slice size on all subsequent dimensions of the
- // group is equal to the entire size of the dimension.
- // Examples of contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
- // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
- // Examples of non contiguous slices:
- // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
- // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
- for (const ReassociationIndices &indices :
- expandShapeOp.getReassociationIndices()) {
- int64_t i = 0;
- int64_t e = indices.size();
- // Find the first expanded dim after the first dim with non-unit extracted
- // size.
- for (; i < e; ++i) {
- if (!isOneInteger(sizes[indices[i]])) {
- // +1 to skip the first non-unit size dim.
- i++;
- break;
- }
- }
-
- // Verify that all subsequent dimensions extract the full size of the
- // source tensor.
- for (; i < e; ++i) {
- int64_t expandedDim = indices[i];
- if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
- outputShape[expandedDim])) {
- return rewriter.notifyMatchFailure(
- sliceOp, "Not a contiguous slice of the expanded tensor.");
- }
- }
- }
-
- return success();
- }
};
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
@@ -582,170 +441,282 @@ struct BubbleUpCollapseShapeThroughExtractSlice
"tensor.extract_slice source not produced by tensor.collapse_shape");
}
- if (!sliceOp.hasUnitStride()) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
- "be supported in this transformation.");
- }
+ SmallVector<OpFoldResult> offsets, sizes, strides;
+ if (failed(getExpandedExtractSliceInfo(
+ sliceOp, collapseShapeOp.getReassociationIndices(),
+ collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides,
+ rewriter)))
+ return failure();
- // The tensor.extract_slice before applying the pattern works on the result
- // of the tensor.collapse_shape, so variables (i.e. inputs for
- // ExtractSliceOp) referring to the state before applying the pattern are
- // named with the prefix "collapsed", and ones referring to the state after
- // applying the pattern are named with the prefix "expanded".
- SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
-
- if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
- collapsedSizes.size()) {
- return rewriter.notifyMatchFailure(sliceOp,
- "unimplemented: rank reducing slice");
- }
+ Value newSliceOp = tensor::ExtractSliceOp::create(
+ rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets,
+ sizes, strides);
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ sliceOp, sliceOp.getResultType(), newSliceOp,
+ collapseShapeOp.getReassociationIndices());
- ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
- SmallVector<ReassociationIndices, 4> reassociationIndices =
- collapseShapeOp.getReassociationIndices();
-
- // Compute new offsets, sizes, and strides for tensor.extract_slice.
- // The new tensor.extract_slice will work on a tensor that has has a rank
- // equal to the rank of the src of the collapse_shape. In each iteration of
- // the loop, the offsets and sizes will be computed per reassociation group.
- SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
- SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
- rewriter.getIndexAttr(1));
-
- for (auto [collapsedSize, collapsedOffset, reassocIndices] :
- llvm::zip_equal(collapsedSizes, collapsedOffsets,
- collapseShapeOp.getReassociationIndices())) {
- // CASE #1 - size and/or offset are dynamic.
- // In this case, the slice can be represented as a contiguous slice only
- // if there is a single dimension in the reassociation group that has a
- // size not equal to 1.
- if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
- int nonUnitSizeCount = 0;
- for (int64_t expandedShapeIdx : reassocIndices) {
- if (srcShape[expandedShapeIdx] != 1) {
- nonUnitSizeCount++;
- expandedSizes.push_back(collapsedSize);
- expandedOffsets.push_back(collapsedOffset);
- continue;
- }
-
- expandedSizes.push_back(rewriter.getIndexAttr(1));
- expandedOffsets.push_back(rewriter.getIndexAttr(0));
- }
+ return success();
+ }
+};
- if (nonUnitSizeCount != 1) {
- return rewriter.notifyMatchFailure(
- sliceOp,
- "unsupported: slice cannot be verified to be contiguous");
- }
- continue;
- }
+} // namespace
- // CASE #2 = size and offset are static.
- // Verify that the slice can be represented as a contiguous slice of the
- // src of the collapse_shape.
- // Checking this is done on order of most internal dimensions first,
- // so traversal is done in reverse order of the reassociation group.
- // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
- // ...,An] then we first find the size and offset for n...k+1 then for k
- // and then for k-1...0.
-
- // currentCollapsedsize and currentCollapsedOffset are initialized with
- // the original collapsed size and offset and divided by the expanded
- // shape size in each dimension as we go along the reassociation group.
- // In essence we are spreading the original collapsed size and offset over
- // the various expanded slice dimensions.
- // The variables are used both to check the validity of the slice and to
- // compute the expanded sizes and offsets.
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
- int64_t currentCollapsedOffset =
- getConstantIntValue(collapsedOffset).value();
-
- SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
-
- ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
- reassocIndices.rend());
- int64_t idx = 0;
- int64_t reassocGroupSize = reassocIndices.size();
-
- // First handle the trailing dimensions where the slice size should be
- // equal to the tensor shape and the offset should be 0 (n...k+1).
- for (; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
-
- if (currentCollapsedsize < expandedShapeSize)
- break;
-
- // We need to make sure that the slice size can be set to the shape size
- // and the offset to 0.
- if ((currentCollapsedsize % expandedShapeSize) != 0 ||
- (currentCollapsedOffset % expandedShapeSize) != 0) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: cannot be extracted as a contiguous slice "
- "of the src of the collapse_shape");
- }
+LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
+ tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ SmallVectorImpl<OpFoldResult> &collapsedOffsets,
+ SmallVectorImpl<OpFoldResult> &collapsedSizes,
+ SmallVectorImpl<OpFoldResult> &collapsedStrides, OpBuilder &b) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
- groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+ return failure();
+ }
- currentCollapsedsize /= expandedShapeSize;
- currentCollapsedOffset /= expandedShapeSize;
+ auto isZeroOffsetAndFullSize = [&](OpFoldResult offset,
+ OpFoldResult sliceSize, int64_t inputDim) {
+ if (!isZeroInteger(offset))
+ return false;
+ ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim);
+ FailureOr<bool> maybeEqual =
+ ValueBoundsConstraintSet::areEqual(sliceSize, inputSize);
+ return llvm::succeeded(maybeEqual) && maybeEqual.value();
+ };
+
+ // Check that the slice is contiguous within each reassociation group.
+ // The slice is contiguous only if, after the first dimension where a
+ // non-unit slice is taken, the slice size on all subsequent dimensions of
+ // the group is equal to the entire size of the dimension.
+ // Examples of contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
+ // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
+ // Examples of non contiguous slices:
+ // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
+ // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
+ for (const ReassociationIndices &indices : reassociation) {
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Find the first expanded dim after the first dim with non-unit extracted
+ // size.
+ for (; i < e; ++i) {
+ if (!isOneInteger(sizes[indices[i]])) {
+ // +1 to skip the first non-unit size dim.
+ i++;
+ break;
}
+ }
+
+ // Verify that all subsequent dimensions extract the full size of the
+ // source tensor.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+ expandedDim)) {
+ return failure();
+ }
+ }
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
+ // referring to the state before applying the pattern are named with the
+ // prefix "expanded", and ones referring to the state after applying the
+ // pattern are named with the prefix "collapsed".
+ Location loc = sliceOp.getLoc();
+ SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
+ SmallVector<OpFoldResult> expandedShape =
+ getMixedSizes(b, loc, sliceOp.getSource());
+
+ // Helper variables and function for accumulating the size values.
+ AffineExpr d0, d1, d2;
+ bindDims(b.getContext(), d0, d1, d2);
+ // Multiply two integers.
+ auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+ auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+ return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2});
+ };
+
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has a rank of
+ // ReassociationIndices.size(). In the loop, a single offset, size, and
+ // stride value is computed per reassociation group.
+ for (const ReassociationIndices &indices : reassociation) {
+ // collapsedSize will hold the size of the single dim that represents the
+ // reassociation group in the non-expanded tensor.
+ OpFoldResult collapsedSize = b.getIndexAttr(1);
+ // The reassocGroupSizes and reassocGroupOffsets are used to create an
+ // affine.linearize_index op to linearize the single offset value required
+ // for this reassociation group.
+ SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
+
+ for (long expandedDim : indices) {
+ // reassocGroupSizes and reassocGroupOffsets can be obtained directly
+ // from the expanded state, but the collapsed size requires calculation
+ // as it did not previously exist.
+ reassocGroupSizes.push_back(expandedShape[expandedDim]);
+ reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
+ collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
+ }
+
+ SmallVector<Value> offsetVals =
+ llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+ });
+ OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create(
+ b, loc, offsetVals, reassocGroupSizes,
+ /*disjoint=*/true)
+ .getResult();
+ collapsedOffsets.push_back(collapsedOffset);
+ collapsedSizes.push_back(collapsedSize);
+
+ // Only unit stride is supported.
+ collapsedStrides.push_back(b.getIndexAttr(1));
+ }
+ return success();
+}
+
+LogicalResult mlir::tensor::getExpandedExtractSliceInfo(
+ tensor::ExtractSliceOp sliceOp,
+ ArrayRef<ReassociationIndices> reassociation,
+ ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &expandedOffsets,
+ SmallVectorImpl<OpFoldResult> &expandedSizes,
+ SmallVectorImpl<OpFoldResult> &expandedStrides, OpBuilder &b) {
+ if (!sliceOp.hasUnitStride()) {
+ return failure();
+ }
+
+ // The tensor.extract_slice before applying the pattern works on the result
+ // of the tensor.collapse_shape, so variables (i.e. inputs for
+ // ExtractSliceOp) referring to the state before applying the pattern are
+ // named with the prefix "collapsed", and ones referring to the state after
+ // applying the pattern are named with the prefix "expanded".
+ SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
+ collapsedSizes.size()) {
+ return failure();
+ }
- // Now handle the first dim where slicing occurs on (k).
- if (idx < reassocGroupSize) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- // We need to make sure that the slice size in this dim + offset will
- // not exceed the shape size.
- if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
- return rewriter.notifyMatchFailure(
- sliceOp, "unsupported: slice cannot be extracted as a contiguous "
- "slice of the src of the collapse_shape");
+ // Compute new offsets, sizes, and strides for tensor.extract_slice.
+ // The new tensor.extract_slice will work on a tensor that has a rank
+ // equal to the rank of the src of the collapse_shape. In each iteration of
+ // the loop, the offsets and sizes will be computed per reassociation group.
+ expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1));
+ for (auto [collapsedSize, collapsedOffset, reassocIndices] :
+ llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) {
+ // CASE #1 - size and/or offset are dynamic.
+ // In this case, the slice can be represented as a contiguous slice only
+ // if there is a single dimension in the reassociation group that has a
+ // size not equal to 1.
+ if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
+ int nonUnitSizeCount = 0;
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ if (expandedShape[expandedShapeIdx] != 1) {
+ nonUnitSizeCount++;
+ expandedSizes.push_back(collapsedSize);
+ expandedOffsets.push_back(collapsedOffset);
+ continue;
}
- groupExpandedSizes.push_back(
- rewriter.getIndexAttr(currentCollapsedsize));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
+ expandedSizes.push_back(b.getIndexAttr(1));
+ expandedOffsets.push_back(b.getIndexAttr(0));
+ }
- currentCollapsedOffset /= expandedShapeSize;
+ if (nonUnitSizeCount != 1) {
+ return failure();
}
+ continue;
+ }
- // Now handle the leading dimensions where the slice size is equal to 1
- // (k-1...0).
- // The size for these dimensions must be 1 because of how we constructed
- // the slice size of the expanded shape. We spread the original collapsed
- // size over the expanded shape sizes until we reached dimension k where
- // the remaining size was smaller than the expanded shape size, and spread
- // the remaining size on it. So, now we are left with only 1s.
- for (idx++; idx < reassocGroupSize; ++idx) {
- int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
- int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
- groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
- groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
- currentCollapsedOffset /= expandedShapeSize;
+ // CASE #2 - size and offset are static.
+ // Verify that the slice can be represented as a contiguous slice of the
+ // src of the collapse_shape.
+ // Checking this is done in order of the most internal dimensions first,
+ // so traversal is done in reverse order of the reassociation group.
+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
+ // ..., An], then we first find the size and offset for n...k+1, then for
+ // k, and then for k-1...0.
+
+ // currentCollapsedsize and currentCollapsedOffset are initialized with
+ // the original collapsed size and offset and divided by the expanded
+ // shape size in each dimension as we go along the reassociation group.
+ // In essence we are spreading the original collapsed size and offset over
+ // the various expanded slice dimensions.
+ // The variables are used both to check the validity of the slice and to
+ // compute the expanded sizes and offsets.
+ int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
+ int64_t currentCollapsedOffset =
+ getConstantIntValue(collapsedOffset).value();
+ SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
+ ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+ reassocIndices.rend());
+ int64_t idx = 0;
+ int64_t reassocGroupSize = reassocIndices.size();
+
+ // First handle the trailing dimensions where the slice size should be
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
+ for (; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+
+ if (currentCollapsedsize < expandedShapeSize)
+ break;
+
+ // We need to make sure that the slice size can be set to the shape size
+ // and the offset to 0.
+ if ((currentCollapsedsize % expandedShapeSize) != 0 ||
+ (currentCollapsedOffset % expandedShapeSize) != 0) {
+ return failure();
}
- expandedSizes.append(groupExpandedSizes.rbegin(),
- groupExpandedSizes.rend());
- expandedOffsets.append(groupExpandedOffsets.rbegin(),
- groupExpandedOffsets.rend());
+ groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(0));
+
+ currentCollapsedsize /= expandedShapeSize;
+ currentCollapsedOffset /= expandedShapeSize;
}
- Value newSliceOp = tensor::ExtractSliceOp::create(
- rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(),
- expandedOffsets, expandedSizes, expandedStrides);
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- sliceOp, sliceOp.getResultType(), newSliceOp,
- collapseShapeOp.getReassociationIndices());
+ // Now handle the first dim where slicing occurs on (k).
+ if (idx < reassocGroupSize) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ // We need to make sure that the slice size in this dim + offset will
+ // not exceed the shape size.
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
+ return failure();
+ }
+ groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
- return success();
+ // Now handle the leading dimensions where the slice size is equal to 1
+ // (k-1...0).
+ // The size for these dimensions must be 1 because of how we constructed
+ // the slice size of the expanded shape. We spread the original collapsed
+ // size over the expanded shape sizes until we reached dimension k where
+ // the remaining size was smaller than the expanded shape size, and spread
+ // the remaining size on it. So, now we are left with only 1s.
+ for (idx++; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
+ groupExpandedSizes.push_back(b.getIndexAttr(1));
+ groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim));
+ currentCollapsedOffset /= expandedShapeSize;
+ }
+ expandedSizes.append(groupExpandedSizes.rbegin(),
+ groupExpandedSizes.rend());
+ expandedOffsets.append(groupExpandedOffsets.rbegin(),
+ groupExpandedOffsets.rend());
}
-};
-
-} // namespace
+ return success();
+}
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
RewritePatternSet &patterns) {