[Mlir-commits] [mlir] [mlir][memref] Simplify memref.copy canonicalization (PR #149506)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 18 05:22:20 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

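The `FoldCopyOfCast` simplification for `memref.copy` is currently implemented twice: once as a canonicalization pattern (`struct FoldCopyOfCast`) and once in the folder (`CopyOp::fold`). This PR removes the redundant pattern and keeps the folder, moving its logic into a `FoldCopyOfCast` helper function that `CopyOp::fold` calls.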

---
Full diff: https://github.com/llvm/llvm-project/pull/149506.diff


1 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+14-53) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index d1a9920aa66c5..4b868ed5b08fc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -715,51 +715,6 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// If the source/target of a CopyOp is a CastOp that does not modify the shape
-/// and element type, the cast can be skipped. Such CastOps only cast the layout
-/// of the type.
-struct FoldCopyOfCast : public OpRewritePattern<CopyOp> {
-  using OpRewritePattern<CopyOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CopyOp copyOp,
-                                PatternRewriter &rewriter) const override {
-    bool modified = false;
-
-    // Check source.
-    if (auto castOp = copyOp.getSource().getDefiningOp<CastOp>()) {
-      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
-      auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
-
-      if (fromType && toType) {
-        if (fromType.getShape() == toType.getShape() &&
-            fromType.getElementType() == toType.getElementType()) {
-          rewriter.modifyOpInPlace(copyOp, [&] {
-            copyOp.getSourceMutable().assign(castOp.getSource());
-          });
-          modified = true;
-        }
-      }
-    }
-
-    // Check target.
-    if (auto castOp = copyOp.getTarget().getDefiningOp<CastOp>()) {
-      auto fromType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
-      auto toType = llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
-
-      if (fromType && toType) {
-        if (fromType.getShape() == toType.getShape() &&
-            fromType.getElementType() == toType.getElementType()) {
-          rewriter.modifyOpInPlace(copyOp, [&] {
-            copyOp.getTargetMutable().assign(castOp.getSource());
-          });
-          modified = true;
-        }
-      }
-    }
-
-    return success(modified);
-  }
-};
 
 /// Fold memref.copy(%x, %x).
 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
@@ -797,22 +752,28 @@ struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
 
 void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
+  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
 }
 
-LogicalResult CopyOp::fold(FoldAdaptor adaptor,
-                           SmallVectorImpl<OpFoldResult> &results) {
-  /// copy(memrefcast) -> copy
-  bool folded = false;
-  Operation *op = *this;
+/// If the source/target of a CopyOp is a CastOp that does not modify the shape
+/// and element type, the cast can be skipped. Such CastOps only cast the layout
+/// of the type.
+LogicalResult FoldCopyOfCast(CopyOp op) {
   for (OpOperand &operand : op->getOpOperands()) {
     auto castOp = operand.get().getDefiningOp<memref::CastOp>();
     if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
       operand.set(castOp.getOperand());
-      folded = true;
+      return success();
     }
   }
-  return success(folded);
+  return failure();
+}
+
+LogicalResult CopyOp::fold(FoldAdaptor adaptor,
+                           SmallVectorImpl<OpFoldResult> &results) {
+
+  /// copy(memrefcast) -> copy
+  return FoldCopyOfCast(*this);
 }
 
 //===----------------------------------------------------------------------===//

``````````

</details>


https://github.com/llvm/llvm-project/pull/149506


More information about the Mlir-commits mailing list