[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)

Ian Wood llvmlistbot at llvm.org
Wed May 21 15:13:27 PDT 2025


================
@@ -28,67 +32,319 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return std::nullopt;
 }
 
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
-    return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It's designed to be feasibly minimal, so the call sites should manage the
+/// validity of the range manually.
+struct ReassociationIndexRange {
+  /// FIXME: Signed type is used for consistency with ReassociationIndices.
+  /// We should consider refactoring all reassociation utilities to use unsigned
+  /// types.
+  int64_t leftIdx = 0, rightIdx = 0;
+
+  /// Util for manual checks of the range's validity
+  LogicalResult verify() const {
+    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+  }
+
+  /// Checks range's containment within another range. Treats the edges
+  /// non-exclusively.
+  bool isInRange(const ReassociationIndexRange &outerRange) const {
+    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+  }
+
+  unsigned size() const {
+    assert(succeeded(verify()));
+    return rightIdx - leftIdx + 1;
+  }
+  bool containsSingleIndex() const { return size() == 1; }
+
+  void expandRight() { ++rightIdx; }
+  void shrinkLeft() { ++leftIdx; }
+
+  /// Implements arithmetic XOR semantics to get non-overlapping indices between
+  /// ranges.
+  ReassociationIndices operator^(ReassociationIndexRange &rhs) const {
+    ReassociationIndices result;
+    result.reserve(size() + rhs.size() / 2); // Attempt to amortize
+    for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
+      if (idx < rhs.leftIdx || idx > rhs.rightIdx)
+        result.push_back(idx);
+    }
+    for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
+      if (rhsIndex < leftIdx || rhsIndex > rightIdx)
+        result.push_back(rhsIndex);
+    }
+    return result;
+  }
+
+  /// Converts the range into ReassociationIndices.
+  ReassociationIndices getFullIndices() const {
+    ReassociationIndices result;
+    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
+      result.push_back(idx);
+    }
+    return result;
+  }
+};
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence that can be collapsed into a dynamic dimension (at least one must
+/// be present in the source).
+/// By default, lazily returns once the first dynamic dimension has been found.
+/// Setting `matchGreedily` as `true` will also mark all subsequent
+/// source dimensions for collapsing into the target.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
+                                    int64_t sourceStartIdx,
+                                    bool matchGreedily = false) {
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  if (!iterationRange.isInRange(sourceShapeAsRange))
+    return failure();
+  auto resultRange = iterationRange;
+
+  bool foundDynamic = false;
+  for (; iterationRange.isInRange(sourceShapeAsRange);
+       iterationRange.expandRight()) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (foundDynamic && !matchGreedily)
+      break;
+    if (sourceSize == ShapedType::kDynamic)
+      foundDynamic = true;
+    resultRange = iterationRange;
+  }
+  if (!foundDynamic)
+    return failure();
+  return resultRange;
+}
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence of static dimensions such that their product matches `targetSize`.
+/// By default, lazily returns once the product matches the target size. Setting
+/// `matchGreedily` as `true` will append all neighboring unit dimensions
+/// (dimensions of 1) to the match.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
+                              int64_t sourceStartIdx, int64_t targetSize,
+                              bool matchGreedily = false) {
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  if (!iterationRange.isInRange(sourceShapeAsRange))
+    return failure();
+  auto resultRange = iterationRange;
 
-  ReassociationIndices currIndices;
   int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
+  bool reachedTargetDimSize = false;
+  while (iterationRange.isInRange(sourceShapeAsRange)) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (reachedTargetDimSize && !matchGreedily)
+      break;
+    if (sourceSize == ShapedType::kDynamic) {
+      if (reachedTargetDimSize)
+        break;
+      // Reassociation for a static dim cannot include a dynamic dim. Reset
+      // induction variables to essentially restart the loop from the next
+      // source dimension.
+      prodOfCollapsedDims = 1;
+      resultRange = {iterationRange.rightIdx + 1, iterationRange.rightIdx + 1};
+      iterationRange = resultRange;
+      continue;
+    }
+    prodOfCollapsedDims *= sourceSize;
+    if (prodOfCollapsedDims > targetSize && reachedTargetDimSize)
       break;
+    // If the target size has been exceeded without matching, we need to shift
+    // the range start right. From the start of the range, roll back the
+    // multiplication until the target size exceeds the product again.
+    while (prodOfCollapsedDims > targetSize &&
+           !iterationRange.containsSingleIndex()) {
+      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
+      prodOfCollapsedDims /= frontSourceSize;
+      iterationRange.shrinkLeft();
+    }
+    resultRange = iterationRange;
+    // We could've reached the target size with the current dimension,
+    // also as a result of the above shift to right.
+    if (prodOfCollapsedDims == targetSize)
+      reachedTargetDimSize = true;
+    // Increment the iteration range
+    iterationRange.expandRight();
+  }
+  if (!reachedTargetDimSize)
+    return failure();
+  return resultRange;
+}
+
+/// Attempts to find a valid collapsing reassociation of `sourceShape` into
+/// `targetShape` through a simple traversal. If successful, an array of source
+/// index ranges is returned, correspondingly to each dimension in the target
+/// shape. The resulting indices shall fully cover the `sourceShape` without
+/// overlaps.
+///
+/// The algorithm is essentially a lazy one, searching for non-greedy matches -
+/// it will only yield a greedy match for the last target dimension.
+/// FIXME: The algorithm can only backtrack when it needs to append an offset
+/// for a static target dimension to the preceding dynamic one (this retains the
+/// linear complexity). As feasible, consider adding further backtracking
+/// routines to enable more reassociations, e.g.:
+/// - ?x2x?x2 into ?x2
+FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
----------------
IanWood1 wrote:

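These helper functions can be declared `static` instead of being placed inside the anonymous namespace: the LLVM Coding Standards recommend reserving anonymous namespaces for class declarations (here, `ReassociationIndexRange`) and using `static` for file-local functions.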



https://github.com/llvm/llvm-project/pull/137963


More information about the Mlir-commits mailing list