[Mlir-commits] [mlir] [mlir][vector] Add memref reshapes to transfer flow opt (PR #110521)

Quinn Dawkins llvmlistbot at llvm.org
Mon Sep 30 08:23:51 PDT 2024


https://github.com/qedawkins created https://github.com/llvm/llvm-project/pull/110521

`vector.transfer_*` folding and forwarding currently do not take reshaping view-like memref ops (expand and collapse shape) into account, leading to potentially invalid store folding or value forwarding. This patch adds tracking for those ops; however, the TransferOptimization patterns in general still don't properly account for other potentially aliasing ops or non-transfer read/write ops, and need a redesign.
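To make the hazard concrete, below is a minimal sketch distilled from the tests added in this patch (%v0, %v1, %c0, and %c0_i32 are assumed to be defined elsewhere; the shapes are illustrative):

  // %alloca and %expand_shape alias the same four i32 elements.
  %alloca = memref.alloca() : memref<4xi32>
  %expand_shape = memref.expand_shape %alloca [[0, 1, 2]] output_shape [1, 4, 1]
      : memref<4xi32> into memref<1x4x1xi32>
  vector.transfer_write %v0, %alloca[%c0] {in_bounds = [true]}
      : vector<4xi32>, memref<4xi32>
  // Before this patch the pass did not look through expand_shape, so this
  // aliasing write was invisible to it...
  vector.transfer_write %v1, %expand_shape[%c0, %c0, %c0]
      {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
  // ...and %v0 could be incorrectly forwarded to this read, even though the
  // write through %expand_shape clobbered it.
  %r = vector.transfer_read %alloca[%c0], %c0_i32 {in_bounds = [true]}
      : memref<4xi32>, vector<4xi32>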

>From 8cd19f82ed27c82381c5bfc942367f461ac01065 Mon Sep 17 00:00:00 2001
From: Quinn Dawkins <quinn at nod-labs.com>
Date: Mon, 30 Sep 2024 10:52:33 -0400
Subject: [PATCH] [mlir][vector] Add memref reshapes to transfer flow opt

`vector.transfer_*` folding and forwarding currently do not take into
account reshaping view-like memref ops (expand and collapse shape),
leading to potentially invalid store folding or value forwarding. This
patch adds tracking for those ops; however, the TransferOptimization
patterns in general still don't properly account for other potentially
aliasing ops or non-transfer read/write ops, and need a redesign.
---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   |  6 +-
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 10 ++-
 .../Transforms/VectorTransferOpTransforms.cpp | 12 +--
 .../Dialect/Vector/vector-transferop-opt.mlir | 86 +++++++++++++++++++
 4 files changed, 102 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 1c094c1c1328a9..9b5211a2b6138f 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -106,9 +106,9 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
   return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
 }
 
-/// Walk up the source chain until something an op other than a `memref.subview`
-/// or `memref.cast` is found.
-MemrefValue skipSubViewsAndCasts(MemrefValue source);
+/// Walk up the source chain until an op other than a `memref.cast`,
+/// `memref.subview`, or `memref.expand/collapse_shape` is found.
+MemrefValue skipViewLikeOps(MemrefValue source);
 
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 68edd45448ee5f..ca75ac26fe96aa 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -193,12 +193,16 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
   return source;
 }
 
-MemrefValue skipSubViewsAndCasts(MemrefValue source) {
+MemrefValue skipViewLikeOps(MemrefValue source) {
   while (auto op = source.getDefiningOp()) {
     if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
       source = cast<MemrefValue>(subView.getSource());
-    } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
-      source = cast.getSource();
+    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
+      source = castOp.getSource();
+    } else if (auto collapse = dyn_cast<memref::CollapseShapeOp>(op)) {
+      source = cast<MemrefValue>(collapse.getSrc());
+    } else if (auto expand = dyn_cast<memref::ExpandShapeOp>(op)) {
+      source = cast<MemrefValue>(expand.getSrc());
     } else {
       return source;
     }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 4c93d3841bf878..a0c83b7cea52a2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -105,8 +105,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  Value source =
-      memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -115,7 +114,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CastOp>(user)) {
+    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::ExpandShapeOp,
+            memref::CastOp>(user)) {
       users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
@@ -192,8 +192,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  Value source =
-      memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -202,7 +201,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
+    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::ExpandShapeOp,
+            memref::CastOp>(user)) {
       users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 07e6647533c6fe..977e4caa499d49 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -257,6 +257,92 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
   return
 }
 
+// The same regression test, but for expand_shape.
+
+// CHECK-LABEL:  func.func @expand_shape
+//       CHECK:    scf.for {{.*}} {
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+
+func.func @expand_shape(%in_0: memref<20xi32>, %vec: vector<1x4x1xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<4xi32>
+  %expand_shape = memref.expand_shape %alloca [[0, 1, 2]] output_shape [1, 4, 1] : memref<4xi32> into memref<1x4x1xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[%arg0] [4] [1] : memref<20xi32> to memref<4xi32, strided<[1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32, strided<[1], offset: ?>>, vector<4xi32>
+    // $alloca and $expand_shape alias
+    vector.transfer_write %1, %alloca[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    vector.transfer_write %vec, %expand_shape[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    %2 = vector.transfer_read %alloca[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32>, vector<4xi32>
+    vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// The same regression test, but the initial write is to the collapsed memref.
+
+// CHECK-LABEL:  func.func @collapse_shape_of_source
+//       CHECK:    scf.for {{.*}} {
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+
+func.func @collapse_shape_of_source(%in_0: memref<20xi32>, %vec: vector<1x4x1xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x4x1xi32>
+  %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x4x1xi32> into memref<4xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[%arg0] [4] [1] : memref<20xi32> to memref<4xi32, strided<[1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32, strided<[1], offset: ?>>, vector<4xi32>
+    vector.transfer_write %1, %collapse_shape[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    %2 = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32>, vector<4xi32>
+    vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// CHECK-LABEL:  func.func @expand_shape_of_source
+//       CHECK:    scf.for {{.*}} {
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_write
+//       CHECK:      vector.transfer_read
+//       CHECK:      vector.transfer_write
+
+func.func @expand_shape_of_source(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<4xi32>
+  %expand_shape = memref.expand_shape %alloca [[0, 1, 2]] output_shape [1, 4, 1] : memref<4xi32> into memref<1x4x1xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[0, %arg0, 0] [1, 4, 1] [1, 1, 1] : memref<1x20x1xi32> to memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>, vector<1x4x1xi32>
+    vector.transfer_write %1, %expand_shape[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    vector.transfer_write %vec, %alloca[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    %2 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32>, vector<1x4x1xi32>
+    vector.transfer_write %2, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
+  }
+  return
+}
+
 // CHECK-LABEL: func @forward_dead_store_dynamic_same_index
 //   CHECK-NOT:   vector.transfer_write
 //   CHECK-NOT:   vector.transfer_read


