[Mlir-commits] [mlir] [mlir][tensor] Fix `getReassociationForCollapse` for tensor/scalar re… (PR #144118)

llvmlistbot at llvm.org
Fri Jun 13 10:20:50 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Artem Gindinson (AGindinson)

<details>
<summary>Changes</summary>

…shapes

Commit 6e5a142 changed the behavior of `getReassociationIndicesForCollapse` when computing reassociations between tensors whose dimensions are all unit or dynamic and scalars/0-d vectors, so that it returned a single group containing every source dimension. The IR representation for such reshapes, however, expects an empty reassociation, like so:
```
func.func @example(%arg0 : tensor<?x?x?xf32>) -> tensor<f32> {
  %0 = tensor.collapse_shape %arg0 [] : tensor<?x?x?xf32> into tensor<f32>
  return %0 : tensor<f32>
}
```
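Why an empty list is the right answer for a 0-d target: every reassociation group produces exactly one dimension of the collapsed type, so the number of groups must equal the result rank. The C++ sketch below makes that structural rule explicit. `isPlausibleCollapseReassociation` is a hypothetical helper, not an MLIR API, and it deliberately ignores dimension sizes; it only shows that a single all-dimensions group describes a rank-1 result rather than a scalar.

```cpp
#include <cstdint>
#include <vector>

// One reassociation group: the source dimensions folded into one result dimension.
using ReassociationIndices = std::vector<int64_t>;

// Hypothetical helper, not part of MLIR: checks group structure only, never sizes.
bool isPlausibleCollapseReassociation(
    int64_t sourceRank, int64_t targetRank,
    const std::vector<ReassociationIndices> &groups) {
  // Each group yields exactly one result dimension.
  if (static_cast<int64_t>(groups.size()) != targetRank)
    return false;
  // A 0-d result has no groups at all, as in `collapse_shape %arg0 []`.
  if (targetRank == 0)
    return true;
  // Otherwise groups must be non-empty and cover 0..sourceRank-1 in order.
  int64_t expected = 0;
  for (const ReassociationIndices &group : groups) {
    if (group.empty())
      return false;
    for (int64_t dim : group)
      if (dim != expected++)
        return false;
  }
  return expected == sourceRank;
}
```

For example, `isPlausibleCollapseReassociation(3, 0, {})` holds while `isPlausibleCollapseReassociation(3, 0, {{0, 1, 2}})` does not; the latter grouping only fits a target of rank 1.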

Restore the original behavior: the routine should report failure only when static (compile-time-known) non-unit dimensions take part in the attempted reassociation, and otherwise return an empty reassociation.
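The restored scalar-target branch can be sketched outside of MLIR as follows. This is a simplified, standalone approximation: `reassociationForScalarCollapse` is a hypothetical name, `kDynamic` is a local stand-in for `ShapedType::kDynamic`, and only the rank-0 target path of `getReassociationIndicesForCollapse` is modeled. The key distinction is between success with zero groups (an engaged, empty optional) and failure (`std::nullopt`).

```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

using ReassociationIndices = std::vector<int64_t>;

// Local stand-in for ShapedType::kDynamic (assumption: any sentinel works here).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical standalone model of the rank-0 target branch only.
std::optional<std::vector<ReassociationIndices>>
reassociationForScalarCollapse(const std::vector<int64_t> &sourceShape) {
  // Mirrors the earlier `numSourceDims <= numTargetDims` guard (0 <= 0 fails).
  if (sourceShape.empty())
    return std::nullopt;
  for (int64_t sourceSize : sourceShape) {
    // A static non-unit dimension can never fold into a 0-d result.
    if (sourceSize != 1 && sourceSize != kDynamic)
      return std::nullopt;
  }
  // Unit and dynamic dimensions are accepted; a 0-d target has zero groups.
  return std::vector<ReassociationIndices>{};
}
```

With this sketch, `{1, kDynamic, 1}` yields an engaged but empty result, matching the updated `ScalarTest` expectations, whereas `{1, 4}` yields `std::nullopt` because the static size 4 cannot fold into a scalar.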

---
Full diff: https://github.com/llvm/llvm-project/pull/144118.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp (+4-6) 
- (modified) mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp (+4-4) 


``````````diff
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 3b1fdb69e8ef1..aa566c0086a2f 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
   // this utility).
   if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  // Early handling for scalar target types.
+  // Early handling for scalar target types. We should report an invalid
+  // reassociation for non-unit static dimensions - no chance to collapse these
+  // into a scalar.
   if (numTargetDims == 0) {
-    ReassociationIndices allSourceIndices;
-    allSourceIndices.reserve(numSourceDims);
     for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
          ++sourceDimIdx) {
       int64_t sourceSize = sourceShape[sourceDimIdx];
-      // All source dimensions must be unit or dynamic.
       if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
         return std::nullopt;
-      allSourceIndices.push_back(sourceDimIdx);
     }
-    return SmallVector<ReassociationIndices>{allSourceIndices};
+    return SmallVector<ReassociationIndices>{};
   }
 
   // Collect source ranges by iterating over the target shape left-to-right.
diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
index db1a87a4de2d5..05f97e875e2dc 100644
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
@@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
 
 TEST(ReassociationIndicesForCollapse, ScalarTest) {
   EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
-            makeOptionalIndices({{0}}));
+            makeOptionalIndices({}));
   EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
-            makeOptionalIndices({{0, 1}}));
+            makeOptionalIndices({}));
   EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
-            makeOptionalIndices({{0}}));
+            makeOptionalIndices({}));
   EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
                                                 ShapedType::kDynamic, 1,
                                                 ShapedType::kDynamic},
                                                {}),
-            makeOptionalIndices({{0, 1, 2, 3, 4}}));
+            makeOptionalIndices({}));
 }
 
 TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/144118

