[Mlir-commits] [mlir] [mlir][memref] canonicalization for erasing copying subview to identical subview (PR #125852)

Frank Schlimbach llvmlistbot at llvm.org
Fri Feb 28 03:07:03 PST 2025


https://github.com/fschlimb updated https://github.com/llvm/llvm-project/pull/125852

>From 588d44c1b210cb117caaa17824f7354d81136757 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 5 Feb 2025 14:06:45 +0100
Subject: [PATCH 1/4] canonicalization for erasing copying subview to identical
 subview

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp   | 28 ++++++++++++++++++--
 mlir/test/Dialect/MemRef/canonicalize.mlir | 30 ++++++++++++++++++++++
 2 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index e0930abc1887d..f94c588f99361 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -824,8 +824,32 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
 
   LogicalResult matchAndRewrite(CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
-    if (copyOp.getSource() != copyOp.getTarget())
-      return failure();
+    if (copyOp.getSource() != copyOp.getTarget()) {
+      // If the source and target are SubViews and they are identical, we can fold.
+      auto source = copyOp.getSource().getDefiningOp<SubViewOp>();
+      auto target = copyOp.getTarget().getDefiningOp<SubViewOp>();
+      if (!source || !target ||
+          source.getSource() != target.getSource() ||
+          llvm::any_of(llvm::zip(source.getOffsets(), target.getOffsets()),
+                    [](std::tuple<Value, Value> offsetPair) {
+                      return std::get<0>(offsetPair) != std::get<1>(offsetPair);
+                    }) ||
+          llvm::any_of(llvm::zip(source.getStaticOffsets(), target.getStaticOffsets()),
+                    [](std::tuple<int64_t, int64_t> offsetPair) {
+                      return std::get<0>(offsetPair) != std::get<1>(offsetPair);
+                    }) ||
+          // sizes must be the same anyway
+          llvm::any_of(llvm::zip(source.getStrides(), target.getStrides()),
+                    [](std::tuple<Value, Value> stridePair) {
+                      return std::get<0>(stridePair) != std::get<1>(stridePair);
+                    }) ||
+          llvm::any_of(llvm::zip(source.getStaticStrides(), target.getStaticStrides()),
+                    [](std::tuple<int64_t, int64_t> stridePair) {
+                      return std::get<0>(stridePair) != std::get<1>(stridePair);
+                    })) {
+          return failure();
+      }
+    }
 
     rewriter.eraseOp(copyOp);
     return success();
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index 02110bc2892d0..56a7014047aa1 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -704,6 +704,36 @@ func.func @self_copy(%m1: memref<?xf32>) {
 
 // -----
 
+func.func @self_copy_subview(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %s: index) {
+  %c3 = arith.constant 3: index
+  %0 = memref.subview %arg0[3] [4] [2] : memref<?xf32> to memref<4xf32, strided<[2], offset: 3>>
+  %1 = memref.subview %arg0[%c3] [4] [2] : memref<?xf32> to memref<4xf32, strided<[2], offset: ?>>
+  %2 = memref.subview %arg0[%c3] [4] [%s] : memref<?xf32> to memref<4xf32, strided<[?], offset: ?>>
+  %3 = memref.subview %arg0[3] [4] [%s] : memref<?xf32> to memref<4xf32, strided<[?], offset: 3>>
+  %4 = memref.subview %arg1[3] [4] [%s] : memref<?xf32> to memref<4xf32, strided<[?], offset: 3>>
+  // erase (source and destination subviews render the same)
+  memref.copy %0, %1 : memref<4xf32, strided<[2], offset: 3>> to memref<4xf32, strided<[2], offset: ?>>
+  // keep (strides differ)
+  memref.copy %2, %1 : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32, strided<[2], offset: ?>>
+  // erase (source and destination subviews render the same)
+  memref.copy %2, %3 : memref<4xf32, strided<[?], offset: ?>> to memref<4xf32, strided<[?], offset: 3>>
+  // keep (source and destination differ)
+  memref.copy %3, %4 : memref<4xf32, strided<[?], offset: 3>> to memref<4xf32, strided<[?], offset: 3>>
+  return
+}
+
+// CHECK-LABEL: func.func @self_copy_subview(
+// CHECK-SAME: [[varg0:%.*]]: memref<?xf32>, [[varg1:%.*]]: memref<?xf32>, [[varg2:%.*]]: index) {
+  // CHECK: [[vsubview:%.*]] = memref.subview [[varg0]][3] [4] [2]
+  // CHECK: [[vsubview_0:%.*]] = memref.subview [[varg0]][3] [4] [[[varg2]]]
+  // CHECK: [[vsubview_1:%.*]] = memref.subview [[varg0]][3] [4] [[[varg2]]]
+  // CHECK: [[vsubview_2:%.*]] = memref.subview [[varg1]][3] [4] [[[varg2]]]
+  // CHECK-NEXT: memref.copy [[vsubview_0]], [[vsubview]]
+  // CHECK-NEXT: memref.copy [[vsubview_1]], [[vsubview_2]]
+  // CHECK-NEXT: return
+
+// -----
+
 // CHECK-LABEL: func @empty_copy
 //  CHECK-NEXT:   return
 func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {

>From c78b154676c2748a785e2dd35710305f3f6c9524 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Wed, 5 Feb 2025 14:26:07 +0100
Subject: [PATCH 2/4] clang-format

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 40 +++++++++++++-----------
 1 file changed, 22 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f94c588f99361..12755749f29f1 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -825,29 +825,33 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
   LogicalResult matchAndRewrite(CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
     if (copyOp.getSource() != copyOp.getTarget()) {
-      // If the source and target are SubViews and they are identical, we can fold.
+      // If the source and target are SubViews and they are identical, we can
+      // fold.
       auto source = copyOp.getSource().getDefiningOp<SubViewOp>();
       auto target = copyOp.getTarget().getDefiningOp<SubViewOp>();
-      if (!source || !target ||
-          source.getSource() != target.getSource() ||
+      if (!source || !target || source.getSource() != target.getSource() ||
           llvm::any_of(llvm::zip(source.getOffsets(), target.getOffsets()),
-                    [](std::tuple<Value, Value> offsetPair) {
-                      return std::get<0>(offsetPair) != std::get<1>(offsetPair);
-                    }) ||
-          llvm::any_of(llvm::zip(source.getStaticOffsets(), target.getStaticOffsets()),
-                    [](std::tuple<int64_t, int64_t> offsetPair) {
-                      return std::get<0>(offsetPair) != std::get<1>(offsetPair);
-                    }) ||
+                       [](std::tuple<Value, Value> offsetPair) {
+                         return std::get<0>(offsetPair) !=
+                                std::get<1>(offsetPair);
+                       }) ||
+          llvm::any_of(
+              llvm::zip(source.getStaticOffsets(), target.getStaticOffsets()),
+              [](std::tuple<int64_t, int64_t> offsetPair) {
+                return std::get<0>(offsetPair) != std::get<1>(offsetPair);
+              }) ||
           // sizes must be the same anyway
           llvm::any_of(llvm::zip(source.getStrides(), target.getStrides()),
-                    [](std::tuple<Value, Value> stridePair) {
-                      return std::get<0>(stridePair) != std::get<1>(stridePair);
-                    }) ||
-          llvm::any_of(llvm::zip(source.getStaticStrides(), target.getStaticStrides()),
-                    [](std::tuple<int64_t, int64_t> stridePair) {
-                      return std::get<0>(stridePair) != std::get<1>(stridePair);
-                    })) {
-          return failure();
+                       [](std::tuple<Value, Value> stridePair) {
+                         return std::get<0>(stridePair) !=
+                                std::get<1>(stridePair);
+                       }) ||
+          llvm::any_of(
+              llvm::zip(source.getStaticStrides(), target.getStaticStrides()),
+              [](std::tuple<int64_t, int64_t> stridePair) {
+                return std::get<0>(stridePair) != std::get<1>(stridePair);
+              })) {
+        return failure();
       }
     }
 

>From 011b2b4ca53bf046387c1b3dba374572048bea72 Mon Sep 17 00:00:00 2001
From: "Schlimbach, Frank" <frank.schlimbach at intel.com>
Date: Fri, 7 Feb 2025 09:43:29 +0100
Subject: [PATCH 3/4] more readable by simplification

---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 35 ++++++++----------------
 1 file changed, 11 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 12755749f29f1..18d39195443e6 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -825,32 +825,19 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
   LogicalResult matchAndRewrite(CopyOp copyOp,
                                 PatternRewriter &rewriter) const override {
     if (copyOp.getSource() != copyOp.getTarget()) {
-      // If the source and target are SubViews and they are identical, we can
-      // fold.
+      // We can still fold if source and target are similar SubViews.
       auto source = copyOp.getSource().getDefiningOp<SubViewOp>();
       auto target = copyOp.getTarget().getDefiningOp<SubViewOp>();
-      if (!source || !target || source.getSource() != target.getSource() ||
-          llvm::any_of(llvm::zip(source.getOffsets(), target.getOffsets()),
-                       [](std::tuple<Value, Value> offsetPair) {
-                         return std::get<0>(offsetPair) !=
-                                std::get<1>(offsetPair);
-                       }) ||
-          llvm::any_of(
-              llvm::zip(source.getStaticOffsets(), target.getStaticOffsets()),
-              [](std::tuple<int64_t, int64_t> offsetPair) {
-                return std::get<0>(offsetPair) != std::get<1>(offsetPair);
-              }) ||
-          // sizes must be the same anyway
-          llvm::any_of(llvm::zip(source.getStrides(), target.getStrides()),
-                       [](std::tuple<Value, Value> stridePair) {
-                         return std::get<0>(stridePair) !=
-                                std::get<1>(stridePair);
-                       }) ||
-          llvm::any_of(
-              llvm::zip(source.getStaticStrides(), target.getStaticStrides()),
-              [](std::tuple<int64_t, int64_t> stridePair) {
-                return std::get<0>(stridePair) != std::get<1>(stridePair);
-              })) {
+      if (!source || !target) {
+        return failure();
+      }
+      if (source.getSource() != target.getSource() ||
+          source.getOffsets() != target.getOffsets() ||
+          source.getStaticOffsets() != target.getStaticOffsets() ||
+          source.getStrides() != target.getStrides() ||
+          source.getStaticStrides() != target.getStaticStrides()) {
+        // By copy semantics, sizes of source and target must be the same
+        // -> no need to check sizes.
         return failure();
       }
     }

>From 438e6e69fd3e2dd773e00eca004ede3f017f66b8 Mon Sep 17 00:00:00 2001
From: Frank Schlimbach <frank.schlimbach at intel.com>
Date: Fri, 28 Feb 2025 12:06:54 +0100
Subject: [PATCH 4/4] Update mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 25cc201fce771..8ca03c8589fb0 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -828,9 +828,8 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
       // We can still fold if source and target are similar SubViews.
       auto source = copyOp.getSource().getDefiningOp<SubViewOp>();
       auto target = copyOp.getTarget().getDefiningOp<SubViewOp>();
-      if (!source || !target) {
+      if (!source || !target)
         return failure();
-      }
       if (source.getSource() != target.getSource() ||
           source.getOffsets() != target.getOffsets() ||
           source.getStaticOffsets() != target.getStaticOffsets() ||



More information about the Mlir-commits mailing list