[Mlir-commits] [mlir] 6f2ba47 - [mlir] Fix ComposeExpandOfCollapseOp for dynamic case (#142663)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 11 14:34:05 PDT 2025
Author: Ian Wood
Date: 2025-06-11T14:34:02-07:00
New Revision: 6f2ba4712f17d7c82228a5b705570571e13a3832
URL: https://github.com/llvm/llvm-project/commit/6f2ba4712f17d7c82228a5b705570571e13a3832
DIFF: https://github.com/llvm/llvm-project/commit/6f2ba4712f17d7c82228a5b705570571e13a3832.diff
LOG: [mlir] Fix ComposeExpandOfCollapseOp for dynamic case (#142663)
Changes `findCollapsingReassociation` to return nullopt in all cases
where both the source and result sub-shapes of a reassociation group
have `>=2` dynamic dims. `expand(collapse)` can reshape to any valid
output shape, but a collapse can only collapse contiguous dimensions.
When there are `>=2` dynamic dimensions, it is impossible to determine
whether it can be simplified to a collapse or whether it is performing
a more advanced reassociation.
This problem was uncovered by
https://github.com/llvm/llvm-project/pull/137963
---------
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
Added:
Modified:
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
mlir/test/Dialect/Tensor/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index af575e10acc8e..61c2a50e514ca 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -387,11 +387,14 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
auto resultSubShape =
resultShape.slice(resultIndices.front(), resultIndices.size());
+ if (llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2 &&
+ llvm::count_if(resultSubShape, ShapedType::isDynamic) >= 2)
+ return std::nullopt;
+
if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape != resultSubShape ||
- llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
+ if (srcSubShape != resultSubShape)
return std::nullopt;
- }
+
for (auto index : llvm::seq<int64_t>(0, srcSubShape.size())) {
composedReassociation.emplace_back(1, srcIndices.front() + index);
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 65c5b3e8602eb..67b03b0a3485b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1272,6 +1272,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
// -----
+func.func @no_compose_collapse_of_expand_dynamic(%arg0 : tensor<?x8x128x?xf16>, %arg1: index) -> tensor<?x128x?xf16> {
+ %collapse = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<?x8x128x?xf16> into tensor<?xf16>
+ %expanded_19 = tensor.expand_shape %collapse [[0, 1, 2]] output_shape [%arg1, 8, %arg1] : tensor<?xf16> into tensor<?x128x?xf16>
+ return %expanded_19 : tensor<?x128x?xf16>
+}
+// CHECK-LABEL: func @no_compose_collapse_of_expand_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor
+// CHECK-SAME: %[[ARG1:.+]]: index
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]]
+// CHECK: return %[[EXPAND]]
+
+// -----
+
// CHECK-LABEL: func @zero_rank_reshape_multi
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
More information about the Mlir-commits
mailing list