[Mlir-commits] [mlir] [mlir][tensor] Loosen restrictions on folding dynamic reshapes (PR #137963)
llvmlistbot at llvm.org
Wed Apr 30 06:27:50 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Artem Gindinson (AGindinson)
The main idea behind the change is to allow expand-of-collapse folds for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the expand op must have a coherent index/affine expression specified in its `output_shape` argument (see example below), and if it doesn't, the IR has already been invalidated at an earlier stage:
```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]]
: tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine]
: tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```
Based on this assumption, the routine in
`getReassociationIndicesForCollapse()` is adjusted to allow dynamic reshapes beyond just `?x..?x1x1x..x1` -> `?`.
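For instance, the collapse/expand pair below (modeled on the `@fold_expand_of_collapse_mixed_target_subshape` test added in this patch; `%sz0`/`%sz1` stand in for the dynamic output sizes) now composes into a single `tensor.collapse_shape`:
```
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
    : tensor<?x4x?x2xf32> into tensor<?x?xf32>
%expanded = tensor.expand_shape %collapsed [[0, 1], [2]] output_shape [%sz0, 4, %sz1]
    : tensor<?x?xf32> into tensor<?x4x?xf32>
// With this change, the pair folds to:
//   tensor.collapse_shape %arg0 [[0], [1], [2, 3]]
//       : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
```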
Moreover, the reassociation utility was refactored to clearly distinguish between dynamic and static subshapes. A few known caveats are noted in a comment: it does not seem possible to fold every qualifying dynamic shape pattern deterministically without also inspecting the affine expressions, and doing that within a single general utility would be difficult to maintain (an ambiguous case is sketched after the list below). Other implementation ideas/larger refactorings could include:
- abandoning the utility usage in the `ComposeExpandOfCollapseOp` pattern and employing logic similar to `ComposeCollapseOfExpandOp`;
- providing dialect-specific implementations for Linalg/Tensor.
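To illustrate the remaining ambiguity, here is a sketch of the fully dynamic case that stays un-folded (matching the renamed `@no_fold_expand_of_collapse_fully_dynamic` test; `%sz0`..`%sz2` stand in for the dynamic output sizes):
```
%collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]]
    : tensor<?x?x?xf32> into tensor<?x?xf32>
%expanded = tensor.expand_shape %collapsed [[0, 1], [2]] output_shape [%sz0, %sz1, %sz2]
    : tensor<?x?xf32> into tensor<?x?x?xf32>
// With neighboring dynamic dimensions there is no deterministic reassociation
// (the leading `?` could map to source dim 0 alone or to dims 0 and 1), so the
// pair is left intact.
```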
Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
---
Full diff: https://github.com/llvm/llvm-project/pull/137963.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+57-46)
- (modified) mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir (+2-2)
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+20-4)
```diff
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index ed40a080441bc..694783849198a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
std::optional<SmallVector<ReassociationIndices>>
mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> targetShape) {
- if (sourceShape.size() <= targetShape.size())
+ unsigned numSourceDims = sourceShape.size(),
+ numTargetDims = targetShape.size();
+ if (numSourceDims <= numTargetDims)
return std::nullopt;
- unsigned sourceDim = 0;
- SmallVector<ReassociationIndices> reassociationMap;
- reassociationMap.reserve(targetShape.size());
-
- ReassociationIndices currIndices;
- int64_t prodOfCollapsedDims = 1;
- while (sourceDim < sourceShape.size()) {
- unsigned targetDim = reassociationMap.size();
- // If we have mapped all the target dimensions stop and handle the remaining
- // tail of size-1 dimensions explicitly.
- if (targetDim == targetShape.size())
- break;
+ SmallVector<ReassociationIndices, 4> reassociationMap;
+ reassociationMap.reserve(numTargetDims);
+ unsigned sourceDim = 0, targetDim = 0;
+ for (; targetDim < numTargetDims; ++targetDim) {
int64_t currTargetShape = targetShape[targetDim];
- while (sourceDim < (sourceShape.size() - 1) &&
- sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+ ReassociationIndices currIndices;
+ // 1. Target dimension is dynamic. Source shape should contain at least
+ // one dynamic dimension.
+ if (currTargetShape == ShapedType::kDynamic) {
+ // FIXME: We stop the search with the first dynamic dimension, while in
+ // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+ // indeterministic altogether when we have neighboring dynamic dimensions
+ // in the target shape. Most of these patterns will be safely rejected,
+ // however we might achieve more correct folds by taking affine
+ // expressions into account, if these can be passed on by the call sites.
+ bool foundDynamic = false;
+ while (sourceDim < numSourceDims) {
+ currIndices.push_back(sourceDim);
+ if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+ foundDynamic = true;
+ break;
+ }
+ }
+ if (!foundDynamic)
+ return std::nullopt;
+
+ reassociationMap.push_back(currIndices);
+ continue;
+ }
+ // 2. Target dimension is static. The product of dimensions of the expanded
+ // shape should match the collapsed dimension shape.
+ int64_t prodOfCollapsedDims = 1;
+ bool reachedTargetDimSize = false;
+ while (sourceDim < numSourceDims) {
+ // Source shape cannot be dynamic if the target dim is static.
+ if (sourceShape[sourceDim] == ShapedType::kDynamic)
+ return std::nullopt;
prodOfCollapsedDims *= sourceShape[sourceDim];
- currIndices.push_back(sourceDim++);
+ if (prodOfCollapsedDims > currTargetShape)
+ break;
+ else if (prodOfCollapsedDims == currTargetShape) {
+ currIndices.push_back(sourceDim++);
+ reachedTargetDimSize = true;
+ break;
+ } else // prodOfCollapsedDims < currTargetShape
+ currIndices.push_back(sourceDim++);
}
-
- // If the current expanded dimension is dynamic, then the collapsed
- // dimensions should also be dynamic and product of all previous unprocessed
- // dimensions of the expanded shape should be 1.
- if (sourceShape[sourceDim] == ShapedType::kDynamic &&
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+ if (!reachedTargetDimSize)
return std::nullopt;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamic &&
- sourceShape[sourceDim] != ShapedType::kDynamic)
- return std::nullopt;
-
- // For static shapes, if the product of dimensions of the expanded shape
- // should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
- return std::nullopt;
-
- currIndices.push_back(sourceDim++);
- reassociationMap.emplace_back(ReassociationIndices{});
- std::swap(reassociationMap.back(), currIndices);
- prodOfCollapsedDims = 1;
+ reassociationMap.push_back(currIndices);
}
- // All the dimensions in the target must have been processed.
- if (reassociationMap.size() != targetShape.size())
- return std::nullopt;
- // Process any remaining entries in the source shape. They all need to be
- // 1 or dynamic.
- for (; sourceDim < sourceShape.size(); sourceDim++) {
- if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+ // Now that we've mapped all the target dimensions, process any remaining
+ // entries in the source shape explicitly. Either the last target dimension
+ // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+ // applies when target shape is empty (can be the case for subshape
+ // reassociations).
+ for (; sourceDim < numSourceDims; sourceDim++) {
+ if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+ sourceShape[sourceDim] != ShapedType::kDynamic &&
sourceShape[sourceDim] != 1)
return std::nullopt;
// The map is empty when the target type is a scalar.
diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 51350e5bc8498..6979770154bab 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
// -----
// CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 85bf6fba52aa4..443f931745557 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1068,7 +1068,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
// -----
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
-> tensor<?x4x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1076,12 +1076,28 @@ func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: ind
: tensor<?x?xf32> into tensor<?x4x?xf32>
return %1 : tensor<?x4x?xf32>
}
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
// CHECK-NOT: tensor.{{.*}}_shape
// -----
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+ -> tensor<?x4x?xf32> {
+ %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+ : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+ : tensor<?x?xf32> into tensor<?x4x?xf32>
+ return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
-> tensor<?x?x?xf32> {
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1089,7 +1105,7 @@ func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1:
: tensor<?x?xf32> into tensor<?x?x?xf32>
return %1 : tensor<?x?x?xf32>
}
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
// CHECK: tensor.collapse_shape
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape
// CHECK: return %[[EXPAND]]
```
https://github.com/llvm/llvm-project/pull/137963