[Mlir-commits] [mlir] 634e253 - [mlir] Add special case for 0-D tensor when fusing expand from collapse (#130838)

llvmlistbot at llvm.org
Tue Mar 11 15:55:58 PDT 2025


Author: Evan Liu
Date: 2025-03-11T15:55:55-07:00
New Revision: 634e25319e0e99affcb61cc9fba639c4d40cc420

URL: https://github.com/llvm/llvm-project/commit/634e25319e0e99affcb61cc9fba639c4d40cc420
DIFF: https://github.com/llvm/llvm-project/commit/634e25319e0e99affcb61cc9fba639c4d40cc420.diff

LOG: [mlir] Add special case for 0-D tensor when fusing expand from collapse (#130838)

A fusion pattern for collapse_shape -> expand_shape was added in
https://github.com/llvm/llvm-project/commit/a95ad2da36b6a996b05c79df6b385cd98bac286d.
However, if the intermediate tensor between the collapse and the expand is a
0-D tensor, the reassociation maps of both ops are empty. This special case
cannot be handled by the general fusion logic in
`BubbleUpExpandThroughParallelCollapse`, so the pattern now bails out on it.
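The C++ sketch below shows why an empty reassociation needs its own exit. It models reassociation lists as plain `std::vector`s instead of MLIR's `ReassociationIndices`. The names `Reassociation`, `groupsAreParallel` and `canBubbleUp` are hypothetical, and only the "parallel" check and the new early exit are reproduced, not the rewrite itself.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a list of MLIR reassociation indices: group i
// holds the dims of the higher-rank tensor that fold into dim i of the
// lower-rank (intermediate) tensor.
using Reassociation = std::vector<std::vector<int64_t>>;

// Models the "parallel" loop in the pattern: group i of the expand is paired
// with group i of the collapse, and the rewrite is rejected if both groups
// fold more than one dim.
bool groupsAreParallel(const Reassociation &expandReInds,
                       const Reassociation &collapseReInds) {
  // Both lists have one group per dim of the intermediate tensor.
  assert(expandReInds.size() == collapseReInds.size());
  for (std::size_t i = 0; i < expandReInds.size(); ++i)
    if (expandReInds[i].size() > 1 && collapseReInds[i].size() > 1)
      return false;
  return true; // Vacuously true when both lists are empty (0-D intermediate).
}

// The same check preceded by the early exit this commit adds: a 0-D
// intermediate has no groups to pair up, so refuse to bubble up.
bool canBubbleUp(const Reassociation &expandReInds,
                 const Reassociation &collapseReInds) {
  if (expandReInds.empty())
    return false;
  return groupsAreParallel(expandReInds, collapseReInds);
}
```

In the 0-D case exercised by the new `@no_bubble_0d_tensor_reshapes` test, both lists are empty. `groupsAreParallel` therefore returns `true` even though nothing was compared, and only the explicit `empty()` guard makes `canBubbleUp` reject the rewrite.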

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
    mlir/test/Dialect/Tensor/bubble-reshapes.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index ae8e3528b02e0..acedf51d0e240 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -160,6 +160,12 @@ struct BubbleUpExpandThroughParallelCollapse
     auto expandReInds = expandOp.getReassociationIndices();
     auto collapseReInds = collapseOp.getReassociationIndices();
 
+    // Special case where the collapsed tensor to expand is a 0-D tensor,
+    // then the reassociation maps will be empty and not produce valid results.
+    if (expandReInds.size() == 0) {
+      return failure();
+    }
+
     // Reshapes are parallel to each other if none of the reassociation indices
     // have greater than 1 index for both reshapes.
     for (auto [expandReassociation, collapseReassociation] :

diff  --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd3..eeed794884942 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -45,3 +45,17 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
 //      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
 //      CHECK:   return %[[EXPAND]]
+
+// -----
+
+func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
+  %expand = tensor.expand_shape %collapse []
+              output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @no_bubble_0d_tensor_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
+//      CHECK:   return %[[EXPAND]]
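The new test collapses `tensor<?xf32>` into `tensor<f32>` and expands that into `tensor<?x?x?x?xf32>`, both with an empty map `[]`. The sketch below uses a hypothetical `isWellFormed` helper, which is not MLIR's verifier, to show that `[]` is only acceptable when the lower-rank side is 0-D. It assumes that the swapped ops would be built from groups derived from the (empty) original maps. Under that assumption, an op that expands `tensor<?xf32>` straight to rank 4 with `[]` would be malformed, because it needs a single group such as `[[0, 1, 2, 3]]`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a list of MLIR reassociation indices.
using Reassociation = std::vector<std::vector<int64_t>>;

// Hypothetical well-formedness check: `reInds` must contain one group per dim
// of the lower-rank tensor, and the groups must list the higher-rank dims
// 0..highRank-1 contiguously and in order.
bool isWellFormed(const Reassociation &reInds, int64_t lowRank,
                  int64_t highRank) {
  assert(lowRank <= highRank && "maps a lower rank onto a higher one");
  if (static_cast<int64_t>(reInds.size()) != lowRank)
    return false;
  if (lowRank == 0)
    return true; // 0-D side: the map is empty (dim sizes are not checked here).
  int64_t next = 0;
  for (const auto &group : reInds) {
    if (group.empty())
      return false;
    for (int64_t dim : group)
      if (dim != next++)
        return false;
  }
  return next == highRank;
}
```

Both original ops pass, since `isWellFormed({}, 0, 1)` and `isWellFormed({}, 0, 4)` hold. An empty map for a rank-1 to rank-4 expand fails, which is why the pattern must return `failure()` rather than reuse the empty maps.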
