[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)

Artem Gindinson llvmlistbot at llvm.org
Fri May 23 11:30:10 PDT 2025


================
@@ -28,67 +32,319 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return std::nullopt;
 }
 
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
-    return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It's designed to be feasibly minimal, so the call sites should manage the
+/// validity of the range manually.
+struct ReassociationIndexRange {
+  /// FIXME: Signed type is used for consistency with ReassociationIndices.
+  /// We should consider refactoring all reassociation utilities to use unsigned
+  /// types.
+  int64_t leftIdx = 0, rightIdx = 0;
+
+  /// Util for manual checks of the range's validity
+  LogicalResult verify() const {
+    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+  }
+
+  /// Checks range's containment within another range. Treats the edges
+  /// non-exclusively.
+  bool isInRange(const ReassociationIndexRange &outerRange) const {
+    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+  }
+
+  unsigned size() const {
+    assert(succeeded(verify()));
+    return rightIdx - leftIdx + 1;
+  }
+  bool containsSingleIndex() const { return size() == 1; }
+
+  void expandRight() { ++rightIdx; }
+  void shrinkLeft() { ++leftIdx; }
+
+  /// Implements arithmetic XOR semantics to get non-overlapping indices between
+  /// ranges.
+  ReassociationIndices operator^(ReassociationIndexRange &rhs) const {
+    ReassociationIndices result;
+    result.reserve(size() + rhs.size() / 2); // Attempt to amortize
+    for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
+      if (idx < rhs.leftIdx || idx > rhs.rightIdx)
+        result.push_back(idx);
+    }
+    for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
+      if (rhsIndex < leftIdx || rhsIndex > rightIdx)
+        result.push_back(rhsIndex);
+    }
----------------
AGindinson wrote:

With:
```cpp
    llvm::append_range(result, llvm::seq(leftIdx, rhs.leftIdx));
    llvm::append_range(result, llvm::seq_inclusive(rightIdx + 1, rhs.rightIdx));
    llvm::append_range(result, llvm::seq(rhs.leftIdx, leftIdx));
    llvm::append_range(result, llvm::seq_inclusive(rhs.rightIdx + 1, rightIdx));
```
nothing gets appended, and other attempts yield very weird behavior, which I assume is somehow caused by `llvm::iota_range` :( I'm not sure what the correct improvement is here.

https://github.com/llvm/llvm-project/pull/137963


More information about the Mlir-commits mailing list