[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)

Ian Wood llvmlistbot at llvm.org
Sun May 11 17:44:53 PDT 2025


================
@@ -31,64 +31,128 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
-
-    int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
+
+  unsigned sourceDimIdx = 0, targetDimIdx = 0;
+  // Source dimensions iteration logic for static target dimensions.
+  // FIXME: Instead of lambda-capturing this function's source shape index "in
+  // place", consider refactoring this into a separate function.
+  auto collectSourceIndicesForStaticTargetDim =
+      [&](int64_t targetShape,
+          bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
+    ReassociationIndices resultIndices;
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
+        return failure();
+      prodOfCollapsedDims *= sourceShape[sourceDimIdx];
+      resultIndices.push_back(sourceDimIdx);
+      if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
+        return failure();
+      while (prodOfCollapsedDims > targetShape) {
+        assert(!resultIndices.empty());
+        auto frontOffsetIdx = resultIndices.begin();
+        prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
+        resultIndices.erase(frontOffsetIdx);
----------------
IanWood1 wrote:

Because `resultIndices` is ultimately just a contiguous sequence of indices between a lower and an upper bound, it may be simpler and more efficient to track only those two bounds and construct the array once at the end. This would also avoid erasing from the front of the array, which shifts every remaining element on each erase.

https://github.com/llvm/llvm-project/pull/137963


More information about the Mlir-commits mailing list