[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)

Artem Gindinson llvmlistbot at llvm.org
Mon May 12 00:23:42 PDT 2025


================
@@ -31,64 +31,128 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
-
-    int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
+
+  unsigned sourceDimIdx = 0, targetDimIdx = 0;
+  // Source dimensions iteration logic for static target dimensions.
+  // FIXME: Instead of lambda-capturing this function's source shape index "in
+  // place", consider refactoring this into a separate function.
+  auto collectSourceIndicesForStaticTargetDim =
+      [&](int64_t targetShape,
+          bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
+    ReassociationIndices resultIndices;
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
+        return failure();
+      prodOfCollapsedDims *= sourceShape[sourceDimIdx];
+      resultIndices.push_back(sourceDimIdx);
+      if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
+        return failure();
+      while (prodOfCollapsedDims > targetShape) {
+        assert(!resultIndices.empty());
+        auto frontOffsetIdx = resultIndices.begin();
+        prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
+        resultIndices.erase(frontOffsetIdx);
+      }
+      if (prodOfCollapsedDims == targetShape) {
+        reachedTargetDimSize = true;
+        ++sourceDimIdx;
+        break;
+      }
+    }
+    if (!reachedTargetDimSize)
+      return failure();
+    return resultIndices;
+  };
+  // Source dimensions iteration logic for dynamic target dimensions.
+  // FIXME: Instead of lambda-capturing this function's source shape index "in
+  // place", consider refactoring this into a separate function.
+  auto collectSourceIndicesForDynamicTargetDim =
+      [&](bool allowStaticNonOnes,
+          bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
+    ReassociationIndices resultIndices;
+    bool foundFirstDynamic = false;
+    while (sourceDimIdx < numSourceDims) {
+      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) {
+        if (foundFirstDynamic && !mapConsecutiveDynDims)
+          break;
+        foundFirstDynamic |= true;
+      } else {
+        if (foundFirstDynamic)
+          break;
+        else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes)
+          return failure();
+      }
+      resultIndices.push_back(sourceDimIdx++);
+    }
+    if (!foundFirstDynamic)
+      return failure();
+    return resultIndices;
+  };
+  // Iterate over target shape.
+  bool wasLastDimDynamic = false;
+  for (; targetDimIdx < numTargetDims; ++targetDimIdx) {
+    int64_t currTargetShape = targetShape[targetDimIdx];
+    if (currTargetShape != ShapedType::kDynamic) {
+      unsigned sourceDimAtStart = sourceDimIdx;
+      auto indices = collectSourceIndicesForStaticTargetDim(
+          currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
+      if (failed(indices))
+        return std::nullopt;
+      if (wasLastDimDynamic) {
+        assert(!reassociationMap.empty());
+        auto &previousIndices = reassociationMap.back();
----------------
AGindinson wrote:

Fair enough. I thought of this more as an "assert-as-comment", but it really is obsolete.

https://github.com/llvm/llvm-project/pull/137963


More information about the Mlir-commits mailing list