[llvm-branch-commits] [mlir] 080943f - [mlir][vector] Support transfer op on tensor optimizations

Thomas Raoux via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jan 6 15:14:19 PST 2021


Author: Thomas Raoux
Date: 2021-01-06T15:09:03-08:00
New Revision: 080943f7525f277579a000cf30364cc96fba6773

URL: https://github.com/llvm/llvm-project/commit/080943f7525f277579a000cf30364cc96fba6773
DIFF: https://github.com/llvm/llvm-project/commit/080943f7525f277579a000cf30364cc96fba6773.diff

LOG: [mlir][vector] Support transfer op on tensor optimizations

Support store-to-load forwarding and dead-store elimination for transfer ops
on tensors.

Differential Revision: https://reviews.llvm.org/D94148

Added: 
    

Modified: 
    mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index ea1189d53b31..161d02cd3435 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -34,13 +34,33 @@ static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
   return op;
 }
 
+/// True if `defWrite` is unmasked and writes exactly the data `read` accesses.
+static bool transferEncompasses(vector::TransferWriteOp defWrite,
+                                vector::TransferReadOp read) {
+  bool sameLayout = defWrite.getVectorType() == read.getVectorType() &&
+                    defWrite.permutation_map() == read.permutation_map();
+  return sameLayout && !defWrite.hasMaskedDim() &&
+         defWrite.indices() == read.indices();
+}
+
+/// Return true if `write` fully overwrites the data stored by `priorWrite`,
+/// i.e. both use the same indices, vector type and permutation map.
+static bool transferEncompasses(vector::TransferWriteOp write,
+                                vector::TransferWriteOp priorWrite) {
+  bool sameLayout = write.getVectorType() == priorWrite.getVectorType() &&
+                    write.permutation_map() == priorWrite.permutation_map();
+  return sameLayout && write.indices() == priorWrite.indices();
+}
+
 namespace {
 
 class TransferOptimization {
 public:
   TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
   void deadStoreOp(vector::TransferWriteOp);
+  void deadStoreOpTensor(vector::TransferWriteOp);
   void storeToLoadForwarding(vector::TransferReadOp);
+  void storeToLoadForwardingTensor(vector::TransferReadOp);
   void removeDeadOp() {
     for (Operation *op : opToErase)
       op->erase();
@@ -99,9 +119,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
-      if (write.indices() == nextWrite.indices() &&
-          write.getVectorType() == nextWrite.getVectorType() &&
-          write.permutation_map() == write.permutation_map() &&
+      if (transferEncompasses(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -173,10 +191,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
               cast<VectorTransferOpInterface>(write.getOperation()),
               cast<VectorTransferOpInterface>(read.getOperation())))
         continue;
-      if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
-          write.indices() == read.indices() &&
-          write.getVectorType() == read.getVectorType() &&
-          write.permutation_map() == read.permutation_map()) {
+      if (dominators.dominates(write, read) &&
+          transferEncompasses(write, read)) {
         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
           lastwrite = write;
         else
@@ -214,15 +230,74 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
+/// Walk up the SSA chain of transfer_write ops that produce the source of
+/// `write`. If one of them is fully overwritten by `write`, bypass it by
+/// rewiring the source operand of the op that consumes it. The bypassed write
+/// is queued for erasure when it has no uses left.
+void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
+  // The op whose source operand is `defWrite`. This is not always `write`:
+  // disjoint writes walked past must stay in the chain, so the rewiring has
+  // to happen on the op that directly consumes the overwritten write.
+  vector::TransferWriteOp writeToModify = write;
+  auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (transferEncompasses(write, defWrite)) {
+      writeToModify.sourceMutable().assign(defWrite.source());
+      if (defWrite->use_empty())
+        opToErase.push_back(defWrite.getOperation());
+      return;
+    }
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(write.getOperation())))
+      break;
+    // Walking past `defWrite` means its source may later be rewired, changing
+    // the tensor it produces. That is only sound when `writeToModify` is its
+    // only user.
+    if (!defWrite->hasOneUse())
+      break;
+    writeToModify = defWrite;
+    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+  }
+}
+
+/// Walk the transfer_write chain that produces the tensor `read` loads from,
+/// skipping disjoint writes. If a write stores exactly the loaded data, all
+/// uses of `read` are replaced with its vector; `read` is queued for erasure.
+void TransferOptimization::storeToLoadForwardingTensor(
+    vector::TransferReadOp read) {
+  auto readIface = cast<VectorTransferOpInterface>(read.getOperation());
+  for (auto write = read.source().getDefiningOp<vector::TransferWriteOp>();
+       write; write = write.source().getDefiningOp<vector::TransferWriteOp>()) {
+    if (transferEncompasses(write, read)) {
+      read.replaceAllUsesWith(write.vector());
+      opToErase.push_back(read.getOperation());
+      return;
+    }
+    auto writeIface = cast<VectorTransferOpInterface>(write.getOperation());
+    if (!isDisjointTransferIndices(writeIface, readIface))
+      return;
+  }
+}
+
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
   TransferOptimization opt(func);
   // Run store to load forwarding first since it can expose more dead store
   // opportunity.
-  func.walk(
-      [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
+  func.walk([&](vector::TransferReadOp read) {
+    // Memref transfers need dominance analysis; tensors follow SSA def chains.
+    if (!read.getShapedType().isa<MemRefType>())
+      return opt.storeToLoadForwardingTensor(read);
+    opt.storeToLoadForwarding(read);
+  });
   opt.removeDeadOp();
-  func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
+  func.walk([&](vector::TransferWriteOp write) {
+    // Same memref / tensor split, this time for dead store elimination.
+    if (!write.getShapedType().isa<MemRefType>())
+      return opt.deadStoreOpTensor(write);
+    opt.deadStoreOp(write);
+  });
   opt.removeDeadOp();
 }

diff  --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 0ed061cab4d8..30464a135e29 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -13,16 +13,16 @@ func @forward_dead_store(%arg0: i1, %arg1 : memref<4x4xf32>,
   %c4 = constant 4 : index
   %c0 = constant 0 : index
   %cf0 = constant 0.0 : f32
-  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : 
+  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
     vector<1x4xf32>, memref<4x4xf32>
-  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} : 
+  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
     memref<4x4xf32>, vector<1x4xf32>
-  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) 
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
     -> (vector<1x4xf32>) {
     %1 = addf %acc, %acc : vector<1x4xf32>
     scf.yield %1 : vector<1x4xf32>
   }
-  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} : 
+  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
     vector<1x4xf32>, memref<4x4xf32>
   return
 }
@@ -103,7 +103,7 @@ func @forward_nested_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 //       CHECK:   vector.transfer_read
 //       CHECK:   return
 func @dead_store_region(%arg0: i1, %arg1 : memref<4x4xf32>,
-  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) 
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index)
   -> (vector<1x4xf32>) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -184,3 +184,56 @@ func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
   return
 }
 
+// CHECK-LABEL: func @forward_dead_store_tensor
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   %[[VTW:.*]] = vector.transfer_write
+//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %pad = constant 0.0 : f32
+  %t0 = vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %vec = vector.transfer_read %t0[%c1, %c0], %pad {masked = [false, false]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %res = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %vec)
+    -> (vector<1x4xf32>) {
+    %dbl = addf %acc, %acc : vector<1x4xf32>
+    scf.yield %dbl : vector<1x4xf32>
+  }
+  %t1 = vector.transfer_write %res, %t0[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %t1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @forward_dead_store_negative_tensor
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   %[[VTW:.*]] = vector.transfer_write
+//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %pad = constant 0.0 : f32
+  %t0 = vector.transfer_write %v0, %arg1[%c1, %i] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %vec = vector.transfer_read %t0[%c1, %c0], %pad {masked = [false, false]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %res = scf.for %iv = %c0 to %c4 step %c1 iter_args(%acc = %vec)
+    -> (vector<1x4xf32>) {
+    %dbl = addf %acc, %acc : vector<1x4xf32>
+    scf.yield %dbl : vector<1x4xf32>
+  }
+  %t1 = vector.transfer_write %res, %t0[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %t1 : tensor<4x4xf32>
+}


        


More information about the llvm-branch-commits mailing list