[Mlir-commits] [mlir] [mlir][tensor]-Handle Dynamic Offset in BubbleUpSliceOpThroughCollapse (PR #178921)
Ian Wood
llvmlistbot at llvm.org
Mon Feb 9 11:55:09 PST 2026
================
@@ -579,10 +580,174 @@ LogicalResult mlir::tensor::getCollapsedExtractSliceInfo(
return success();
}
+/// Returns true when `ofr` can be shown to be a multiple of `factor`.
+///
+/// Two forms of `ofr` are analyzed:
+///  * a constant recognized by `getConstantIntValue`, checked with a plain
+///    remainder;
+///  * an SSA value defined by an `affine.apply`, whose map is fully composed
+///    with its operands, simplified, and then queried for divisibility on
+///    its single result expression.
+/// Every other input conservatively yields false.
+static bool isMultipleOf(OpFoldResult ofr, int64_t factor) {
+  // Constant case: the answer is a direct remainder check.
+  if (std::optional<int64_t> constant = getConstantIntValue(ofr))
+    return *constant % factor == 0;
+
+  // Dynamic case: only values produced by affine.apply are analyzed.
+  Value dynamicValue = dyn_cast<Value>(ofr);
+  if (!dynamicValue)
+    return false;
+  auto producer = dynamicValue.getDefiningOp<affine::AffineApplyOp>();
+  if (!producer)
+    return false;
+
+  // Compose the producer's map with its operands and simplify before asking
+  // the resulting expression whether it is a multiple of `factor`.
+  AffineMap composedMap = producer.getAffineMap();
+  SmallVector<Value> composedOperands(producer.getOperands());
+  affine::fullyComposeAffineMapAndOperands(&composedMap, &composedOperands);
+  composedMap = simplifyAffineMap(composedMap);
+  return composedMap.getNumResults() == 1 &&
+         composedMap.getResult(0).isMultipleOf(factor);
+}
+
+/// Given a `collapsedOffset` and `collapsedSize`, this function
+/// validates that the slice is representable as a contiguous slice
+/// in the `expandedShape` and computes the corresponding expanded sizes.
+/// Returns failure if the slice cannot be guaranteed to be contiguous.
+/// On success, populates `groupSizes` with the expanded sizes for each
+/// dimension in the reassociation group.
+static LogicalResult computeExpandedSliceInfoForReassocGroup(
+ OpBuilder &b, OpFoldResult collapsedSize, OpFoldResult collapsedOffset,
+ const ReassociationIndices &reassocIndices, ArrayRef<int64_t> expandedShape,
+ SmallVectorImpl<OpFoldResult> &groupSizes) {
+ assert(groupSizes.empty() && "Group sizes must be empty");
+ // The first case is when there's only one non-unit dimension in the
+ // reassociation group.
+ // When there's only one non-unit dimension, the slice is trivially
+ // contiguous - offset and size go directly on that dimension.
+ // This works for both dynamic size and dynamic offset.
+ int nonUnitSizeCount = llvm::count_if(
+ reassocIndices, [&expandedShape](int64_t expandedShapeIdx) {
+ return expandedShape[expandedShapeIdx] != 1;
+ });
+ if (nonUnitSizeCount == 1) {
+ for (int64_t expandedShapeIdx : reassocIndices) {
+ if (expandedShape[expandedShapeIdx] != 1)
+ groupSizes.push_back(collapsedSize);
+ else
+ groupSizes.push_back(b.getIndexAttr(1));
+ }
+ return success();
+ }
+
+ // Having dynamic extracted size requires additional complex
+ // analysis to guarantee contiguous slicing.
+ if (isa<Value>(collapsedSize))
+ return failure();
+
+ std::optional<int64_t> staticSize = getConstantIntValue(collapsedSize);
+ assert(staticSize.has_value() && "Expected static size");
+
+ // The extracted size is exactly one element; the offset may be static
+ // or dynamic. This is a trivial case where contiguous slicing can
+ // always be guaranteed.
+ if (staticSize.value() == 1) {
+ SmallVector<int64_t> basis;
+ for (size_t i = 0; i < reassocIndices.size(); ++i)
+ groupSizes.push_back(b.getIndexAttr(1));
+
+ return success();
+ }
+
+ // Size is static and greater than 1, offset may be static or dynamic.
+ // Use traversal to find dimension k where slicing occurs.
+ // Verify that the slice can be represented as a contiguous slice of the
+ // src of the collapse_shape.
+ // The check proceeds in order of the innermost dimensions first,
+ // so traversal is done in reverse order of the reassociation group.
+ // If the expected slice shape is [1, 1, ..., 1, Sk, A(k+1), A(k+2),
+ // ..., An] then we first find the size and offset for n...k+1, then for k,
+ // and then for k-1...0.
+
+ // currentCollapsedsize is initialized with the original collapsed size
+ // and divided by the expanded shape size in each dimension as we go along
+ // the reassociation group. In essence we are spreading the original
+ // collapsed size over the various expanded slice dimensions.
+ // currentOffsetDivisor is initialized with 1 and multiplied by the expanded
+ // shape size in each dimension as we go along the reassociation group.
+ // These variables are used both to check the validity of the slice and to
+ // compute the expanded sizes and offsets.
+ assert(staticSize.value() > 1 && "Expected size to be greater than 1");
+ int64_t currentCollapsedsize = staticSize.value();
+ int64_t currentOffsetDivisor = 1;
+
+ ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
+ reassocIndices.rend());
+ int64_t idx = 0;
+ int64_t reassocGroupSize = reassocIndices.size();
+
+ SmallVector<OpFoldResult> groupExpandedSizes;
+
+ // First handle the trailing dimensions where the slice size should be
+ // equal to the tensor shape and the offset should be 0 (n...k+1).
+ for (; idx < reassocGroupSize; ++idx) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+
+ if (currentCollapsedsize < expandedShapeSize)
+ break;
+
+ // Check size divisibility.
+ if ((currentCollapsedsize % expandedShapeSize) != 0)
+ return failure();
+
+ // Check dynamic/static offset divisibility.
+ currentOffsetDivisor *= expandedShapeSize;
+ if (!isMultipleOf(collapsedOffset, currentOffsetDivisor))
+ return failure();
+
+ // Trailing dims get full shape and zero offset.
+ groupSizes.push_back(b.getIndexAttr(expandedShapeSize));
+ currentCollapsedsize /= expandedShapeSize;
+ }
+
+ // Now handle the first dim where slicing occurs on (k).
+ if (idx < reassocGroupSize) {
+ int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]];
+ std::optional<int64_t> staticOffset = getConstantIntValue(collapsedOffset);
+
+ if (staticOffset.has_value()) {
+ // Static offset: check that offset + size doesn't exceed dimension.
+ int64_t offsetInDim =
+ (staticOffset.value() / currentOffsetDivisor) % expandedShapeSize;
+ if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize)
+ return failure();
----------------
IanWood1 wrote:
```suggestion
if ((currentCollapsedsize + offsetInDim) > expandedShapeSize)
return failure();
```
We should be able to apply the transform in cases like:
```mlir
func.func @bubble_up_extract_slice_through_collapse_shape_boundary_offset(%arg0: tensor<3x10xf32>) -> tensor<5xf32> {
%collapsed = tensor.collapse_shape %arg0 [[0, 1]] : tensor<3x10xf32> into tensor<30xf32>
%extracted_slice = tensor.extract_slice %collapsed[5] [5] [1] : tensor<30xf32> to tensor<5xf32>
return %extracted_slice : tensor<5xf32>
}
```
https://github.com/llvm/llvm-project/pull/178921
More information about the Mlir-commits
mailing list