[Mlir-commits] [mlir] [mlir][tensor] Fix computation of collapse in dynamic cases (PR #127560)
Ian Wood
llvmlistbot at llvm.org
Mon Feb 17 19:29:25 PST 2025
https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/127560
None
From d6f03858c97158e7aa7599eec8180cb1011bb428 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Sun, 16 Feb 2025 14:55:20 -0800
Subject: [PATCH] Compute reassociation in dynamic cases.
Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 27 +++++++++++-----------
mlir/test/Dialect/Tensor/canonicalize.mlir | 14 +++++++++++
2 files changed, 27 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..c053ed488982a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -49,27 +49,26 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
int64_t currTargetShape = targetShape[targetDim];
while (sourceDim < (sourceShape.size() - 1) &&
sourceShape[sourceDim] != ShapedType::kDynamic &&
- prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+ (currTargetShape == ShapedType::kDynamic ||
+ prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape)) {
prodOfCollapsedDims *= sourceShape[sourceDim];
currIndices.push_back(sourceDim++);
}
+ if (sourceDim >= sourceShape.size())
+ return std::nullopt;
+
// If the current expanded dimension is dynamic, then the collapsed
// dimensions should also be dynamic and product of all previous unprocessed
// dimensions of the expanded shape should be 1.
if (sourceShape[sourceDim] == ShapedType::kDynamic &&
- (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
- return std::nullopt;
-
- // If the collapsed dim is dynamic, the current expanded dim should also
- // be dynamic.
- if (currTargetShape == ShapedType::kDynamic &&
- sourceShape[sourceDim] != ShapedType::kDynamic)
+ currTargetShape != ShapedType::kDynamic)
return std::nullopt;
// For static shapes, if the product of dimensions of the expanded shape
// should match the collapsed dimension shape.
- if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+ if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+ prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
return std::nullopt;
currIndices.push_back(sourceDim++);
@@ -315,11 +314,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
// have proven that these are not sliced. In this case we just take
// the full extent of each dimension in the reassociation list.
if (linearizedDimensions[it.index()]) {
- llvm::append_range(
- offsetsSizesAndStrides,
- llvm::map_range(it.value(), [&](int64_t idx) -> Range {
- return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
- }));
+ llvm::append_range(offsetsSizesAndStrides,
+ llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+ return {zeroAttr, collapseShapeInputShape[idx],
+ oneAttr};
+ }));
continue;
}
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90cc0ca658ffb..bbbef2ebc9d2b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1191,6 +1191,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
// -----
+func.func @compose_expand_of_collapse_dynamic_collapse(%arg0 : tensor<4x13x10x64x?xf16>, %arg1 : index) -> tensor<4x13x10x?xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x13x10x64x?xf16> into tensor<52x10x?xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 13, 10, %arg1] : tensor<52x10x?xf16> into tensor<4x13x10x?xf16>
+ return %expanded : tensor<4x13x10x?xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic_collapse
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x13x10x64x?xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
// CHECK-LABEL: func @zero_rank_reshape_multi
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
More information about the Mlir-commits
mailing list