[Mlir-commits] [mlir] [mlir][Tensor] Generalize the pattern to swap `tensor.collapse_shape` -> `tensor.expand_shape`. (PR #133819)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 31 16:09:52 PDT 2025


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/133819

The current patterns compared the reassociation indices for the two ops and failed if neither of them was of size 1. This patch relaxes this restriction by handling a new case where the reassociation indices might be of the same size.

Also generalizes the pattern so that, when generating the swapped `tensor.expand_shape` -> `tensor.collapse_shape`, an op that would be degenerate (a no-op) is not generated.

>From e768306e072523519a22f328b3284ddd99b6f84a Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mravisha at amd.com>
Date: Mon, 31 Mar 2025 15:50:39 -0500
Subject: [PATCH] [mlir][Tensor] Generalize the pattern to swap
 `tensor.collapse_shape` -> `tensor.expand_shape`.

The current patterns compared the reassociation indices for the two ops
and failed if neither of them was of size 1. This patch relaxes this
restriction by handling a new case where the reassociation indices
might be of the same size.

Also generalizes the pattern so that, when generating the swapped
`tensor.expand_shape` -> `tensor.collapse_shape`, an op that would be
degenerate (a no-op) is not generated.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 113 ++++++++++++++----
 mlir/test/Dialect/Tensor/bubble-reshapes.mlir |  63 +++++++++-
 2 files changed, 150 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index acedf51d0e240..2542039267f01 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -166,10 +166,39 @@ struct BubbleUpExpandThroughParallelCollapse
       return failure();
     }
 
-    // Reshapes are parallel to each other if none of the reassociation indices
-    // have greater than 1 index for both reshapes.
+    // Reshapes are parallel to each other (by construction the number of
+    // reassociations specified in the collapse and expand are the same), if at
+    // any position
+    // 1. either the reassociation indices are of the same size, or
+    // 2. either the reassociation in the collapse or the expand is of size 1.
+    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
     for (auto [expandReassociation, collapseReassociation] :
          llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        // Even if the reassociations are the same, the collapse/expand should
+        // result in the same dimensions. i.e. 4x8x2 into 64 should be expanded
+        // into 4x8x2 again. In presence of dynamic dimensions one can only
+        // verify "equality" when there is only one dynamic dimension present,
+        // and all other static dimensions are equal.
+        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+            collapseReassociation.front(), collapseReassociation.size());
+        int64_t numCollapsedDynamic =
+            llvm::count_if(collapsedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+            expandReassociation.front(), expandReassociation.size());
+        int64_t numExpandedDynamic =
+            llvm::count_if(expandedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+            collapsedStaticShapes != expandedStaticShapes) {
+          return failure();
+        }
+        continue;
+      }
+      // If the reassociations are not same, one or the other needs to be of
+      // size one.
       if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
         return failure();
     }
@@ -177,33 +206,61 @@ struct BubbleUpExpandThroughParallelCollapse
     // Compute new reassociation indices and expanded/collaped shapes.
     SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
     Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
+    SmallVector<OpFoldResult> sourceSizes =
         tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
     SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+            resultSizeIndex = 0;
+
+    for (size_t idx = 0, idx_end = collapseReInds.size(); idx < idx_end;
+         idx++) {
+      auto collapseReassociation = collapseReInds[idx];
+      auto expandReassociation = expandReInds[idx];
+
+      // Case 1. The reassociations are same in the collapse producer
+      // and expand consumer. In the swapped expand, each of the final
+      // dimensions are kept as is in the expand and the collapse. So,
+      // for every element in the `ReassociationIndices` vector add a new
+      // `ReassociationIndices` vector for the swapped expand and collapse
+      // (of size 1).
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReInds.push_back({newCollapseIndex++});
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
+          sourceSizeIndex++;
+        }
+        continue;
+      }
+
+      // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
+      // in the expand is of size == 1). In this case, the original dimensions
+      // are preserved on expansion and collapsed subsequently.
       if (collapseReassociation.size() != 1) {
         ReassociationIndices newCollapseReassociation;
         for (size_t i = 0; i < collapseReassociation.size(); ++i) {
-          newCollapseReassociation.push_back(index);
-          newExpandReInds.push_back({index++});
-          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+          newCollapseReassociation.push_back(newCollapseIndex++);
+          newExpandReInds.push_back({newExpandIndex++});
+          newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
         }
+        resultSizeIndex++;
         newCollapseReInds.push_back(newCollapseReassociation);
-        expandIndex++;
         continue;
       }
+
+      // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
+      // in the collapse is of size == 1). In this case, the expansion happens
+      // first and the expanded dimensions are preserved on collapse.
       ReassociationIndices newExpandReassociation;
-      auto expandReassociation = expandReInds[idx];
       for (size_t i = 0; i < expandReassociation.size(); ++i) {
-        newExpandReassociation.push_back(index);
-        newCollapseReInds.push_back({index++});
-        newExpandSizes.push_back(expandSizes[expandIndex++]);
+        newExpandReassociation.push_back(newExpandIndex++);
+        newCollapseReInds.push_back({newCollapseIndex++});
+        newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
       }
       newExpandReInds.push_back(newExpandReassociation);
-      collapseIndex++;
+      sourceSizeIndex++;
     }
 
     // Swap reshape order.
@@ -211,11 +268,25 @@ struct BubbleUpExpandThroughParallelCollapse
     SmallVector<int64_t> staticSizes;
     dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
     auto expandResultType = expandOp.getResultType().clone(staticSizes);
-    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
-        newExpandSizes);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        expandOp, newExpand.getResult(), newCollapseReInds);
+    Value newCollapseSrc = collapseOp.getSrc();
+    // If the number of reassociation indices in the new `expand_shape` op
+    // matches the number of dimensions of the result, then the expand_shape
+    // is a no-op.
+    if (newExpandReInds.size() != newExpandSizes.size()) {
+      newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
+          loc, expandResultType, newCollapseSrc, newExpandReInds,
+          newExpandSizes);
+    }
+
+    // If the number of reassociation indices in the new `collapse_shape` op
+    // matches the number of dimensions of the source, then the collapse_shape
+    // is a no-op.
+    Value replacement = newCollapseSrc;
+    if (newCollapseReInds.size() != newExpandSizes.size()) {
+      replacement = rewriter.create<tensor::CollapseShapeOp>(
+          loc, newCollapseSrc, newCollapseReInds);
+    }
+    rewriter.replaceOp(expandOp, replacement);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
index eeed794884942..1a277af96c6f3 100644
--- a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -48,14 +48,67 @@ func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %
 
 // -----
 
-func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
-  %collapse = tensor.collapse_shape %arg0 [] : tensor<?xf32> into tensor<f32>
+func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
+  %collapse = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
   %expand = tensor.expand_shape %collapse []
-              output_shape [%s0, %s1, %s2, %s3] : tensor<f32> into tensor<?x?x?x?xf32>
-  return %expand : tensor<?x?x?x?xf32>
+              output_shape [1, 1, 1] : tensor<f32> into tensor<1x1x1xf32>
+  return %expand : tensor<1x1x1xf32>
 }
 //      CHECK: func @no_bubble_0d_tensor_reshapes
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1xf32>
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
 //      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
 //      CHECK:   return %[[EXPAND]]
+
+// -----
+
+// Test the case where the reassociation indices in the collapse and expand
+// are of same size.
+func.func @bubble_expand_match_non_unit_size_reassocation(
+      %arg0 : tensor<4x?x4x32x4x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
+      : tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+//      CHECK: func @bubble_expand_match_non_unit_size_reassocation
+// CHECK-SAME:     %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
+// CHECK-SAME:     %[[ARG1:[a-zA-z0-9]+]]: index
+// CHECK-SAME:     %[[ARG2:[a-zA-z0-9]+]]: index
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK-SAME:       {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
+// CHECK-SAME:       [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
+//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
+// CHECK-SAME:       {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
+//      CHECK:   return %[[COLLAPSED]]
+
+// -----
+
+// Test the case where the trailing collapse isn't needed.
+func.func @no_collapse_generated(
+      %arg0 : tensor<4x?x4x128x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]]
+      : tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
+      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
+  return %expanded : tensor<4x?x4x128x?x32xf16>
+}
+//      CHECK: func @no_collapse_generated
+//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape 
+//      CHECK:   return %[[EXPANDED]]
+
+// -----
+
+// Test the case where the leading expand isn't needed.
+func.func @no_expand_generated(
+      %arg0 : tensor<4x?x4x128x?x?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<4x?x4x128x?x?xf16> {
+  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]]
+      : tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
+  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4], [5]] output_shape [4, %arg1, 4, 128, %arg2, %arg3]
+      : tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
+  return %expanded : tensor<4x?x4x128x?x?xf16>
+}
+//      CHECK: func @no_expand_generated
+//      CHECK:   %[[EXPANDED:.+]] = tensor.collapse_shape
+//      CHECK:   return %[[EXPANDED]]



More information about the Mlir-commits mailing list