[Mlir-commits] [mlir] [mlir][tensor][memref] Enhance collapse(expand(src)) canonicalization pattern. (PR #145995)

Han-Chung Wang llvmlistbot at llvm.org
Thu Jun 26 19:06:59 PDT 2025


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/145995

>From 76eeb3633b6ac40e1e6bdeb0b4bd0efe2d59b214 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 26 Jun 2025 16:54:16 -0700
Subject: [PATCH 1/2] [mlir][tensor][memref] Enhance collapse(expand(src))
 canonicalization pattern.

The expand_shape op takes dynamic output shape values, and we need to take
them into account when we compose the ops. Otherwise, it fails to create
the new expand_shape op.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../mlir/Dialect/Utils/ReshapeOpsUtils.h      | 37 ++++++++++++++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir    | 18 +++++++++
 mlir/test/Dialect/Tensor/canonicalize.mlir    | 18 +++++++++
 3 files changed, 72 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c2a50e514ca..7f946f739baf9 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -305,8 +306,42 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       rewriter.replaceOpWithNewOp<CollapseOpTy>(
           collapseOp, resultType, expandOp.getSrc(), composedReassociation);
     } else if (srcRank < resultRank) {
+      // Compute the dynamic output shape for the new expand_shape op.
+      Location loc = collapseOp.getLoc();
+      SmallVector<OpFoldResult> origOutputShape =
+          expandOp.getMixedOutputShape();
+      SmallVector<OpFoldResult> newOutputShape;
+      for (auto indices : collapseOp.getReassociationIndices()) {
+        int64_t numStaticElems = 1;
+        SmallVector<Value> dynamicSizes;
+        for (auto idx : indices) {
+          OpFoldResult size = origOutputShape[idx];
+          if (auto maybeCst = getConstantIntValue(size)) {
+            numStaticElems *= maybeCst.value();
+            continue;
+          }
+          dynamicSizes.push_back(cast<Value>(size));
+        }
+        if (dynamicSizes.empty()) {
+          newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
+          continue;
+        }
+
+        // There is at least one dynamic size, so we can intialize `result` to
+        // the first dynamic size.
+        Value result = dynamicSizes[0];
+        for (auto v : llvm::drop_begin(dynamicSizes))
+          result = rewriter.create<arith::MulIOp>(loc, result, v);
+        if (numStaticElems != 1) {
+          result = rewriter.create<arith::MulIOp>(
+              loc, result,
+              rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
+        }
+        newOutputShape.push_back(result);
+      }
       rewriter.replaceOpWithNewOp<ExpandOpTy>(
-          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+          collapseOp, resultType, expandOp.getSrc(), composedReassociation,
+          newOutputShape);
     } else {
       // Collapses/expansions that do not change the rank are not allowed. Use
       // a cast instead.
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7a267ae8a2c95..decc85a9af3c9 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -466,6 +466,24 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %arg1: index, %arg2: index) -> memref<8x?x?xf16> {
+  %expanded = memref.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : memref<?xf16> into memref<4x2x?x?x32xf16>
+  %collapsed = memref.collapse_shape %expanded [[0, 1], [2], [3, 4]] : memref<4x2x?x?x32xf16> into memref<8x?x?xf16>
+  return %collapsed : memref<8x?x?xf16>
+}
+//       CHECK: func @compose_collapse_of_expand_partially_dynamic
+//  CHECK-SAME:   %[[SRC:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D2:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D3:.[a-zA-Z0-9]+]]
+//   CHECK-DAG:   %[[C32:.+]] = arith.constant 32
+//       CHECK:   %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+//       CHECK:   %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
+//  CHECK-SAME:     [0, 1, 2]
+//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @do_not_compose_collapse_of_expand_non_identity_layout(
     %arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
     -> memref<?xf32, strided<[?], offset: 0>> {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3251c5a4a2bfd..ed87bdafe80c9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1243,6 +1243,24 @@ func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
 
 // -----
 
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %arg1: index, %arg2: index) -> tensor<8x?x?xf16> {
+  %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
+  %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]] : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>
+  return %collapsed : tensor<8x?x?xf16>
+}
+//       CHECK: func @compose_collapse_of_expand_partially_dynamic
+//  CHECK-SAME:   %[[SRC:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D2:.[a-zA-Z0-9]+]]
+//  CHECK-SAME:   %[[ORIG_D3:.[a-zA-Z0-9]+]]
+//   CHECK-DAG:   %[[C32:.+]] = arith.constant 32
+//       CHECK:   %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+//       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
+//  CHECK-SAME:     [0, 1, 2]
+//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+//       CHECK:   return %[[RESULT]]
+
+// -----
+
 func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
     -> tensor<1x1x1x1xf32> {
   %0 = tensor.collapse_shape %arg0 []

>From abaaa94c124707bb57c55f41a4109a349b4c97cf Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 26 Jun 2025 19:06:44 -0700
Subject: [PATCH 2/2] address comments

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 11 ++++++-----
 mlir/test/Dialect/MemRef/canonicalize.mlir        |  4 ++--
 mlir/test/Dialect/Tensor/canonicalize.mlir        |  4 ++--
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 7f946f739baf9..704e39e908841 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -311,12 +311,13 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
       SmallVector<OpFoldResult> origOutputShape =
           expandOp.getMixedOutputShape();
       SmallVector<OpFoldResult> newOutputShape;
-      for (auto indices : collapseOp.getReassociationIndices()) {
+      for (const ReassociationIndices &indices :
+           collapseOp.getReassociationIndices()) {
         int64_t numStaticElems = 1;
         SmallVector<Value> dynamicSizes;
-        for (auto idx : indices) {
+        for (int64_t idx : indices) {
           OpFoldResult size = origOutputShape[idx];
-          if (auto maybeCst = getConstantIntValue(size)) {
+          if (std::optional<int64_t> maybeCst = getConstantIntValue(size)) {
             numStaticElems *= maybeCst.value();
             continue;
           }
@@ -327,10 +328,10 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
           continue;
         }
 
-        // There is at least one dynamic size, so we can intialize `result` to
+        // There is at least one dynamic size, so we can initialize `result` to
         // the first dynamic size.
         Value result = dynamicSizes[0];
-        for (auto v : llvm::drop_begin(dynamicSizes))
+        for (Value v : llvm::drop_begin(dynamicSizes))
           result = rewriter.create<arith::MulIOp>(loc, result, v);
         if (numStaticElems != 1) {
           result = rewriter.create<arith::MulIOp>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index decc85a9af3c9..a91e54a126100 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -476,10 +476,10 @@ func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %a
 //  CHECK-SAME:   %[[ORIG_D2:.[a-zA-Z0-9]+]]
 //  CHECK-SAME:   %[[ORIG_D3:.[a-zA-Z0-9]+]]
 //   CHECK-DAG:   %[[C32:.+]] = arith.constant 32
-//       CHECK:   %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+//       CHECK:   %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
 //       CHECK:   %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
 //  CHECK-SAME:     [0, 1, 2]
-//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
 //       CHECK:   return %[[RESULT]]
 
 // -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ed87bdafe80c9..3f9236095138b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1253,10 +1253,10 @@ func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %a
 //  CHECK-SAME:   %[[ORIG_D2:.[a-zA-Z0-9]+]]
 //  CHECK-SAME:   %[[ORIG_D3:.[a-zA-Z0-9]+]]
 //   CHECK-DAG:   %[[C32:.+]] = arith.constant 32
-//       CHECK:   %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+//       CHECK:   %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
 //       CHECK:   %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
 //  CHECK-SAME:     [0, 1, 2]
-//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+//  CHECK-SAME:     output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
 //       CHECK:   return %[[RESULT]]
 
 // -----



More information about the Mlir-commits mailing list