[Mlir-commits] [mlir] [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (PR #170037)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 30 06:52:11 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (kdmitry1)
<details>
<summary>Changes</summary>
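memref.expand_shape did not have a folder for a memref.cast producer. This change adds a canonicalization pattern that folds a static-to-dynamic memref.cast on the source operand into memref.expand_shape: the expand_shape is rebuilt on the statically shaped source, and a memref.cast back to the original result type is inserted only if the result type changes.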
Example:
```mlir
%0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
%c0 = arith.constant 0 : index
%dim0 = memref.dim %0, %c0 : memref<?x4xf32>
%1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%dim0, 1, 4] : memref<?x4xf32> into memref<?x1x4xf32>
```
is converted to:
```mlir
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
%cast = memref.cast %expand_shape : memref<8x1x4xf32> to memref<?x1x4xf32>
```
---
Full diff: https://github.com/llvm/llvm-project/pull/170037.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+79-1)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+84)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..49dc23b702875 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,89 @@ LogicalResult ExpandShapeOp::verify() {
return success();
}
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ auto cast = op.getSrc().getDefiningOp<CastOp>();
+ if (!cast)
+ return failure();
+
+ if (!CastOp::canFoldIntoConsumerOp(cast))
+ return failure();
+
+ auto originalOutputShape = op.getMixedOutputShape();
+ auto newOutputShape = originalOutputShape;
+ SmallVector<int64_t> newOutputShapeSizes;
+ SmallVector<Value> newOperands;
+
+ // Convert output shape dims from dynamic to static where possible.
+ for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+ auto dimVal = dimSize.dyn_cast<Value>();
+ if (!dimVal) {
+ newOutputShapeSizes.push_back(getConstantIntValue(dimSize).value());
+ continue;
+ }
+
+ auto constOp = dimVal.getDefiningOp<arith::ConstantIndexOp>();
+ if (!constOp) {
+ newOperands.push_back(dimVal);
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
+ continue;
+ }
+
+ newOutputShape[dimIdx] = constOp.getValue();
+ newOutputShapeSizes.push_back(
+ getConstantIntValue(constOp.getValue()).value());
+ }
+
+ if (newOperands.size() == op->getNumOperands())
+ return rewriter.notifyMatchFailure(
+ op, "no static-to-dynamic conversions found");
+
+ auto castSource = cast.getSource();
+ auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+ auto reassociationIndices = op.getReassociationIndices();
+ for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+ int64_t castSourceDynCount = castSourceType.isDynamicDim(idx) ? 1 : 0;
+ auto newOutputShapeSizesSlice =
+ ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+ int64_t newOutputDynCount =
+ llvm::count_if(newOutputShapeSizesSlice, ShapedType::isDynamic);
+ if (castSourceDynCount != newOutputDynCount)
+ return rewriter.notifyMatchFailure(
+ op, "folding cast will result in changing dynamicity in "
+ "reassociation group");
+ }
+
+ auto newResultTypeOrFailure = ExpandShapeOp::computeExpandedType(
+ castSourceType, newOutputShapeSizes, reassociationIndices);
+
+ if (failed(newResultTypeOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "could not compute new expanded type after folding cast");
+
+ if (*newResultTypeOrFailure == op.getResultType()) {
+ rewriter.modifyOpInPlace(
+ op, [&]() { op.getSrcMutable().assign(castSource); });
+ } else {
+ Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+ *newResultTypeOrFailure, castSource,
+ reassociationIndices, newOutputShape);
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+ }
+ return success();
+ }
+};
+
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ExpandShapeOpMemRefCastFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..c2d0376fc9723 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,90 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// -----
+// CHECK-LABEL: @fold_memref_expand_with_static_to_dynamic_cast
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0 : memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %c4 = arith.constant 4 : index
+ %dim_ext = arith.divui %dim0 , %c4: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+ : memref<?x?xf32> into memref<1x?x1x?xf32>
+ %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial1(
+// CHECK-NOT: memref.cast
+// CHECK: return
+func.func @fold_memref_expand_static_to_dynamic_partial1(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+ : memref<?x?xf32> into memref<?x?x?x?xf32>
+ %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+ : memref<8x4xf32> into memref<2x1x4x4xf32>
+ return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK: memref.cast
+// CHECK: memref.expand_shape
+// CHECK: memref.cast
+// CHECK: return
+// CHECK: }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %dim_ext = arith.divui %dim0 , %arg1: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/170037
More information about the Mlir-commits
mailing list