[Mlir-commits] [mlir] 369c51a - [mlir][vector] Add transfer_op LoadToStore forwarding and deadStore optimizations

Thomas Raoux llvmlistbot at llvm.org
Fri Nov 20 12:02:02 PST 2020


Author: Thomas Raoux
Date: 2020-11-20T11:59:01-08:00
New Revision: 369c51a74b5327464e27e0749ca7ac59ac1349ce

URL: https://github.com/llvm/llvm-project/commit/369c51a74b5327464e27e0749ca7ac59ac1349ce
DIFF: https://github.com/llvm/llvm-project/commit/369c51a74b5327464e27e0749ca7ac59ac1349ce.diff

LOG: [mlir][vector] Add transfer_op LoadToStore forwarding and deadStore optimizations

Add transformations that forward the value written by a transfer_write to a
transfer_read of the same location, and that remove a dead transfer_write when
it is overwritten before being read.

Differential Revision: https://reviews.llvm.org/D91321

Added: 
    mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
    mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorTransforms.h
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Dialect/Vector/CMakeLists.txt
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/lib/Transforms/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
index ae619c1ee41d..5c2edecdbc7e 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -268,6 +268,10 @@ struct PointwiseExtractPattern : public OpRewritePattern<ExtractMapOp> {
   FilterConstraintType filter;
 };
 
+/// Implements transfer op write to read forwarding and dead transfer write
+/// optimizations on `func`: first, the vector stored by a transfer_write is
+/// forwarded to a later transfer_read of the same location (store-to-load
+/// forwarding); then, transfer_writes that are overwritten before being read
+/// are erased (dead store elimination). Forwarding runs first since it can
+/// expose more dead stores.
+void transferOpflowOpt(FuncOp func);
+
 } // namespace vector
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 448004db32fa..f70fba819b66 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -25,6 +25,7 @@ class OpBuilder;
 class Operation;
 class Value;
 class VectorType;
+class VectorTransferOpInterface;
 
 /// Return the number of elements of basis, `0` if empty.
 int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
@@ -159,6 +160,11 @@ makePermutationMap(Operation *op, ArrayRef<Value> indices,
 AffineMap getTransferMinorIdentityMap(MemRefType memRefType,
                                       VectorType vectorType);
 
+/// Return true if we can prove that the transfer operations access disjoint
+/// memory. The check is conservative: transfers on different memref values or
+/// with different vector types are never reported as disjoint, and only
+/// indices defined by constants are used to establish disjointness. A `false`
+/// result therefore means "may overlap", not "overlaps".
+bool isDisjointTransferSet(VectorTransferOpInterface transferA,
+                           VectorTransferOpInterface transferB);
+
 namespace matcher {
 
 /// Matches vector.transfer_read, vector.transfer_write and ops that return a

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 5097812423cb..d292f4d9782e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Transforms/LoopUtils.h"
@@ -80,42 +81,6 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
   }
 }
 
-/// Return true if we can prove that the transfer operations access disjoint
-/// memory.
-static bool isDisjoint(VectorTransferOpInterface transferA,
-                       VectorTransferOpInterface transferB) {
-  if (transferA.memref() != transferB.memref())
-    return false;
-  // For simplicity only look at transfer of same type.
-  if (transferA.getVectorType() != transferB.getVectorType())
-    return false;
-  unsigned rankOffset = transferA.getLeadingMemRefRank();
-  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
-    auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
-    auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();
-    // If any of the indices are dynamic we cannot prove anything.
-    if (!indexA || !indexB)
-      continue;
-
-    if (i < rankOffset) {
-      // For dimension used as index if we can prove that index are different we
-      // know we are accessing disjoint slices.
-      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
-          indexB.getValue().cast<IntegerAttr>().getInt())
-        return true;
-    } else {
-      // For this dimension, we slice a part of the memref we need to make sure
-      // the intervals accessed don't overlap.
-      int64_t distance =
-          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
-                   indexB.getValue().cast<IntegerAttr>().getInt());
-      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
-        return true;
-    }
-  }
-  return false;
-}
-
 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
   bool changed = true;
   while (changed) {
@@ -185,14 +150,14 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
           continue;
         if (auto transferWriteUse =
                 dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
-          if (!isDisjoint(
+          if (!isDisjointTransferSet(
                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                   cast<VectorTransferOpInterface>(
                       transferWriteUse.getOperation())))
             return WalkResult::advance();
         } else if (auto transferReadUse =
                        dyn_cast<vector::TransferReadOp>(use.getOwner())) {
-          if (!isDisjoint(
+          if (!isDisjointTransferSet(
                   cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                   cast<VectorTransferOpInterface>(
                       transferReadUse.getOperation())))

diff  --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt
index 7c8c58e3fbfb..5c345fec7204 100644
--- a/mlir/lib/Dialect/Vector/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRVector
   VectorOps.cpp
+  VectorTransferOpTransforms.cpp
   VectorTransforms.cpp
   VectorUtils.cpp
   EDSC/Builders.cpp

diff  --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
new file mode 100644
index 000000000000..fd3317ded246
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -0,0 +1,228 @@
+//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements functions concerned with optimizing transfer_read and
+// transfer_write ops.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
+#include "mlir/Dialect/Vector/VectorUtils.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Function.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "vector-transfer-opt"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+using namespace mlir;
+
+/// Walk up the parent chain of `op` (starting with `op` itself) and return the
+/// first operation that lives directly in `region`. Returns nullptr when
+/// `region` does not transitively contain `op`.
+static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
+  while (op && op->getParentRegion() != region)
+    op = op->getParentOp();
+  return op;
+}
+
+namespace {
+
+/// Store-to-load forwarding and dead store elimination for the vector
+/// transfer ops of one function, driven by (post-)dominance information
+/// computed once at construction. Ops proven dead are collected in
+/// `opToErase` and only erased when `removeDeadOp` is called.
+class TransferOptimization {
+public:
+  explicit TransferOptimization(FuncOp func)
+      : dominators(func), postDominators(func) {}
+  /// Record the given transfer_write for erasure if it is fully overwritten
+  /// before any potentially aliasing read.
+  void deadStoreOp(vector::TransferWriteOp);
+  /// If legal, replace all uses of the given transfer_read with the vector
+  /// stored by a dominating transfer_write, and record the read for erasure.
+  void storeToLoadForwarding(vector::TransferReadOp);
+  /// Erase every op recorded so far and reset the list.
+  void removeDeadOp() {
+    for (Operation *op : opToErase)
+      op->erase();
+    opToErase.clear();
+  }
+
+private:
+  bool isReachable(Operation *start, Operation *dest);
+  DominanceInfo dominators;
+  PostDominanceInfo postDominators;
+  std::vector<Operation *> opToErase;
+};
+
+/// Return true if there is a path from start operation to dest operation,
+/// otherwise return false. The operations have to be in the same region.
+/// The search is block-granular: it succeeds if `start` dominates `dest`, or
+/// if a block reachable from the successors of `start`'s block dominates
+/// `dest`'s block.
+bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
+  assert(start->getParentRegion() == dest->getParentRegion() &&
+         "This function only works for ops in the same region");
+  // Simple case where the start op dominates the destination.
+  if (dominators.dominates(start, dest))
+    return true;
+  // Otherwise walk the CFG forward from the successors of start's block.
+  Block *startBlock = start->getBlock();
+  Block *destBlock = dest->getBlock();
+  SmallVector<Block *, 32> worklist(startBlock->succ_begin(),
+                                    startBlock->succ_end());
+  SmallPtrSet<Block *, 32> visited;
+  while (!worklist.empty()) {
+    Block *bb = worklist.pop_back_val();
+    if (!visited.insert(bb).second)
+      continue;
+    // A block dominating destBlock lies on every path to destBlock, so
+    // reaching it means destBlock is reachable as well (assuming destBlock is
+    // reachable from the region entry).
+    if (dominators.dominates(bb, destBlock))
+      return true;
+    worklist.append(bb->succ_begin(), bb->succ_end());
+  }
+  return false;
+}
+
+/// For transfer_write to overwrite fully another transfer_write must:
+/// 1. Access the same memref with the same indices, vector type and
+/// permutation map.
+/// 2. Post-dominate the other transfer_write operation.
+/// If several candidates are available, one must be post-dominated by all the
+/// others since they are all post-dominating the same transfer_write. We only
+/// consider the transfer_write post-dominated by all the other candidates as
+/// this will be the first transfer_write executed after the potentially dead
+/// transfer_write.
+/// If we found such an overwriting transfer_write we know that the original
+/// transfer_write is dead if all reads that can be reached from the potentially
+/// dead transfer_write are dominated by the overwriting transfer_write.
+/// Any user of the memref other than a transfer_write (including ops that are
+/// not transfer ops at all) is conservatively treated as a read.
+void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
+  LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
+                    << "\n");
+  llvm::SmallVector<Operation *, 8> reads;
+  Operation *firstOverwriteCandidate = nullptr;
+  for (auto *user : write.memref().getUsers()) {
+    if (user == write.getOperation())
+      continue;
+    if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
+      // Check candidate that can overwrite the store: it must touch exactly
+      // the same elements as `write`, so every access property has to match
+      // between the two writes.
+      if (write.indices() == nextWrite.indices() &&
+          write.getVectorType() == nextWrite.getVectorType() &&
+          write.permutation_map() == nextWrite.permutation_map() &&
+          postDominators.postDominates(nextWrite, write)) {
+        if (firstOverwriteCandidate == nullptr ||
+            postDominators.postDominates(firstOverwriteCandidate, nextWrite))
+          firstOverwriteCandidate = nextWrite;
+        else
+          assert(
+              postDominators.postDominates(nextWrite, firstOverwriteCandidate));
+      }
+    } else {
+      if (auto read = dyn_cast<vector::TransferReadOp>(user)) {
+        // Don't need to consider disjoint reads.
+        if (isDisjointTransferSet(
+                cast<VectorTransferOpInterface>(write.getOperation()),
+                cast<VectorTransferOpInterface>(read.getOperation())))
+          continue;
+      }
+      reads.push_back(user);
+    }
+  }
+  if (firstOverwriteCandidate == nullptr)
+    return;
+  Region *topRegion = firstOverwriteCandidate->getParentRegion();
+  Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
+  assert(writeAncestor &&
+         "write op should be recursively part of the top region");
+
+  // NOTE(review): in a region with several blocks and back edges, a read can
+  // be dominated by the overwrite yet still be reached from `write` along a
+  // path that skips the overwrite (the overwrite executing earlier, before
+  // `write`). Confirm this cannot happen for the CFGs this runs on.
+  for (Operation *read : reads) {
+    Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
+    // TODO: if the read and write have the same ancestor we could recurse in
+    // the region to know if the read is reachable with more precision.
+    if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
+      continue;
+    if (!dominators.dominates(firstOverwriteCandidate, read)) {
+      LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read
+                        << "\n");
+      return;
+    }
+  }
+  LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
+                    << " overwritten by: " << *firstOverwriteCandidate << "\n");
+  opToErase.push_back(write.getOperation());
+}
+
+/// A transfer_write candidate to storeToLoad forwarding must:
+/// 1. Access the same memref with the same indices, vector type and
+/// permutation map as the transfer_read, and not be masked.
+/// 2. Dominate the transfer_read operation.
+/// If several candidates are available, one must be dominated by all the others
+/// since they are all dominating the same transfer_read. We only consider the
+/// transfer_write dominated by all the other candidates as this will be the
+/// last transfer_write executed before the transfer_read.
+/// If we found such a candidate we can do the forwarding if all the other
+/// potentially aliasing ops that may reach the transfer_read are post-dominated
+/// by the transfer_write.
+void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
+  // Masked (potentially out-of-bounds) reads are not handled.
+  if (read.hasMaskedDim())
+    return;
+  LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
+                    << "\n");
+  SmallVector<Operation *, 8> blockingWrites;
+  vector::TransferWriteOp lastwrite = nullptr;
+  for (Operation *user : read.memref().getUsers()) {
+    if (isa<vector::TransferReadOp>(user))
+      continue;
+    if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
+      // If there is a write, but we can prove that it is disjoint we can ignore
+      // the write.
+      if (isDisjointTransferSet(
+              cast<VectorTransferOpInterface>(write.getOperation()),
+              cast<VectorTransferOpInterface>(read.getOperation())))
+        continue;
+      if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
+          write.indices() == read.indices() &&
+          write.getVectorType() == read.getVectorType() &&
+          write.permutation_map() == read.permutation_map()) {
+        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
+          lastwrite = write;
+        else
+          assert(dominators.dominates(write, lastwrite));
+        continue;
+      }
+    }
+    // Any other user (a non-matching write or a non-transfer op) may modify
+    // the memref and is treated as blocking.
+    blockingWrites.push_back(user);
+  }
+
+  if (lastwrite == nullptr)
+    return;
+
+  Region *topRegion = lastwrite.getParentRegion();
+  Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
+  assert(readAncestor &&
+         "read op should be recursively part of the top region");
+
+  // NOTE(review): as in deadStoreOp, with back edges a blocking write that is
+  // post-dominated by `lastwrite` could still reach the read without passing
+  // through `lastwrite` again. Confirm this is acceptable for the CFGs seen.
+  for (Operation *write : blockingWrites) {
+    Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
+    // TODO: if the store and read have the same ancestor we could recurse in
+    // the region to know if the read is reachable with more precision.
+    if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
+      continue;
+    if (!postDominators.postDominates(lastwrite, write)) {
+      LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
+                        << *write << "\n");
+      return;
+    }
+  }
+
+  LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
+                    << " to: " << *read.getOperation() << "\n");
+  read.replaceAllUsesWith(lastwrite.vector());
+  opToErase.push_back(read.getOperation());
+}
+
+} // namespace
+
+void mlir::vector::transferOpflowOpt(FuncOp func) {
+  TransferOptimization opt(func);
+  // Run store to load forwarding first since it can expose more dead store
+  // opportunity.
+  func.walk(
+      [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
+  opt.removeDeadOp();
+  func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
+  opt.removeDeadOp();
+}

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 75ebb2f7d959..3ab1f500f5d1 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -312,3 +312,36 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
   return true;
 }
 
+/// Return true if we can prove that the transfer operations access disjoint
+/// memory. The check is conservative: it returns false whenever disjointness
+/// cannot be established from constant indices.
+bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
+                                 VectorTransferOpInterface transferB) {
+  if (transferA.memref() != transferB.memref())
+    return false;
+  // For simplicity only look at transfer of same type.
+  if (transferA.getVectorType() != transferB.getVectorType())
+    return false;
+  // The index reasoning below assumes that vector dimension `i - rankOffset`
+  // spans memref dimension `i` and that leading memref dimensions are accessed
+  // at a single index. That only holds for the minor identity permutation
+  // map; with e.g. a transposed map the extent would be read from the wrong
+  // vector dimension and overlapping accesses could be reported as disjoint.
+  AffineMap minorIdentity = getTransferMinorIdentityMap(
+      transferA.getMemRefType(), transferA.getVectorType());
+  if (transferA.permutation_map() != minorIdentity ||
+      transferB.permutation_map() != minorIdentity)
+    return false;
+  unsigned rankOffset = transferA.getLeadingMemRefRank();
+  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
+    auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>();
+    auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>();
+    // If any of the indices are dynamic we cannot prove anything.
+    if (!indexA || !indexB)
+      continue;
+
+    int64_t offsetA = indexA.getValue().cast<IntegerAttr>().getInt();
+    int64_t offsetB = indexB.getValue().cast<IntegerAttr>().getInt();
+    if (i < rankOffset) {
+      // For leading dimensions, if we can prove that the indices are different
+      // we know we are accessing disjoint slices.
+      if (offsetA != offsetB)
+        return true;
+    } else {
+      // For this dimension we slice a part of the memref, so we need to make
+      // sure the intervals accessed don't overlap.
+      int64_t distance = std::abs(offsetA - offsetB);
+      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
+        return true;
+    }
+  }
+  return false;
+}

diff  --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
new file mode 100644
index 000000000000..0ed061cab4d8
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -0,0 +1,186 @@
+// RUN: mlir-opt %s -test-vector-transferop-opt | FileCheck %s
+
+// The transfer_read at [%c1, %c0] is replaced by the vector stored by the
+// preceding transfer_write to the same location. Once the read is gone, that
+// transfer_write is dead: the transfer_write after the loop overwrites the
+// same location and nothing reads it in between.
+// CHECK-LABEL: func @forward_dead_store
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   scf.for
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func @forward_dead_store(%arg0: i1, %arg1 : memref<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) {
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : 
+    vector<1x4xf32>, memref<4x4xf32>
+  %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} : 
+    memref<4x4xf32>, vector<1x4xf32>
+  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) 
+    -> (vector<1x4xf32>) {
+    %1 = addf %acc, %acc : vector<1x4xf32>
+    scf.yield %1 : vector<1x4xf32>
+  }
+  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} : 
+    vector<1x4xf32>, memref<4x4xf32>
+  return
+}
+
+// The transfer_read nested in scf.if receives the vector stored by the
+// dominating transfer_write of %v0 at the same indices. The earlier
+// transfer_write at [%i, %c0] may alias it, but it does not block forwarding
+// because it is post-dominated by the forwarded transfer_write.
+// CHECK-LABEL: func @forward_nested
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   scf.if
+//   CHECK-NOT:     vector.transfer_read
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func @forward_nested(%arg0: i1, %arg1 : memref<4x4xf32>, %v0 : vector<1x4xf32>,
+  %v1 : vector<1x4xf32>, %i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cf0 = constant 0.0 : f32
+  vector.transfer_write %v1, %arg1[%i, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  %x = scf.if %arg0 -> (vector<1x4xf32>) {
+    %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
+      memref<4x4xf32>, vector<1x4xf32>
+    scf.yield %0 : vector<1x4xf32>
+  } else {
+    scf.yield %v1 : vector<1x4xf32>
+  }
+  vector.transfer_write %x, %arg1[%c0, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  return
+}
+
+// Negative test: the transfer_write in the else region of scf.if blocks
+// store-to-load forwarding. The analysis does not recurse into the region, so
+// it cannot tell that this transfer_write never reaches the transfer_read in
+// the then region.
+// CHECK-LABEL: func @forward_nested_negative
+//       CHECK:   vector.transfer_write
+//       CHECK:   scf.if
+//       CHECK:     vector.transfer_read
+//       CHECK:   } else {
+//       CHECK:     vector.transfer_write
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func @forward_nested_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cf0 = constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  %x = scf.if %arg0 -> (vector<1x4xf32>) {
+    %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
+      memref<4x4xf32>, vector<1x4xf32>
+    scf.yield %0 : vector<1x4xf32>
+  } else {
+    vector.transfer_write %v1, %arg1[%i, %c0] {masked = [false, false]} :
+      vector<1x4xf32>, memref<4x4xf32>
+    scf.yield %v1 : vector<1x4xf32>
+  }
+  vector.transfer_write %x, %arg1[%c0, %i] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  return
+}
+
+// Two transfer_writes are dead here, because each is overwritten at the same
+// indices before any aliasing read:
+// - the transfer_write nested in the second scf.if;
+// - the first of the two consecutive transfer_writes of %x.
+// The first transfer_write of %v0 is kept because the transfer_read at
+// [%i, %c0] in the else region may read it.
+// CHECK-LABEL: func @dead_store_region
+//       CHECK:   vector.transfer_write
+//       CHECK:   scf.if
+//       CHECK:   } else {
+//       CHECK:     vector.transfer_read
+//       CHECK:   }
+//       CHECK:   scf.if
+//   CHECK-NOT:     vector.transfer_write
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   return
+func @dead_store_region(%arg0: i1, %arg1 : memref<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) 
+  -> (vector<1x4xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cf0 = constant 0.0 : f32
+  vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  %x = scf.if %arg0 -> (vector<1x4xf32>) {
+    scf.yield %v1 : vector<1x4xf32>
+  } else {
+    %0 = vector.transfer_read %arg1[%i, %c0], %cf0 {masked = [false, false]} :
+      memref<4x4xf32>, vector<1x4xf32>
+    scf.yield %0 : vector<1x4xf32>
+  }
+  scf.if %arg0 {
+    vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+      vector<1x4xf32>, memref<4x4xf32>
+  }
+  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  %1 = vector.transfer_read %arg1[%i, %c0], %cf0 {masked = [false, false]} :
+    memref<4x4xf32>, vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// Negative test: the transfer_write inside scf.if is not dead. The
+// transfer_read that follows it in the same region, at [%i, %c0], may read
+// the written value before the transfer_write after scf.if overwrites it.
+// CHECK-LABEL: func @dead_store_negative
+//       CHECK:   scf.if
+//       CHECK:     vector.transfer_write
+//       CHECK:     vector.transfer_read
+//       CHECK:   } else {
+//       CHECK:   }
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func @dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
+  %v0 :vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cf0 = constant 0.0 : f32
+  %x = scf.if %arg0 -> (vector<1x4xf32>) {
+    vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+      vector<1x4xf32>, memref<4x4xf32>
+    %0 = vector.transfer_read %arg1[%i, %c0], %cf0 {masked = [false, false]} :
+      memref<4x4xf32>, vector<1x4xf32>
+    scf.yield %0 : vector<1x4xf32>
+  } else {
+    scf.yield %v1 : vector<1x4xf32>
+  }
+  vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
+    vector<1x4xf32>, memref<4x4xf32>
+  return
+}
+
+// The transfer_write in the inner scf.if is dead. The transfer_write that
+// follows the inner scf.if overwrites the same location, and the only
+// transfer_read executes before both of them.
+// CHECK-LABEL: func @dead_store_nested_region
+//       CHECK:   scf.if
+//       CHECK:     vector.transfer_read
+//       CHECK:     scf.if
+//   CHECK-NOT:       vector.transfer_write
+//       CHECK:     }
+//       CHECK:     vector.transfer_write
+//       CHECK:   }
+//       CHECK:   return
+func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %cf0 = constant 0.0 : f32
+  scf.if %arg0 {
+    %0 = vector.transfer_read %arg2[%i, %c0], %cf0 {masked = [false, false]} :
+      memref<4x4xf32>, vector<1x4xf32>
+    scf.if %arg1 {
+      vector.transfer_write %v1, %arg2[%c1, %c0] {masked = [false, false]} :
+        vector<1x4xf32>, memref<4x4xf32>
+    }
+    vector.transfer_write %v0, %arg2[%c1, %c0] {masked = [false, false]} :
+      vector<1x4xf32>, memref<4x4xf32>
+  }
+  return
+}
+

diff  --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index 484e78f2b596..602bf8148cd8 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -312,6 +312,11 @@ struct TestVectorTransferFullPartialSplitPatterns
   }
 };
 
+/// Test pass that applies the transfer op store-to-load forwarding and dead
+/// store elimination to the function it runs on.
+struct TestVectorTransferOpt
+    : public PassWrapper<TestVectorTransferOpt, FunctionPass> {
+  void runOnFunction() override {
+    auto func = getFunction();
+    transferOpflowOpt(func);
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -348,6 +353,9 @@ void registerTestVectorConversions() {
   PassRegistration<TestVectorToLoopPatterns> vectorToForLoop(
       "test-vector-to-forloop",
       "Test conversion patterns to break up a vector op into a for loop");
+  PassRegistration<TestVectorTransferOpt> transferOpOpt(
+      "test-vector-transferop-opt",
+      "Test optimization transformations for transfer ops");
 }
 } // namespace test
 } // namespace mlir


        


More information about the Mlir-commits mailing list