[Mlir-commits] [mlir] a53dbe2 - [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (#170037)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 31 08:08:04 PST 2025


Author: kdmitry1
Date: 2025-12-31T16:08:00Z
New Revision: a53dbe29d74dcab17078f9699e4a8c9f02ece0c3

URL: https://github.com/llvm/llvm-project/commit/a53dbe29d74dcab17078f9699e4a8c9f02ece0c3
DIFF: https://github.com/llvm/llvm-project/commit/a53dbe29d74dcab17078f9699e4a8c9f02ece0c3.diff

LOG: [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (#170037)

memref.expand_shape did not have a folder for a producing memref.cast op.
Added a canonicalization pattern that folds a static-to-dynamic memref.cast
into its consuming memref.expand_shape.

Example:

```mlir
  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
  %c0 = arith.constant 0 : index
  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%dim0, 1, 4] : memref<?x4xf32> into memref<?x1x4xf32>
```

is converted to:

```mlir
  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
  %cast = memref.cast %expand_shape : memref<8x1x4xf32> to memref<?x1x4xf32>

```

Added: 
    

Modified: 
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/test/Dialect/MemRef/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index eb321bbc15ded..7bc6ae5f21e8b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,77 @@ LogicalResult ExpandShapeOp::verify() {
   return success();
 }
 
+/// Folds a `memref.cast` that only erases static information (sizes, strides
+/// or offset becoming dynamic) into a consuming `memref.expand_shape`. A new
+/// `memref.cast` restores the original result type if the folded one differs.
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.getSrc().getDefiningOp<CastOp>();
+    if (!cast || !CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+
+    // Non-negative constant output sizes become static; the rest stay dynamic.
+    SmallVector<OpFoldResult> newOutputShape = op.getMixedOutputShape();
+    SmallVector<int64_t> newOutputShapeSizes;
+    for (OpFoldResult &dimSize : newOutputShape) {
+      std::optional<int64_t> size = getConstantIntValue(dimSize);
+      bool isStatic = size && *size >= 0;
+      newOutputShapeSizes.push_back(isStatic ? *size : ShapedType::kDynamic);
+      if (isStatic)
+        dimSize = rewriter.getIndexAttr(*size);
+    }
+
+    Value castSource = cast.getSource();
+    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+    SmallVector<ReassociationIndices> reassociationIndices =
+        op.getReassociationIndices();
+    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+      ArrayRef<int64_t> groupSizes =
+          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+      bool dynamicGroup = llvm::is_contained(groupSizes, ShapedType::kDynamic);
+      if (castSourceType.isDynamicDim(idx) != dynamicGroup)
+        return rewriter.notifyMatchFailure(
+            op, "folding cast will result in changing dynamicity in "
+                "reassociation group");
+      if (dynamicGroup)
+        continue;
+      // A static group must multiply out to the static source size.
+      int64_t groupProduct = 1;
+      for (int64_t size : groupSizes)
+        groupProduct *= size;
+      if (groupProduct != castSourceType.getDimSize(idx))
+        return rewriter.notifyMatchFailure(
+            op, "static output sizes do not match the static source size");
+    }
+
+    FailureOr<MemRefType> newResultType = ExpandShapeOp::computeExpandedType(
+        castSourceType, newOutputShapeSizes, reassociationIndices);
+    if (failed(newResultType))
+      return rewriter.notifyMatchFailure(
+          op, "could not compute new expanded type after folding cast");
+
+    if (*newResultType == op.getResultType()) {
+      rewriter.modifyOpInPlace(
+          op, [&]() { op.getSrcMutable().assign(castSource); });
+      return success();
+    }
+    Value newExpand =
+        ExpandShapeOp::create(rewriter, op->getLoc(), *newResultType,
+                              castSource, reassociationIndices, newOutputShape);
+    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newExpand);
+    return success();
+  }
+};
+
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ExpandShapeOpMemRefCastFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>

diff  --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7b4dea6a24396..132acfd9b1d48 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,144 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
 
 // -----
 
+// CHECK-LABEL:   func.func @fold_memref_expand_with_static_to_dynamic_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME:                  output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
+// CHECK:         }
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%src: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+  %erased = memref.cast %src : memref<8x4xf32> to memref<?x4xf32>
+  %two = arith.constant 2 : index
+  %expanded = memref.expand_shape %erased [[0, 1, 2], [3]] output_shape [%two, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %result = memref.cast %expanded : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %result : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME:              output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK:         }
+func.func @fold_memref_expand_static_to_dynamic_partial(%src : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %erased = memref.cast %src : memref<8x?xf32> to memref<?x?xf32>
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %rows = memref.dim %erased, %zero : memref<?x?xf32>
+  %cols = memref.dim %erased, %one : memref<?x?xf32>
+  %expanded = memref.expand_shape %erased [[0, 1], [2, 3]] output_shape [1, %rows, 1, %cols]
+      : memref<?x?xf32> into memref<1x?x1x?xf32>
+  %result = memref.cast %expanded : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+  return %result : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial_with_arith_const_as_dim(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME:               output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK:         }
+func.func @fold_memref_expand_static_to_dynamic_partial_with_arith_const_as_dim(%src : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %erased = memref.cast %src : memref<8x?xf32> to memref<?x?xf32>
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %rows = memref.dim %erased, %zero : memref<?x?xf32>
+  %cols = memref.dim %erased, %one : memref<?x?xf32>
+  %expanded = memref.expand_shape %erased [[0, 1], [2, 3]] output_shape [%one, %rows, %one, %cols]
+      : memref<?x?xf32> into memref<?x?x?x?xf32>
+  %result = memref.cast %expanded : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+  return %result : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_multiple(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x?xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
+// CHECK-NOT:     memref.cast
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME:        output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
+// CHECK-NOT:     memref.cast
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<8x1x?x?xf32>
+// CHECK:         }
+func.func @fold_memref_expand_static_to_dynamic_multiple(%src : memref<8x?xf32>, %sz2 : index, %sz3 : index) -> memref<8x1x?x?xf32> {
+  %erased = memref.cast %src : memref<8x?xf32> to memref<?x?xf32>
+  %zero = arith.constant 0 : index
+  %rows = memref.dim %erased, %zero : memref<?x?xf32>
+  %expanded = memref.expand_shape %erased [[0, 1], [2, 3]] output_shape [%rows, 1, %sz2, %sz3]
+      : memref<?x?xf32> into memref<?x1x?x?xf32>
+  %result = memref.cast %expanded : memref<?x1x?x?xf32> to memref<8x1x?x?xf32>
+  return %result : memref<8x1x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK:           %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME:          output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
+// CHECK:         }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%src : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+  %refined = memref.cast %src : memref<?x4xf32> to memref<8x4xf32>
+  %expanded = memref.expand_shape %refined [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+      : memref<8x4xf32> into memref<2x1x4x4xf32>
+  return %expanded : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x4xf32>,
+// CHECK-SAME:      %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
+// CHECK:           %[[DIVUI_0:.*]] = arith.divui %[[C8]], %[[ARG1]] : index
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME:          output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
+// CHECK:           %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+// CHECK:           return %[[CAST_1]] : memref<2x1x4x4xf32>
+// CHECK:         }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%src : memref<8x4xf32>, %divisor : index) -> memref<2x1x4x4xf32> {
+  %erased = memref.cast %src : memref<8x4xf32> to memref<?x4xf32>
+  %zero = arith.constant 0 : index
+  %rows = memref.dim %erased, %zero : memref<?x4xf32>
+  %outer = arith.divui %rows, %divisor : index
+  %expanded = memref.expand_shape %erased [[0, 1, 2], [3]] output_shape [%outer, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %result = memref.cast %expanded : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %result : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_layout(
+// CHECK-SAME:      %[[ARG0:.*]]: memref<8x4xf32>) -> memref<8x1x4xf32> {
+// CHECK:           %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]]
+// CHECK-SAME:          output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
+// CHECK:           return %[[EXPAND_SHAPE_0]] : memref<8x1x4xf32>
+// CHECK:         }
+func.func @fold_memref_expand_static_to_dynamic_layout(%src : memref<8x4xf32>) -> memref<8x1x4xf32> {
+  %erased = memref.cast %src : memref<8x4xf32> to memref<8x4xf32, strided<[?, ?], offset: ?>>
+  %expanded = memref.expand_shape %erased [[0, 1], [2]] output_shape [8, 1, 4]
+      : memref<8x4xf32, strided<[?, ?], offset: ?>> into memref<8x1x4xf32, strided<[?, ?, ?], offset: ?>>
+  %result = memref.cast %expanded : memref<8x1x4xf32, strided<[?, ?, ?], offset: ?>> to memref<8x1x4xf32>
+  return %result : memref<8x1x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
 // CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]


        


More information about the Mlir-commits mailing list