[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)

Artem Gindinson llvmlistbot at llvm.org
Thu May 8 06:20:55 PDT 2025


================
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
 
+  unsigned sourceDim = 0, targetDim = 0;
+  for (; targetDim < numTargetDims; ++targetDim) {
     int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+    ReassociationIndices currIndices;
+    // 1. Target dimension is dynamic. Source shape should contain at least
+    // one dynamic dimension.
+    if (currTargetShape == ShapedType::kDynamic) {
----------------
AGindinson wrote:

JFYI:
1. "Similar logic to `CollapseOfExpand`" based purely on reassociation maps is not possible: `expand[[1, 2, 3], ...]` | `collapse[[1, 2], ...]` is determinate, whereas `collapse[[1, 2, 3]]` | `expand[[1, 2]]` is not.
2. I may have found a much simpler algorithm after all. We just need to walk the target shape in strides of 2; only for a `[dyn, static]` pair would we need to recurse a bit for the `static` part. Examples from my earlier comment: `?x5x8x3x2 into ?x48`, `?x8x3x1x1x1x1x5x2 into ?x48`
3. Interestingly enough, consecutive dynamic dims in the target can sometimes be determinate (enough). Consider `?x?xNxK into ?x?xK` (with N on either edge of a dynamic subexpression), or `?x1x1x? into ?x?`. By iterating in pairs, I can also detect such cases (the pre-existing "tail" loop for source dimensions helps to isolate the logic).
4. Bonus to point 3: always assigning the 1's to the preceding dynamic dimension seems a better policy than no policy at all (the original algorithm makes a lot of arbitrary decisions about ones anyway). None of that impacts correctness; it can only block further folds down the road in complex, presumably rare dispatches.
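
To make the `[dyn, static]` case from point 2 concrete, here is a rough standalone sketch. It is not the code from this PR. `kDyn` and `collapseIntoDynStatic` are hypothetical names, the target is assumed to be exactly `?xS`, and all static sizes are assumed positive. It consumes source dims from the right until their product equals `S`, then hands everything to the left (unit dims included, as in point 4) to the dynamic group.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic (not the MLIR constant).
constexpr int64_t kDyn = std::numeric_limits<int64_t>::min();

using Groups = std::vector<std::vector<int64_t>>;

// Hypothetical helper: reassociation groups for collapsing `source` into a
// two-dimensional target `?xstaticTarget`. Assumes positive static sizes.
std::optional<Groups> collapseIntoDynStatic(const std::vector<int64_t> &source,
                                            int64_t staticTarget) {
  std::size_t split = source.size();
  int64_t prod = 1;
  bool matched = false;
  // Consume source dims from the right until their product equals the
  // static target size.
  while (split > 0 && !matched) {
    int64_t dim = source[split - 1];
    if (dim == kDyn)
      break; // A dynamic dim cannot be part of the static group.
    prod *= dim;
    if (prod > staticTarget)
      break; // Overshot: no trailing run of dims matches exactly.
    --split;
    matched = (prod == staticTarget);
  }
  if (!matched)
    return std::nullopt;

  // Everything left of the split, unit dims included, goes to the dynamic
  // group, which must contain at least one dynamic source dim.
  Groups groups(2);
  bool hasDynamic = false;
  for (std::size_t i = 0; i < source.size(); ++i) {
    if (i < split) {
      hasDynamic = hasDynamic || source[i] == kDyn;
      groups[0].push_back(static_cast<int64_t>(i));
    } else {
      groups[1].push_back(static_cast<int64_t>(i));
    }
  }
  if (!hasDynamic)
    return std::nullopt;
  return groups;
}
```

Under these assumptions, `?x5x8x3x2 into ?x48` yields `[[0, 1], [2, 3, 4]]`, while `?x8x3x1x1x1x1x5x2 into ?x48` yields no grouping in this sketch, because no trailing run of source dims multiplies to exactly 48.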

Code update "soon" :)

https://github.com/llvm/llvm-project/pull/137963

