[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)

Artem Gindinson llvmlistbot at llvm.org
Tue May 20 15:46:11 PDT 2025


https://github.com/AGindinson updated https://github.com/llvm/llvm-project/pull/137963

>From 2ae44808cd573c31331d143d687afcdab6986a72 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Mon, 28 Apr 2025 17:24:11 +0000
Subject: [PATCH 1/8] [mlir][tensor] Loosen restrictions on folding dynamic
 reshapes

The main idea behind the change is to allow expand-of-collapse folds
for reshapes like `?x?xk` -> `?` (where k>1). The rationale is that the
expand op must have a coherent index/affine expression specified in its
`output_shape` argument (see the example below); if it doesn't, the IR
has already been invalidated at an earlier stage:
```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]]
	         : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine]
		: tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```

Based on this assumption, adjust the routine in
`getReassociationIndicesForCollapse()` to allow dynamic reshapes
beyond just `?x..?x1x1x..x1` -> `?`.

Moreover, the reassociation utility was refactored to clearly distinguish
between dynamic and static subshapes. A few known caveats are noted in a
comment: it doesn't seem possible to fold all qualifying dynamic shape
patterns deterministically without also inspecting the associated affine
expressions, and that would be difficult to maintain in a single general
utility. Other implementation ideas/larger refactorings could include:
- abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern,
  employing similar logic to `ComposeCollapseOfExpandOp`;
- providing dialect-specific implementations for Linalg/Tensor.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    | 103 ++++++++++--------
 .../Dialect/Linalg/simplify-pack-unpack.mlir  |   4 +-
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  24 +++-
 3 files changed, 79 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index ed40a080441bc..694783849198a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
 
+  unsigned sourceDim = 0, targetDim = 0;
+  for (; targetDim < numTargetDims; ++targetDim) {
     int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+    ReassociationIndices currIndices;
+    // 1. Target dimension is dynamic. Source shape should contain at least
+    // one dynamic dimension.
+    if (currTargetShape == ShapedType::kDynamic) {
+      // FIXME: We stop the search with the first dynamic dimension, while in
+      // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+      // indeterministic altogether when we have neighboring dynamic dimensions
+      // in the target shape. Most of these patterns will be safely rejected,
+      // however we might achieve more correct folds by taking affine
+      // expressions into account, if these can be passed on by the call sites.
+      bool foundDynamic = false;
+      while (sourceDim < numSourceDims) {
+        currIndices.push_back(sourceDim);
+        if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+          foundDynamic = true;
+          break;
+        }
+      }
+      if (!foundDynamic)
+        return std::nullopt;
+
+      reassociationMap.push_back(currIndices);
+      continue;
+    }
+    // 2. Target dimension is static. The product of dimensions of the expanded
+    // shape should match the collapsed dimension shape.
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    while (sourceDim < numSourceDims) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDim] == ShapedType::kDynamic)
+        return std::nullopt;
       prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+      if (prodOfCollapsedDims > currTargetShape)
+        break;
+      else if (prodOfCollapsedDims == currTargetShape) {
+        currIndices.push_back(sourceDim++);
+        reachedTargetDimSize = true;
+        break;
+      } else // prodOfCollapsedDims < currTargetShape
+        currIndices.push_back(sourceDim++);
     }
-
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+    if (!reachedTargetDimSize)
       return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    reassociationMap.push_back(currIndices);
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
-    return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+  // Now that we've mapped all the target dimensions, process any remaining
+  // entries in the source shape explicitly. Either the last target dimension
+  // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+  // applies when target shape is empty (can be the case for subshape
+  // reassociations).
+  for (; sourceDim < numSourceDims; sourceDim++) {
+    if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+        sourceShape[sourceDim] != ShapedType::kDynamic &&
         sourceShape[sourceDim] != 1)
       return std::nullopt;
     // The map is empty when the target type is a scalar.
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT:     tensor.collapse
-// CHECK:         linalg.unpack
+// CHECK:     tensor.collapse
+// CHECK-NOT:         linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..443f931745557 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,7 +1068,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1076,12 +1076,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 //   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+//   CHECK-NOT:   tensor.expand_shape
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+//  CHECK-SAME:     : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+//  CHECK-NEXT:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1089,7 +1105,7 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
 //       CHECK:   tensor.collapse_shape
 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
 //       CHECK:   return %[[EXPAND]]

>From 52ff4e0a1e81e13282975076d730e741a1da1cae Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:12:21 +0000
Subject: [PATCH 2/8] [fixup] Algorithm rewrite

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 140 ++++++++++++++-------
 1 file changed, 93 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 694783849198a..1cd06a2757363 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -39,67 +39,113 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
   reassociationMap.reserve(numTargetDims);
 
   unsigned sourceDim = 0, targetDim = 0;
-  for (; targetDim < numTargetDims; ++targetDim) {
-    int64_t currTargetShape = targetShape[targetDim];
-    ReassociationIndices currIndices;
-    // 1. Target dimension is dynamic. Source shape should contain at least
-    // one dynamic dimension.
-    if (currTargetShape == ShapedType::kDynamic) {
-      // FIXME: We stop the search with the first dynamic dimension, while in
-      // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
-      // indeterministic altogether when we have neighboring dynamic dimensions
-      // in the target shape. Most of these patterns will be safely rejected,
-      // however we might achieve more correct folds by taking affine
-      // expressions into account, if these can be passed on by the call sites.
-      bool foundDynamic = false;
-      while (sourceDim < numSourceDims) {
-        currIndices.push_back(sourceDim);
-        if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
-          foundDynamic = true;
-          break;
-        }
-      }
-      if (!foundDynamic)
-        return std::nullopt;
-
-      reassociationMap.push_back(currIndices);
-      continue;
-    }
-    // 2. Target dimension is static. The product of dimensions of the expanded
-    // shape should match the collapsed dimension shape.
+  // Source dimensions iteration logic for static target dimensions.
+  // FIXME: Instead of lambda-capturing this function's source shape index "in
+  // place", consider refactoring this into a separate function.
+  auto collectSourceIndicesForStaticTargetDim =
+      [&](int64_t targetShape,
+          bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
+    ReassociationIndices resultIndices;
     int64_t prodOfCollapsedDims = 1;
     bool reachedTargetDimSize = false;
-    while (sourceDim < numSourceDims) {
+    for (; sourceDim < numSourceDims; ++sourceDim) {
       // Source shape cannot be dynamic if the target dim is static.
       if (sourceShape[sourceDim] == ShapedType::kDynamic)
-        return std::nullopt;
+        return failure();
       prodOfCollapsedDims *= sourceShape[sourceDim];
-      if (prodOfCollapsedDims > currTargetShape)
-        break;
-      else if (prodOfCollapsedDims == currTargetShape) {
-        currIndices.push_back(sourceDim++);
+      resultIndices.push_back(sourceDim);
+      if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
+        return failure();
+      while (prodOfCollapsedDims > targetShape) {
+        assert(!resultIndices.empty());
+        auto frontOffsetIdx = resultIndices.begin();
+        prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
+        resultIndices.erase(frontOffsetIdx);
+      }
+      if (prodOfCollapsedDims == targetShape) {
         reachedTargetDimSize = true;
+        ++sourceDim;
         break;
-      } else // prodOfCollapsedDims < currTargetShape
-        currIndices.push_back(sourceDim++);
+      }
     }
     if (!reachedTargetDimSize)
+      return failure();
+    return resultIndices;
+  };
+  // Source dimensions iteration logic for dynamic target dimensions.
+  // FIXME: Instead of lambda-capturing this function's source shape index "in
+  // place", consider refactoring this into a separate function.
+  auto collectSourceIndicesForDynamicTargetDim =
+      [&](bool allowStaticNonOnes,
+          bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
+    ReassociationIndices resultIndices;
+    bool foundFirstDynamic = false;
+    while (sourceDim < numSourceDims) {
+      if (sourceShape[sourceDim] == ShapedType::kDynamic) {
+        if (foundFirstDynamic && !mapConsecutiveDynDims)
+          break;
+        foundFirstDynamic |= true;
+      } else {
+        if (foundFirstDynamic)
+          break;
+        else if (sourceShape[sourceDim] > 1 && !allowStaticNonOnes)
+          return failure();
+      }
+      resultIndices.push_back(sourceDim++);
+    }
+    if (!foundFirstDynamic)
+      return failure();
+    return resultIndices;
+  };
+  // Iterate over target shape.
+  bool wasLastDimDynamic = false;
+  for (; targetDim < numTargetDims; ++targetDim) {
+    int64_t currTargetShape = targetShape[targetDim];
+    if (currTargetShape != ShapedType::kDynamic) {
+      unsigned sourceDimAtStart = sourceDim;
+      auto indices = collectSourceIndicesForStaticTargetDim(
+          currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
+      if (failed(indices))
+        return std::nullopt;
+      if (wasLastDimDynamic) {
+        assert(!reassociationMap.empty());
+        auto &previousIndices = reassociationMap.back();
+        for (; sourceDimAtStart < indices->front(); ++sourceDimAtStart)
+          previousIndices.push_back(sourceDimAtStart);
+      }
+      reassociationMap.push_back(*indices);
+      wasLastDimDynamic = false;
+      continue;
+    }
+
+    bool isNextDimDynamic = targetDim + 1 < numTargetDims &&
+                            targetShape[targetDim + 1] == ShapedType::kDynamic;
+    auto indices = collectSourceIndicesForDynamicTargetDim(
+        /*allowStaticNonOnes=*/!wasLastDimDynamic,
+        /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic);
+    if (failed(indices))
       return std::nullopt;
-    reassociationMap.push_back(currIndices);
+    reassociationMap.push_back(*indices);
+    wasLastDimDynamic = true;
   }
   // Now that we've mapped all the target dimensions, process any remaining
-  // entries in the source shape explicitly. Either the last target dimension
-  // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
-  // applies when target shape is empty (can be the case for subshape
-  // reassociations).
+  // entries in the source shape explicitly.
   for (; sourceDim < numSourceDims; sourceDim++) {
-    if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
-        sourceShape[sourceDim] != ShapedType::kDynamic &&
-        sourceShape[sourceDim] != 1)
+    const bool isOne = sourceShape[sourceDim] == 1,
+               isDynamic = sourceShape[sourceDim] == ShapedType::kDynamic;
+    if (targetShape.empty()) {
+      if (!isOne && !isDynamic)
+        return std::nullopt;
+      continue;
+    }
+    if (wasLastDimDynamic && isDynamic)
+      return std::nullopt;
+    // If the last target dimension is static, only source dimensions of 1 are
+    // acceptable.
+    if (!wasLastDimDynamic && !isOne)
       return std::nullopt;
-    // The map is empty when the target type is a scalar.
-    if (!reassociationMap.empty())
-      reassociationMap.back().push_back(sourceDim);
+    assert(!reassociationMap.empty());
+    reassociationMap.back().push_back(sourceDim);
   }
   return reassociationMap;
 }

>From 1c85a6875fecf52e6714b81bbe9bd2da81178c9e Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:12:34 +0000
Subject: [PATCH 3/8] [fixup] Add/expand unit tests

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
Co-authored-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  15 +++
 mlir/unittests/Dialect/Utils/CMakeLists.txt   |   1 +
 .../Dialect/Utils/ReshapeOpsUtilsTest.cpp     | 125 ++++++++++++++++++
 3 files changed, 141 insertions(+)
 create mode 100644 mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp

diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 443f931745557..035ea850c9102 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1112,6 +1112,21 @@ func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %
 
 // -----
 
+func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
+      : tensor<?x?x?xf32> into tensor<?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2]
+      : tensor<?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic
+//       CHECK:   tensor.collapse_shape
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
+//       CHECK:   return %[[EXPAND]]
+
+// -----
+
 func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -> tensor<?x384xf32> {
   %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
   %c0 = arith.constant 0 : index
diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt
index 61b9cdcb3b8f3..e921c8bcfb4e5 100644
--- a/mlir/unittests/Dialect/Utils/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_unittest(MLIRDialectUtilsTests
   StructuredOpsUtilsTest.cpp
+  ReshapeOpsUtilsTest.cpp
   IndexingUtilsTest.cpp
 )
 mlir_target_link_libraries(MLIRDialectUtilsTests
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
new file mode 100644
index 0000000000000..bfcc70150e2ed
--- /dev/null
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -0,0 +1,125 @@
+//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gtest/gtest.h"
+#include <optional>
+
+using namespace mlir;
+
+/// Helper to make constructing
+/// `std::optional<SmallVector<ReassociationIndices>>` more readable.
+static std::optional<SmallVector<ReassociationIndices>>
+makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
+  return std::optional<SmallVector<ReassociationIndices>>(list);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTest) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}),
+            makeOptionalIndices({{0}, {1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}),
+            makeOptionalIndices({{0, 1}, {2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}),
+            std::nullopt);
+}
+
+TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1, 1}),
+            makeOptionalIndices({{0}, {1, 2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTest) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(
+      getReassociationIndicesForCollapse(
+          {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
+      makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, ShapedType::kDynamic},
+                {1, ShapedType::kDynamic}),
+            makeOptionalIndices({{0}, {1, 2}}));
+
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, ShapedType::kDynamic},
+                {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+                                               {ShapedType::kDynamic, 20}),
+            makeOptionalIndices({{0, 1}, {2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
+                                               {ShapedType::kDynamic, 20}),
+            makeOptionalIndices({{0, 1}, {2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
+            makeOptionalIndices({{0, 1}, {2, 3, 4}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
+                {ShapedType::kDynamic, 20, ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
+                                               {ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}, {2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 1, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0}, {1, 2}}));
+}
+
+TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
+                                               {ShapedType::kDynamic, 10}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
+            std::nullopt);
+  EXPECT_EQ(
+      getReassociationIndicesForCollapse(
+          {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
+          {ShapedType::kDynamic, ShapedType::kDynamic}),
+      std::nullopt);
+}

>From 0fe986e217e52cc519823a999e83155f1cd2d3ef Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:15:08 +0000
Subject: [PATCH 4/8] [fixup] variable renaming

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 39 +++++++++++-----------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 1cd06a2757363..a6ee21d941e17 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -38,7 +38,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
   SmallVector<ReassociationIndices, 4> reassociationMap;
   reassociationMap.reserve(numTargetDims);
 
-  unsigned sourceDim = 0, targetDim = 0;
+  unsigned sourceDimIdx = 0, targetDimIdx = 0;
   // Source dimensions iteration logic for static target dimensions.
   // FIXME: Instead of lambda-capturing this function's source shape index "in
   // place", consider refactoring this into a separate function.
@@ -48,12 +48,12 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
     ReassociationIndices resultIndices;
     int64_t prodOfCollapsedDims = 1;
     bool reachedTargetDimSize = false;
-    for (; sourceDim < numSourceDims; ++sourceDim) {
+    for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
       // Source shape cannot be dynamic if the target dim is static.
-      if (sourceShape[sourceDim] == ShapedType::kDynamic)
+      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
         return failure();
-      prodOfCollapsedDims *= sourceShape[sourceDim];
-      resultIndices.push_back(sourceDim);
+      prodOfCollapsedDims *= sourceShape[sourceDimIdx];
+      resultIndices.push_back(sourceDimIdx);
       if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
         return failure();
       while (prodOfCollapsedDims > targetShape) {
@@ -64,7 +64,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
       }
       if (prodOfCollapsedDims == targetShape) {
         reachedTargetDimSize = true;
-        ++sourceDim;
+        ++sourceDimIdx;
         break;
       }
     }
@@ -80,18 +80,18 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
           bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
     ReassociationIndices resultIndices;
     bool foundFirstDynamic = false;
-    while (sourceDim < numSourceDims) {
-      if (sourceShape[sourceDim] == ShapedType::kDynamic) {
+    while (sourceDimIdx < numSourceDims) {
+      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) {
         if (foundFirstDynamic && !mapConsecutiveDynDims)
           break;
         foundFirstDynamic |= true;
       } else {
         if (foundFirstDynamic)
           break;
-        else if (sourceShape[sourceDim] > 1 && !allowStaticNonOnes)
+        else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes)
           return failure();
       }
-      resultIndices.push_back(sourceDim++);
+      resultIndices.push_back(sourceDimIdx++);
     }
     if (!foundFirstDynamic)
       return failure();
@@ -99,10 +99,10 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
   };
   // Iterate over target shape.
   bool wasLastDimDynamic = false;
-  for (; targetDim < numTargetDims; ++targetDim) {
-    int64_t currTargetShape = targetShape[targetDim];
+  for (; targetDimIdx < numTargetDims; ++targetDimIdx) {
+    int64_t currTargetShape = targetShape[targetDimIdx];
     if (currTargetShape != ShapedType::kDynamic) {
-      unsigned sourceDimAtStart = sourceDim;
+      unsigned sourceDimAtStart = sourceDimIdx;
       auto indices = collectSourceIndicesForStaticTargetDim(
           currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
       if (failed(indices))
@@ -118,8 +118,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
       continue;
     }
 
-    bool isNextDimDynamic = targetDim + 1 < numTargetDims &&
-                            targetShape[targetDim + 1] == ShapedType::kDynamic;
+    bool isNextDimDynamic =
+        targetDimIdx + 1 < numTargetDims &&
+        targetShape[targetDimIdx + 1] == ShapedType::kDynamic;
     auto indices = collectSourceIndicesForDynamicTargetDim(
         /*allowStaticNonOnes=*/!wasLastDimDynamic,
         /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic);
@@ -130,9 +131,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
   }
   // Now that we've mapped all the target dimensions, process any remaining
   // entries in the source shape explicitly.
-  for (; sourceDim < numSourceDims; sourceDim++) {
-    const bool isOne = sourceShape[sourceDim] == 1,
-               isDynamic = sourceShape[sourceDim] == ShapedType::kDynamic;
+  for (; sourceDimIdx < numSourceDims; sourceDimIdx++) {
+    const bool isOne = sourceShape[sourceDimIdx] == 1,
+               isDynamic = sourceShape[sourceDimIdx] == ShapedType::kDynamic;
     if (targetShape.empty()) {
       if (!isOne && !isDynamic)
         return std::nullopt;
@@ -145,7 +146,7 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
     if (!wasLastDimDynamic && !isOne)
       return std::nullopt;
     assert(!reassociationMap.empty());
-    reassociationMap.back().push_back(sourceDim);
+    reassociationMap.back().push_back(sourceDimIdx);
   }
   return reassociationMap;
 }
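
[Editor's sketch, not part of the patch] The product-matching logic for a static target dimension — introduced here and later refactored into `findReassociationRangeForSize` — is a sliding-window search: extend the window to the right while multiplying source sizes, and shrink it from the left whenever the product overshoots the target. A minimal standalone model (hypothetical helper name, `-1` as a stand-in for `ShapedType::kDynamic`; this simplifies away the greedy unit-dim handling of the real implementation):

```cpp
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

// Returns the inclusive source-index range [left, right] whose static sizes
// multiply to exactly `targetSize`, searching from `start`. A dynamic source
// dim cannot participate in a static target dim, so the window restarts
// just past it.
std::optional<std::pair<int64_t, int64_t>>
findRangeForSize(const std::vector<int64_t> &source, int64_t start,
                 int64_t targetSize) {
  int64_t left = start, product = 1;
  for (int64_t right = start; right < (int64_t)source.size(); ++right) {
    if (source[right] == kDynamic) {
      left = right + 1;
      product = 1;
      continue;
    }
    product *= source[right];
    // Overshoot: roll the multiplication back from the window's left edge.
    while (product > targetSize && left < right)
      product /= source[left++];
    if (product == targetSize)
      return std::make_pair(left, right);
  }
  return std::nullopt;
}
```

For example, collapsing `2x3x5` into `...x15x...` matches source dims 1-2, while `4x3` can never produce `6` because the window only drops whole leading dims.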

>From e3aa2394225bb23cde429cc06bea28df7257070a Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 9 May 2025 15:51:59 +0000
Subject: [PATCH 5/8] [fixup] Additional edge-case

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp           | 10 ++++++++--
 mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp |  9 +++++++++
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index a6ee21d941e17..8c19f20f446da 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -139,8 +139,14 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
         return std::nullopt;
       continue;
     }
-    if (wasLastDimDynamic && isDynamic)
-      return std::nullopt;
+    // If the last 2 dimensions in the target were dynamic, the tail in the
+    // source shape cannot contain a dynamic value. E.g. ?x?->? is valid,
+    // however ?x?x10x?->?x? would be indeterminate.
+    if (wasLastDimDynamic && numTargetDims > 1 &&
+        targetShape[numTargetDims - 2] == ShapedType::kDynamic) {
+      if (isDynamic)
+        return std::nullopt;
+    }
     // If the last target dimension is static, only source dimensions of 1 are
     // acceptable.
     if (!wasLastDimDynamic && !isOne)
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index bfcc70150e2ed..2564866fac493 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -92,6 +92,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) {
   EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
                                                {ShapedType::kDynamic}),
             makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, ShapedType::kDynamic, 1},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0}, {1, 2}}));
   EXPECT_EQ(getReassociationIndicesForCollapse(
                 {1, ShapedType::kDynamic, ShapedType::kDynamic},
                 {ShapedType::kDynamic, ShapedType::kDynamic}),
@@ -122,4 +126,9 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
           {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
           {ShapedType::kDynamic, ShapedType::kDynamic}),
       std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
+                 ShapedType::kDynamic},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            std::nullopt);
 }

>From 16a932c8fa45f00e6474dd18bd8b7781a4b2fac8 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Tue, 20 May 2025 20:00:24 +0000
Subject: [PATCH 6/8] [WIP] Current tests pass

---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    | 430 +++++++++++++-----
 .../Dialect/Utils/ReshapeOpsUtilsTest.cpp     |  24 +
 2 files changed, 337 insertions(+), 117 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 8c19f20f446da..25dd434fc2122 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -10,6 +10,10 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 
 #include <numeric>
 #include <optional>
@@ -28,6 +32,257 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return std::nullopt;
 }
 
+namespace {
+/// A simple struct to represent ReassociationIndices as an inclusive interval.
+/// It's designed to be feasibly minimal, so the call sites should manage the
+/// validity of the range manually.
+struct ReassociationIndexRange {
+  /// FIXME: Signed type is used for consistency with ReassociationIndices.
+  /// We should consider refactoring all reassociation utilities to use unsigned
+  /// types.
+  int64_t leftIdx = 0, rightIdx = 0;
+
+  /// Util for manual checks of the range's validity
+  LogicalResult verify() const {
+    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
+  }
+
+  /// Checks range's containment within another range. Treats the edges
+  /// non-exclusively.
+  bool isInRange(const ReassociationIndexRange &outerRange) const {
+    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
+  }
+
+  unsigned size() const {
+    assert(succeeded(verify()));
+    return rightIdx - leftIdx + 1;
+  }
+  bool containsSingleIndex() const { return size() == 1; }
+
+  void expandRight() { ++rightIdx; }
+  void shrinkLeft() { ++leftIdx; }
+
+  /// Implements arithmetic XOR semantics to get non-overlapping indices between
+  /// ranges.
+  ReassociationIndices operator^(ReassociationIndexRange &rhs) const {
+    ReassociationIndices result;
+    result.reserve((size() + rhs.size()) / 2); // Attempt to amortize
+    for (int64_t idx = this->leftIdx; idx <= this->rightIdx; ++idx) {
+      if (idx < rhs.leftIdx || idx > rhs.rightIdx)
+        result.push_back(idx);
+    }
+    for (int64_t rhsIndex = rhs.leftIdx; rhsIndex <= rhs.rightIdx; ++rhsIndex) {
+      if (rhsIndex < leftIdx || rhsIndex > rightIdx)
+        result.push_back(rhsIndex);
+    }
+    return result;
+  }
+
+  /// Converts the range into ReassociationIndices.
+  ReassociationIndices getFullIndices() const {
+    ReassociationIndices result;
+    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
+      result.push_back(idx);
+    }
+    return result;
+  }
+};
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence that can be collapsed into a dynamic dimension (at least one must
+/// be present in the source).
+/// By default, lazily returns once the first dynamic dimension has been found.
+/// Setting `matchGreedily` as `true` will also mark all subsequent
+/// source dimensions for collapsing into the target.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
+                                    int64_t sourceStartIdx,
+                                    bool matchGreedily = false) {
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  if (!iterationRange.isInRange(sourceShapeAsRange))
+    return failure();
+  auto resultRange = iterationRange;
+
+  bool foundDynamic = false;
+  for (; iterationRange.isInRange(sourceShapeAsRange);
+       iterationRange.expandRight()) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (foundDynamic && !matchGreedily)
+      break;
+    if (sourceSize == ShapedType::kDynamic)
+      foundDynamic = true;
+    resultRange = iterationRange;
+  }
+  if (!foundDynamic)
+    return failure();
+  return resultRange;
+}
+
+/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
+/// sequence of static dimensions such that their product matches `targetSize`.
+/// By default, lazily returns once the product matches the target size. Setting
+/// `matchGreedily` as `true` will append all neighboring unit dimensions
+/// (dimensions of 1) to the match.
+FailureOr<ReassociationIndexRange>
+findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
+                              int64_t sourceStartIdx, int64_t targetSize,
+                              bool matchGreedily = false) {
+  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  const unsigned numSourceDims = sourceShape.size();
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+  if (!iterationRange.isInRange(sourceShapeAsRange))
+    return failure();
+  auto resultRange = iterationRange;
+
+  int64_t prodOfCollapsedDims = 1;
+  bool reachedTargetDimSize = false;
+  while (iterationRange.isInRange(sourceShapeAsRange)) {
+    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
+    if (reachedTargetDimSize && !matchGreedily)
+      break;
+    if (sourceSize == ShapedType::kDynamic) {
+      if (reachedTargetDimSize)
+        break;
+      // Reassociation for a static dim cannot include a dynamic dim. Reset
+      // induction variables to essentially restart the loop from the next
+      // source dimension.
+      prodOfCollapsedDims = 1;
+      resultRange = {iterationRange.rightIdx + 1, iterationRange.rightIdx + 1};
+      iterationRange = resultRange;
+      continue;
+    }
+    prodOfCollapsedDims *= sourceSize;
+    if (prodOfCollapsedDims > targetSize && reachedTargetDimSize)
+      break;
+    // If the target size has been exceeded without matching, we need to shift
+    // the range start right. From the start of the range, roll back the
+    // multiplication until the target size exceeds the product again.
+    while (prodOfCollapsedDims > targetSize &&
+           !iterationRange.containsSingleIndex()) {
+      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
+      prodOfCollapsedDims /= frontSourceSize;
+      iterationRange.shrinkLeft();
+    }
+    resultRange = iterationRange;
+    // We could've reached the target size with the current dimension,
+    // also as a result of the above shift to right.
+    if (prodOfCollapsedDims == targetSize)
+      reachedTargetDimSize = true;
+    // Increment the iteration range
+    iterationRange.expandRight();
+  }
+  if (!reachedTargetDimSize)
+    return failure();
+  return resultRange;
+}
+
+/// Attempts to find a valid collapsing reassociation of `sourceShape` into
+/// `targetShape` through a simple traversal. If successful, an array of source
+/// index ranges is returned, corresponding to each dimension in the target
+/// shape. The resulting indices shall fully cover the `sourceShape` without
+/// overlaps.
+///
+/// The algorithm is essentially a lazy one, searching for non-greedy matches -
+/// it will only yield a greedy match for the last target dimension.
+/// FIXME: The algorithm can only backtrack when it needs to append an offset
+/// for a static target dimension to the preceding dynamic one (this retains the
+/// linear complexity). As feasible, consider adding further backtracking
+/// routines to enable more reassociations, e.g.:
+/// - ?x2x?x2 into ?x2
+FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape) {
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  assert(numSourceDims > numTargetDims);
+  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
+
+  SmallVector<ReassociationIndexRange> reassocRanges;
+  reassocRanges.reserve(numTargetDims);
+  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
+  // cases, e.g.:
+  // - ?x2x3x5 into ?x15
+  std::optional<int64_t> prevTargetSize = std::nullopt;
+  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
+       targetDimIdx < numTargetDims; ++targetDimIdx) {
+    int64_t targetSize = targetShape[targetDimIdx];
+    std::optional<int64_t> nextTargetSize = std::nullopt;
+
+    // Simply check if there are any subsequent target dimensions left - if not,
+    // the match must be made greedily.
+    bool isLastTargetDim = targetDimIdx == numTargetDims - 1;
+    bool shouldMatchGreedily = isLastTargetDim;
+    FailureOr<ReassociationIndexRange> sourceRange;
+    if (targetSize == ShapedType::kDynamic) {
+      sourceRange = findReassociationRangeForDynamicDim(
+          sourceShape, sourceDimIdx, shouldMatchGreedily);
+    } else {
+      sourceRange = findReassociationRangeForSize(
+          sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
+    }
+
+    // Run sanity checks on the returned index range.
+    if (failed(sourceRange) || failed(sourceRange->verify()) ||
+        !sourceRange->isInRange(sourceShapeAsRange))
+      return failure();
+    if (sourceRange->leftIdx > sourceDimIdx) {
+      // If some source dimensions had to be skipped in order to find a match,
+      // they must be collapsed into the directly preceding dynamic dimension.
+      if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
+        return failure();
+      reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
+    }
+
+    // Store the gathered information as required for the next iteration.
+    prevTargetSize = targetSize;
+    sourceDimIdx = sourceRange->rightIdx + 1;
+    reassocRanges.emplace_back(std::move(*sourceRange));
+  }
+  // Fail if the source shape wasn't a full match for the target shape. We only
+  // need to check the last recorded index - any other gaps should have been
+  // mended by the main loop.
+  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
+    return failure();
+  return reassocRanges;
+}
+
+/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
+/// the shapes right-to-left.
+FailureOr<SmallVector<ReassociationIndexRange>>
+findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
+                                   ArrayRef<int64_t> targetShape,
+                                   bool iterateRightToLeft) {
+  if (!iterateRightToLeft)
+    return findReassociationRangesForCollapse(sourceShape, targetShape);
+  // FIXME: It would be preferable to avoid the expensive copies. At the moment,
+  // this approach is chosen for readability of the main implementation.
+  auto sourceToReverse = sourceShape.vec(), targetToReverse = targetShape.vec();
+  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
+  std::reverse(targetToReverse.begin(), targetToReverse.end());
+  auto invertedRanges =
+      findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
+  if (failed(invertedRanges))
+    return failure();
+  auto rangesToInvert = *invertedRanges;
+  unsigned numSourceDims = sourceShape.size();
+  // We have received the ranges for inverted shapes. Now we have to invert
+  // the ranges back to correspond with the original source shape.
+  for (auto &range : rangesToInvert) {
+    if (failed(range.verify()))
+      return failure();
+    int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
+    range.leftIdx = numSourceDims - 1 - invRightIdx;
+    range.rightIdx = numSourceDims - 1 - invLeftIdx;
+  }
+  // Also invert the ordering of the ranges to correspond with the original
+  // target shape.
+  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
+  return rangesToInvert;
+}
+} // namespace
+
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
@@ -35,124 +290,65 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
            numTargetDims = targetShape.size();
   if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  SmallVector<ReassociationIndices, 4> reassociationMap;
-  reassociationMap.reserve(numTargetDims);
-
-  unsigned sourceDimIdx = 0, targetDimIdx = 0;
-  // Source dimensions iteration logic for static target dimensions.
-  // FIXME: Instead of lambda-capturing this function's source shape index "in
-  // place", consider refactoring this into a separate function.
-  auto collectSourceIndicesForStaticTargetDim =
-      [&](int64_t targetShape,
-          bool mayHaveOffset = false) -> FailureOr<ReassociationIndices> {
-    ReassociationIndices resultIndices;
-    int64_t prodOfCollapsedDims = 1;
-    bool reachedTargetDimSize = false;
-    for (; sourceDimIdx < numSourceDims; ++sourceDimIdx) {
-      // Source shape cannot be dynamic if the target dim is static.
-      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic)
-        return failure();
-      prodOfCollapsedDims *= sourceShape[sourceDimIdx];
-      resultIndices.push_back(sourceDimIdx);
-      if (prodOfCollapsedDims > targetShape && !mayHaveOffset)
-        return failure();
-      while (prodOfCollapsedDims > targetShape) {
-        assert(!resultIndices.empty());
-        auto frontOffsetIdx = resultIndices.begin();
-        prodOfCollapsedDims /= sourceShape[*frontOffsetIdx];
-        resultIndices.erase(frontOffsetIdx);
-      }
-      if (prodOfCollapsedDims == targetShape) {
-        reachedTargetDimSize = true;
-        ++sourceDimIdx;
-        break;
-      }
-    }
-    if (!reachedTargetDimSize)
-      return failure();
-    return resultIndices;
-  };
-  // Source dimensions iteration logic for dynamic target dimensions.
-  // FIXME: Instead of lambda-capturing this function's source shape index "in
-  // place", consider refactoring this into a separate function.
-  auto collectSourceIndicesForDynamicTargetDim =
-      [&](bool allowStaticNonOnes,
-          bool mapConsecutiveDynDims) -> FailureOr<ReassociationIndices> {
-    ReassociationIndices resultIndices;
-    bool foundFirstDynamic = false;
-    while (sourceDimIdx < numSourceDims) {
-      if (sourceShape[sourceDimIdx] == ShapedType::kDynamic) {
-        if (foundFirstDynamic && !mapConsecutiveDynDims)
-          break;
-        foundFirstDynamic |= true;
-      } else {
-        if (foundFirstDynamic)
-          break;
-        else if (sourceShape[sourceDimIdx] > 1 && !allowStaticNonOnes)
-          return failure();
-      }
-      resultIndices.push_back(sourceDimIdx++);
-    }
-    if (!foundFirstDynamic)
-      return failure();
-    return resultIndices;
-  };
-  // Iterate over target shape.
-  bool wasLastDimDynamic = false;
-  for (; targetDimIdx < numTargetDims; ++targetDimIdx) {
-    int64_t currTargetShape = targetShape[targetDimIdx];
-    if (currTargetShape != ShapedType::kDynamic) {
-      unsigned sourceDimAtStart = sourceDimIdx;
-      auto indices = collectSourceIndicesForStaticTargetDim(
-          currTargetShape, /*mayHaveOffset=*/wasLastDimDynamic);
-      if (failed(indices))
+  // Early handling for scalar target types.
+  if (numTargetDims == 0) {
+    ReassociationIndices allSourceIndices(numSourceDims);
+    for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
+         ++sourceDimIdx) {
+      int64_t sourceSize = sourceShape[sourceDimIdx];
+      // All source dimensions must be unit or dynamic.
+      if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
         return std::nullopt;
-      if (wasLastDimDynamic) {
-        assert(!reassociationMap.empty());
-        auto &previousIndices = reassociationMap.back();
-        for (; sourceDimAtStart < indices->front(); ++sourceDimAtStart)
-          previousIndices.push_back(sourceDimAtStart);
-      }
-      reassociationMap.push_back(*indices);
-      wasLastDimDynamic = false;
-      continue;
+      allSourceIndices.emplace_back(sourceDimIdx);
     }
+    return SmallVector<ReassociationIndices>{allSourceIndices};
+  }
 
-    bool isNextDimDynamic =
-        targetDimIdx + 1 < numTargetDims &&
-        targetShape[targetDimIdx + 1] == ShapedType::kDynamic;
-    auto indices = collectSourceIndicesForDynamicTargetDim(
-        /*allowStaticNonOnes=*/!wasLastDimDynamic,
-        /*mapConsecutiveDynDims=*/!wasLastDimDynamic && !isNextDimDynamic);
-    if (failed(indices))
+  // Collect source ranges by iterating over the target shape left-to-right.
+  auto maybeForwardRanges =
+      findReassociationRangesForCollapse(sourceShape, targetShape);
+  if (failed(maybeForwardRanges))
+    return std::nullopt;
+  auto &ranges = *maybeForwardRanges;
+  // Now do the same in reverse. We need to get another valid reassociation
+  // through some other strategy, and then compare the results in order to
+  // disambiguate mixed subshapes, such as:
+  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
+  // This leads us to lose some of the reassociation opportunities that can only
+  // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
+  // backtracking, the algorithm will fail right-to-left. However, this is the
+  // best way to preserve correctness.
+  //
+  // NB: The reversed shapes must not be temporary as we're passing through an
+  // ArrayRef.
+  auto maybeReverseRanges = findReassociationRangesForCollapse(
+      sourceShape, targetShape, /*iterateRightToLeft=*/true);
+  if (failed(maybeReverseRanges))
+    return std::nullopt;
+  auto &reverseRanges = *maybeReverseRanges;
+
+  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
+    return std::nullopt;
+  // Now we can check for ambiguity of each target dimension's reassociation. If
+  // successful, we put the full indices into our result map for the target
+  // shape.
+  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
+  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
+       ++targetDimIdx) {
+    auto &range = ranges[targetDimIdx];
+    auto &reverseRange = reverseRanges[targetDimIdx];
+    // Get non-overlapping indices between the ranges
+    ReassociationIndices nonMatchingIndices = range ^ reverseRange;
+    // The ranges should overlap, at the very least
+    if (nonMatchingIndices.size() == range.size() + reverseRange.size())
       return std::nullopt;
-    reassociationMap.push_back(*indices);
-    wasLastDimDynamic = true;
-  }
-  // Now that we've mapped all the target dimensions, process any remaining
-  // entries in the source shape explicitly.
-  for (; sourceDimIdx < numSourceDims; sourceDimIdx++) {
-    const bool isOne = sourceShape[sourceDimIdx] == 1,
-               isDynamic = sourceShape[sourceDimIdx] == ShapedType::kDynamic;
-    if (targetShape.empty()) {
-      if (!isOne && !isDynamic)
-        return std::nullopt;
-      continue;
-    }
-    // If the last 2 dimensions in the target were dynamic, the tail in the
-    // source shape cannot contain a dynamic value. E.g. ?x?->? is valid,
-    // however ?x?x10x?->?x? would be indeterminate.
-    if (wasLastDimDynamic && numTargetDims > 1 &&
-        targetShape[numTargetDims - 2] == ShapedType::kDynamic) {
-      if (isDynamic)
+    // Unit dimensions can be collapsed wherever - this is the only ambiguity
+    // that we allow.
+    for (int64_t sourceDimIdx : nonMatchingIndices) {
+      if (sourceShape[sourceDimIdx] != 1)
         return std::nullopt;
     }
-    // If the last target dimension is static, only source dimensions of 1 are
-    // acceptable.
-    if (!wasLastDimDynamic && !isOne)
-      return std::nullopt;
-    assert(!reassociationMap.empty());
-    reassociationMap.back().push_back(sourceDimIdx);
+    reassociationMap[targetDimIdx] = range.getFullIndices();
   }
   return reassociationMap;
 }
@@ -379,11 +575,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index 2564866fac493..a179d91129edb 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -57,6 +57,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) {
   EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
                                                {ShapedType::kDynamic}),
             makeOptionalIndices({{0, 1, 2}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1},
+                {ShapedType::kDynamic, ShapedType::kDynamic}),
+            makeOptionalIndices({{0, 1}, {2, 3, 4}}));
   EXPECT_EQ(
       getReassociationIndicesForCollapse(
           {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
@@ -76,6 +80,10 @@ TEST(ReassociationIndicesForCollapse, DynamicTest) {
   EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
                                                {ShapedType::kDynamic}),
             makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10},
+                {ShapedType::kDynamic, 10}),
+            makeOptionalIndices({{0, 1, 2, 3}, {4}}));
   EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
                                                {ShapedType::kDynamic, 20}),
             makeOptionalIndices({{0, 1}, {2}}));
@@ -131,4 +139,20 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
                  ShapedType::kDynamic},
                 {ShapedType::kDynamic, ShapedType::kDynamic}),
             std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 10, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 12, ShapedType::kDynamic}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
+                {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
+            std::nullopt);
 }
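
[Editor's sketch, not part of the patch] Two pieces of index arithmetic carry the bidirectional strategy in this patch: mapping ranges computed against reversed shapes back to the original indexing, and the XOR of two ranges used to locate ambiguous source dims between the forward and backward passes. A standalone model (hypothetical names, plain `std::pair` instead of `ReassociationIndexRange`):

```cpp
#include <cstdint>
#include <utility>
#include <vector>

using Range = std::pair<int64_t, int64_t>; // inclusive [left, right]

// Map a range computed on a reversed source shape of `n` dims back to the
// original indexing: the reversed right edge becomes the original left edge,
// and vice versa (mirrors the inversion loop in the right-to-left variant).
Range invertRange(Range r, int64_t n) {
  return {n - 1 - r.second, n - 1 - r.first};
}

// Indices covered by exactly one of the two ranges, mirroring
// ReassociationIndexRange::operator^. An XOR as large as both ranges
// combined means they don't overlap at all, which the caller rejects;
// any remaining non-matching indices must be unit dims.
std::vector<int64_t> rangeXor(Range a, Range b) {
  std::vector<int64_t> result;
  for (int64_t i = a.first; i <= a.second; ++i)
    if (i < b.first || i > b.second)
      result.push_back(i);
  for (int64_t i = b.first; i <= b.second; ++i)
    if (i < a.first || i > a.second)
      result.push_back(i);
  return result;
}
```

E.g., for a 5-dim source, a range over reversed dims 0-1 maps back to original dims 3-4; forward range [0, 2] and backward range [2, 4] disagree on dims {0, 1, 3, 4}, so the fold is accepted only if those are all unit dims.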

>From dd36c47d7a6bb402497a3c4c1757f47928132e06 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Tue, 20 May 2025 22:26:12 +0000
Subject: [PATCH 7/8] [WIP] New tests

---
 .../Dialect/Utils/ReshapeOpsUtilsTest.cpp      | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index a179d91129edb..124c8ce86fc9c 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -155,4 +155,22 @@ TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
                 {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
                 {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
             std::nullopt);
+
+  //===----------------------------------------------------------------------===//
+  // TODO: Reassociation for the following examples can be computed, but isn't
+  // supported by `getReassociationIndicesForCollapse`.
+  //===----------------------------------------------------------------------===//
+
+  // TODO: Fails because there's no backtracking when some source dimensions
+  // remain unmatched at either edge.
+  EXPECT_EQ(getReassociationIndicesForCollapse(
+                {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10},
+                {ShapedType::kDynamic, 10}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2},
+                                               {1, ShapedType::kDynamic, 2}),
+            std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1},
+                                               {2, ShapedType::kDynamic}),
+            std::nullopt);
 }

>From 07ed33d4363d64ed32f85fe0b296ca39cc916124 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Tue, 20 May 2025 22:45:24 +0000
Subject: [PATCH 8/8] [fixup] Add scalar target tests & fix em

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    |  3 ++-
 .../Dialect/Utils/ReshapeOpsUtilsTest.cpp     | 24 +++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 25dd434fc2122..209577db3272f 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -292,7 +292,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
     return std::nullopt;
   // Early handling for scalar target types.
   if (numTargetDims == 0) {
-    ReassociationIndices allSourceIndices(numSourceDims);
+    ReassociationIndices allSourceIndices;
+    allSourceIndices.reserve(numSourceDims);
     for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
          ++sourceDimIdx) {
       int64_t sourceSize = sourceShape[sourceDimIdx];
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index 124c8ce86fc9c..7abdf75c34cda 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "llvm/ADT/STLExtras.h"
 #include "gtest/gtest.h"
 #include <optional>
@@ -20,6 +21,29 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
   return std::optional<SmallVector<ReassociationIndices>>(list);
 }
 
+TEST(ReassociationIndicesForCollapse, ScalarTest) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
+            makeOptionalIndices({{0}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
+            makeOptionalIndices({{0, 1}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
+            makeOptionalIndices({{0}}));
+  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
+                                                ShapedType::kDynamic, 1,
+                                                ShapedType::kDynamic},
+                                               {}),
+            makeOptionalIndices({{0, 1, 2, 3, 4}}));
+}
+
+TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
+  EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt);
+  EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt);
+  EXPECT_EQ(
+      getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}),
+      std::nullopt);
+}
+
 TEST(ReassociationIndicesForCollapse, StaticTest) {
   EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
             makeOptionalIndices({{0, 1}}));
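
[Editor's sketch, not part of the patch] The scalar early-exit fixed above reduces to a simple predicate: a shape collapses into a rank-0 tensor iff it is non-empty and every dim is a unit or dynamic (a dynamic dim that folds away must be 1 at runtime). A standalone model mirroring the new `ScalarTest`/`ScalarTestFailure` cases (`-1` as a stand-in for `ShapedType::kDynamic`):

```cpp
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

// True iff `shape` can be collapsed into a scalar: non-empty, and all dims
// are either 1 or dynamic. Any static dim > 1 makes the collapse invalid.
bool collapsesToScalar(const std::vector<int64_t> &shape) {
  if (shape.empty())
    return false;
  for (int64_t size : shape)
    if (size != 1 && size != kDynamic)
      return false;
  return true;
}
```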
