[Mlir-commits] [mlir] [mlir][vector] Add memref reshapes to transfer flow opt (PR #110521)

Quinn Dawkins llvmlistbot at llvm.org
Mon Sep 30 20:49:55 PDT 2024


https://github.com/qedawkins updated https://github.com/llvm/llvm-project/pull/110521

>From 8cd19f82ed27c82381c5bfc942367f461ac01065 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 30 Sep 2024 10:52:33 -0400
Subject: [PATCH 1/2] [mlir][vector] Add memref reshapes to transfer flow opt

`vector.transfer_*` folding and forwarding currently do not take into
account reshaping view-like memref ops (expand and collapse shape),
leading to potentially invalid store folding or value forwarding. This
patch adds tracking for those ops; however, the TransferOptimization
patterns in general still don't properly account for other potential
aliasing ops or non-transfer read/write ops and need a redesign.
---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   |  6 +-
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 10 ++-
 .../Transforms/VectorTransferOpTransforms.cpp | 12 +--
 .../Dialect/Vector/vector-transferop-opt.mlir | 86 +++++++++++++++++++
 4 files changed, 102 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 1c094c1c1328a9..9b5211a2b6138f 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -106,9 +106,9 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
   return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
 }
 
-/// Walk up the source chain until something an op other than a `memref.subview`
-/// or `memref.cast` is found.
-MemrefValue skipSubViewsAndCasts(MemrefValue source);
+/// Walk up the source chain until something an op other than a `memref.cast`,
+/// `memref.subview`, or `memref.expand/collapse_shape` is found.
+MemrefValue skipViewLikeOps(MemrefValue source);
 
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 68edd45448ee5f..ca75ac26fe96aa 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -193,12 +193,16 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
   return source;
 }
 
-MemrefValue skipSubViewsAndCasts(MemrefValue source) {
+MemrefValue skipViewLikeOps(MemrefValue source) {
   while (auto op = source.getDefiningOp()) {
     if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
       source = cast<MemrefValue>(subView.getSource());
-    } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
-      source = cast.getSource();
+    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
+      source = castOp.getSource();
+    } else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(op)) {
+      source = cast<MemrefValue>(collapse.getSrc());
+    } else if (auto expand = dyn_cast<memref::ExpandShapeOp>(op)) {
+      source = cast<MemrefValue>(expand.getSrc());
     } else {
       return source;
     }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 4c93d3841bf878..a0c83b7cea52a2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -105,8 +105,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  Value source =
-      memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -115,7 +114,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CastOp>(user)) {
+    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::ExpandShapeOp,
+            memref::CastOp>(user)) {
       users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
@@ -192,8 +192,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  Value source =
-      memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -202,7 +201,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
+    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::ExpandShapeOp,
+            memref::CastOp>(user)) {
       users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 07e6647533c6fe..977e4caa499d49 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -257,6 +257,92 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
   return
 }
 
+// The same regression test for expand_shape.
+
+// CHECK-LABEL:  func.func @expand_shape
+//       CHECK:    scf.for {{.*}} {
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+
+func.func @expand_shape(%in_0: memref<20xi32>, %vec: vector<1x4x1xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<4xi32>
+  %expand_shape = memref.expand_shape %alloca [[0, 1, 2]] output_shape [1, 4, 1] : memref<4xi32> into memref<1x4x1xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[%arg0] [4] [1] : memref<20xi32> to memref<4xi32, strided<[1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32, strided<[1], offset: ?>>, vector<4xi32>
+    // %alloca and %expand_shape alias
+    vector.transfer_write %1, %alloca[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    vector.transfer_write %vec, %expand_shape[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    %2 = vector.transfer_read %alloca[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32>, vector<4xi32>
+    vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// The same regression test, but the initial write is to the collapsed memref.
+
+// CHECK-LABEL:  func.func @collapse_shape_of_source
+//       CHECK:    scf.for {{.*}} {
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+
+func.func @collapse_shape_of_source(%in_0: memref<20xi32>, %vec: vector<1x4x1xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x4x1xi32>
+  %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x4x1xi32> into memref<4xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[%arg0] [4] [1] : memref<20xi32> to memref<4xi32, strided<[1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32, strided<[1], offset: ?>>, vector<4xi32>
+    vector.transfer_write %1, %collapse_shape[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    %2 = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32>, vector<4xi32>
+    vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// CHECK-LABEL:  func.func @expand_shape_of_source
+//       CHECK:    scf.for {{.*}} {
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+
+func.func @expand_shape_of_source(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<4xi32>
+  %expand_shape = memref.expand_shape %alloca [[0, 1, 2]] output_shape [1, 4, 1] : memref<4xi32> into memref<1x4x1xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[0, %arg0, 0] [1, 4, 1] [1, 1, 1] : memref<1x20x1xi32> to memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>, vector<1x4x1xi32>
+    vector.transfer_write %1, %expand_shape[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    vector.transfer_write %vec, %alloca[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    %2 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32>, vector<1x4x1xi32>
+    vector.transfer_write %2, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
+  }
+  return
+}
+
 // CHECK-LABEL: func @forward_dead_store_dynamic_same_index
 //   CHECK-NOT:   vector.transfer_write
 //   CHECK-NOT:   vector.transfer_read

>From 691e2c70c7889271ef553d513eb01ccec3d2533e Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 30 Sep 2024 23:49:43 -0400
Subject: [PATCH 2/2] Switch to view-like op interface

---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h       |  4 ++--
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp     | 15 +++++----------
 .../Transforms/VectorTransferOpTransforms.cpp     |  6 ++----
 3 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 9b5211a2b6138f..ca3326dbbef519 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -106,8 +106,8 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
   return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
 }
 
-/// Walk up the source chain until something an op other than a `memref.cast`,
-/// `memref.subview`, or `memref.expand/collapse_shape` is found.
+/// Walk up the source chain, skipping ops that implement ViewLikeOpInterface,
+/// and return the first value that is not defined by such an op.
 MemrefValue skipViewLikeOps(MemrefValue source);
 
 } // namespace memref
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index ca75ac26fe96aa..7321b19068016c 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
@@ -195,17 +196,11 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
 
 MemrefValue skipViewLikeOps(MemrefValue source) {
   while (auto op = source.getDefiningOp()) {
-    if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
-      source = cast<MemrefValue>(subView.getSource());
-    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
-      source = castOp.getSource();
-    } else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(op)) {
-      source = cast<MemrefValue>(collapse.getSrc());
-    } else if (auto expand = dyn_cast<memref::ExpandShapeOp>(op)) {
-      source = cast<MemrefValue>(expand.getSrc());
-    } else {
-      return source;
+    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
+      source = cast<MemrefValue>(viewLike.getViewSource());
+      continue;
     }
+    return source;
   }
   return source;
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index a0c83b7cea52a2..e05c801121ffc4 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -114,8 +114,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::ExpandShapeOp,
-            memref::CastOp>(user)) {
+    if (isa<ViewLikeOpInterface>(user)) {
       users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
@@ -201,8 +200,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::ExpandShapeOp,
-            memref::CastOp>(user)) {
+    if (isa<ViewLikeOpInterface>(user)) {
       users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }



More information about the Mlir-commits mailing list