[Mlir-commits] [mlir] [mlir][tensor] Fix `getReassociationForCollapse` for tensor/scalar re… (PR #144118)
llvmlistbot at llvm.org
Fri Jun 13 10:20:50 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Artem Gindinson (AGindinson)
<details>
<summary>Changes</summary>
…shapes
Commit 6e5a142 changed the behavior of the function when computing reassociations for reshapes from tensors made up solely of unit and dynamic dimensions into scalars or 0-D vectors. The IR representation for such reshapes actually expects an empty reassociation, like so:
```
func.func @example(%arg0 : tensor<?x?x?xf32>) -> tensor<f32> {
  %0 = tensor.collapse_shape %arg0 [] : tensor<?x?x?xf32> into tensor<f32>
  return %0 : tensor<f32>
}
```
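Why the empty list is the expected answer can be sketched from MLIR's reshape convention that a collapse carries one reassociation group per result dimension, so a rank-0 result carries no groups at all. The check below is a hypothetical illustration only (`hasOneGroupPerResultDim`, `Group`, and `Reassociation` are not MLIR APIs), using `std::vector` in place of `ReassociationIndices`:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for MLIR's ReassociationIndices and a list of them:
// each inner vector lists the source dimensions folded into one result dim.
using Group = std::vector<int64_t>;
using Reassociation = std::vector<Group>;

// Hypothetical check: a collapse needs exactly one group per result
// dimension, so a rank-0 (scalar) result must have no groups at all.
bool hasOneGroupPerResultDim(const Reassociation &reassociation,
                             std::size_t targetRank) {
  return reassociation.size() == targetRank;
}

// What the function returned for tensor<?x?x?xf32> -> tensor<f32> after
// 6e5a142: a single group holding every source dimension.
const Reassociation kSingleGroup = {{0, 1, 2}};

// The reassociation written in `tensor.collapse_shape %arg0 []`.
const Reassociation kEmpty = {};
```

Under this check the single group `{{0, 1, 2}}` is rejected for a `tensor<f32>` result, while the empty list matches the `[]` in the example above.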
Restore the original behavior: return an empty reassociation for such reshapes, and report failure when compile-time-known non-unit dimensions are part of the attempted reassociation, since those cannot be collapsed into a scalar.
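The post-patch scalar-target branch can be mirrored outside MLIR for experimentation. The sketch below is a simplified, hypothetical re-implementation (`collapseToScalarReassociation` is not an MLIR function); it assumes `std::vector` and `std::optional` in place of `SmallVector` and `ReassociationIndices`, and models `ShapedType::kDynamic` as the minimum `int64_t`, which is how current MLIR defines it.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for ShapedType::kDynamic (assumed: the minimum int64_t value).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for MLIR's ReassociationIndices.
using ReassociationIndices = std::vector<int64_t>;

// Hypothetical, simplified mirror of the scalar-target branch of
// getReassociationIndicesForCollapse after this patch. Only rank-0 targets
// are handled by this sketch.
std::optional<std::vector<ReassociationIndices>>
collapseToScalarReassociation(const std::vector<int64_t> &sourceShape,
                              const std::vector<int64_t> &targetShape) {
  assert(targetShape.empty() && "sketch only covers scalar targets");
  // Mirrors the early exit: a collapse must reduce the number of dims.
  if (sourceShape.size() <= targetShape.size())
    return std::nullopt;
  for (int64_t size : sourceShape) {
    // A static non-unit dimension cannot be collapsed into a scalar.
    if (size != 1 && size != kDynamic)
      return std::nullopt;
  }
  // Success is an engaged but empty list, matching `collapse_shape ... []`.
  return std::vector<ReassociationIndices>{};
}
```

Under these assumptions, `{1, kDynamic, kDynamic, 1, kDynamic}` yields an engaged but empty optional (mirroring the updated unit test), while `{2, 1}` yields `std::nullopt` because the static size 2 cannot be folded into a scalar. Returning an empty list rather than `std::nullopt` keeps "succeeded with no groups" distinguishable from "cannot collapse".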
---
Full diff: https://github.com/llvm/llvm-project/pull/144118.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+4-6)
- (modified) mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp (+4-4)
``````````diff
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 3b1fdb69e8ef1..aa566c0086a2f 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
   // this utility).
   if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  // Early handling for scalar target types.
+  // Early handling for scalar target types. We should report an invalid
+  // reassociation for non-unit static dimensions - no chance to collapse these
+  // into a scalar.
   if (numTargetDims == 0) {
-    ReassociationIndices allSourceIndices;
-    allSourceIndices.reserve(numSourceDims);
     for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
          ++sourceDimIdx) {
       int64_t sourceSize = sourceShape[sourceDimIdx];
-      // All source dimensions must be unit or dynamic.
       if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
         return std::nullopt;
-      allSourceIndices.push_back(sourceDimIdx);
     }
-    return SmallVector<ReassociationIndices>{allSourceIndices};
+    return SmallVector<ReassociationIndices>{};
   }
 
   // Collect source ranges by iterating over the target shape left-to-right.
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index db1a87a4de2d5..05f97e875e2dc 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
 
 TEST(ReassociationIndicesForCollapse, ScalarTest) {
   EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
-            makeOptionalIndices({{0}}));
+            makeOptionalIndices({}));
   EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
-            makeOptionalIndices({{0, 1}}));
+            makeOptionalIndices({}));
   EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
-            makeOptionalIndices({{0}}));
+            makeOptionalIndices({}));
   EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
                                                 ShapedType::kDynamic, 1,
                                                 ShapedType::kDynamic},
                                                {}),
-            makeOptionalIndices({{0, 1, 2, 3, 4}}));
+            makeOptionalIndices({}));
 }
 
 TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/144118