[Mlir-commits] [mlir] [mlir][memref] Fold CopyOp if source and target are practically the same (PR #171801)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 11 03:02:09 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-memref

Author: Maya Amrami (amrami)

<details>
<summary>Changes</summary>

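memref.copy(%a, %a) is already folded.
This change extends the folder so that a copy is also erased when its source and target differ only by a chain of memref.cast, memref.collapse_shape and memref.expand_shape ops. An example is
memref.copy(expand_shape(collapse_shape(%a)), %a), which is now folded because the source and target are effectively the same data.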

---
Full diff: https://github.com/llvm/llvm-project/pull/171801.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+14-1) 
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+16) 


``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..834adb8a4cc09 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -812,9 +812,22 @@ namespace {
 struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
   using OpRewritePattern<CopyOp>::OpRewritePattern;
 
+  // Given a value, traverse through CastOp, CollapseShapeOp and ExpandShapeOp
+  // to get the root value. The root and the original value have the same data.
+  // Thus operations like memref.subview are not considered here.
+  static Value getRoot(Value v) {
+    while (auto *definingOp = v.getDefiningOp()) {
+      if (!isa<CastOp, CollapseShapeOp, ExpandShapeOp>(definingOp))
+        return v;
+      v = definingOp->getOperand(0);
+    }
+    return v;
+  }
   LogicalResult matchAndRewrite(CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
-    if (copyOp.getSource() != copyOp.getTarget())
+    Value sourceData = getRoot(copyOp.getSource());
+    Value targetData = getRoot(copyOp.getTarget());
+    if (sourceData != targetData)
       return failure();
 
     rewriter.eraseOp(copyOp);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..f73b1b4675bba 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -719,6 +719,22 @@ func.func @self_copy(%m1: memref<?xf32>) {
 
 // -----
 
+func.func @practically_self_copy() {
+  %alloc = memref.alloc() : memref<1x8x384x384xui8>
+  %subview = memref.subview %alloc[0, 0, 0, 0] [1, 1, 384, 384] [1, 1, 1, 1] : memref<1x8x384x384xui8> to memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1], [2], [3]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+  "my.op"(%collapse_shape) : (memref<1x384x384xui8, strided<[1179648, 384, 1]>>) -> ()
+  %expand_shape = memref.expand_shape %collapse_shape [[0, 1], [2], [3]] output_shape [1, 1, 384, 384] : memref<1x384x384xui8, strided<[1179648, 384, 1]>> into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+  %cast = memref.cast %subview : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> to memref<1x1x384x384xui8, affine_map<(d0, d1, d2, d3) -> (d0 * 1179648 + d1 * 147456 + d2 * 384 + d3)>>
+  memref.copy %expand_shape, %subview : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+  return
+}
+
+// CHECK-LABEL: func @practically_self_copy
+//   CHECK-NOT:   memref.copy
+
+// -----
+
 // CHECK-LABEL: func @empty_copy
 //  CHECK-NEXT:   return
 func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/171801


More information about the Mlir-commits mailing list