[Mlir-commits] [mlir] [mlir][tensor] Fold dynamic expand of collapse (PR #127689)

Ian Wood llvmlistbot at llvm.org
Tue Feb 18 11:28:26 PST 2025


https://github.com/IanWood1 created https://github.com/llvm/llvm-project/pull/127689

If the shape is collapsed into a single dim, the `ReassociationIndices` for the shape are trivial to compute: every source dimension maps to the one target dimension. Importantly, there is no need to worry about the dynamic dimensions that usually prevent computing `ReassociationIndices`.

>From 5accf5d08e932295427dff8f7db43094f57b6c6c Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 18 Feb 2025 22:47:05 -0800
Subject: [PATCH] Handle collapse into single element

Signed-off-by: Ian Wood <ianwood2024 at u.northwestern.edu>
---
 mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 13 ++++++++-----
 mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++++++
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 0336423c57b1d..169f28cece4dc 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -33,6 +33,9 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
   if (sourceShape.size() <= targetShape.size())
     return std::nullopt;
+  if (targetShape.size() == 1)
+    return SmallVector<ReassociationIndices>{
+        llvm::to_vector(llvm::seq<int64_t>(0, sourceShape.size()))};
   unsigned sourceDim = 0;
   SmallVector<ReassociationIndices> reassociationMap;
   reassociationMap.reserve(targetShape.size());
@@ -315,11 +318,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(
-          offsetsSizesAndStrides,
-          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
-          }));
+      llvm::append_range(offsetsSizesAndStrides,
+                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+                           return {zeroAttr, collapseShapeInputShape[idx],
+                                   oneAttr};
+                         }));
       continue;
     }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 90cc0ca658ffb..bbbef2ebc9d2b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1191,6 +1191,20 @@ func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %
 
 // -----
 
+func.func @compose_expand_of_collapse_dynamic_collapse(%arg0 : tensor<4x13x10x64x?xf16>, %arg1 : index) -> tensor<4x13x10x?xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x13x10x64x?xf16> into tensor<52x10x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 13,  10, %arg1] : tensor<52x10x?xf16> into tensor<4x13x10x?xf16>
+  return %expanded : tensor<4x13x10x?xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic_collapse
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x13x10x64x?xf16>
+//      CHECK:   %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME:     [0], [1], [2], [3, 4]
+//      CHECK:   return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @zero_rank_reshape_multi
 func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
   // CHECK: return %arg0



More information about the Mlir-commits mailing list