[Mlir-commits] [mlir] cb4a407 - [mlir][tensor] Loosen restrictions on folding dynamic reshapes (#137963)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 3 09:09:06 PDT 2025


Author: Artem Gindinson
Date: 2025-06-03T09:09:01-07:00
New Revision: cb4a407e5c2a8a5972781d2a3be362f437602fae

URL: https://github.com/llvm/llvm-project/commit/cb4a407e5c2a8a5972781d2a3be362f437602fae
DIFF: https://github.com/llvm/llvm-project/commit/cb4a407e5c2a8a5972781d2a3be362f437602fae.diff

LOG: [mlir][tensor] Loosen restrictions on folding dynamic reshapes (#137963)

The main idea behind the change is to allow expand-of-collapse folds for
reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the
expand op must have a coherent index/affine expression specified in its
`output_shape` argument (see example below), and if it doesn't, the IR
has already been invalidated at an earlier stage:
```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]]
    : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine]
    : tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```

On the above assumption, adjust the routine in
`getReassociationIndicesForCollapse()` to allow dynamic reshapes beyond
just `?x..?x1x1x..x1` -> `?`. Dynamic subshapes introduce two kinds of
issues:
1. n>2 consecutive dynamic dimensions in the source shape cannot be
collapsed together into 1<k<n neighboring dynamic dimensions in the
target shape, since there'd be more than one suitable reassociation
(example: `?x?x10x? into ?x?`)
2. When figuring out static subshape reassociations based on products,
there are cases where a static dimension is collapsed with a dynamic
one, and should therefore be skipped when comparing products of source &
target dimensions (e.g. `?x2x3x4 into ?x12`)
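
Both issues can be reproduced with a brute-force enumeration. The sketch
below is not part of this patch: `kDyn`, `groupMatches` and
`countReassociations` are hypothetical names, and `kDyn` merely stands in for
`ShapedType::kDynamic`. It assumes that a dynamic target dimension may absorb
any contiguous source group containing at least one dynamic dimension, and
that a static target dimension needs an all-static group with a matching
product.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic.
constexpr int64_t kDyn = -1;

// Returns true if source dims [begin, end) may collapse into a target
// dimension of size `target` under the assumptions stated above.
static bool groupMatches(const std::vector<int64_t> &source, size_t begin,
                         size_t end, int64_t target) {
  assert(begin < end && end <= source.size());
  bool hasDynamic = false;
  int64_t product = 1;
  for (size_t i = begin; i < end; ++i) {
    if (source[i] == kDyn)
      hasDynamic = true;
    else
      product *= source[i];
  }
  if (target == kDyn)
    return hasDynamic;
  return !hasDynamic && product == target;
}

// Counts every way to split `source` into contiguous groups, one per target
// dimension. Exponential in the worst case; for illustration only.
static unsigned countReassociations(const std::vector<int64_t> &source,
                                    const std::vector<int64_t> &target,
                                    size_t srcIdx = 0, size_t tgtIdx = 0) {
  if (tgtIdx == target.size())
    return srcIdx == source.size() ? 1 : 0;
  unsigned count = 0;
  for (size_t end = srcIdx + 1; end <= source.size(); ++end)
    if (groupMatches(source, srcIdx, end, target[tgtIdx]))
      count += countReassociations(source, target, end, tgtIdx + 1);
  return count;
}
```

Under these assumptions, `?x?x10x? into ?x?` yields three candidate
reassociations (issue 1), while `?x2x3x4 into ?x12` yields exactly one, in
which the static `2` is grouped with the leading dynamic dimension (issue 2).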

To address 1, we should detect such sequences in the target shape before
assigning multiple dynamic dimensions into the same index set. For 2, we
take note that a static target dimension was preceded by a dynamic one
and allow an "offset" subshape of source static dimensions, as long as
there's an exact sequence for the target size later in the source shape.
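
The "offset" search for a static target dimension can be pictured as a
sliding window over the source shape. The following is a simplified sketch
loosely modelled on `findReassociationRangeForSize()` from this patch, written
against `std::vector` rather than the MLIR types. `kDyn` and `findStaticRange`
are hypothetical names, and the greedy matching of trailing unit dimensions is
left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic.
constexpr int64_t kDyn = -1;

// Starting at `start`, finds the first inclusive index range of static source
// dimensions whose product equals the static size `target`.
static std::optional<std::pair<size_t, size_t>>
findStaticRange(const std::vector<int64_t> &source, size_t start,
                int64_t target) {
  assert(target >= 0 && "target must be a static size");
  size_t left = start;
  int64_t product = 1;
  for (size_t right = start; right < source.size(); ++right) {
    if (source[right] == kDyn) {
      // A static target cannot absorb a dynamic dim: restart after it.
      left = right + 1;
      product = 1;
      continue;
    }
    product *= source[right];
    // Overshot the target: drop dimensions from the front of the window.
    while (product > target && left < right) {
      product /= source[left];
      ++left;
    }
    if (product == target)
      return std::make_pair(left, right);
  }
  return std::nullopt;
}
```

For `?x2x3x4 into ?x12`, searching from source index 1 returns the range
`[2, 3]`. Index 1 was skipped, so the caller has to fold it into the preceding
dynamic target dimension, which is only legal when that preceding target
dimension is in fact dynamic.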

This PR aims to address all reshapes that can be determined based purely
on shapes (and original reassociation maps, as done in
`ComposeExpandOfCollapseOp::findCollapsingReassociation`). It doesn't
seem possible to fold all qualifying dynamic shape patterns in a
deterministic way without looking into affine expressions
simultaneously. That would be difficult to maintain in a single general
utility, so a path forward would be to provide dialect-specific
implementations for Linalg/Tensor.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Co-authored-by: Ian Wood <ianwood2024 at u.northwestern.edu>

Added: 
    mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp

Modified: 
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/unittests/Dialect/Utils/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index ed40a080441bc..3b1fdb69e8ef1 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -10,6 +10,10 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 
 #include <numeric>
 #include <optional>
@@ -28,67 +32,329 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return std::nullopt;
 }
 
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
-    return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It's designed to be feasibly minimal, so the call sites should manage the
+/// validity of the range manually.
+struct ReassociationIndexRange {
+  /// FIXME: Signed type is used for consistency with ReassociationIndices.
+  /// We should consider refactoring all reassociation utilities to use unsigned
+  /// types.
+  int64_t leftIdx = 0, rightIdx = 0;
+
+  /// Util for manual checks of the range's validity
+  LogicalResult verify() const {
+    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+  }
+
+  /// Checks range's containment within another range. Treats the edges
+  /// non-exclusively.
+  bool isInRange(const ReassociationIndexRange &outerRange) const {
+    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+  }
 
-  ReassociationIndices currIndices;
+  unsigned size() const {
+    assert(succeeded(verify()));
+    return rightIdx - leftIdx + 1;
+  }
+  bool containsSingleIndex() const { return size() == 1; }
+
+  /// Collects indices that do not overlap between this and another range.
+  ReassociationIndices
+  getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
+    if (rightIdx < rhs.leftIdx) {
+      // The intervals do not overlap - concatenate the indices from both.
+      auto jointFullIndices = getFullIndices();
+      jointFullIndices.append(rhs.getFullIndices());
+      return jointFullIndices;
+    }
+    ReassociationIndices result;
+    // Handle the chunk left of the overlapping range.
+    int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
+    int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
+    llvm::append_range(result, llvm::seq(leftStart, leftEnd));
+    // Handle the chunk right of the overlapping range. Symmetrically, we should
+    // skip the edge of the overlap AND include the rightmost index.
+    int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
+    int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
+    if (rightStart < rightEnd)
+      llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
+    return result;
+  }
+
+  /// Converts the range into ReassociationIndices.
+  ReassociationIndices getFullIndices() const {
+    ReassociationIndices result;
+    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
+      result.push_back(idx);
+    }
+    return result;
+  }
+};
+} // namespace
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence that can be collapsed into a dynamic dimension (at least one must
+/// be present in the source).
+/// By default, lazily returns once the first dynamic dimension has been found.
+/// Setting `matchGreedily` as `true` will also mark all subsequent
+/// source dimensions for collapsing into the target.
+static FailureOr<ReassociationIndexRange>
+findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
+                                    int64_t sourceStartIdx,
+                                    bool matchGreedily = false) {
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
+
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  for (; iterationRange.isInRange(sourceShapeAsRange);
+       iterationRange.rightIdx++) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (sourceSize == ShapedType::kDynamic) {
+      resultRange = iterationRange;
+      break;
+    }
+  }
+  if (!resultRange)
+    return failure();
+  if (matchGreedily)
+    resultRange->rightIdx = sourceShapeAsRange.rightIdx;
+  return *resultRange;
+}
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence of static dimensions such that their product matches `targetSize`.
+/// By default, lazily returns once the product matches the target size. Setting
+/// `matchGreedily` as `true` will append all neighboring unit dimensions
+/// (dimensions of 1) to the match.
+static FailureOr<ReassociationIndexRange>
+findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
+                              int64_t sourceStartIdx, int64_t targetSize,
+                              bool matchGreedily = false) {
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
+
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
   int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
+  while (iterationRange.isInRange(sourceShapeAsRange)) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (sourceSize == ShapedType::kDynamic) {
+      // Reassociation for a static dim cannot include a dynamic dim. Reset
+      // induction variables to essentially restart the loop from the next
+      // source dimension.
+      prodOfCollapsedDims = 1;
+      iterationRange = {iterationRange.rightIdx + 1,
+                        iterationRange.rightIdx + 1};
+      continue;
+    }
+    prodOfCollapsedDims *= sourceSize;
+    // If the target size has been exceeded without matching, we need to shift
+    // the range start right. From the start of the range, roll back the
+    // multiplication until the target size exceeds the product again.
+    while (prodOfCollapsedDims > targetSize &&
+           !iterationRange.containsSingleIndex()) {
+      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
+      prodOfCollapsedDims /= frontSourceSize;
+      // Shrink the range rightwards
+      iterationRange.leftIdx++;
+    }
+    // We could've reached the target size with the current dimension,
+    // also as a result of the above shift to right.
+    if (prodOfCollapsedDims == targetSize) {
+      resultRange = iterationRange;
       break;
+    }
+    // Increment the iteration range
+    iterationRange.rightIdx++;
+  }
+  if (!resultRange)
+    return failure();
+  if (matchGreedily) {
+    // We now want to collect all unit dimensions directly after the target
+    // product match. Advance the iterator to avoid OOB when the product match
+    // happens at the last element.
+    iterationRange.rightIdx++;
+    while (iterationRange.isInRange(sourceShapeAsRange) &&
+           sourceShape[iterationRange.rightIdx] == 1) {
+      resultRange = iterationRange;
+      iterationRange.rightIdx++;
+    }
+  }
+  return *resultRange;
+}
+
+/// Attempts to find a valid collapsing reassociation of `sourceShape` into
+/// `targetShape` through a simple traversal. If successful, an array of source
+/// index ranges is returned, correspondingly to each dimension in the target
+/// shape. The resulting indices shall fully cover the `sourceShape` without
+/// overlaps.
+///
+/// The algorithm is essentially a lazy one, searching for non-greedy matches -
+/// it will only yield a greedy match for the last target dimension.
+/// FIXME: The algorithm can only backtrack when it needs to append an offset
+/// for a static target dimension to the preceding dynamic one (this retains the
+/// linear complexity). As feasible, consider adding further backtracking
+/// routines to enable more reassociations, e.g.:
+/// - ?x2x?x2 into ?x2
+static FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape) {
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  assert(numSourceDims > numTargetDims);
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+
+  SmallVector<ReassociationIndexRange> reassocRanges;
+  reassocRanges.reserve(numTargetDims);
+  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
+  // cases, e.g.:
+  // - ?x2x3x5 into ?x15
+  std::optional<int64_t> prevTargetSize = std::nullopt;
+  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
+       targetDimIdx < numTargetDims; ++targetDimIdx) {
+    int64_t targetSize = targetShape[targetDimIdx];
+    // Simply check if there are any subsequent target dimensions left - if not,
+    // the match must be made greedily.
+    bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
+    FailureOr<ReassociationIndexRange> sourceRange;
+    if (targetSize == ShapedType::kDynamic) {
+      sourceRange = findReassociationRangeForDynamicDim(
+          sourceShape, sourceDimIdx, shouldMatchGreedily);
+    } else {
+      sourceRange = findReassociationRangeForSize(
+          sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
+    }
 
-    int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+    // Run sanity checks on the returned index range.
+    if (failed(sourceRange) || failed(sourceRange->verify()) ||
+        !sourceRange->isInRange(sourceShapeAsRange))
+      return failure();
+    if (sourceRange->leftIdx > sourceDimIdx) {
+      // If some source dimensions had to be skipped in order to find a match,
+      // they must be collapsed into the directly preceding dynamic dimension.
+      if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
+        return failure();
+      reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
     }
 
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
-      return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    // Store the gathered information as required for the next iteration.
+    prevTargetSize = targetSize;
+    sourceDimIdx = sourceRange->rightIdx + 1;
+    reassocRanges.push_back(*sourceRange);
+  }
+  // Fail if the source shape wasn't a full match for the target shape. We only
+  // need to check the last recorded index - any other gaps should have been
+  // mended by the main loop.
+  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
+    return failure();
+  return reassocRanges;
+}
+
+/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
+/// the shapes right-to-left.
+static FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape,
+                                   bool iterateRightToLeft) {
+  if (!iterateRightToLeft)
+    return findReassociationRangesForCollapse(sourceShape, targetShape);
+  // NB: To iterate right-to-left, we currently reverse the shapes and then
+  // reverse the result back. The reversed shapes must not be temporary, as
+  // we're passing through an ArrayRef.
+  // FIXME: It would be preferable to avoid the expensive copies. At the moment,
+  // this approach is chosen for readability of the main implementation.
+  std::vector<int64_t> sourceToReverse = sourceShape.vec(),
+                       targetToReverse = targetShape.vec();
+  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
+  std::reverse(targetToReverse.begin(), targetToReverse.end());
+  auto invertedRanges =
+      findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
+  if (failed(invertedRanges))
+    return failure();
+  SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
+  unsigned numSourceDims = sourceShape.size();
+  // We have received the ranges for inverted shapes. Now we have to invert
+  // the ranges back to correspond with the original source shape.
+  for (auto &range : rangesToInvert) {
+    int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
+    range.leftIdx = numSourceDims - 1 - invRightIdx;
+    range.rightIdx = numSourceDims - 1 - invLeftIdx;
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
+  // Also invert the ordering of the ranges to correspond with the original
+  // target shape.
+  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
+  return rangesToInvert;
+}
+
+std::optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                         ArrayRef<int64_t> targetShape) {
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  // We're supposed to search for a collapsing reassociation. If the sizes
+  // match, there's no actual collapsing taking place - it's either a no-op or a
+  // `tensor.reshape`-style reassociation (that would be beyond the scope of
+  // this utility).
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
-        sourceShape[sourceDim] != 1)
-      return std::nullopt;
-    // The map is empty when the target type is a scalar.
-    if (!reassociationMap.empty())
-      reassociationMap.back().push_back(sourceDim);
+  // Early handling for scalar target types.
+  if (numTargetDims == 0) {
+    ReassociationIndices allSourceIndices;
+    allSourceIndices.reserve(numSourceDims);
+    for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
+         ++sourceDimIdx) {
+      int64_t sourceSize = sourceShape[sourceDimIdx];
+      // All source dimensions must be unit or dynamic.
+      if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
+        return std::nullopt;
+      allSourceIndices.push_back(sourceDimIdx);
+    }
+    return SmallVector<ReassociationIndices>{allSourceIndices};
+  }
+
+  // Collect source ranges by iterating over the target shape left-to-right.
+  FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
+      findReassociationRangesForCollapse(sourceShape, targetShape);
+  if (failed(maybeForwardRanges))
+    return std::nullopt;
+  auto &ranges = *maybeForwardRanges;
+  // Now do the same in reverse. We need to get another valid reassociation
+  // through some other strategy, and then compare the results in order to
+  // disambiguate mixed subshapes, such as:
+  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
+  // This leads us to lose some of the reassociation opportunities that can only
+  // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
+  // backtracking, the algorithm will fail right-to-left. However, this is the
+  // best way to preserve correctness.
+  FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
+      findReassociationRangesForCollapse(sourceShape, targetShape,
+                                         /*iterateRightToLeft=*/true);
+  if (failed(maybeReverseRanges))
+    return std::nullopt;
+  auto &reverseRanges = *maybeReverseRanges;
+
+  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
+    return std::nullopt;
+  // Now we can check for ambiguity of each target dimension's reassociation. If
+  // successful, we put the full indices into our result map for the target
+  // shape.
+  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
+  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
+       ++targetDimIdx) {
+    ReassociationIndexRange &range = ranges[targetDimIdx];
+    ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
+    // Get non-overlapping indices between the ranges
+    ReassociationIndices nonMatchingIndices =
+        range.getNonOverlappingIndicesWith(reverseRange);
+    // Unit dimensions can be collapsed wherever - this is the only ambiguity
+    // that we allow.
+    for (int64_t sourceDimIdx : nonMatchingIndices) {
+      if (sourceShape[sourceDimIdx] != 1)
+        return std::nullopt;
+    }
+    reassociationMap[targetDimIdx] = range.getFullIndices();
   }
   return reassociationMap;
 }
@@ -315,11 +581,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 

diff  --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT:     tensor.collapse
-// CHECK:         linalg.unpack
+// CHECK:     tensor.collapse
+// CHECK-NOT:         linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index cdcd7f305d2d9..3eaf824b99115 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1094,7 +1094,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1102,12 +1102,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 //   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+//   CHECK-NOT:   tensor.expand_shape
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+//  CHECK-SAME:     : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+//  CHECK-NEXT:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1115,7 +1131,22 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
+//       CHECK:   tensor.collapse_shape
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
+//       CHECK:   return %[[EXPAND]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
+      : tensor<?x?x?xf32> into tensor<?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2]
+      : tensor<?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic
 //       CHECK:   tensor.collapse_shape
 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
 //       CHECK:   return %[[EXPAND]]

diff  --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt
index 61b9cdcb3b8f3..e921c8bcfb4e5 100644
--- a/mlir/unittests/Dialect/Utils/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRDialectUtilsTests
   StructuredOpsUtilsTest.cpp
+  ReshapeOpsUtilsTest.cpp
   IndexingUtilsTest.cpp
 )
 mlir_target_link_libraries(MLIRDialectUtilsTests

diff  --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
new file mode 100644
index 0000000000000..db1a87a4de2d5
--- /dev/null
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -0,0 +1,203 @@
+//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace mlir;
+
+/// Helper to make constructing
+/// `std::optional<SmallVector<ReassociationIndices>>` more readable.
+static std::optional<SmallVector<ReassociationIndices>>
+makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
+  return std::optional<SmallVector<ReassociationIndices>>(list);
+}
+
+TEST(ReassociationIndicesForCollapse, ScalarTest) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
+            makeOptionalIndices({{0}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
+            makeOptionalIndices({{0}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
+                                                ShapedType::kDynamic, 1,
+                                                ShapedType::kDynamic},
+                                               {}),
+            makeOptionalIndices({{0, 1, 2, 3, 4}}));
+}
+
+TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt);
+  EXPECT_EQ(
+      getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}),
+      std::nullopt);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTest) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}),
+            makeOptionalIndices({{0}, {1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}),
+            makeOptionalIndices({{0, 1}, {2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
+  // No-op reassociation
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}),
+            std::nullopt);
+  // Invalid static reassociations
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}),
+            std::nullopt);
+  // Non-collapsing (expanding) reassociation
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}),
+            std::nullopt);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1, 1}, {1, 1, 1}),
+            makeOptionalIndices({{0}, {1}, {2, 3}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTest) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}, {2, 3, 4}}));
+  EXPECT_EQ(
+      getReassociationIndicesForCollapse(
+          {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
+      makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, ShapedType::kDynamic},
+                {1, ShapedType::kDynamic}),
+            makeOptionalIndices({{0}, {1, 2}}));
+
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, ShapedType::kDynamic},
+                {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10},
+                {ShapedType::kDynamic, 10}),
+            makeOptionalIndices({{0, 1, 2, 3}, {4}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+                                               {ShapedType::kDynamic, 20}),
+            makeOptionalIndices({{0, 1}, {2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
+                                               {ShapedType::kDynamic, 20}),
+            makeOptionalIndices({{0, 1}, {2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
+            makeOptionalIndices({{0, 1}, {2, 3, 4}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
+                {ShapedType::kDynamic, 20, ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, ShapedType::kDynamic, 1},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0}, {1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}, {2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 1, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0}, {1, 2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+                                               {ShapedType::kDynamic, 10}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
+            std::nullopt);
+  EXPECT_EQ(
+      getReassociationIndicesForCollapse(
+          {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
+          {ShapedType::kDynamic, ShapedType::kDynamic}),
+      std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
+                 ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 10, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 12, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
+            std::nullopt);
+
+  //===----------------------------------------------------------------------===//
+  // TODO: Reassociation for the following examples can be computed, but isn't
+  // supported by `getReassociationIndicesForCollapse`.
+  //===----------------------------------------------------------------------===//
+
+  // TODO: Fails because there's no backtracking when some source dimensions
+  // remain unmatched at either edge.
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10},
+                {ShapedType::kDynamic, 10}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2},
+                                               {1, ShapedType::kDynamic, 2}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1},
+                                               {2, ShapedType::kDynamic}),
+            std::nullopt);
+}
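
The accept/reject split in the tests above can be explored with a brute-force counter. The block below is only an illustrative sketch, not the algorithm used by `getReassociationIndicesForCollapse`; `kDyn` and `countGroupings` are hypothetical names that do not exist in MLIR. It counts the contiguous groupings of source dimensions that are consistent with the target shape, assuming that a static target dimension needs an all-static group with a matching product and that a dynamic target dimension needs a group with at least one dynamic dimension. A count above one corresponds to the ambiguity described in the log message.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ShapedType::kDynamic (assumption: any negative
// sentinel works for this sketch; this is not the MLIR constant).
constexpr int64_t kDyn = -1;

// Counts the ways to split `source` into contiguous groups, one group per
// `target` dimension. A static target dimension requires an all-static group
// whose product equals it; a dynamic target dimension requires a group that
// contains at least one dynamic dimension.
static int countGroupings(const std::vector<int64_t> &source,
                          const std::vector<int64_t> &target,
                          std::size_t srcPos = 0, std::size_t tgtPos = 0) {
  // All target dimensions are matched: valid only if the source is used up.
  if (tgtPos == target.size())
    return srcPos == source.size() ? 1 : 0;

  int count = 0;
  bool sawDynamic = false;
  int64_t product = 1;
  // Grow the current group one source dimension at a time.
  for (std::size_t end = srcPos; end < source.size(); ++end) {
    if (source[end] == kDyn)
      sawDynamic = true;
    else
      product *= source[end];

    bool matches = target[tgtPos] == kDyn
                       ? sawDynamic
                       : (!sawDynamic && product == target[tgtPos]);
    if (matches)
      count += countGroupings(source, target, end + 1, tgtPos + 1);
  }
  return count;
}
```

Under these assumptions, `?x?x?` into `?x?` yields two groupings (ambiguous), `?x10x20` into `?x10` yields none, and `?x10x?x10` into `?x10` yields exactly one, which agrees with the TODO note that a reassociation exists for that case even though the function returns `std::nullopt`. The sketch treats unit dimensions like any other static dimension, so it also counts two groupings for `?x1x?` into `?x?`, a case where the tests above expect a single result.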


        

