[Mlir-commits] [mlir] [mlir][memref] Fold CopyOp if source and target are practically the same (PR #171801)
Maya Amrami
llvmlistbot at llvm.org
Thu Dec 11 03:01:39 PST 2025
https://github.com/amrami created https://github.com/llvm/llvm-project/pull/171801
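memref.copy(%a, %a) is already folded.
This patch extends the folder so that copies whose source and target are views of the same data, reached only through memref.cast, memref.collapse_shape and memref.expand_shape, are folded as well, e.g.:
memref.copy(expand_shape(collapse_shape(%a)), %a), because source and target are practically the same buffer.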
>From 4718103bbd8fb0d3d6966d53e0b62f126ac85d7d Mon Sep 17 00:00:00 2001
From: Maya Amrami <maya.amrami at mobileye.com>
Date: Tue, 9 Dec 2025 15:38:47 +0200
Subject: [PATCH] [mlir][memref] Fold CopyOp if source and target are
practically the same
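memref.copy(%a, %a) is already folded.
This patch extends the folder so that copies whose source and target are views of the same data,
reached only through memref.cast, memref.collapse_shape and memref.expand_shape, are folded as well,
e.g. memref.copy(expand_shape(collapse_shape(%a)), %a), because source and target are practically the same buffer.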
Change-Id: I4f7eac64b1a5f015c8a7e3c7372b91d5cf058e46
---
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 15 ++++++++++++++-
mlir/test/Dialect/MemRef/canonicalize.mlir | 16 ++++++++++++++++
2 files changed, 30 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 1035d7cb46e6e..834adb8a4cc09 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -812,9 +812,22 @@ namespace {
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
using OpRewritePattern<CopyOp>::OpRewritePattern;
+  // Returns the value reached by walking `v` back through memref.cast,
+  // memref.collapse_shape and memref.expand_shape, which only reinterpret the
+  // type/shape of their source and view the same data. Ops such as
+  // memref.subview select a different region and are intentionally not walked.
+  static Value getRoot(Value v) {
+    for (auto *def = v.getDefiningOp();
+         def && isa<CastOp, CollapseShapeOp, ExpandShapeOp>(def);
+         def = v.getDefiningOp())
+      v = def->getOperand(0);
+    return v;
+  }
LogicalResult matchAndRewrite(CopyOp copyOp,
PatternRewriter &rewriter) const override {
- if (copyOp.getSource() != copyOp.getTarget())
+ Value sourceData = getRoot(copyOp.getSource());
+ Value targetData = getRoot(copyOp.getTarget());
+ if (sourceData != targetData)
return failure();
rewriter.eraseOp(copyOp);
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index e02717a2f5689..f73b1b4675bba 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -719,6 +719,22 @@ func.func @self_copy(%m1: memref<?xf32>) {
// -----
+// CHECK-LABEL: func @practically_self_copy
+// CHECK-NOT: memref.copy
+// CHECK: return
+func.func @practically_self_copy() {
+  %alloc = memref.alloc() : memref<1x8x384x384xui8>
+  %subview = memref.subview %alloc[0, 0, 0, 0] [1, 1, 384, 384] [1, 1, 1, 1] : memref<1x8x384x384xui8> to memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+  %collapse_shape = memref.collapse_shape %subview [[0, 1], [2], [3]] : memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>> into memref<1x384x384xui8, strided<[1179648, 384, 1]>>
+  "my.op"(%collapse_shape) : (memref<1x384x384xui8, strided<[1179648, 384, 1]>>) -> ()
+  %expand_shape = memref.expand_shape %collapse_shape [[0, 1], [2], [3]] output_shape [1, 1, 384, 384] : memref<1x384x384xui8, strided<[1179648, 384, 1]>> into memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>>
+  // The source is a collapse/expand round-trip of %subview itself, so the copy is a no-op.
+  memref.copy %expand_shape, %subview : memref<1x1x384x384xui8, strided<[1179648, 1179648, 384, 1]>> to memref<1x1x384x384xui8, strided<[1179648, 147456, 384, 1]>>
+  return
+}
+
+// -----
+
// CHECK-LABEL: func @empty_copy
// CHECK-NEXT: return
func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
More information about the Mlir-commits
mailing list