[Mlir-commits] [mlir] [mlir][tensor] Fix `getReassociationForCollapse` for tensor/scalar re… (PR #144118)
Artem Gindinson
llvmlistbot at llvm.org
Fri Jun 13 09:53:31 PDT 2025
https://github.com/AGindinson created https://github.com/llvm/llvm-project/pull/144118
…shapes
Commit 6e5a142 changed what the function returns when computing reassociations between tensors (consisting solely of unit/dynamic dimensions) and scalars/0d vectors: it began returning a single group that lists every source dimension. The IR representation for such reshapes expects an empty reassociation instead, like so:
```
func.func @example(%arg0 : tensor<?x?x?xf32>) -> tensor<f32> {
  %0 = tensor.collapse_shape %arg0 [] : tensor<?x?x?xf32> into tensor<f32>
  return %0 : tensor<f32>
}
```
Restore the original behavior: return an empty reassociation for these reshapes, and report failure only when a compile-time-known non-unit dimension is part of the attempted reassociation, since such a dimension cannot be collapsed into a scalar.
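
For readers without the MLIR sources at hand, the scalar-target branch described above can be approximated by the standalone sketch below. It is an illustration under stated assumptions, not the upstream code: `collapseToScalar` is a hypothetical helper name, `std::vector` stands in for `SmallVector`/`ArrayRef`, and `kDynamic` is a local stand-in for `ShapedType::kDynamic`.

```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Local stand-in for mlir::ShapedType::kDynamic (assumption: any sentinel
// that cannot be a real extent works for this sketch).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

using ReassociationIndices = std::vector<int64_t>;
using Reassociation = std::vector<ReassociationIndices>;

// Hypothetical helper that mirrors only the "target is a scalar" case of
// getReassociationIndicesForCollapse, i.e. an empty target shape.
std::optional<Reassociation>
collapseToScalar(const std::vector<int64_t> &sourceShape) {
  // Mirrors the rank check that precedes the branch: the source must have
  // strictly more dimensions than the (rank-0) target.
  if (sourceShape.empty())
    return std::nullopt;
  for (int64_t size : sourceShape) {
    // A static non-unit dimension can never be folded into a scalar.
    if (size != 1 && size != kDynamic)
      return std::nullopt;
  }
  // Success is an engaged optional holding an EMPTY reassociation, which
  // matches the `[]` in the tensor.collapse_shape example.
  return Reassociation{};
}
```

The distinction that matters here is between an engaged optional holding zero groups (a valid collapse to a scalar) and `std::nullopt` (no valid collapse); callers should check `has_value()` rather than the size of the result.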
From 1f421b4c3b4cd67859a275b7a1770e3691f38515 Mon Sep 17 00:00:00 2001
From: Artem Gindinson <gindinson at roofline.ai>
Date: Fri, 13 Jun 2025 16:44:19 +0000
Subject: [PATCH] [mlir][tensor] Fix `getReassociationForCollapse` for
tensor/scalar reshapes
Commit 6e5a142 changed the behavior of the function when computing
reassociations between tensors (consisting of unit/dynamic dimensions) and
scalars/0d vectors. The IR representation for such reshapes actually expects
an empty reassociation, like so:
```
func.func @example(%arg0 : tensor<?x?x?xf32>) -> tensor<f32> {
%0 = tensor.collapse_shape %arg0 [] : tensor<?x?x?xf32> into tensor<f32>
}
```
Restore the original behavior - the routine should resort to reporting
failures when compile time-known non-unit dimensions are part of the
attempted reassociation.
Signed-off-by: Artem Gindinson <gindinson at roofline.ai>
---
mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 10 ++++------
mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp | 8 ++++----
2 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 3b1fdb69e8ef1..aa566c0086a2f 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
// this utility).
if (numSourceDims <= numTargetDims)
return std::nullopt;
- // Early handling for scalar target types.
+ // Early handling for scalar target types. We should report an invalid
+ // reassociation for non-unit static dimensions - no chance to collapse these
+ // into a scalar.
if (numTargetDims == 0) {
- ReassociationIndices allSourceIndices;
- allSourceIndices.reserve(numSourceDims);
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
++sourceDimIdx) {
int64_t sourceSize = sourceShape[sourceDimIdx];
- // All source dimensions must be unit or dynamic.
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
return std::nullopt;
- allSourceIndices.push_back(sourceDimIdx);
}
- return SmallVector<ReassociationIndices>{allSourceIndices};
+ return SmallVector<ReassociationIndices>{};
}
// Collect source ranges by iterating over the target shape left-to-right.
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index db1a87a4de2d5..05f97e875e2dc 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
TEST(ReassociationIndicesForCollapse, ScalarTest) {
EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
- makeOptionalIndices({{0}}));
+ makeOptionalIndices({}));
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
- makeOptionalIndices({{0, 1}}));
+ makeOptionalIndices({}));
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
- makeOptionalIndices({{0}}));
+ makeOptionalIndices({}));
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
ShapedType::kDynamic, 1,
ShapedType::kDynamic},
{}),
- makeOptionalIndices({{0, 1, 2, 3, 4}}));
+ makeOptionalIndices({}));
}
TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {