[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Dec 1 23:29:05 PST 2025
https://github.com/kdmitry1 updated https://github.com/llvm/llvm-project/pull/170037
>From fac3e3db8ecc9ef2557bc85b5829b3f22a1bb2a8 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <dmitry.kaptsenel at mobileye.com>
Date: Sun, 30 Nov 2025 14:42:35 +0200
Subject: [PATCH 1/4] [mlir] Fold memref.cast static-to-dynamic to
memref.expand_shape
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 80 ++++++++++++++++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 84 ++++++++++++++++++++++
2 files changed, 163 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..49dc23b702875 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,89 @@ LogicalResult ExpandShapeOp::verify() {
return success();
}
+/// Canonicalization pattern: folds a producing `memref.cast` into a
+/// `memref.expand_shape` when `CastOp::canFoldIntoConsumerOp` allows it.
+///
+/// Constant `output_shape` entries are turned into static sizes and the
+/// expanded type is recomputed from the cast source. If the recomputed type
+/// equals the current result type, only the source operand is replaced (in
+/// place); otherwise a new `memref.expand_shape` is created on the cast source
+/// and cast back to the original result type, so existing users see the same
+/// type.
+///
+/// The rewrite is rejected when:
+///  - a reassociation group's dynamicity would no longer match the
+///    corresponding cast-source dimension, or
+///  - a static cast-source dimension differs from the product of its (fully
+///    static) group sizes, e.g. `memref<8x4xf32>` with `output_shape
+///    [3, 1, 4, 4]`. The original IR may be valid there because the cast hid
+///    the static size, but the rewritten `memref.expand_shape` would not be.
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    // Named `castOp` (not `cast`) so it does not shadow `llvm::cast`.
+    auto castOp = op.getSrc().getDefiningOp<CastOp>();
+    if (!castOp)
+      return failure();
+
+    if (!CastOp::canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    // Turn constant output sizes into static ones. `newOutputShape` keeps the
+    // mixed (attribute/value) form needed by the builder; `newOutputShapeSizes`
+    // holds the int64_t form (kDynamic for the remaining SSA values) used for
+    // type computation.
+    SmallVector<OpFoldResult> newOutputShape = op.getMixedOutputShape();
+    SmallVector<int64_t> newOutputShapeSizes;
+    newOutputShapeSizes.reserve(newOutputShape.size());
+    for (OpFoldResult &dimSize : newOutputShape) {
+      std::optional<int64_t> constSize = getConstantIntValue(dimSize);
+      if (!constSize) {
+        newOutputShapeSizes.push_back(ShapedType::kDynamic);
+        continue;
+      }
+      newOutputShapeSizes.push_back(*constSize);
+      dimSize = rewriter.getIndexAttr(*constSize);
+    }
+
+    // No early exit is taken when no output size became static. The cast may
+    // still be worth folding (e.g. if it only relaxes the layout), and the
+    // per-group checks below reject rewrites that would produce invalid IR.
+
+    Value castSource = castOp.getSource();
+    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+    SmallVector<ReassociationIndices> reassociationIndices =
+        op.getReassociationIndices();
+    for (auto [srcDim, group] : llvm::enumerate(reassociationIndices)) {
+      ArrayRef<int64_t> groupSizes =
+          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+      bool groupIsDynamic =
+          llvm::is_contained(groupSizes, ShapedType::kDynamic);
+      if (castSourceType.isDynamicDim(srcDim) != groupIsDynamic)
+        return rewriter.notifyMatchFailure(
+            op, "folding cast will result in changing dynamicity in "
+                "reassociation group");
+      if (groupIsDynamic)
+        continue;
+
+      // Both the group and the source dim are static here: their sizes must
+      // agree, otherwise the rewritten expand_shape would be invalid.
+      int64_t groupProduct = 1;
+      for (int64_t size : groupSizes)
+        groupProduct *= size;
+      if (groupProduct != castSourceType.getDimSize(srcDim))
+        return rewriter.notifyMatchFailure(
+            op, "static reassociation group size does not match the cast "
+                "source dimension");
+    }
+
+    FailureOr<MemRefType> newResultTypeOrFailure =
+        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
+                                           reassociationIndices);
+    if (failed(newResultTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "could not compute new expanded type after folding cast");
+
+    if (*newResultTypeOrFailure == op.getResultType()) {
+      // Same result type: just bypass the cast.
+      rewriter.modifyOpInPlace(
+          op, [&]() { op.getSrcMutable().assign(castSource); });
+    } else {
+      // Different (more static) result type: expand the cast source directly
+      // and cast back to the type the existing users expect.
+      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+                                          *newResultTypeOrFailure, castSource,
+                                          reassociationIndices, newOutputShape);
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+    }
+    return success();
+  }
+};
+
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
+  // ExpandShapeOpMemRefCastFolder folds a producing memref.cast into the
+  // expand_shape when the cast can be absorbed by its consumer.
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ExpandShapeOpMemRefCastFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..c2d0376fc9723 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,90 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// -----
+// A static-to-dynamic cast (8x4 -> ?x4) feeds expand_shape. The leading
+// output size is derived from memref.dim of the cast, and a trailing cast
+// restores the static result type. No memref.cast is expected to remain.
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %c4 = arith.constant 4 : index
+ %dim_ext = arith.divui %dim0 , %c4: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// Only dim 0 of the source (8x?) is static. The fold is expected to go
+// through while reassociation group [2, 3] keeps its SSA-valued size, and no
+// memref.cast should remain.
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+ : memref<?x?xf32> into memref<1x?x1x?xf32>
+ %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// Variant of the partial case where the unit sizes are passed as SSA
+// constants (%c1), so the expand_shape result starts fully dynamic
+// (?x?x?x?). No memref.cast is expected to remain.
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+ : memref<?x?xf32> into memref<?x?x?x?xf32>
+ %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// Negative test: the cast goes from dynamic (?x4) to static (8x4), so it is
+// expected to stay in front of the expand_shape.
+// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+ : memref<8x4xf32> into memref<2x1x4x4xf32>
+ return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// Negative test: the leading output size depends on %arg1, so group
+// [0, 1, 2] stays dynamic while dim 0 of the cast source is static (8).
+// Folding would change the group's dynamicity, so both casts are kept.
+// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: memref.cast
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %dim_ext = arith.divui %dim0 , %arg1: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
>From c33223075ccfae2e8a57ed19ee0af702c14bd653 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Mon, 1 Dec 2025 13:19:28 +0200
Subject: [PATCH 2/4] Update the newOutputShape-building loop as suggested by
Matthias Springer
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 23 ++++++++---------------
1 file changed, 8 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 49dc23b702875..11bfc99320644 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2517,29 +2517,22 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
if (!CastOp::canFoldIntoConsumerOp(cast))
return failure();
- auto originalOutputShape = op.getMixedOutputShape();
- auto newOutputShape = originalOutputShape;
+ SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
+ SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
SmallVector<int64_t> newOutputShapeSizes;
SmallVector<Value> newOperands;
// Convert output shape dims from dynamic to static where possible.
for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
- auto dimVal = dimSize.dyn_cast<Value>();
- if (!dimVal) {
- newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+ auto sizeOpt = getConstantIntValue(dimSize);
+ if (sizeOpt.has_value()) {
+ newOutputShapeSizes.push_back(sizeOpt.value());
+ newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
continue;
}
- auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
- if (!constOp) {
- newOperands.push_back(dimVal);
- newOutputShapeSizes.push_back(ShapedType::kDynamic);
- continue;
- }
-
- newOutputShape[dimIdx] = constOp.getValue();
- newOutputShapeSizes.push_back(
- getConstantIntValue(constOp.getValue()).value());
+ newOperands.push_back(llvm::cast<Value>(dimSize));
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
}
if (newOperands.size() == op->getNumOperands())
>From a4de41b41a2c58f9a6110320d0f7b4255cd967fc Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Tue, 2 Dec 2025 08:57:20 +0200
Subject: [PATCH 3/4] Removed more autos. Made lit checks more explicit.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++++----
mlir/test/Dialect/MemRef/canonicalize.mlir | 25 ++++++++++++++++------
2 files changed, 24 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11bfc99320644..ba2cabe668f13 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2539,9 +2539,10 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
return rewriter.notifyMatchFailure(
op, "no static-to-dynamic conversions found");
- auto castSource = cast.getSource();
+ Value castSource = cast.getSource();
auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
- auto reassociationIndices = op.getReassociationIndices();
+ SmallVector<ReassociationIndices> reassociationIndices =
+ op.getReassociationIndices();
for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
auto newOutputShapeSizesSlice =
@@ -2554,8 +2555,9 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
"reassociation group");
}
- auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
- castSourceType, newOutputShapeSizes, reassociationIndices);
+ FailureOr<MemRefType> newResultTypeOrFailure =
+ ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
+ reassociationIndices);
if (failed(newResultTypeOrFailure))
return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index c2d0376fc9723..641b9a0a8624c 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -553,6 +553,8 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
// CHECK-NOT: memref.cast
+// CHECK: memref.expand_shape {{.*}} output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK-NOT: memref.cast
// CHECK: return
func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
%0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
@@ -570,6 +572,8 @@ func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32
// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
// CHECK-NOT: memref.cast
+// CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK-NOT: memref.cast
// CHECK: return
func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
%0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
@@ -587,6 +591,8 @@ func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>)
// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
// CHECK-NOT: memref.cast
+// CHECK: memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK-NOT: memref.cast
// CHECK: return
func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
%0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
@@ -603,9 +609,10 @@ func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>
// -----
// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
-// CHECK: memref.cast
-// CHECK: memref.expand_shape
-// CHECK: return
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
// CHECK: }
func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
%0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
@@ -617,10 +624,14 @@ func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4
// -----
// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
-// CHECK: memref.cast
-// CHECK: memref.expand_shape
-// CHECK: memref.cast
-// CHECK: return
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
+// CHECK: %[[DIVUI_0:.*]] = arith.divui %[[CONSTANT_0]], %[[ARG1]] : index
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
+// CHECK: %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+// CHECK: return %[[CAST_1]] : memref<2x1x4x4xf32>
// CHECK: }
func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
%0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
>From 34c9711280120555c516a482687ca61cdea9c160 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Tue, 2 Dec 2025 09:24:13 +0200
Subject: [PATCH 4/4] Allow fold when a single dynamic dim is expanded into
multiple dynamic dims
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 7 +++----
mlir/test/Dialect/MemRef/canonicalize.mlir | 20 ++++++++++++++++++++
2 files changed, 23 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ba2cabe668f13..90b7a866ba6d1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2544,12 +2544,11 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
SmallVector<ReassociationIndices> reassociationIndices =
op.getReassociationIndices();
for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
- int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
auto newOutputShapeSizesSlice =
ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
- int64_t newOutputDynCount =
- llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
- if (castSourceDynCount != newOutputDynCount)
+ bool newOutputDynamic =
+ llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
+ if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
return rewriter.notifyMatchFailure(
op, "folding cast will result in changing dynamicity in "
"reassociation group");
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 641b9a0a8624c..854c8ba0597e1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -608,6 +608,26 @@ func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>
// -----
+// The single dynamic source dim (dim 1 of 8x?) expands into two dynamic
+// dims (%arg1, %arg2). The fold is still expected to apply directly on
+// %arg0, leaving no memref.cast.
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_multiple(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
+// CHECK-NOT: memref.cast
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
+// CHECK-NOT: memref.cast
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x?x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_multiple(%arg0 : memref<8x?xf32>, %arg1 : index, %arg2 : index) -> memref<8x1x?x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%dim0, 1, %arg1, %arg2]
+ : memref<?x?xf32> into memref<?x1x?x?xf32>
+ %2 = memref.cast %1 : memref<?x1x?x?xf32> to memref<8x1x?x?xf32>
+ return %2 : memref<8x1x?x?xf32>
+}
+
+// -----
+
// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
More information about the Mlir-commits
mailing list