[Mlir-commits] [mlir] [mlir][tensor] Fix computation of collapse in dynamic cases (PR #127560)

Ian Wood llvmlistbot at llvm.org
Mon Feb 17 19:29:25 PST 2025


https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/127560

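This fixes getReassociationIndicesForCollapse for dynamic shapes. The greedy scan over the source shape now keeps absorbing static dimensions while the current target dimension is dynamic; a group that collapses into a dynamic target dimension must end at a dynamic source dimension; and an explicit bounds check returns std::nullopt when the source shape is exhausted while target dimensions remain. A new canonicalization test covers the previously rejected dynamic-collapse case, and a standalone sketch of the updated logic follows the patch.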

From d6f03858c97158e7aa7599eec8180cb1011bb428 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Sun, 16 Feb 2025 14:55:20 -0800
Subject: [PATCH] Compute reassociation in dynamic cases.

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 27 +++++++++++-----------
 mlir/test/Dialect/Tensor/canonicalize.mlir | 14 +++++++++++
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..c053ed488982a 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -49,27 +49,26 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
     int64_t currTargetShape = targetShape[targetDim];
     while (sourceDim < (sourceShape.size() - 1) &&
            sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+           (currTargetShape == ShapedType::kDynamic ||
+            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape)) {
       prodOfCollapsedDims *= sourceShape[sourceDim];
       currIndices.push_back(sourceDim++);
     }
 
+    if (sourceDim >= sourceShape.size())
+      return std::nullopt;
+
     // If the current expanded dimension is dynamic, then the collapsed
     // dimension should also be dynamic.
     if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
-      return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
+        currTargetShape != ShapedType::kDynamic)
       return std::nullopt;
 
     // For static shapes, the product of dimensions of the expanded shape
     // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+        prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
       return std::nullopt;
 
     currIndices.push_back(sourceDim++);
@@ -315,11 +314,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90cc0ca658ffb..bbbef2ebc9d2b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1191,6 +1191,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
 
 // -----
 
+func.func @compose_expand_of_collapse_dynamic_collapse(%arg0 : tensor<4x13x10x64x?xf16>, %arg1 : index) -> tensor<4x13x10x?xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x13x10x64x?xf16> into tensor<52x10x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 13, 10, %arg1] : tensor<52x10x?xf16> into tensor<4x13x10x?xf16>
+  return %expanded : tensor<4x13x10x?xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic_collapse
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x13x10x64x?xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0
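
For reference, below is a minimal standalone C++ sketch of the patched grouping loop. It is an illustration, not the in-tree implementation: kDynamic is assumed to be MLIR's int64 min sentinel, the helper name and the main driver are made up for the example, and edge cases the real utility handles (such as trailing unit dimensions) are left out.

#include <cstdint>
#include <iostream>
#include <limits>
#include <optional>
#include <vector>

// Stand-in for MLIR's ShapedType::kDynamic sentinel (assumed int64 min).
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
using ReassociationIndices = std::vector<int64_t>;

// Sketch of the patched getReassociationIndicesForCollapse: greedily group
// source dims into each target dim, letting the scan continue through static
// source dims whenever the current target dim is dynamic.
std::optional<std::vector<ReassociationIndices>>
getReassociationForCollapse(const std::vector<int64_t> &sourceShape,
                            const std::vector<int64_t> &targetShape) {
  if (sourceShape.empty() || targetShape.empty())
    return std::nullopt;
  std::vector<ReassociationIndices> reassociation;
  size_t sourceDim = 0;
  for (int64_t currTargetShape : targetShape) {
    ReassociationIndices currIndices;
    int64_t prodOfCollapsedDims = 1;
    while (sourceDim < sourceShape.size() - 1 &&
           sourceShape[sourceDim] != kDynamic &&
           (currTargetShape == kDynamic ||
            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape)) {
      prodOfCollapsedDims *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }
    if (sourceDim >= sourceShape.size())
      return std::nullopt;
    // A dynamic source dim may only close a group for a dynamic target dim.
    if (sourceShape[sourceDim] == kDynamic && currTargetShape != kDynamic)
      return std::nullopt;
    // A static source dim must make the group's product match exactly; with a
    // dynamic target dim this forces the group to end on a dynamic source dim,
    // since no static product compares equal to the sentinel.
    if (sourceShape[sourceDim] != kDynamic &&
        prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
      return std::nullopt;
    currIndices.push_back(sourceDim++);
    reassociation.push_back(std::move(currIndices));
  }
  // Simplification: require the whole source shape to be consumed.
  if (sourceDim != sourceShape.size())
    return std::nullopt;
  return reassociation;
}

int main() {
  // Shapes from the new test: tensor<4x13x10x64x?xf16> -> tensor<52x10x?xf16>.
  auto groups = getReassociationForCollapse({4, 13, 10, 64, kDynamic},
                                            {52, 10, kDynamic});
  if (!groups) {
    std::cout << "no valid reassociation\n";
    return 1;
  }
  for (const ReassociationIndices &group : *groups) {
    for (int64_t idx : group)
      std::cout << idx << ' ';
    std::cout << '\n';
  }
  return 0;
}

Run on the shapes from the new test, this prints "0 1", "2", and "3 4" on separate lines, i.e. the grouping [[0, 1], [2], [3, 4]] that the folded tensor.collapse_shape in the CHECK lines uses. Before the patch, a dynamic target dimension stopped the scan immediately, and the static 64 then tripped the "collapsed dim is dynamic but expanded dim is not" check, so the fold was rejected.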


