[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)

Artem Gindinson llvmlistbot at llvm.org
Wed Apr 30 06:26:53 PDT 2025


https://github.com/AGindinson created https://github.com/llvm/llvm-project/pull/137963

This change allows expand-of-collapse folds for reshapes such as `?x?xk` -> `?` (k > 1). The rationale is that the expand op must already carry a coherent index/affine expression in its `output_shape` argument (see the example below); if it does not, the IR was already invalid before this fold runs:
```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]]
	         : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine]
		: tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```
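
To make the coherence requirement concrete: within each reassociation group, the product of the sizes passed in `output_shape` has to equal the runtime size of the collapsed dimension. The following is a minimal standalone sketch of that condition in plain C++; `isCoherentExpansion` is a hypothetical helper for illustration, not an MLIR API, and it works on already-known runtime sizes rather than on SSA values.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of MLIR): returns true when the sizes of one
// expanded group multiply back to the size of the collapsed dimension.
bool isCoherentExpansion(int64_t collapsedDim,
                         const std::vector<int64_t> &expandedGroup) {
  int64_t product = 1;
  for (int64_t dim : expandedGroup)
    product *= dim;
  return product == collapsedDim;
}
```

For the example above, a collapsed size of 96 expanded as `[96 / 32, 32]` is coherent, whereas a collapsed size of 100 expanded as `[100 / 32, 32]` is not, because the signed division truncates to 3 and 3 * 32 != 100.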

Based on this assumption, the routine in `getReassociationIndicesForCollapse()` is adjusted to allow dynamic reshapes beyond just `?x..?x1x1x..x1` -> `?`.
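
The adjusted search can be sketched outside of MLIR as follows. This is an illustrative re-statement of the logic in the patch using only the standard library, under a few assumptions: `kDynamic` is a stand-in sentinel for `ShapedType::kDynamic`, `collapseReassociation` and `Indices` are hypothetical names, and the tail step that appends leftover source dimensions to the last group is assumed, since that part of the function lies outside the diff hunk.

```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumption: any sentinel works here).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
using Indices = std::vector<int64_t>;

// Hypothetical standalone version of the collapse reassociation search.
std::optional<std::vector<Indices>>
collapseReassociation(const std::vector<int64_t> &source,
                      const std::vector<int64_t> &target) {
  const int64_t numSource = source.size(), numTarget = target.size();
  if (numSource <= numTarget)
    return std::nullopt;
  std::vector<Indices> map;
  map.reserve(numTarget);
  int64_t src = 0;
  for (int64_t tgt = 0; tgt < numTarget; ++tgt) {
    Indices group;
    // 1. Dynamic target dim: consume source dims up to and including the
    // first dynamic one.
    if (target[tgt] == kDynamic) {
      bool foundDynamic = false;
      while (src < numSource) {
        group.push_back(src);
        if (source[src++] == kDynamic) {
          foundDynamic = true;
          break;
        }
      }
      if (!foundDynamic)
        return std::nullopt;
      map.push_back(group);
      continue;
    }
    // 2. Static target dim: the product of consumed static source dims must
    // hit the target size exactly.
    int64_t product = 1;
    bool reached = false;
    while (src < numSource) {
      if (source[src] == kDynamic)
        return std::nullopt;
      product *= source[src];
      if (product > target[tgt])
        break;
      group.push_back(src++);
      if (product == target[tgt]) {
        reached = true;
        break;
      }
    }
    if (!reached)
      return std::nullopt;
    map.push_back(group);
  }
  // Leftover source dims: accepted freely if the last target dim is dynamic,
  // otherwise each must be 1 or dynamic. Appending them to the last group is
  // assumed from the unchanged code outside the diff hunk.
  const bool lastIsDynamic = !target.empty() && target.back() == kDynamic;
  for (; src < numSource; ++src) {
    if (!lastIsDynamic && source[src] != kDynamic && source[src] != 1)
      return std::nullopt;
    if (!map.empty())
      map.back().push_back(src);
  }
  return map;
}
```

With this sketch, `?x4x?x2` -> `?x4x?` yields `[[0], [1], [2, 3]]`, which matches the collapse expected by the new canonicalization test, and `?x?xk` -> `?` yields a single group `[0, 1, 2]`. A shape pair such as `2x?x?x4x8` -> `?x4x8` is rejected, because the search for the first target dimension stops at the first dynamic source dimension.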

In addition, the reassociation utility is refactored to distinguish clearly between dynamic and static subshapes. A few known caveats are recorded in a FIXME comment: it does not seem possible to fold every qualifying dynamic shape pattern deterministically without also inspecting the affine expressions, and that would be difficult to maintain in a single general utility. Larger refactorings that could address this include:
- abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern, employing similar logic to `ComposeCollapseOfExpandOp`;
- providing dialect-specific implementations for Linalg/Tensor.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>

From 2ae44808cd573c31331d143d687afcdab6986a72 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Mon, 28 Apr 2025 17:24:11 +0000
Subject: [PATCH] [mlir][tensor] Loosen restrictions on folding dynamic
 reshapes

The main idea behind the change is to allow expand-of-collapse folds
for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the
expand op must have a coherent index/affine expression specified in its
`output_shape` argument (see example below), and if it doesn't, the IR
has already been invalidated at an earlier stage:
```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]]
	         : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine]
		: tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```

On the above assumption, adjust the routine in
`getReassociationIndicesForCollapse()` to allow dynamic reshapes
beyond just `?x..?x1x1x..x1` -> `?`.

Moreover, the reassociation util was refactored to clearly distinguish
between dynamic and static subshapes. A few known caveats were noted as
a comment; it doesn't seem possible to fold all qualifying dynamic shape
patterns in a deterministic way without looking into affine expressions
simultaneously. That would be difficult to maintain in a single general
utility. Other implementation ideas/larger refactoring could include:
- abandoning the util usage in the `ComposeExpandOfCollapseOp` pattern,
  employing similar logic to `ComposeCollapseOfExpandOp`;
- providing dialect-specific implementations for Linalg/Tensor.

Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp    | 103 ++++++++++--------
 .../Dialect/Linalg/simplify-pack-unpack.mlir  |   4 +-
 mlir/test/Dialect/Tensor/canonicalize.mlir    |  24 +++-
 3 files changed, 79 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index ed40a080441bc..694783849198a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
 
+  unsigned sourceDim = 0, targetDim = 0;
+  for (; targetDim < numTargetDims; ++targetDim) {
     int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+    ReassociationIndices currIndices;
+    // 1. Target dimension is dynamic. Source shape should contain at least
+    // one dynamic dimension.
+    if (currTargetShape == ShapedType::kDynamic) {
+      // FIXME: We stop the search with the first dynamic dimension, while in
+      // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+      // indeterministic altogether when we have neighboring dynamic dimensions
+      // in the target shape. Most of these patterns will be safely rejected,
+      // however we might achieve more correct folds by taking affine
+      // expressions into account, if these can be passed on by the call sites.
+      bool foundDynamic = false;
+      while (sourceDim < numSourceDims) {
+        currIndices.push_back(sourceDim);
+        if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+          foundDynamic = true;
+          break;
+        }
+      }
+      if (!foundDynamic)
+        return std::nullopt;
+
+      reassociationMap.push_back(currIndices);
+      continue;
+    }
+    // 2. Target dimension is static. The product of dimensions of the expanded
+    // shape should match the collapsed dimension shape.
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    while (sourceDim < numSourceDims) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDim] == ShapedType::kDynamic)
+        return std::nullopt;
       prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+      if (prodOfCollapsedDims > currTargetShape)
+        break;
+      else if (prodOfCollapsedDims == currTargetShape) {
+        currIndices.push_back(sourceDim++);
+        reachedTargetDimSize = true;
+        break;
+      } else // prodOfCollapsedDims < currTargetShape
+        currIndices.push_back(sourceDim++);
     }
-
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+    if (!reachedTargetDimSize)
       return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    reassociationMap.push_back(currIndices);
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
-    return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+  // Now that we've mapped all the target dimensions, process any remaining
+  // entries in the source shape explicitly. Either the last target dimension
+  // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+  // applies when target shape is empty (can be the case for subshape
+  // reassociations).
+  for (; sourceDim < numSourceDims; sourceDim++) {
+    if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+        sourceShape[sourceDim] != ShapedType::kDynamic &&
         sourceShape[sourceDim] != 1)
       return std::nullopt;
     // The map is empty when the target type is a scalar.
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT:     tensor.collapse
-// CHECK:         linalg.unpack
+// CHECK:     tensor.collapse
+// CHECK-NOT:         linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..443f931745557 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,7 +1068,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1076,12 +1076,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 //   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+//   CHECK-NOT:   tensor.expand_shape
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+//  CHECK-SAME:     : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+//  CHECK-NEXT:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1089,7 +1105,7 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
 //       CHECK:   tensor.collapse_shape
 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
 //       CHECK:   return %[[EXPAND]]


