[Mlir-commits] [mlir] [mlir][memref] Simplify memref.copy canonicalization (PR #149506)
lonely eagle
llvmlistbot at llvm.org
Fri Jul 18 14:50:57 PDT 2025
https://github.com/linuxlonelyeagle updated https://github.com/llvm/llvm-project/pull/149506
>From 63ccf0a4731fa80a45eca10ff71f74d0311eb65b Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 18 Jul 2025 09:41:32 +0000
Subject: [PATCH 1/2] simplify memref.copy canonicalization.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 67 +++++-------------------
1 file changed, 14 insertions(+), 53 deletions(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d1a9920aa66c5..4b868ed5b08fc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -715,51 +715,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
namespace {
-/// If the source/target of a CopyOp is a CastOp that does not modify the shape
-/// and element type, the cast can be skipped. Such CastOps only cast the layout
-/// of the type.
-struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
- using OpRewritePattern<CopyOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CopyOp copyOp,
- PatternRewriter &rewriter) const override {
- bool modified = false;
-
- // Check source.
- if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
-
- if (fromType && toType) {
- if (fromType.getShape() == toType.getShape() &&
- fromType.getElementType() == toType.getElementType()) {
- rewriter.modifyOpInPlace(copyOp, [&] {
- copyOp.getSourceMutable().assign(castOp.getSource());
- });
- modified = true;
- }
- }
- }
-
- // Check target.
- if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
- auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
- auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
-
- if (fromType && toType) {
- if (fromType.getShape() == toType.getShape() &&
- fromType.getElementType() == toType.getElementType()) {
- rewriter.modifyOpInPlace(copyOp, [&] {
- copyOp.getTargetMutable().assign(castOp.getSource());
- });
- modified = true;
- }
- }
- }
-
- return success(modified);
- }
-};
/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
@@ -797,22 +752,28 @@ struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
+ results.add<FoldEmptyCopy, FoldSelfCopy>(context);
}
-LogicalResult CopyOp::fold(FoldAdaptor adaptor,
- SmallVectorImpl<OpFoldResult> &results) {
- /// copy(memrefcast) -> copy
- bool folded = false;
- Operation *op = *this;
+/// If the source/target of a CopyOp is a CastOp that does not modify the shape
+/// and element type, the cast can be skipped. Such CastOps only cast the layout
+/// of the type.
+LogicalResult FoldCopyOfCast(CopyOp op) {
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
- folded = true;
+ return success();
}
}
- return success(folded);
+ return failure();
+}
+
+LogicalResult CopyOp::fold(FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &results) {
+
+ // Fold copy(memref.cast) -> copy.
+ return FoldCopyOfCast(*this);
}
//===----------------------------------------------------------------------===//
>From 10abc3f78ee82ca0f707fefde8f5c8811cdf38bb Mon Sep 17 00:00:00 2001
From: linuxlonelyeagle <2020382038 at qq.com>
Date: Fri, 18 Jul 2025 21:50:43 +0000
Subject: [PATCH 2/2] Make FoldCopyOfCast static.
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 4b868ed5b08fc..51c813682ce25 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -758,7 +758,7 @@ void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the layout
/// of the type.
-LogicalResult FoldCopyOfCast(CopyOp op) {
+static LogicalResult FoldCopyOfCast(CopyOp op) {
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
More information about the Mlir-commits
mailing list