[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 30 06:52:11 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (kdmitry1)

<details>
<summary>Changes</summary>

memref.expand_shape did not have a folder for a producing memref.cast. This adds a canonicalization pattern that folds a static-to-dynamic memref.cast into the consuming memref.expand_shape, making the corresponding output dimensions static and, if the result type changes, re-inserting a cast on the result.

Example:

```mlir
  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
  %c0 = arith.constant 0 : index
  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
  %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%dim0, 1, 4] : memref<?x4xf32> into memref<?x1x4xf32>
```

is canonicalized to:

```mlir
  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
  %cast = memref.cast %expand_shape : memref<8x1x4xf32> to memref<?x1x4xf32>
```
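
In practice this rewrite is picked up by the regular canonicalization driver (e.g. `mlir-opt -canonicalize`, which is how the tests below exercise it). As a minimal sketch only, not part of this PR, the same patterns could be applied programmatically; the helper name below is hypothetical:

```cpp
// Minimal sketch (assumed usage, not from this patch): greedily apply the
// memref.expand_shape canonicalization patterns, which with this PR include
// the new ExpandShapeOpMemRefCastFolder, to all ops nested under `root`.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult canonicalizeExpandShape(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  // Registers the ExpandShapeOp canonicalization patterns declared in
  // ExpandShapeOp::getCanonicalizationPatterns.
  mlir::memref::ExpandShapeOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
```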



---
Full diff: https://github.com/llvm/llvm-project/pull/170037.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+79-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+84) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..49dc23b702875 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,89 @@ LogicalResult ExpandShapeOp::verify() {
   return success();
 }
 
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto cast = op.getSrc().getDefiningOp<CastOp>();
+    if (!cast)
+      return failure();
+
+    if (!CastOp::canFoldIntoConsumerOp(cast))
+      return failure();
+
+    auto originalOutputShape = op.getMixedOutputShape();
+    auto newOutputShape = originalOutputShape;
+    SmallVector<int64_t> newOutputShapeSizes;
+    SmallVector<Value> newOperands;
+
+    // Convert output shape dims from dynamic to static where possible.
+    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+      auto dimVal = dimSize.dyn_cast<Value>();
+      if (!dimVal) {
+        newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+        continue;
+      }
+
+      auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
+      if (!constOp) {
+        newOperands.push_back(dimVal);
+        newOutputShapeSizes.push_back(ShapedType::kDynamic);
+        continue;
+      }
+
+      newOutputShape[dimIdx] = constOp.getValue();
+      newOutputShapeSizes.push_back(
+          getConstantIntValue(constOp.getValue()).value());
+    }
+
+    if (newOperands.size() == op->getNumOperands())
+      return rewriter.notifyMatchFailure(
+          op, "no static-to-dynamic conversions found");
+
+    auto castSource = cast.getSource();
+    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+    auto reassociationIndices = op.getReassociationIndices();
+    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+      int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+      auto newOutputShapeSizesSlice =
+          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+      int64_t newOutputDynCount =
+          llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+      if (castSourceDynCount != newOutputDynCount)
+        return rewriter.notifyMatchFailure(
+            op, "folding cast will result in changing dynamicity in "
+                "reassociation group");
+    }
+
+    auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
+        castSourceType, newOutputShapeSizes, reassociationIndices);
+
+    if (failed(newResultTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "could not compute new expanded type after folding cast");
+
+    if (*newResultTypeOrFailure == op.getResultType()) {
+      rewriter.modifyOpInPlace(
+          op, [&]() { op.getSrcMutable().assign(castSource); });
+    } else {
+      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+                                          *newResultTypeOrFailure, castSource,
+                                          reassociationIndices, newOutputShape);
+      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+    }
+    return success();
+  }
+};
+
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
-      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+      ExpandShapeOpMemRefCastFolder>(context);
 }
 
 FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..c2d0376fc9723 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,90 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
 
 // -----
 
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %c4 = arith.constant 4 : index
+  %dim_ext = arith.divui %dim0 , %c4: index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+      : memref<?x?xf32> into memref<1x?x1x?xf32>
+  %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-NOT:     memref.cast
+// CHECK:         return
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+  %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+  %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+  %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+      : memref<?x?xf32> into memref<?x?x?x?xf32>
+  %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+  return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK:           memref.cast
+// CHECK:           memref.expand_shape
+// CHECK:           return
+// CHECK:         }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+      : memref<8x4xf32> into memref<2x1x4x4xf32>
+  return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK:           memref.cast
+// CHECK:           memref.expand_shape
+// CHECK:           memref.cast
+// CHECK:           return
+// CHECK:         }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+  %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+  %c0 = arith.constant 0 : index
+  %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+  %dim_ext = arith.divui %dim0 , %arg1: index
+  %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+      : memref<?x4xf32> into memref<?x1x4x4xf32>
+  %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+  return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL:   func @collapse_after_memref_cast_type_change(
 // CHECK-SAME:      %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
 // CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/170037


More information about the Mlir-commits mailing list