[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)
Artem Gindinson
llvmlistbot at llvm.org
Fri May 23 07:55:10 PDT 2025
================
@@ -28,67 +32,319 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return std::nullopt;
}
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
- return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It is deliberately minimal, so the call sites should manage the validity of
+/// the range manually.
+struct ReassociationIndexRange {
+ /// FIXME: Signed type is used for consistency with ReassociationIndices.
+ /// We should consider refactoring all reassociation utilities to use unsigned
+ /// types.
+ int64_t leftIdx = 0, rightIdx = 0;
+
+ /// Util for manual checks of the range's validity
+ LogicalResult verify() const {
+ return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+ }
+
+ /// Checks range's containment within another range. Treats the edges
+ /// non-exclusively.
+ bool isInRange(const ReassociationIndexRange &outerRange) const {
+ return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+ }
+
+ unsigned size() const {
+ assert(succeeded(verify()));
+ return rightIdx - leftIdx + 1;
+ }
+ bool containsSingleIndex() const { return size() == 1; }
+
+ void expandRight() { ++rightIdx; }
+ void shrinkLeft() { ++leftIdx; }
+
+ /// Implements arithmetic XOR semantics to get non-overlapping indices between
+ /// ranges.
+ ReassociationIndices operator^(ReassociationIndexRange &rhs) const {
+ ReassociationIndices result;
+ result.reserve(size() + rhs.size() / 2); // Attempt to amortize
+ for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
+ if (idx < rhs.leftIdx || idx > rhs.rightIdx)
+ result.push_back(idx);
+ }
+ for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
+ if (rhsIndex < leftIdx || rhsIndex > rightIdx)
+ result.push_back(rhsIndex);
+ }
+ return result;
+ }
+
+ /// Converts the range into ReassociationIndices.
+ ReassociationIndices getFullIndices() const {
+ ReassociationIndices result;
+ for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
+ result.push_back(idx);
+ }
+ return result;
+ }
+};
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence that can be collapsed into a dynamic dimension (at least one must
+/// be present in the source).
+/// By default, lazily returns once the first dynamic dimension has been found.
+/// Setting `matchGreedily` as `true` will also mark all subsequent
+/// source dimensions for collapsing into the target.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
+ int64_t sourceStartIdx,
+ bool matchGreedily = false) {
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+ const unsigned numSourceDims = sourceShape.size();
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+ if (!iterationRange.isInRange(sourceShapeAsRange))
+ return failure();
+ auto resultRange = iterationRange;
+
+ bool foundDynamic = false;
+ for (; iterationRange.isInRange(sourceShapeAsRange);
+ iterationRange.expandRight()) {
+ int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+ if (foundDynamic && !matchGreedily)
+ break;
+ if (sourceSize == ShapedType::kDynamic)
+ foundDynamic = true;
+ resultRange = iterationRange;
+ }
+ if (!foundDynamic)
+ return failure();
+ return resultRange;
+}
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence of static dimensions such that their product matches `targetSize`.
+/// By default, lazily returns once the product matches the target size. Setting
+/// `matchGreedily` as `true` will append all neighboring unit dimensions
+/// (dimensions of 1) to the match.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
+ int64_t sourceStartIdx, int64_t targetSize,
+ bool matchGreedily = false) {
+ ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+ const unsigned numSourceDims = sourceShape.size();
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+ if (!iterationRange.isInRange(sourceShapeAsRange))
+ return failure();
+ auto resultRange = iterationRange;
- ReassociationIndices currIndices;
int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
- // If we have mapped all the target dimensions stop and handle the remaining
- // tail of size-1 dimensions explicitly.
- if (targetDim == targetShape.size())
+ bool reachedTargetDimSize = false;
+ while (iterationRange.isInRange(sourceShapeAsRange)) {
+ int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+ if (reachedTargetDimSize && !matchGreedily)
+ break;
+ if (sourceSize == ShapedType::kDynamic) {
+ if (reachedTargetDimSize)
+ break;
+ // Reassociation for a static dim cannot include a dynamic dim. Reset
+ // induction variables to essentially restart the loop from the next
+ // source dimension.
+ prodOfCollapsedDims = 1;
+ resultRange = {iterationRange.rightIdx + 1, iterationRange.rightIdx + 1};
+ iterationRange = resultRange;
+ continue;
+ }
+ prodOfCollapsedDims *= sourceSize;
+ if (prodOfCollapsedDims > targetSize && reachedTargetDimSize)
break;
+ // If the target size has been exceeded without matching, we need to shift
+ // the range start right. From the start of the range, roll back the
+ // multiplication until the target size exceeds the product again.
+ while (prodOfCollapsedDims > targetSize &&
+ !iterationRange.containsSingleIndex()) {
+ int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
+ prodOfCollapsedDims /= frontSourceSize;
+ iterationRange.shrinkLeft();
+ }
+ resultRange = iterationRange;
+    // We could've reached the target size with the current dimension,
+    // also as a result of the above shift to the right.
+ if (prodOfCollapsedDims == targetSize)
+ reachedTargetDimSize = true;
+ // Increment the iteration range
+ iterationRange.expandRight();
+ }
+ if (!reachedTargetDimSize)
+ return failure();
+ return resultRange;
+}
+
+/// Attempts to find a valid collapsing reassociation of `sourceShape` into
+/// `targetShape` through a simple traversal. If successful, an array of source
+/// index ranges is returned, corresponding to each dimension in the target
+/// shape. The resulting indices shall fully cover the `sourceShape` without
+/// overlaps.
+///
+/// The algorithm is essentially a lazy one, searching for non-greedy matches -
+/// it will only yield a greedy match for the last target dimension.
+/// FIXME: The algorithm can only backtrack when it needs to append an offset
+/// for a static target dimension to the preceding dynamic one (this retains the
+/// linear complexity). As feasible, consider adding further backtracking
+/// routines to enable more reassociations, e.g.:
+/// - ?x2x?x2 into ?x2
+FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape) {
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ assert(numSourceDims > numTargetDims);
+ ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+
+ SmallVector<ReassociationIndexRange> reassocRanges;
+ reassocRanges.reserve(numTargetDims);
+ // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
+ // cases, e.g.:
+ // - ?x2x3x5 into ?x15
+ std::optional<int64_t> prevTargetSize = std::nullopt;
+ for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
+ targetDimIdx < numTargetDims; ++targetDimIdx) {
+ int64_t targetSize = targetShape[targetDimIdx];
+ // Simply check if there are any subsequent target dimensions left - if not,
+ // the match must be made greedily.
+ bool isLastTargetDim = targetDimIdx == numTargetDims - 1;
+ bool shouldMatchGreedily = isLastTargetDim;
+ FailureOr<ReassociationIndexRange> sourceRange;
+ if (targetSize == ShapedType::kDynamic) {
+ sourceRange = findReassociationRangeForDynamicDim(
+ sourceShape, sourceDimIdx, shouldMatchGreedily);
+ } else {
+ sourceRange = findReassociationRangeForSize(
+ sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
+ }
- int64_t currTargetShape = targetShape[targetDim];
- while (sourceDim < (sourceShape.size() - 1) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
- prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
+ // Run sanity checks on the returned index range.
+ if (failed(sourceRange) || failed(sourceRange->verify()) ||
+ !sourceRange->isInRange(sourceShapeAsRange))
+ return failure();
+ if (sourceRange->leftIdx > sourceDimIdx) {
+ // If some source dimensions had to be skipped in order to find a match,
+ // they must be collapsed into the directly preceding dynamic dimension.
+ if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
+ return failure();
+ reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
}
- // If the current expanded dimension is dynamic, then the collapsed
- // dimensions should also be dynamic and product of all previous unprocessed
- // dimensions of the expanded shape should be 1.
- if (sourceShape[sourceDim] == ShapedType::kDynamic &&
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
- return std::nullopt;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamic &&
- sourceShape[sourceDim] != ShapedType::kDynamic)
- return std::nullopt;
-
- // For static shapes, if the product of dimensions of the expanded shape
- // should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
- return std::nullopt;
-
- currIndices.push_back(sourceDim++);
- reassociationMap.emplace_back(ReassociationIndices{});
- std::swap(reassociationMap.back(), currIndices);
- prodOfCollapsedDims = 1;
+ // Store the gathered information as required for the next iteration.
+ prevTargetSize = targetSize;
+ sourceDimIdx = sourceRange->rightIdx + 1;
+ reassocRanges.emplace_back(std::move(*sourceRange));
+ }
+ // Fail if the source shape wasn't a full match for the target shape. We only
+ // need to check the last recorded index - any other gaps should have been
+ // mended by the main loop.
+ if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
+ return failure();
+ return reassocRanges;
+}
+
+/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
+/// the shapes right-to-left.
+FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape,
+ bool iterateRightToLeft) {
+ if (!iterateRightToLeft)
+ return findReassociationRangesForCollapse(sourceShape, targetShape);
+ // FIXME: It would be preferable to avoid the expensive copies. At the moment,
+ // this approach is chosen for readability of the main implementation.
+ auto sourceToReverse = sourceShape.vec(), targetToReverse = targetShape.vec();
+ std::reverse(sourceToReverse.begin(), sourceToReverse.end());
+ std::reverse(targetToReverse.begin(), targetToReverse.end());
+ auto invertedRanges =
+ findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
+ if (failed(invertedRanges))
+ return failure();
+ auto rangesToInvert = *invertedRanges;
+ unsigned numSourceDims = sourceShape.size();
+ // We have received the ranges for inverted shapes. Now we have to invert
+ // the ranges back to correspond with the original source shape.
+ for (auto &range : rangesToInvert) {
+ if (failed(range.verify()))
+ return failure();
----------------
AGindinson wrote:
Good point, to be frank I'd even skip the assert
https://github.com/llvm/llvm-project/pull/137963
More information about the Mlir-commits
mailing list