[Mlir-commits] [mlir] 90d2f8c - [mlir][vector] Teach `TransferOptimization` to look through trivial aliases (#87805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 16 02:53:18 PDT 2024


Author: Benjamin Maxwell
Date: 2024-05-16T10:53:14+01:00
New Revision: 90d2f8c630e1ddddd034e4a0e575929c08dd26bf

URL: https://github.com/llvm/llvm-project/commit/90d2f8c630e1ddddd034e4a0e575929c08dd26bf
DIFF: https://github.com/llvm/llvm-project/commit/90d2f8c630e1ddddd034e4a0e575929c08dd26bf.diff

LOG: [mlir][vector] Teach `TransferOptimization` to look through trivial aliases (#87805)

This allows `TransferOptimization` to eliminate dead stores and to forward
stores to loads when the accesses go through trivial aliases (rather than
only through identical memref values).

A trivial alias is (currently) defined as:

  1. A `memref.cast`
  2. A `memref.subview` with zero offsets and unit strides
  3. A chain of (1) and (2)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
    mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
    mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 46003ed846869..1c094c1c1328a 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -22,6 +22,9 @@ namespace mlir {
 
 class MemRefType;
 
+/// A value with a memref type, i.e. a `Value` statically typed as a
+/// `BaseMemRefType`. Obtain one from a plain `Value` with
+/// `cast<MemrefValue>(value)`.
+using MemrefValue = TypedValue<BaseMemRefType>;
+
 namespace memref {
 
 /// Returns true, if the memref type has static shapes and represents a
@@ -93,6 +96,20 @@ computeStridesIRBlock(Location loc, OpBuilder &builder,
   return computeSuffixProductIRBlock(loc, builder, sizes);
 }
 
+/// Walk up the source chain of `source`, skipping `memref.cast` ops and
+/// `memref.subview` ops whose offsets are all zero and whose strides are all
+/// one, and return the first value not produced by such an op. Per
+/// `memref.subview` semantics the skipped ops preserve indexing (an index into
+/// the result addresses the same element as that index into the source), but
+/// a skipped subview may be smaller than its source, so the result does not
+/// necessarily alias the entire view.
+MemrefValue skipFullyAliasingOperations(MemrefValue source);
+
+/// Checks if two (memref) values are the same, or reach the same value once
+/// `skipFullyAliasingOperations` has peeled casts and zero-offset, unit-stride
+/// subviews off both of them. Equal indices into such trivial aliases address
+/// the same element, although the aliases' sizes may differ.
+inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
+  return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
+}
+
+/// Walk up the source chain until an op other than a `memref.subview` or
+/// `memref.cast` is found. Unlike `skipFullyAliasingOperations`, this skips
+/// every subview, regardless of its offsets and strides.
+MemrefValue skipSubViewsAndCasts(MemrefValue source);
+
 } // namespace memref
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index c93e5a9dcd39f..05d5ca2ce12f4 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -178,5 +178,35 @@ computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
   return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
 }
 
+MemrefValue skipFullyAliasingOperations(MemrefValue source) {
+  // Peel producers off `source` for as long as they are trivial aliases.
+  for (Operation *def = source.getDefiningOp(); def;
+       def = source.getDefiningOp()) {
+    // A `memref.cast` still points to the same memory.
+    if (auto castOp = dyn_cast<memref::CastOp>(def)) {
+      source = castOp.getSource();
+      continue;
+    }
+    // A `memref.subview` only qualifies when every offset is zero and every
+    // stride is one; anything else ends the walk.
+    auto subViewOp = dyn_cast<memref::SubViewOp>(def);
+    if (!subViewOp || !subViewOp.hasZeroOffset() || !subViewOp.hasUnitStride())
+      break;
+    source = cast<MemrefValue>(subViewOp.getSource());
+  }
+  return source;
+}
+
+// Follows `memref.subview` (with any offsets, sizes and strides) and
+// `memref.cast` producers back from `source`, returning the first value that
+// is not defined by either op (including values with no defining op, such as
+// block arguments).
+MemrefValue skipSubViewsAndCasts(MemrefValue source) {
+  while (auto op = source.getDefiningOp()) {
+    if (auto subViewOp = dyn_cast<memref::SubViewOp>(op)) {
+      source = cast<MemrefValue>(subViewOp.getSource());
+    } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
+      // Named `castOp` (not `cast`) so this local cannot shadow the `cast<>`
+      // function template used in the branch above.
+      source = castOp.getSource();
+    } else {
+      return source;
+    }
+  }
+  return source;
+}
+
 } // namespace memref
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 0ffef6aabccc1..997b56a1ce142 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -104,10 +105,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  Value source = write.getSource();
-  // Skip subview ops.
-  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
-    source = subView.getSource();
+  Value source =
+      memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -116,8 +115,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
-      users.append(subView->getUsers().begin(), subView->getUsers().end());
+    if (isa<memref::SubViewOp, memref::CastOp>(user)) {
+      users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
     if (isMemoryEffectFree(user))
@@ -126,7 +125,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
-      if (write.getSource() == nextWrite.getSource() &&
+      if (memref::isSameViewOrTrivialAlias(
+              cast<MemrefValue>(nextWrite.getSource()),
+              cast<MemrefValue>(write.getSource())) &&
           checkSameValueWAW(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
@@ -191,10 +192,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  Value source = read.getSource();
-  // Skip subview ops.
-  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
-    source = subView.getSource();
+  Value source =
+      memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -203,12 +202,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
-      users.append(subView->getUsers().begin(), subView->getUsers().end());
-      continue;
-    }
-    if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
-      users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
+    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
+      users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
     if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
@@ -221,7 +216,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
               cast<VectorTransferOpInterface>(read.getOperation()),
               /*testDynamicValueUsingBounds=*/true))
         continue;
-      if (write.getSource() == read.getSource() &&
+      if (memref::isSameViewOrTrivialAlias(
+              cast<MemrefValue>(read.getSource()),
+              cast<MemrefValue>(write.getSource())) &&
           dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
           lastwrite = write;

diff  --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 0719c0dd17427..3ddfacf40cf64 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -485,3 +485,33 @@ func.func @forward_dead_constant_splat_store_with_masking_negative_3(%buffer : m
   vector.transfer_write %x, %buffer[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
   return
 }
+
+// Here each read/write is to a different subview, but they all point to the
+// exact same bit of memory (just through casts and subviews with unit strides
+// and zero offsets).
+// CHECK-LABEL: func @forward_and_eliminate_stores_through_trivial_aliases
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func.func @forward_and_eliminate_stores_through_trivial_aliases(
+  %buffer : memref<?x?xf32>, %vec: vector<[8]x[8]xf32>, %size: index, %a_size: index, %another_size: index
+) {
+  // NOTE(review): %size is not used in the body.
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  %cst = arith.constant 0.0 : f32
+  // Store %vec straight into %buffer. The final store below writes the same
+  // [%c0, %c0] window through a trivial alias, so the directives above expect
+  // this store to be removed.
+  vector.transfer_write %vec, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  // Chain of trivial aliases of %buffer: a zero-offset, unit-stride subview,
+  // a cast of it, and a zero-offset, unit-stride subview of that cast.
+  %direct_subview = memref.subview %buffer[0, 0] [%a_size, %a_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+  %cast = memref.cast %direct_subview : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32>
+  %subview_of_cast = memref.subview %cast[0, 0] [%another_size, %another_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
+  // Load through the first alias; expected to be forwarded from the store of
+  // %vec above, so no read should remain before the loop.
+  %21 = vector.transfer_read %direct_subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<[8]x[8]xf32>
+  %23 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %21) -> (vector<[8]x[8]xf32>) {
+    %24 = arith.addf %arg3, %arg3 : vector<[8]x[8]xf32>
+    scf.yield %24 : vector<[8]x[8]xf32>
+  }
+  // Store through the deepest alias; the only write expected to survive.
+  vector.transfer_write %23, %subview_of_cast[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
+  return
+}


        


More information about the Mlir-commits mailing list