[Mlir-commits] [mlir] a53dbe2 - [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (#170037)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 31 08:08:04 PST 2025
Author: kdmitry1
Date: 2025-12-31T16:08:00Z
New Revision: a53dbe29d74dcab17078f9699e4a8c9f02ece0c3
URL: https://github.com/llvm/llvm-project/commit/a53dbe29d74dcab17078f9699e4a8c9f02ece0c3
DIFF: https://github.com/llvm/llvm-project/commit/a53dbe29d74dcab17078f9699e4a8c9f02ece0c3.diff
LOG: [mlir] Fold memref.cast static-to-dynamic to memref.expand_shape (#170037)
memref.expand_shape did not have a folder for a memref.cast producer. This
adds a canonicalization pattern that folds a static-to-dynamic memref.cast
into the memref.expand_shape that consumes it.
Example:
```mlir
%0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
%c0 = arith.constant 0 : index
%dim0 = memref.dim %0, %c0 : memref<?x4xf32>
%1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [%dim0, 1, 4] : memref<?x4xf32> into memref<?x1x4xf32>
```
is converted to:
```mlir
%expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
%cast = memref.cast %expand_shape : memref<8x1x4xf32> to memref<?x1x4xf32>
```
Added:
Modified:
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index eb321bbc15ded..7bc6ae5f21e8b 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -2504,11 +2504,77 @@ LogicalResult ExpandShapeOp::verify() {
return success();
}
+struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
+public:
+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+ // Rewrites expand_shape(cast(x)) to expand x directly; if the result type changes, a cast back to the original result type is inserted.
+ LogicalResult matchAndRewrite(ExpandShapeOp op,
+ PatternRewriter &rewriter) const override {
+ auto cast = op.getSrc().getDefiningOp<CastOp>();
+ if (!cast)
+ return failure();
+ // Only proceed when CastOp itself reports that this cast may be folded into a consumer.
+ if (!CastOp::canFoldIntoConsumerOp(cast))
+ return failure();
+ // newOutputShape stays in mixed value/attribute form for building the op; newOutputShapeSizes is the parallel int64_t shape used to compute the type.
+ SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
+ SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
+ SmallVector<int64_t> newOutputShapeSizes;
+
+ // Convert output shape dims from dynamic to static where possible.
+ for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
+ std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
+ if (!sizeOpt.has_value()) {
+ newOutputShapeSizes.push_back(ShapedType::kDynamic);
+ continue;
+ }
+ // Constant size: record it and replace the entry with an index attribute.
+ newOutputShapeSizes.push_back(sizeOpt.value());
+ newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
+ }
+ // A reassociation group must contain a dynamic size exactly when the matching cast-source dim is dynamic; otherwise the fold is rejected.
+ Value castSource = cast.getSource();
+ auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
+ SmallVector<ReassociationIndices> reassociationIndices =
+ op.getReassociationIndices();
+ for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
+ auto newOutputShapeSizesSlice =
+ ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
+ bool newOutputDynamic =
+ llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
+ if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
+ return rewriter.notifyMatchFailure(
+ op, "folding cast will result in changing dynamicity in "
+ "reassociation group");
+ }
+ // Recompute the expanded result type from the cast source type and the refined sizes.
+ FailureOr<MemRefType> newResultTypeOrFailure =
+ ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
+ reassociationIndices);
+
+ if (failed(newResultTypeOrFailure))
+ return rewriter.notifyMatchFailure(
+ op, "could not compute new expanded type after folding cast");
+ // Unchanged result type: only swap the source operand in place. Otherwise create a new expand_shape and cast its result back to the original type.
+ if (*newResultTypeOrFailure == op.getResultType()) {
+ rewriter.modifyOpInPlace(
+ op, [&]() { op.getSrcMutable().assign(castSource); });
+ } else {
+ Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
+ *newResultTypeOrFailure, castSource,
+ reassociationIndices, newOutputShape);
+ rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
+ }
+ return success();
+ }
+};
+
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
- ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
+ ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
+ ExpandShapeOpMemRefCastFolder>(context);
}
FailureOr<std::optional<SmallVector<Value>>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 7b4dea6a24396..132acfd9b1d48 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -551,6 +551,144 @@ func.func @fold_memref_expand_cast(%arg0 : memref<?x?xf32>) -> memref<2x4x4xf32>
// -----
+// CHECK-LABEL: func.func @fold_memref_expand_with_static_to_dynamic_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME: output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
+// CHECK: }
+func.func @fold_memref_expand_with_static_to_dynamic_cast(%arg0: memref<8x4xf32>) -> memref<2x1x4x4xf32> {
+ %c2 = arith.constant 2 : index
+ %cast = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %expand_shape = memref.expand_shape %cast [[0, 1, 2], [3]] output_shape [%c2, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %cast_0 = memref.cast %expand_shape : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %cast_0 : memref<2x1x4x4xf32>
+}
+// Case above: the only dynamic output size is the constant %c2, so the expected output expands %arg0 directly with a fully static type and no casts.
+// -----
+// Only dim 0 of the cast source is static: the expected output expands %arg0 directly, with size 8 in the first group and a memref.dim of %arg0 for the remaining dynamic size.
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME: output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_partial(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [1, %dim0, 1, %dim1]
+ : memref<?x?xf32> into memref<1x?x1x?xf32>
+ %2 = memref.cast %1 : memref<1x?x1x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+// Same as the partial case, but the unit sizes are passed as the arith.constant value %c1, so the input expand_shape result type is fully dynamic.
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_partial_with_arith_const_as_dim(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<8x?xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME: output_shape [1, 8, 1, %[[DIM1]]] : memref<8x?xf32> into memref<1x8x1x?xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<1x8x1x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_partial_with_arith_const_as_dim(%arg0 : memref<8x?xf32>) -> memref<1x8x1x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %dim1 = memref.dim %0, %c1 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%c1, %dim0, %c1, %dim1]
+ : memref<?x?xf32> into memref<?x?x?x?xf32>
+ %2 = memref.cast %1 : memref<?x?x?x?xf32> to memref<1x8x1x?xf32>
+ return %2 : memref<1x8x1x?xf32>
+}
+
+// -----
+// The second reassociation group takes both sizes from function arguments: the expected output has no memref.cast left and expands %arg0 with sizes [8, 1, %arg1, %arg2].
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_multiple(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x?xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> memref<8x1x?x?xf32> {
+// CHECK-NOT: memref.cast
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]]
+// CHECK-SAME: output_shape [8, 1, %[[ARG1]], %[[ARG2]]] : memref<8x?xf32> into memref<8x1x?x?xf32>
+// CHECK-NOT: memref.cast
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x?x?xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_multiple(%arg0 : memref<8x?xf32>, %arg1 : index, %arg2 : index) -> memref<8x1x?x?xf32> {
+ %0 = memref.cast %arg0 : memref<8x?xf32> to memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x?xf32>
+ %1 = memref.expand_shape %0 [[0, 1], [2, 3]] output_shape [%dim0, 1, %arg1, %arg2]
+ : memref<?x?xf32> into memref<?x1x?x?xf32>
+ %2 = memref.cast %1 : memref<?x1x?x?xf32> to memref<8x1x?x?xf32>
+ return %2 : memref<8x1x?x?xf32>
+}
+
+// -----
+// Negative case: the cast goes from dynamic to static, so the expected output keeps the memref.cast and the expand_shape still consumes its result.
+// CHECK-LABEL: func.func @not_fold_memref_expand_with_dynamic_to_static_cast(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<?x4xf32> to memref<8x4xf32>
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME: output_shape [2, 1, 4, 4] : memref<8x4xf32> into memref<2x1x4x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<2x1x4x4xf32>
+// CHECK: }
+func.func @not_fold_memref_expand_with_dynamic_to_static_cast(%arg0 : memref<?x4xf32>) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<?x4xf32> to memref<8x4xf32>
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [2, 1, 4, 4]
+ : memref<8x4xf32> into memref<2x1x4x4xf32>
+ return %1 : memref<2x1x4x4xf32>
+}
+
+// -----
+// Negative case: the first output size is an arith.divui by a function argument, not a constant, so the expected output keeps both casts around the expand_shape.
+// CHECK-LABEL: func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>,
+// CHECK-SAME: %[[ARG1:.*]]: index) -> memref<2x1x4x4xf32> {
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]] : memref<8x4xf32> to memref<?x4xf32>
+// CHECK: %[[DIVUI_0:.*]] = arith.divui %[[C8]], %[[ARG1]] : index
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[CAST_0]] {{\[\[}}0, 1, 2], [3]]
+// CHECK-SAME: output_shape {{\[}}%[[DIVUI_0]], 1, 4, 4] : memref<?x4xf32> into memref<?x1x4x4xf32>
+// CHECK: %[[CAST_1:.*]] = memref.cast %[[EXPAND_SHAPE_0]] : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+// CHECK: return %[[CAST_1]] : memref<2x1x4x4xf32>
+// CHECK: }
+func.func @not_fold_memref_expand_static_to_dynamic_cast_if_really_dynamic(%arg0 : memref<8x4xf32>, %arg1 : index) -> memref<2x1x4x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<?x4xf32>
+ %c0 = arith.constant 0 : index
+ %dim0 = memref.dim %0, %c0 : memref<?x4xf32>
+ %dim_ext = arith.divui %dim0 , %arg1: index
+ %1 = memref.expand_shape %0 [[0, 1, 2], [3]] output_shape [%dim_ext, 1, 4, 4]
+ : memref<?x4xf32> into memref<?x1x4x4xf32>
+ %2 = memref.cast %1 : memref<?x1x4x4xf32> to memref<2x1x4x4xf32>
+ return %2 : memref<2x1x4x4xf32>
+}
+
+// -----
+// The cast changes only the layout (to a strided layout with dynamic strides and offset): the expected output expands %arg0 directly with no casts.
+// CHECK-LABEL: func.func @fold_memref_expand_static_to_dynamic_layout(
+// CHECK-SAME: %[[ARG0:.*]]: memref<8x4xf32>) -> memref<8x1x4xf32> {
+// CHECK: %[[EXPAND_SHAPE_0:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2]]
+// CHECK-SAME: output_shape [8, 1, 4] : memref<8x4xf32> into memref<8x1x4xf32>
+// CHECK: return %[[EXPAND_SHAPE_0]] : memref<8x1x4xf32>
+// CHECK: }
+func.func @fold_memref_expand_static_to_dynamic_layout(%arg0 : memref<8x4xf32>) -> memref<8x1x4xf32> {
+ %0 = memref.cast %arg0 : memref<8x4xf32> to memref<8x4xf32, strided<[?, ?], offset: ?>>
+ %1 = memref.expand_shape %0 [[0, 1], [2]] output_shape [8, 1, 4]
+ : memref<8x4xf32, strided<[?, ?], offset: ?>> into memref<8x1x4xf32, strided<[?,?,?], offset: ?>>
+ %2 = memref.cast %1 : memref<8x1x4xf32, strided<[?,?,?], offset: ?>> to memref<8x1x4xf32>
+ return %2 : memref<8x1x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @collapse_after_memref_cast_type_change(
// CHECK-SAME: %[[INPUT:.*]]: memref<?x512x1x1xf32>) -> memref<?x?xf32> {
// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[INPUT]]
More information about the Mlir-commits
mailing list