[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 1 23:29:05 PST 2025


https://github.com/kdmitry1 updated https://github.com/llvm/llvm-project/pull/170037

>From fac3e3db8ecc9ef2557bc85b5829b3f22a1bb2a8 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <dmitry.kaptsenel at mobileye.com>
Date: Sun, 30 Nov 2025 14:42:35 +0200
Subject: [PATCH 1/4] [mlir] Fold static-to-dynamic memref.cast into
 memref.expand_shape

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 80 ++++++++++++++++++++-
 mlir/test/Dialect/MemRef/canonicalize.mlir | 84 ++++++++++++++++++++++
 2 files changed, 163 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..49dc23b702875 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,89 @@ LogicalResult ExpandShapeOp::verify() {
   return success();
 }
 
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.getSrc().getDefiningOp<CastOp>();
+    if (!cast)
+      return failure();
+
+    if (!CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+
+    auto originalOutputShape = op.getMixedOutputShape();
+    auto newOutputShape = originalOutputShape;
+    SmallVector<int64_t> newOutputShapeSizes;
+    SmallVector<Value> newOperands;
+
+    // Convert output shape dims from dynamic to static where possible.
+    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+      auto dimVal = dimSize.dyn_cast<Value>();
+      if (!dimVal) {
+        newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+        continue;
+      }
+
+      auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
+      if (!constOp) {
+        newOperands.push_back(dimVal);
+        newOutputShapeSizes.push_back(ShapedType::kDynamic);
+        continue;
+      }
+
+      newOutputShape[dimIdx] = constOp.getValue();
+      newOutputShapeSizes.push_back(
+          getConstantIntValue(constOp.getValue()).value());
+    }
+
+    if (newOperands.size() == op->getNumOperands())
+      return rewriter.notifyMatchFailure(
+          op, "no static-to-dynamic conversions found");
+
+    auto castSource = cast.getSource();
+    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+    auto reassociationIndices = op.getReassociationIndices();
+    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+      int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+      auto newOutputShapeSizesSlice =
+          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+      int64_t newOutputDynCount =
+          llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+      if (castSourceDynCount != newOutputDynCount)
+        return rewriter.notifyMatchFailure(
+            op, "folding cast will result in changing dynamicity in "
+                "reassociation group");
+    }
+
+    auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
+        castSourceType, newOutputShapeSizes, reassociationIndices);
+
+    if (failed(newResultTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "could not compute new expanded type after folding cast");
+
+    if (*newResultTypeOrFailure == op.getResultType()) {
+      rewriter.modifyOpInPlace(
+          op, [&]() { op.getSrcMutable().assign(castSource); });
+    } else {
+      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+                                          *newResultTypeOrFailure, castSource,
+                                          reassociationIndices, newOutputShape);
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+    }
+    return success();
+  }
+};
+
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ExpandShapeOpMemRefCastFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..c2d0376fc9723 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,90 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
 
 // -----
 
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %c4 = arith.constant 4 : index
+  %dim_ext = arith.divui %dim0 , %c4: index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+      : memref<?x?xf32> into memref<1x?x1x?xf32>
+  %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+      : memref<?x?xf32> into memref<?x?x?x?xf32>
+  %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK:           memref.cast
+// CHECK:           memref.expand_shape
+// CHECK:           return
+// CHECK:         }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+      : memref<8x4xf32> into memref<2x1x4x4xf32>
+  return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK:           memref.cast
+// CHECK:           memref.expand_shape
+// CHECK:           memref.cast
+// CHECK:           return
+// CHECK:         }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %dim_ext = arith.divui %dim0 , %arg1: index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
 // CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]

>From c33223075ccfae2e8a57ed19ee0af702c14bd653 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Mon, 1 Dec 2025 13:19:28 +0200
Subject: [PATCH 2/4] Update newOutputShape building loop per review from
 Matthias Springer

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 23 ++++++++---------------
 1 file changed, 8 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 49dc23b702875..11bfc99320644 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2517,29 +2517,22 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
     if (!CastOp::canFoldIntoConsumerOp(cast))
       return failure();
 
-    auto originalOutputShape = op.getMixedOutputShape();
-    auto newOutputShape = originalOutputShape;
+    SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
+    SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
     SmallVector<int64_t> newOutputShapeSizes;
     SmallVector<Value> newOperands;
 
     // Convert output shape dims from dynamic to static where possible.
     for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
-      auto dimVal = dimSize.dyn_cast<Value>();
-      if (!dimVal) {
-        newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+      auto sizeOpt = getConstantIntValue(dimSize);
+      if (sizeOpt.has_value()) {
+        newOutputShapeSizes.push_back(sizeOpt.value());
+        newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
         continue;
       }
 
-      auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
-      if (!constOp) {
-        newOperands.push_back(dimVal);
-        newOutputShapeSizes.push_back(ShapedType::kDynamic);
-        continue;
-      }
-
-      newOutputShape[dimIdx] = constOp.getValue();
-      newOutputShapeSizes.push_back(
-          getConstantIntValue(constOp.getValue()).value());
+      newOperands.push_back(llvm::cast<Value>(dimSize));
+      newOutputShapeSizes.push_back(ShapedType::kDynamic);
     }
 
     if (newOperands.size() == op->getNumOperands())

>From a4de41b41a2c58f9a6110320d0f7b4255cd967fc Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Tue, 2 Dec 2025 08:57:20 +0200
Subject: [PATCH 3/4] Replace more auto declarations with explicit types; make
 lit checks more explicit

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 10 +++++----
 mlir/test/Dialect/MemRef/canonicalize.mlir | 25 ++++++++++++++++------
 2 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 11bfc99320644..ba2cabe668f13 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2539,9 +2539,10 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
       return rewriter.notifyMatchFailure(
           op, "no static-to-dynamic conversions found");
 
-    auto castSource = cast.getSource();
+    Value castSource = cast.getSource();
     auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
-    auto reassociationIndices = op.getReassociationIndices();
+    SmallVector<ReassociationIndices> reassociationIndices =
+        op.getReassociationIndices();
     for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
       int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
       auto newOutputShapeSizesSlice =
@@ -2554,8 +2555,9 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
                 "reassociation group");
     }
 
-    auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
-        castSourceType, newOutputShapeSizes, reassociationIndices);
+    FailureOr<MemRefType> newResultTypeOrFailure =
+        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
+                                           reassociationIndices);
 
     if (failed(newResultTypeOrFailure))
       return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index c2d0376fc9723..641b9a0a8624c 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -553,6 +553,8 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
 
 // CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
 // CHECK-NOT:     memref.cast
+// CHECK:         memref.expand_shape {{.*}} output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK-NOT:     memref.cast
 // CHECK:         return
 func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
   %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
@@ -570,6 +572,8 @@ func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32
 
 // CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial(
 // CHECK-NOT:     memref.cast
+// CHECK:         memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK-NOT:     memref.cast
 // CHECK:         return
 func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
   %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
@@ -587,6 +591,8 @@ func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>)
 
 // CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial1(
 // CHECK-NOT:     memref.cast
+// CHECK:         memref.expand_shape {{.*}} {{\[\[}}0, 1], [2, 3]] output_shape [1, 8, 1, %{{.*}}] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK-NOT:     memref.cast
 // CHECK:         return
 func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
   %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
@@ -603,9 +609,10 @@ func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>
 // -----
 
 // CHECK-LABEL:   func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
-// CHECK:           memref.cast
-// CHECK:           memref.expand_shape
-// CHECK:           return
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK:           %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
 // CHECK:         }
 func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
   %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
@@ -617,10 +624,14 @@ func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4
 // -----
 
 // CHECK-LABEL:   func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
-// CHECK:           memref.cast
-// CHECK:           memref.expand_shape
-// CHECK:           memref.cast
-// CHECK:           return
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x4xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
+// CHECK:           %[[CONSTANT_0:.*]] = arith.constant 8 : index
+// CHECK:           %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
+// CHECK:           %[[DIVUI_0:.*]] = arith.divui %[[CONSTANT_0]], %[[ARG1]] : index
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]] output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
+// CHECK:           %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+// CHECK:           return %[[CAST_1]] : memref<2x1x4x4xf32>
 // CHECK:         }
 func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
   %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>

>From 34c9711280120555c516a482687ca61cdea9c160 Mon Sep 17 00:00:00 2001
From: Dmitry Kaptsenel <21355508+kdmitry1 at users.noreply.github.com>
Date: Tue, 2 Dec 2025 09:24:13 +0200
Subject: [PATCH 4/4] Allow fold when a single dynamic dim is expanded into
 multiple dynamic dims

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   |  7 +++----
 mlir/test/Dialect/MemRef/canonicalize.mlir | 20 ++++++++++++++++++++
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index ba2cabe668f13..90b7a866ba6d1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2544,12 +2544,11 @@ struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
     SmallVector<ReassociationIndices> reassociationIndices =
         op.getReassociationIndices();
     for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
-      int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
       auto newOutputShapeSizesSlice =
           ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
-      int64_t newOutputDynCount =
-          llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
-      if (castSourceDynCount != newOutputDynCount)
+      bool newOutputDynamic =
+          llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
+      if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
         return rewriter.notifyMatchFailure(
             op, "folding cast will result in changing dynamicity in "
                 "reassociation group");
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 641b9a0a8624c..854c8ba0597e1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -608,6 +608,26 @@ func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>
 
 // -----
 
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_multiple(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
+// CHECK-NOT:     memref.cast
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
+// CHECK-NOT:     memref.cast
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<8x1x?x?xf32>
+// CHECK:         }
+func.func @fold_memref_expand_static_to_dynamic_multiple(%arg0 : memref<8x?xf32>, %arg1 : index, %arg2 : index) -> memref<8x1x?x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%dim0, 1, %arg1, %arg2]
+      : memref<?x?xf32> into memref<?x1x?x?xf32>
+  %2 = memref.cast %1 : memref<?x1x?x?xf32> to memref<8x1x?x?xf32>
+  return %2 : memref<8x1x?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
 // CHECK-SAME:      %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
 // CHECK:           %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>



More information about the Mlir-commits mailing list