[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Dec 21 06:54:20 PST 2025
https://github.com/kdmitry1 updated https://github.com/llvm/llvm-project/pull/170037
>From fac3e3db8ecc9ef2557bc85b5829b3f22a1bb2a8 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <dmitry.kaptsenel at mobileye.com>
Date: Sun, 30 Nov 2025 14:42:35 +0200
Subject: [PATCH 1/8] [mlir] Fold memref.cast static-to-dynamic to
memref.expand_shape
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 80 ++++++++++++++++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 84 ++++++++++++++++++++++
2 files changed, 163 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..49dc23b702875 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,89 @@ LogicalResult ExpandShapeOp::verify() {
return success();
}
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ auto cast = op.getSrc().getDefiningOp<CastOp>();
+ if (!cast)
+ return failure();
+
+ if (!CastOp::canFoldIntoConsumerOp(cast))
+ return failure();
+
+ auto originalOutputShape = op.getMixedOutputShape();
+ auto newOutputShape = originalOutputShape;
+ SmallVector<int64_t> newOutputShapeSizes;
+ SmallVector<Value> newOperands;
+
+ // Convert output shape dims from dynamic to static where possible.
+ for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+ auto dimVal = dimSize.dyn_cast<Value>();
+ if (!dimVal) {
+ newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+ continue;
+ }
+
+ auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
+ if (!constOp) {
+ newOperands.push_back(dimVal);
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
+ continue;
+ }
+
+ newOutputShape[dimIdx] = constOp.getValue();
+ newOutputShapeSizes.push_back(
+ getConstantIntValue(constOp.getValue()).value());
+ }
+
+ if (newOperands.size() == op->getNumOperands())
+ return rewriter.notifyMatchFailure(
+ op, "no static-to-dynamic conversions found");
+
+ auto castSource = cast.getSource();
+ auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+ auto reassociationIndices = op.getReassociationIndices();
+ for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+ int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+ auto newOutputShapeSizesSlice =
+ ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+ int64_t newOutputDynCount =
+ llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+ if (castSourceDynCount != newOutputDynCount)
+ return rewriter.notifyMatchFailure(
+ op, "folding cast will result in changing dynamicity in "
+ "reassociation group");
+ }
+
+ auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
+ castSourceType, newOutputShapeSizes, reassociationIndices);
+
+ if (failed(newResultTypeOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "could not compute new expanded type after folding cast");
+
+ if (*newResultTypeOrFailure == op.getResultType()) {
+ rewriter.modifyOpInPlace(
+ op, [&]() { op.getSrcMutable().assign(castSource); });
+ } else {
+ Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+ *newResultTypeOrFailure, castSource,
+ reassociationIndices, newOutputShape);
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+ }
+ return success();
+ }
+};
+
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ExpandShapeOpMemRefCastFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..c2d0376fc9723 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,90 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// -----
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %c4 = arith.constant 4 : index
+ %dim_ext = arith.divui %dim0 , %c4: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+ : memref<?x?xf32> into memref<1x?x1x?xf32>
+ %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+ : memref<?x?xf32> into memref<?x?x?x?xf32>
+ %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+ : memref<8x4xf32> into memref<2x1x4x4xf32>
+ return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: memref.cast
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %dim_ext = arith.divui %dim0 , %arg1: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
>From 0127b7363771bab53eaac0f1c23988f221441b01 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Mon, 1 Dec 2025 13:19:28 +0200
Subject: [PATCH 2/8] Updated newOutputShape building loop as suggested by
Matthias Springer
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 23 ++++++++---------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 49dc23b702875..11bfc99320644 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2517,29 +2517,22 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
if (!CastOp::canFoldIntoConsumerOp(cast))
return failure();
- auto originalOutputShape = op.getMixedOutputShape();
- auto newOutputShape = originalOutputShape;
+ SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
+ SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
SmallVector<int64_t> newOutputShapeSizes;
SmallVector<Value> newOperands;
// Convert output shape dims from dynamic to static where possible.
for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
- auto dimVal = dimSize.dyn_cast<Value>();
- if (!dimVal) {
- newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+ auto sizeOpt = getConstantIntValue(dimSize);
+ if (sizeOpt.has_value()) {
+ newOutputShapeSizes.push_back(sizeOpt.value());
+ newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
continue;
}
- auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
- if (!constOp) {
- newOperands.push_back(dimVal);
- newOutputShapeSizes.push_back(ShapedType::kDynamic);
- continue;
- }
-
- newOutputShape[dimIdx] = constOp.getValue();
- newOutputShapeSizes.push_back(
- getConstantIntValue(constOp.getValue()).value());
+ newOperands.push_back(llvm::cast<Value>(dimSize));
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
}
if (newOperands.size() == op->getNumOperands())
>From 35ffdbbdc1b2fb3f375271e597a8e5bb9527c833 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Tue, 2 Dec 2025 08:57:20 +0200
Subject: [PATCH 3/8] Removed more autos. Made lit checks more explicit.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++++----
mlir/test/Dialect/MemRef/canonicalize.mlir | 25 ++++++++++++++++------
2 files changed, 24 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11bfc99320644..ba2cabe668f13 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2539,9 +2539,10 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
return rewriter.notifyMatchFailure(
op, "no static-to-dynamic conversions found");
- auto castSource = cast.getSource();
+ Value castSource = cast.getSource();
auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
- auto reassociationIndices = op.getReassociationIndices();
+ SmallVector<ReassociationIndices> reassociationIndices =
+ op.getReassociationIndices();
for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
auto newOutputShapeSizesSlice =
@@ -2554,8 +2555,9 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
"reassociation group");
}
- auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
- castSourceType, newOutputShapeSizes, reassociationIndices);
+ FailureOr<MemRefType> newResultTypeOrFailure =
+ ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
+ reassociationIndices);
if (failed(newResultTypeOrFailure))
return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index c2d0376fc9723..641b9a0a8624c 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -553,6 +553,8 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
// CHECK-NOT: memref.cast
+// CHECK: memref.expand_shape {{.*}} output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK-NOT: memref.cast
// CHECK: return
func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
%0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
@@ -570,6 +572,8 @@ func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32
// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
// CHECK-NOT: memref.cast
+// CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK-NOT: memref.cast
// CHECK: return
func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
%0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
@@ -587,6 +591,8 @@ func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>)
// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
// CHECK-NOT: memref.cast
+// CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK-NOT: memref.cast
// CHECK: return
func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
%0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
@@ -603,9 +609,10 @@ func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>
// -----
// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
-// CHECK: memref.cast
-// CHECK: memref.expand_shape
-// CHECK: return
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
// CHECK: }
func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
%0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
@@ -617,10 +624,14 @@ func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4
// -----
// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
-// CHECK: memref.cast
-// CHECK: memref.expand_shape
-// CHECK: memref.cast
-// CHECK: return
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
+// CHECK: %[[DIVUI_0:.*]] = arith.divui %[[CONSTANT_0]], %[[ARG1]] : index
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
+// CHECK: %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+// CHECK: return %[[CAST_1]] : memref<2x1x4x4xf32>
// CHECK: }
func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
%0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
>From 4de35571b1615e771d94cc8ba7b5becfa608a58a Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Tue, 2 Dec 2025 09:24:13 +0200
Subject: [PATCH 4/8] Allow folding when a single dynamic dim is expanded to
multiple dynamic dims
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 7 +++----
mlir/test/Dialect/MemRef/canonicalize.mlir | 20 ++++++++++++++++++++
2 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ba2cabe668f13..90b7a866ba6d1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2544,12 +2544,11 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
SmallVector<ReassociationIndices> reassociationIndices =
op.getReassociationIndices();
for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
- int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
auto newOutputShapeSizesSlice =
ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
- int64_t newOutputDynCount =
- llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
- if (castSourceDynCount != newOutputDynCount)
+ bool newOutputDynamic =
+ llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
+ if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
return rewriter.notifyMatchFailure(
op, "folding cast will result in changing dynamicity in "
"reassociation group");
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 641b9a0a8624c..854c8ba0597e1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -608,6 +608,26 @@ func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>
// -----
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_multiple(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
+// CHECK-NOT: memref.cast
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
+// CHECK-NOT: memref.cast
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x?x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_multiple(%arg0 : memref<8x?xf32>, %arg1 : index, %arg2 : index) -> memref<8x1x?x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%dim0, 1, %arg1, %arg2]
+ : memref<?x?xf32> into memref<?x1x?x?xf32>
+ %2 = memref.cast %1 : memref<?x1x?x?xf32> to memref<8x1x?x?xf32>
+ return %2 : memref<8x1x?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
>From 41f3d98a1eb47750c8379e4e8cda3a30ff5fe982 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Wed, 3 Dec 2025 08:52:08 +0200
Subject: [PATCH 5/8] Make lit checks fully explicit instead of using CHECK-NOT
and rename constants
---
mlir/test/Dialect/MemRef/canonicalize.mlir | 34 ++++++++++++----------
1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 854c8ba0597e1..f241a78d022e4 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,11 +551,11 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// -----
-// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
-// CHECK-NOT: memref.cast
-// CHECK: memref.expand_shape {{.*}} output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
-// CHECK-NOT: memref.cast
-// CHECK: return
+// CHECK-LABEL: func.func @fold_memref_expand_with_static_to_dynamic_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
+// CHECK: }
func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
%0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
%c0 = arith.constant 0 : index
@@ -571,10 +571,12 @@ func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32
// -----
// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
-// CHECK-NOT: memref.cast
-// CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
-// CHECK-NOT: memref.cast
-// CHECK: return
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK: }
func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
%0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
%c0 = arith.constant 0 : index
@@ -590,10 +592,12 @@ func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>)
// -----
// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
-// CHECK-NOT: memref.cast
-// CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
-// CHECK-NOT: memref.cast
-// CHECK: return
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK: }
func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
%0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
%c0 = arith.constant 0 : index
@@ -646,9 +650,9 @@ func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4
// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>,
// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
-// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
-// CHECK: %[[DIVUI_0:.*]] = arith.divui %[[CONSTANT_0]], %[[ARG1]] : index
+// CHECK: %[[DIVUI_0:.*]] = arith.divui %[[C8]], %[[ARG1]] : index
// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
// CHECK: %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
// CHECK: return %[[CAST_1]] : memref<2x1x4x4xf32>
>From 9f93c62c21e32c0f8f4e16a22a7816b96d3d9091 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Wed, 3 Dec 2025 09:05:32 +0200
Subject: [PATCH 6/8] Another "auto" fix missed earlier
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 90b7a866ba6d1..34cccdbfc57c0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2524,7 +2524,7 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
// Convert output shape dims from dynamic to static where possible.
for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
- auto sizeOpt = getConstantIntValue(dimSize);
+ std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
if (sizeOpt.has_value()) {
newOutputShapeSizes.push_back(sizeOpt.value());
newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
>From c1d1f06817a11c1f71d68d12e86664d71a1fa87c Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Wed, 3 Dec 2025 09:32:05 +0200
Subject: [PATCH 7/8] Simplify further and add support for layout-only
casting
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 14 ++++----------
mlir/test/Dialect/MemRef/canonicalize.mlir | 15 +++++++++++++++
2 files changed, 19 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 34cccdbfc57c0..6fda6d8a8de52 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2520,25 +2520,19 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
SmallVector<int64_t> newOutputShapeSizes;
- SmallVector<Value> newOperands;
// Convert output shape dims from dynamic to static where possible.
for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
- if (sizeOpt.has_value()) {
- newOutputShapeSizes.push_back(sizeOpt.value());
- newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
+ if (!sizeOpt.has_value()) {
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
continue;
}
- newOperands.push_back(llvm::cast<Value>(dimSize));
- newOutputShapeSizes.push_back(ShapedType::kDynamic);
+ newOutputShapeSizes.push_back(sizeOpt.value());
+ newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
}
- if (newOperands.size() == op->getNumOperands())
- return rewriter.notifyMatchFailure(
- op, "no static-to-dynamic conversions found");
-
Value castSource = cast.getSource();
auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
SmallVector<ReassociationIndices> reassociationIndices =
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index f241a78d022e4..5d1c2a0ef28f6 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -670,6 +670,21 @@ func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0
// -----
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_layout(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<8x1x4xf32> {
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x4xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_layout(%arg0 : memref<8x4xf32>) -> memref<8x1x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<8x4xf32, strided<[?, ?], offset: ?>>
+ %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [8, 1, 4]
+ : memref<8x4xf32, strided<[?, ?], offset: ?>> into memref<8x1x4xf32, strided<[?,?,?], offset: ?>>
+ %2 = memref.cast %1 : memref<8x1x4xf32, strided<[?,?,?], offset: ?>> to memref<8x1x4xf32>
+ return %2 : memref<8x1x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
>From fc3ca6ab31b781f8457c2adfce5c526b378a795c Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Thu, 18 Dec 2025 08:40:00 +0200
Subject: [PATCH 8/8] Updated lit tests according to @adam-smnk's comments
---
mlir/test/Dialect/MemRef/canonicalize.mlir | 40 ++++++++++++----------
1 file changed, 22 insertions(+), 18 deletions(-)
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 5d1c2a0ef28f6..47e0389d41abf 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -553,19 +553,17 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// CHECK-LABEL: func.func @fold_memref_expand_with_static_to_dynamic_cast(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
-// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME: output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
// CHECK: }
-func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
- %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
- %c0 = arith.constant 0 : index
- %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
- %c4 = arith.constant 4 : index
- %dim_ext = arith.divui %dim0 , %c4: index
- %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+ %c2 = arith.constant 2 : index
+ %cast = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %expand_shape = memref.expand_shape %cast [[0, 1, 2], [3]] output_shape [%c2, 1, 4, 4]
: memref<?x4xf32> into memref<?x1x4x4xf32>
- %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
- return %2 : memref<2x1x4x4xf32>
+ %cast_0 = memref.cast %expand_shape : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %cast_0 : memref<2x1x4x4xf32>
}
// -----
@@ -574,7 +572,8 @@ func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32
// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
-// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME: output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
// CHECK: }
func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
@@ -591,14 +590,15 @@ func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>)
// -----
-// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial_with_arith_const_as_dim(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
-// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME: output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
// CHECK: }
-func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+func.func @fold_memref_expand_static_to_dynamic_partial_with_arith_const_as_dim(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
%0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -616,7 +616,8 @@ func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>
// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>,
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
// CHECK-NOT: memref.cast
-// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME: output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
// CHECK-NOT: memref.cast
// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x?x?xf32>
// CHECK: }
@@ -635,7 +636,8 @@ func.func @fold_memref_expand_static_to_dynamic_multiple(%arg0 : memref<8x?xf32>
// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
-// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME: output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
// CHECK: }
func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
@@ -653,7 +655,8 @@ func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
// CHECK: %[[DIVUI_0:.*]] = arith.divui %[[C8]], %[[ARG1]] : index
-// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME: output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
// CHECK: %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
// CHECK: return %[[CAST_1]] : memref<2x1x4x4xf32>
// CHECK: }
@@ -672,7 +675,8 @@ func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0
// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_layout(
// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<8x1x4xf32> {
-// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]]
+// CHECK-SAME: output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x4xf32>
// CHECK: }
func.func @fold_memref_expand_static_to_dynamic_layout(%arg0 : memref<8x4xf32>) -> memref<8x1x4xf32> {
More information about the Mlir-commits
mailing list