[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)
Artem Gindinson
llvmlistbot at llvm.org
Mon May 12 00:23:42 PDT 2025
================
@@ -31,64 +31,128 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ if (numSourceDims <= numTargetDims)
return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
-
- ReassociationIndices currIndices;
- int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
- // If we have mapped all the target dimensions stop and handle the remaining
- // tail of size-1 dimensions explicitly.
- if (targetDim == targetShape.size())
- break;
-
- int64_t currTargetShape = targetShape[targetDim];
- while (sourceDim < (sourceShape.size() - 1) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
- prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
+ SmallVector<ReassociationIndices, 4> reassociationMap;
+ reassociationMap.reserve(numTargetDims);
+
+ unsigned sourceDimIdx = 0, targetDimIdx = 0;
+ // Source dimensions iteration logic for static target dimensions.
+ // FIXME: Instead of lambda-capturing this function's source shape index "in
+ // place", consider refactoring this into a separate function.
+ auto collectSourceIndicesForStaticTargetDim =
+ [&](int64_t targetShape,
+ bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
+ ReassociationIndices resultIndices;
+ int64_t prodOfCollapsedDims = 1;
+ bool reachedTargetDimSize = false;
+ for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
+ // Source shape cannot be dynamic if the target dim is static.
+ if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
+ return failure();
+ prodOfCollapsedDims *= sourceShape[sourceDimIdx];
+ resultIndices.push_back(sourceDimIdx);
+ if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
+ return failure();
+ while (prodOfCollapsedDims > targetShape) {
+ assert(!resultIndices.empty());
+ auto frontOffsetIdx = resultIndices.begin();
+ prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
+ resultIndices.erase(frontOffsetIdx);
+ }
+ if (prodOfCollapsedDims == targetShape) {
+ reachedTargetDimSize = true;
+ ++sourceDimIdx;
+ break;
+ }
+ }
+ if (!reachedTargetDimSize)
+ return failure();
+ return resultIndices;
+ };
+ // Source dimensions iteration logic for dynamic target dimensions.
+ // FIXME: Instead of lambda-capturing this function's source shape index "in
+ // place", consider refactoring this into a separate function.
+ auto collectSourceIndicesForDynamicTargetDim =
+ [&](bool allowStaticNonOnes,
+ bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
+ ReassociationIndices resultIndices;
+ bool foundFirstDynamic = false;
+ while (sourceDimIdx < numSourceDims) {
+ if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) {
+ if (foundFirstDynamic && !mapConsecutiveDynDims)
+ break;
+ foundFirstDynamic |= true;
+ } else {
+ if (foundFirstDynamic)
+ break;
+ else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes)
+ return failure();
+ }
+ resultIndices.push_back(sourceDimIdx++);
+ }
+ if (!foundFirstDynamic)
+ return failure();
+ return resultIndices;
+ };
+ // Iterate over target shape.
+ bool wasLastDimDynamic = false;
+ for (; targetDimIdx < numTargetDims; ++targetDimIdx) {
+ int64_t currTargetShape = targetShape[targetDimIdx];
+ if (currTargetShape != ShapedType::kDynamic) {
+ unsigned sourceDimAtStart = sourceDimIdx;
+ auto indices = collectSourceIndicesForStaticTargetDim(
+ currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
+ if (failed(indices))
+ return std::nullopt;
+ if (wasLastDimDynamic) {
+ assert(!reassociationMap.empty());
+ auto &previousIndices = reassociationMap.back();
----------------
AGindinson wrote:
Fair enough, I thought of this more as an "assert-as-comment", but it really is obsolete.
https://github.com/llvm/llvm-project/pull/137963
More information about the Mlir-commits
mailing list