[Mlir-commits] [mlir] Revert "[mlir][tensor] Loosen restrictions on folding dynamic reshapes" (PR #142639)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jun 3 13:54:19 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
Author: Ian Wood (IanWood1)
<details>
<summary>Changes</summary>
Reverts llvm/llvm-project#137963
---
Patch is 32.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142639.diff
5 Files Affected:
- (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+53-319)
- (modified) mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir (+2-2)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+4-35)
- (modified) mlir/unittests/Dialect/Utils/CMakeLists.txt (-1)
- (removed) mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp (-203)
``````````diff
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 3b1fdb69e8ef1..1a04d702e0559 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -10,10 +10,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/LogicalResult.h"
#include <numeric>
#include <optional>
@@ -32,329 +28,67 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return std::nullopt;
}
-namespace {
-/// A simple struct to represent ReassociationIndices as an inclusive interval.
-/// It's designed to be feasibly minimal, so the call sites should manage the
-/// validity of the range manually.
-struct ReassociationIndexRange {
- /// FIXME: Signed type is used for consistency with ReassociationIndices.
- /// We should consider refactoring all reassociation utilities to use unsigned
- /// types.
- int64_t leftIdx = 0, rightIdx = 0;
-
- /// Util for manual checks of the range's validity
- LogicalResult verify() const {
- return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
- }
-
- /// Checks range's containment within another range. Treats the edges
- /// non-exclusively.
- bool isInRange(const ReassociationIndexRange &outerRange) const {
- return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
- }
-
- unsigned size() const {
- assert(succeeded(verify()));
- return rightIdx - leftIdx + 1;
- }
- bool containsSingleIndex() const { return size() == 1; }
-
- /// Collects indices that do not overlap between this and another range.
- ReassociationIndices
- getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
- if (rightIdx < rhs.leftIdx) {
- // The intervals do not overlap - concatenate the indices from both.
- auto jointFullIndices = getFullIndices();
- jointFullIndices.append(rhs.getFullIndices());
- return jointFullIndices;
- }
- ReassociationIndices result;
- // Handle the chunk left of the overlapping range.
- int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
- int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
- llvm::append_range(result, llvm::seq(leftStart, leftEnd));
- // Handle the chunk right of the overlapping range. Symmetrically, we should
- // skip the edge of the overlap AND include the rightmost index.
- int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
- int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
- if (rightStart < rightEnd)
- llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
- return result;
- }
-
- /// Converts the range into ReassociationIndices.
- ReassociationIndices getFullIndices() const {
- ReassociationIndices result;
- for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
- result.push_back(idx);
- }
- return result;
- }
-};
-} // namespace
-
-/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
-/// sequence that can be collapsed into a dynamic dimension (at least one must
-/// be present in the source).
-/// By default, lazily returns once the first dynamic dimension has been found.
-/// Setting `matchGreedily` as `true` will also mark all subsequent
-/// source dimensions for collapsing into the target.
-static FailureOr<ReassociationIndexRange>
-findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
- int64_t sourceStartIdx,
- bool matchGreedily = false) {
- const unsigned numSourceDims = sourceShape.size();
- ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
- std::optional<ReassociationIndexRange> resultRange = std::nullopt;
-
- ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
- for (; iterationRange.isInRange(sourceShapeAsRange);
- iterationRange.rightIdx++) {
- int64_t sourceSize = sourceShape[iterationRange.rightIdx];
- if (sourceSize == ShapedType::kDynamic) {
- resultRange = iterationRange;
- break;
- }
- }
- if (!resultRange)
- return failure();
- if (matchGreedily)
- resultRange->rightIdx = sourceShapeAsRange.rightIdx;
- return *resultRange;
-}
+std::optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> targetShape) {
+ if (sourceShape.size() <= targetShape.size())
+ return std::nullopt;
+ unsigned sourceDim = 0;
+ SmallVector<ReassociationIndices> reassociationMap;
+ reassociationMap.reserve(targetShape.size());
-/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
-/// sequence of static dimensions such that their product matches `targetSize`.
-/// By default, lazily returns once the product matches the target size. Setting
-/// `matchGreedily` as `true` will append all neighboring unit dimensions
-/// (dimensions of 1) to the match.
-static FailureOr<ReassociationIndexRange>
-findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
- int64_t sourceStartIdx, int64_t targetSize,
- bool matchGreedily = false) {
- const unsigned numSourceDims = sourceShape.size();
- ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
- std::optional<ReassociationIndexRange> resultRange = std::nullopt;
-
- ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+ ReassociationIndices currIndices;
int64_t prodOfCollapsedDims = 1;
- while (iterationRange.isInRange(sourceShapeAsRange)) {
- int64_t sourceSize = sourceShape[iterationRange.rightIdx];
- if (sourceSize == ShapedType::kDynamic) {
- // Reassociation for a static dim cannot include a dynamic dim. Reset
- // induction variables to essentially restart the loop from the next
- // source dimension.
- prodOfCollapsedDims = 1;
- iterationRange = {iterationRange.rightIdx + 1,
- iterationRange.rightIdx + 1};
- continue;
- }
- prodOfCollapsedDims *= sourceSize;
- // If the target size has been exceeded without matching, we need to shift
- // the range start right. From the start of the range, roll back the
- // multiplication until the target size exceeds the product again.
- while (prodOfCollapsedDims > targetSize &&
- !iterationRange.containsSingleIndex()) {
- int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
- prodOfCollapsedDims /= frontSourceSize;
- // Shrink the range rightwards
- iterationRange.leftIdx++;
- }
- // We could've reached the target size with the current dimension,
- // also as a result of the above shift to right.
- if (prodOfCollapsedDims == targetSize) {
- resultRange = iterationRange;
+ while (sourceDim < sourceShape.size()) {
+ unsigned targetDim = reassociationMap.size();
+ // If we have mapped all the target dimensions stop and handle the remaining
+ // tail of size-1 dimensions explicitly.
+ if (targetDim == targetShape.size())
break;
- }
- // Increment the iteration range
- iterationRange.rightIdx++;
- }
- if (!resultRange)
- return failure();
- if (matchGreedily) {
- // We now want to collect all unit dimensions directly after the target
- // product match. Advance the iterator to avoid OOB when the product match
- // happens at the last element.
- iterationRange.rightIdx++;
- while (iterationRange.isInRange(sourceShapeAsRange) &&
- sourceShape[iterationRange.rightIdx] == 1) {
- resultRange = iterationRange;
- iterationRange.rightIdx++;
- }
- }
- return *resultRange;
-}
-/// Attempts to find a valid collapsing reassociation of `sourceShape` into
-/// `targetShape` through a simple traversal. If successful, an array of source
-/// index ranges is returned, correspondingly to each dimension in the target
-/// shape. The resulting indices shall fully cover the `sourceShape` without
-/// overlaps.
-///
-/// The algorithm is essentially a lazy one, searching for non-greedy matches -
-/// it will only yield a greedy match for the last target dimension.
-/// FIXME: The algorithm can only backtrack when it needs to append an offset
-/// for a static target dimension to the preceding dynamic one (this retains the
-/// linear complexity). As feasible, consider adding further backtracking
-/// routines to enable more reassociations, e.g.:
-/// - ?x2x?x2 into ?x2
-static FailureOr<SmallVector<ReassociationIndexRange>>
-findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape) {
- unsigned numSourceDims = sourceShape.size(),
- numTargetDims = targetShape.size();
- assert(numSourceDims > numTargetDims);
- ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
-
- SmallVector<ReassociationIndexRange> reassocRanges;
- reassocRanges.reserve(numTargetDims);
- // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
- // cases, e.g.:
- // - ?x2x3x5 into ?x15
- std::optional<int64_t> prevTargetSize = std::nullopt;
- for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
- targetDimIdx < numTargetDims; ++targetDimIdx) {
- int64_t targetSize = targetShape[targetDimIdx];
- // Simply check if there are any subsequent target dimensions left - if not,
- // the match must be made greedily.
- bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
- FailureOr<ReassociationIndexRange> sourceRange;
- if (targetSize == ShapedType::kDynamic) {
- sourceRange = findReassociationRangeForDynamicDim(
- sourceShape, sourceDimIdx, shouldMatchGreedily);
- } else {
- sourceRange = findReassociationRangeForSize(
- sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
+ int64_t currTargetShape = targetShape[targetDim];
+ while (sourceDim < (sourceShape.size() - 1) &&
+ sourceShape[sourceDim] != ShapedType::kDynamic &&
+ prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+ prodOfCollapsedDims *= sourceShape[sourceDim];
+ currIndices.push_back(sourceDim++);
}
- // Run sanity checks on the returned index range.
- if (failed(sourceRange) || failed(sourceRange->verify()) ||
- !sourceRange->isInRange(sourceShapeAsRange))
- return failure();
- if (sourceRange->leftIdx > sourceDimIdx) {
- // If some source dimensions had to be skipped in order to find a match,
- // they must be collapsed into the directly preceding dynamic dimension.
- if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
- return failure();
- reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
- }
-
- // Store the gathered information as required for the next iteration.
- prevTargetSize = targetSize;
- sourceDimIdx = sourceRange->rightIdx + 1;
- reassocRanges.push_back(*sourceRange);
+ // If the current expanded dimension is dynamic, then the collapsed
+ // dimensions should also be dynamic and product of all previous unprocessed
+ // dimensions of the expanded shape should be 1.
+ if (sourceShape[sourceDim] == ShapedType::kDynamic &&
+ (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+ return std::nullopt;
+
+ // If the collapsed dim is dynamic, the current expanded dim should also
+ // be dynamic.
+ if (currTargetShape == ShapedType::kDynamic &&
+ sourceShape[sourceDim] != ShapedType::kDynamic)
+ return std::nullopt;
+
+ // For static shapes, if the product of dimensions of the expanded shape
+ // should match the collapsed dimension shape.
+ if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+ return std::nullopt;
+
+ currIndices.push_back(sourceDim++);
+ reassociationMap.emplace_back(ReassociationIndices{});
+ std::swap(reassociationMap.back(), currIndices);
+ prodOfCollapsedDims = 1;
}
- // Fail if the source shape wasn't a full match for the target shape. We only
- // need to check the last recorded index - any other gaps should have been
- // mended by the main loop.
- if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
- return failure();
- return reassocRanges;
-}
-
-/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
-/// the shapes right-to-left.
-static FailureOr<SmallVector<ReassociationIndexRange>>
-findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape,
- bool iterateRightToLeft) {
- if (!iterateRightToLeft)
- return findReassociationRangesForCollapse(sourceShape, targetShape);
- // NB: To iterate right-to-left, we currently reverse the shapes and then
- // reverse the result back. The reversed shapes must not be temporary, as
- // we're passing through an ArrayRef.
- // FIXME: It would be preferable to avoid the expensive copies. At the moment,
- // this approach is chosen for readability of the main implementation.
- std::vector<int64_t> sourceToReverse = sourceShape.vec(),
- targetToReverse = targetShape.vec();
- std::reverse(sourceToReverse.begin(), sourceToReverse.end());
- std::reverse(targetToReverse.begin(), targetToReverse.end());
- auto invertedRanges =
- findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
- if (failed(invertedRanges))
- return failure();
- SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
- unsigned numSourceDims = sourceShape.size();
- // We have received the ranges for inverted shapes. Now we have to invert
- // the ranges back to correspond with the original source shape.
- for (auto &range : rangesToInvert) {
- int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
- range.leftIdx = numSourceDims - 1 - invRightIdx;
- range.rightIdx = numSourceDims - 1 - invLeftIdx;
- }
- // Also invert the ordering of the ranges to correspond with the original
- // target shape.
- std::reverse(rangesToInvert.begin(), rangesToInvert.end());
- return rangesToInvert;
-}
-
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape) {
- unsigned numSourceDims = sourceShape.size(),
- numTargetDims = targetShape.size();
- // We're supposed to search for a collapsing reassociation. If the sizes
- // match, there's no actual collapsing taking place - it's either a no-op or a
- // `tensor.reshape`-style reassociation (that would be beyond the scope of
- // this utility).
- if (numSourceDims <= numTargetDims)
- return std::nullopt;
- // Early handling for scalar target types.
- if (numTargetDims == 0) {
- ReassociationIndices allSourceIndices;
- allSourceIndices.reserve(numSourceDims);
- for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
- ++sourceDimIdx) {
- int64_t sourceSize = sourceShape[sourceDimIdx];
- // All source dimensions must be unit or dynamic.
- if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
- return std::nullopt;
- allSourceIndices.push_back(sourceDimIdx);
- }
- return SmallVector<ReassociationIndices>{allSourceIndices};
- }
-
- // Collect source ranges by iterating over the target shape left-to-right.
- FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
- findReassociationRangesForCollapse(sourceShape, targetShape);
- if (failed(maybeForwardRanges))
- return std::nullopt;
- auto &ranges = *maybeForwardRanges;
- // Now do the same in reverse. We need to get another valid reassociation
- // through some other strategy, and then compare the results in order to
- // disambiguate mixed subshapes, such as:
- // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
- // This leads us to lose some of the reassociation opportunities that can only
- // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
- // backtracking, the algorithm will fail right-to-left. However, this is the
- // best way to preserve correctness.
- FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
- findReassociationRangesForCollapse(sourceShape, targetShape,
- /*iterateRightToLeft=*/true);
- if (failed(maybeReverseRanges))
- return std::nullopt;
- auto &reverseRanges = *maybeReverseRanges;
-
- if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
+ // All the dimensions in the target must have been processed.
+ if (reassociationMap.size() != targetShape.size())
return std::nullopt;
- // Now we can check for ambiguity of each target dimension's reassociation. If
- // successful, we put the full indices into our result map for the target
- // shape.
- SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
- for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
- ++targetDimIdx) {
- ReassociationIndexRange &range = ranges[targetDimIdx];
- ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
- // Get non-overlapping indices between the ranges
- ReassociationIndices nonMatchingIndices =
- range.getNonOverlappingIndicesWith(reverseRange);
- // Unit dimensions can be collapsed wherever - this is the only ambiguity
- // that we allow.
- for (int64_t sourceDimIdx : nonMatchingIndices) {
- if (sourceShape[sourceDimIdx] != 1)
- return std::nullopt;
- }
- reassociationMap[targetDimIdx] = range.getFullIndices();
+ // Process any remaining entries in the source shape. They all need to be
+ // 1 or dynamic.
+ for (; sourceDim < sourceShape.size(); sourceDim++) {
+ if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+ sourceShape[sourceDim] != 1)
+ return std::nullopt;
+ // The map is empty when the target type is a scalar.
+ if (!reassociationMap.empty())
+ reassociationMap.back().push_back(sourceDim);
}
return reassociationMap;
}
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 6979770154bab..51350e5bc8498 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK: tensor.collapse
-// CHECK-NOT: linalg.unpack
+// CHECK-NOT: tensor.collapse
+// CHECK: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..0abec7e01d184 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1117,7 +1117,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
// -----
-func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
-> tensor<?x4x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1125,28 +1125,12 @@ func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %ar
: tensor<?x?xf32> into tensor<?x4x?xf32>
return %1 : tensor<?x4x?xf32>
}
-// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
// CHECK-NOT: tensor.{{.*}}_shape
// -----
-func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
- -> tensor<?x4x?xf32> {
- %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
- : tensor<?x4x?x2xf32> into tensor<?x?xf32>
- %1 = tensor.expand_shape %0 [...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/142639
More information about the Mlir-commits mailing list