[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)
Ian Wood
llvmlistbot at llvm.org
Tue May 27 10:18:28 PDT 2025
================
@@ -28,67 +32,319 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
return std::nullopt;
}
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
- return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It's designed to be feasibly minimal, so the call sites should manage the
+/// validity of the range manually.
+struct ReassociationIndexRange {
+ /// FIXME: Signed type is used for consistency with ReassociationIndices.
+ /// We should consider refactoring all reassociation utilities to use unsigned
+ /// types.
+ int64_t leftIdx = 0, rightIdx = 0;
+
+ /// Util for manual checks of the range's validity
+ LogicalResult verify() const {
+ return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+ }
+
+ /// Checks range's containment within another range. Treats the edges
+ /// non-exclusively.
+ bool isInRange(const ReassociationIndexRange &outerRange) const {
+ return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+ }
+
+ unsigned size() const {
+ assert(succeeded(verify()));
+ return rightIdx - leftIdx + 1;
+ }
+ bool containsSingleIndex() const { return size() == 1; }
+
+ void expandRight() { ++rightIdx; }
+ void shrinkLeft() { ++leftIdx; }
+
+ /// Implements arithmetic XOR semantics to get non-overlapping indices between
+ /// ranges.
+ ReassociationIndices operator^(ReassociationIndexRange &rhs) const {
+ ReassociationIndices result;
+ result.reserve(size() + rhs.size() / 2); // Attempt to amortize
+ for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
+ if (idx < rhs.leftIdx || idx > rhs.rightIdx)
+ result.push_back(idx);
+ }
+ for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
+ if (rhsIndex < leftIdx || rhsIndex > rightIdx)
+ result.push_back(rhsIndex);
+ }
----------------
IanWood1 wrote:
I think this should work
```cpp
ReassociationIndices result;
int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
llvm::append_range(result, llvm::seq(leftStart, leftEnd));
int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
if (rightStart < rightEnd)
llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
return result;
```
The first part is straightforward because `llvm::seq` already excludes `leftEnd` from the range. The second part is trickier because we want the opposite behavior: exclude the smaller of the two right edges (hence the `+ 1` when computing `rightStart`) but include `rightEnd` (hence `llvm::seq_inclusive`).
https://github.com/llvm/llvm-project/pull/137963
More information about the Mlir-commits
mailing list