[Mlir-commits] [mlir] 634e253 - [mlir] Add special case for 0-D tensor when fusing expand from collapse (#130838)
llvmlistbot at llvm.org
Tue Mar 11 15:55:58 PDT 2025
Author: Evan Liu
Date: 2025-03-11T15:55:55-07:00
New Revision: 634e25319e0e99affcb61cc9fba639c4d40cc420
URL: https://github.com/llvm/llvm-project/commit/634e25319e0e99affcb61cc9fba639c4d40cc420
DIFF: https://github.com/llvm/llvm-project/commit/634e25319e0e99affcb61cc9fba639c4d40cc420.diff
LOG: [mlir] Add special case for 0-D tensor when fusing expand from collapse (#130838)
One fusion pattern for collapse_shape -> expand_shape was added in
https://github.com/llvm/llvm-project/commit/a95ad2da36b6a996b05c79df6b385cd98bac286d.
However, if the intermediate tensor between a collapse and an expand is a
0-D tensor, the reassociation maps of the two ops are a special case (both
are empty), and `BubbleUpExpandThroughParallelCollapse` cannot fuse them
with its general logic.
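The reasoning above can be sketched in C++ under simplifying assumptions. MLIR's reassociation indices are modelled as plain `std::vector`s, and `isParallelPair`, `Reassociation`, and the `guardZeroD` flag are hypothetical names, not MLIR API. The parallel check follows the comment in the patched hunk ("none of the reassociation indices have greater than 1 index for both reshapes"). The rest of the pattern, including the rewrite itself, is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for MLIR's reassociation types: each inner vector
// lists the dimensions of the higher-rank tensor that form one group.
using ReassociationIndices = std::vector<std::int64_t>;
using Reassociation = std::vector<ReassociationIndices>;

// Simplified model (not the MLIR source) of the checks applied to a
// collapse_shape -> expand_shape pair. `guardZeroD` toggles an early exit
// equivalent to the one added in this commit.
bool isParallelPair(const Reassociation &expandReInds,
                    const Reassociation &collapseReInds, bool guardZeroD) {
  // 0-D special case: an empty map means the intermediate tensor has rank 0,
  // so there are no groups to pair up and the pattern should bail out.
  if (guardZeroD && expandReInds.empty())
    return false;

  // Both maps have one group per dimension of the intermediate tensor, so
  // they must have the same length to be walked in lock step.
  assert(expandReInds.size() == collapseReInds.size());

  // "Parallel": no group has more than one index in both reshapes.
  for (std::size_t i = 0; i < expandReInds.size(); ++i) {
    if (expandReInds[i].size() > 1 && collapseReInds[i].size() > 1)
      return false;
  }
  // With empty maps (and no guard) the loop never runs, so the pair is
  // accepted vacuously.
  return true;
}
```

Fed the empty maps from the new `@no_bubble_0d_tensor_reshapes` test, the sketch accepts the pair only when `guardZeroD` is false, because the loop has nothing to iterate over. The maps from the existing `@no_bubble_partial_intersecting_reshapes` test (`[[0, 1], [2, 3]]` for the expand and `[[0, 1, 2], [3]]` for the collapse) are rejected either way.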
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
mlir/test/Dialect/Tensor/bubble-reshapes.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index ae8e3528b02e0..acedf51d0e240 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -160,6 +160,12 @@ struct BubbleUpExpandThroughParallelCollapse
auto expandReInds = expandOp.getReassociationIndices();
auto collapseReInds = collapseOp.getReassociationIndices();
+ // Special case where the collapsed tensor to expand is a 0-D tensor,
+ // then the reassociation maps will be empty and not produce valid results.
+ if (expandReInds.size() == 0) {
+ return failure();
+ }
+
// Reshapes are parallel to each other if none of the reassociation indices
// have greater than 1 index for both reshapes.
for (auto [expandReassociation, collapseReassociation] :
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index cf6b12852bcd3..eeed794884942 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -45,3 +45,17 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
// CHECK: return %[[EXPAND]]
+
+// -----
+
+func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+ %collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
+ %expand = tensor.expand_shape %collapse []
+ output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
+ return %expand : tensor<?x?x?x?xf32>
+}
+// CHECK: func @no_bubble_0d_tensor_reshapes
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
+// CHECK: return %[[EXPAND]]