[Mlir-commits] [mlir] [mlir][memref] Fold CopyOp if source and target are practically the same (PR #171801)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 11 03:02:09 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref
Author: Maya Amrami (amrami)
<details>
<summary>Changes</summary>
memref.copy(%a, %a) is already folded.
The FoldSelfCopy canonicalization pattern is extended so that cases such as:
memref.copy(collapse(%a), %a) are folded as well, because the source and target refer to the same underlying data.
---
Full diff: https://github.com/llvm/llvm-project/pull/171801.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+14-1)
- (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+16)
``````````diff
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..834adb8a4cc09 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -812,9 +812,22 @@ namespace {
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
using OpRewritePattern<CopyOp>::OpRewritePattern;
+  // Walks back from `v` through CastOp, CollapseShapeOp and ExpandShapeOp
+  // (always via operand #0) and returns the first value not defined by one of
+  // them. Those ops keep the data of their input; memref.subview does not.
+  static Value getRoot(Value v) {
+    auto *producer = v.getDefiningOp();
+    while (producer && isa<CastOp, CollapseShapeOp, ExpandShapeOp>(producer)) {
+      v = producer->getOperand(0);
+      producer = v.getDefiningOp();
+    }
+    return v;
+  }
LogicalResult matchAndRewrite(CopyOp copyOp,
PatternRewriter &rewriter) const override {
- if (copyOp.getSource() != copyOp.getTarget())
+ Value sourceData = getRoot(copyOp.getSource());
+ Value targetData = getRoot(copyOp.getTarget());
+ if (sourceData != targetData)
return failure();
rewriter.eraseOp(copyOp);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..f73b1b4675bba 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -719,6 +719,22 @@ func.func @self_copy(%m1: memref<?xf32>) {
// -----
+func.func @practically_self_copy() {
+  %alloc = memref.alloc() : memref<1x8x384x384xui8>
+  %subview = memref.subview %alloc[0, 0, 0, 0] [1, 1, 384, 384] [1, 1, 1, 1] : memref<1x8x384x384xui8> to memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1], [2], [3]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+  "my.op"(%collapse_shape) : (memref<1x384x384xui8, strided<[1179648, 384, 1]>>) -> ()
+  %expand_shape = memref.expand_shape %collapse_shape [[0, 1], [2], [3]] output_shape [1, 1, 384, 384] : memref<1x384x384xui8, strided<[1179648, 384, 1]>> into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+  memref.copy %expand_shape, %subview : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+  return
+}
+
+// CHECK-LABEL: func @practically_self_copy
+// CHECK: "my.op"
+// CHECK-NOT: memref.copy
+
+// -----
+
// CHECK-LABEL: func @empty_copy
// CHECK-NEXT: return
func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/171801
More information about the Mlir-commits
mailing list