[Mlir-commits] [mlir] [mlir][tensor][memref] Enhance collapse(expand(src)) canonicalization pattern. (PR #145995)
Han-Chung Wang
llvmlistbot at llvm.org
Thu Jun 26 19:06:59 PDT 2025
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/145995
>From 76eeb3633b6ac40e1e6bdeb0b4bd0efe2d59b214 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 26 Jun 2025 16:54:16 -0700
Subject: [PATCH 1/2] [mlir][tensor][memref] Enhance collapse(expand(src))
canonicalization pattern.
The expand_shape op takes dynamic output shape values, and we need to take
them into account when composing the ops. Otherwise, the pattern fails to
create the new expand_shape op.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../mlir/Dialect/Utils/ReshapeOpsUtils.h | 37 ++++++++++++++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 18 +++++++++
mlir/test/Dialect/Tensor/canonicalize.mlir | 18 +++++++++
3 files changed, 72 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 61c2a50e514ca..7f946f739baf9 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@@ -305,8 +306,42 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
rewriter.replaceOpWithNewOp<CollapseOpTy>(
collapseOp, resultType, expandOp.getSrc(), composedReassociation);
} else if (srcRank < resultRank) {
+ // Compute the dynamic output shape for the new expand_shape op.
+ Location loc = collapseOp.getLoc();
+ SmallVector<OpFoldResult> origOutputShape =
+ expandOp.getMixedOutputShape();
+ SmallVector<OpFoldResult> newOutputShape;
+ for (auto indices : collapseOp.getReassociationIndices()) {
+ int64_t numStaticElems = 1;
+ SmallVector<Value> dynamicSizes;
+ for (auto idx : indices) {
+ OpFoldResult size = origOutputShape[idx];
+ if (auto maybeCst = getConstantIntValue(size)) {
+ numStaticElems *= maybeCst.value();
+ continue;
+ }
+ dynamicSizes.push_back(cast<Value>(size));
+ }
+ if (dynamicSizes.empty()) {
+ newOutputShape.push_back(rewriter.getIndexAttr(numStaticElems));
+ continue;
+ }
+
+ // There is at least one dynamic size, so we can intialize `result` to
+ // the first dynamic size.
+ Value result = dynamicSizes[0];
+ for (auto v : llvm::drop_begin(dynamicSizes))
+ result = rewriter.create<arith::MulIOp>(loc, result, v);
+ if (numStaticElems != 1) {
+ result = rewriter.create<arith::MulIOp>(
+ loc, result,
+ rewriter.create<arith::ConstantIndexOp>(loc, numStaticElems));
+ }
+ newOutputShape.push_back(result);
+ }
rewriter.replaceOpWithNewOp<ExpandOpTy>(
- collapseOp, resultType, expandOp.getSrc(), composedReassociation);
+ collapseOp, resultType, expandOp.getSrc(), composedReassociation,
+ newOutputShape);
} else {
// Collapses/expansions that do not change the rank are not allowed. Use
// a cast instead.
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7a267ae8a2c95..decc85a9af3c9 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -466,6 +466,24 @@ func.func @compose_collapse_of_collapse(%arg0 : memref<?x?x?x?x?xf32>)
// -----
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %arg1: index, %arg2: index) -> memref<8x?x?xf16> {
+ %expanded = memref.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : memref<?xf16> into memref<4x2x?x?x32xf16>
+ %collapsed = memref.collapse_shape %expanded [[0, 1], [2], [3, 4]] : memref<4x2x?x?x32xf16> into memref<8x?x?xf16>
+ return %collapsed : memref<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @do_not_compose_collapse_of_expand_non_identity_layout(
%arg0: memref<?x?xf32, strided<[?, 1], offset: 0>>, %sz0: index, %sz1: index)
-> memref<?xf32, strided<[?], offset: 0>> {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3251c5a4a2bfd..ed87bdafe80c9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1243,6 +1243,24 @@ func.func @compose_collapse_of_expand_1D(%arg0 : tensor<2048xf32>)
// -----
+func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %arg1: index, %arg2: index) -> tensor<8x?x?xf16> {
+ %expanded = tensor.expand_shape %arg0 [[0, 1, 2, 3, 4]] output_shape [4, 2, %arg1, %arg2, 32] : tensor<?xf16> into tensor<4x2x?x?x32xf16>
+ %collapsed = tensor.collapse_shape %expanded [[0, 1], [2], [3, 4]] : tensor<4x2x?x?x32xf16> into tensor<8x?x?xf16>
+ return %collapsed : tensor<8x?x?xf16>
+}
+// CHECK: func @compose_collapse_of_expand_partially_dynamic
+// CHECK-SAME: %[[SRC:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C32:.+]] = arith.constant 32
+// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
+// CHECK-SAME: [0, 1, 2]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
func.func @compose_expand_of_collapse_0_rank_to_expand(%arg0 : tensor<1x1x1xf32>)
-> tensor<1x1x1x1xf32> {
%0 = tensor.collapse_shape %arg0 []
>From abaaa94c124707bb57c55f41a4109a349b4c97cf Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Thu, 26 Jun 2025 19:06:44 -0700
Subject: [PATCH 2/2] address comments
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 11 ++++++-----
mlir/test/Dialect/MemRef/canonicalize.mlir | 4 ++--
mlir/test/Dialect/Tensor/canonicalize.mlir | 4 ++--
3 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 7f946f739baf9..704e39e908841 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -311,12 +311,13 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
SmallVector<OpFoldResult> origOutputShape =
expandOp.getMixedOutputShape();
SmallVector<OpFoldResult> newOutputShape;
- for (auto indices : collapseOp.getReassociationIndices()) {
+ for (const ReassociationIndices &indices :
+ collapseOp.getReassociationIndices()) {
int64_t numStaticElems = 1;
SmallVector<Value> dynamicSizes;
- for (auto idx : indices) {
+ for (int64_t idx : indices) {
OpFoldResult size = origOutputShape[idx];
- if (auto maybeCst = getConstantIntValue(size)) {
+ if (std::optional<int64_t> maybeCst = getConstantIntValue(size)) {
numStaticElems *= maybeCst.value();
continue;
}
@@ -327,10 +328,10 @@ struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
continue;
}
- // There is at least one dynamic size, so we can intialize `result` to
+ // There is at least one dynamic size, so we can initialize `result` to
// the first dynamic size.
Value result = dynamicSizes[0];
- for (auto v : llvm::drop_begin(dynamicSizes))
+ for (Value v : llvm::drop_begin(dynamicSizes))
result = rewriter.create<arith::MulIOp>(loc, result, v);
if (numStaticElems != 1) {
result = rewriter.create<arith::MulIOp>(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index decc85a9af3c9..a91e54a126100 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -476,10 +476,10 @@ func.func @compose_collapse_of_expand_partially_dynamic(%arg0: memref<?xf16>, %a
// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C32:.+]] = arith.constant 32
-// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
// CHECK: %[[RESULT:.+]] = memref.expand_shape %[[SRC]]
// CHECK-SAME: [0, 1, 2]
-// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
// CHECK: return %[[RESULT]]
// -----
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index ed87bdafe80c9..3f9236095138b 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1253,10 +1253,10 @@ func.func @compose_collapse_of_expand_partially_dynamic(%arg0: tensor<?xf16>, %a
// CHECK-SAME: %[[ORIG_D2:.[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ORIG_D3:.[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C32:.+]] = arith.constant 32
-// CHECK: %[[NEW_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
+// CHECK: %[[COLLAPSED_D2:.+]] = arith.muli %[[ORIG_D3]], %[[C32]]
// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SRC]]
// CHECK-SAME: [0, 1, 2]
-// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[NEW_D2]]]
+// CHECK-SAME: output_shape [8, %[[ORIG_D2]], %[[COLLAPSED_D2]]]
// CHECK: return %[[RESULT]]
// -----
More information about the Mlir-commits
mailing list