[Mlir-commits] [mlir] 3fc0fbe - [mlir][vector] Move transferOp on tensor opt to folder/canonicalization

llvmlistbot at llvm.org
Fri Apr 16 08:13:39 PDT 2021


Author: thomasraoux
Date: 2021-04-16T08:13:10-07:00
New Revision: 3fc0fbefc84382b5e63b4e497ee3744d678cfb91

URL: https://github.com/llvm/llvm-project/commit/3fc0fbefc84382b5e63b4e497ee3744d678cfb91
DIFF: https://github.com/llvm/llvm-project/commit/3fc0fbefc84382b5e63b4e497ee3744d678cfb91.diff

LOG: [mlir][vector] Move transferOp on tensor opt to folder/canonicalization

Move the existing optimization for transfer ops on tensors to the folder and
canonicalization patterns. This handles the write-after-write and
read-after-write cases, and also adds a write-after-read case.

Differential Revision: https://reviews.llvm.org/D100597
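
For illustration, here is a minimal sketch of the new write-after-read fold,
mirroring the store_after_load_tensor test added below (the SSA names and
constants are illustrative):

  %0 = vector.transfer_read %t0[%c1, %c0], %cf0 :
    tensor<4x4xf32>, vector<1x4xf32>
  %w0 = vector.transfer_write %0, %t0[%c1, %c0] :
    vector<1x4xf32>, tensor<4x4xf32>

Because the write stores back exactly the vector that was just read from the
same indices of the same tensor, %w0 folds to %t0 and both transfer ops
become dead.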

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/VectorOps.td
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir
    mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
    mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
index b0ff956fe49de..70dbdf8d05582 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1421,6 +1421,7 @@ def Vector_TransferWriteOp :
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_LoadOp : Vector_Op<"load"> {

diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index 56f8f6211cccc..3816d8d660c06 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -28,6 +28,11 @@ class Value;
 class VectorType;
 class VectorTransferOpInterface;
 
+namespace vector {
+class TransferWriteOp;
+class TransferReadOp;
+} // namespace vector
+
 /// Return the number of elements of basis, `0` if empty.
 int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
 
@@ -177,6 +182,16 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
                                VectorTransferOpInterface transferB);
 
+/// Return true if the transfer_write fully writes the data accessed by the
+/// transfer_read.
+bool checkSameValueRAW(vector::TransferWriteOp defWrite,
+                       vector::TransferReadOp read);
+
+/// Return true if the write op fully overwrites the priorWrite transfer_write
+/// op.
+bool checkSameValueWAW(vector::TransferWriteOp write,
+                       vector::TransferWriteOp priorWrite);
+
 namespace matcher {
 
 /// Matches vector.transfer_read, vector.transfer_write and ops that return a

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 0ad89109b3ada..934ea611c4324 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2512,7 +2512,35 @@ static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
   return success();
 }
 
+///  ```
+///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+///  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
+///    : tensor<4x4xf32>, vector<1x4xf32>
+///  ```
+///  -> Folds into
+///  ```
+///  %v0
+///  ```
+static Value foldRAW(TransferReadOp readOp) {
+  if (!readOp.getShapedType().isa<RankedTensorType>())
+    return {};
+  auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (checkSameValueRAW(defWrite, readOp))
+      return defWrite.vector();
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(readOp.getOperation())))
+      break;
+    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+  }
+  return {};
+}
+
 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+  if (Value vec = foldRAW(*this))
+    return vec;
   /// transfer_read(memrefcast) -> transfer_read
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return getResult();
@@ -2724,10 +2752,47 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
   return success();
 }
 
+static bool checkSameValueWAR(vector::TransferReadOp read,
+                              vector::TransferWriteOp write) {
+  return read.source() == write.source() && read.indices() == write.indices() &&
+         read.permutation_map() == write.permutation_map() &&
+         read.getVectorType() == write.getVectorType() && !read.mask() &&
+         !write.mask();
+}
+/// Fold a write-after-read transfer_write:
+/// ```
+///    %t0 = ...
+///    %v = vector.transfer_read %t0[%c0...] :
+///      tensor<static_sizesxf32>, vector<static_sizesxf32>
+///    %t1 = vector.transfer_write %v, %t0[%c0...] :
+///      vector<static_sizesxf32>, tensor<static_sizesxf32>
+/// ```
+///
+/// into:
+///
+/// ```
+///    %t0
+/// ```
+static LogicalResult foldWAR(TransferWriteOp write,
+                             SmallVectorImpl<OpFoldResult> &results) {
+  if (!write.source().getType().isa<RankedTensorType>())
+    return failure();
+  auto read = write.vector().getDefiningOp<vector::TransferReadOp>();
+  if (!read)
+    return failure();
+
+  if (!checkSameValueWAR(read, write))
+    return failure();
+  results.push_back(read.source());
+  return success();
+}
+
 LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<OpFoldResult> &results) {
   if (succeeded(foldReadInitWrite(*this, operands, results)))
     return success();
+  if (succeeded(foldWAR(*this, results)))
+    return success();
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return success();
   return foldMemRefCast(*this);
@@ -2745,6 +2810,67 @@ void TransferWriteOp::getEffects(
                          SideEffects::DefaultResource::get());
 }
 
+namespace {
+/// Remove a dead transfer write from the SSA chain so that it can be
+/// eliminated by DCE.
+/// ```
+///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+///  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+/// ```
+///
+/// into:
+///
+/// ```
+///  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+///  %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+///  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
+///    : vector<1x4xf32>, tensor<4x4xf32>
+/// ```
+///
+/// The `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't
+/// have any other uses.
+class foldWAW final : public OpRewritePattern<TransferWriteOp> {
+public:
+  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!writeOp.getShapedType().isa<RankedTensorType>())
+      return failure();
+    vector::TransferWriteOp writeToModify = writeOp;
+
+    auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
+    while (defWrite) {
+      if (checkSameValueWAW(writeOp, defWrite)) {
+        writeToModify.sourceMutable().assign(defWrite.source());
+        return success();
+      }
+      if (!isDisjointTransferIndices(
+              cast<VectorTransferOpInterface>(defWrite.getOperation()),
+              cast<VectorTransferOpInterface>(writeOp.getOperation())))
+        break;
+      // If the previous write op doesn't have any other use, we can safely
+      // look at the previous store to see if it can be removed.
+      if (!defWrite->hasOneUse())
+        break;
+      writeToModify = defWrite;
+      defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+    }
+    return failure();
+  }
+};
+} // namespace
+
+void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<foldWAW>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//

diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index f0b7a389eb494..ae6f3949c3998 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -34,34 +34,13 @@ static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
   return op;
 }
 
-/// Return true if the transfer_write fully writes the data accessed by the
-/// transfer_read.
-static bool transferEncompasses(vector::TransferWriteOp defWrite,
-                                vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() &&
-         defWrite.indices() == read.indices() &&
-         defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.permutation_map() == read.permutation_map();
-}
-
-/// Return true if the write op fully over-write the priorWrite transfer_write
-/// op.
-static bool transferEncompasses(vector::TransferWriteOp write,
-                                vector::TransferWriteOp priorWrite) {
-  return priorWrite.indices() == write.indices() &&
-         priorWrite.getVectorType() == write.getVectorType() &&
-         priorWrite.permutation_map() == write.permutation_map();
-}
-
 namespace {
 
 class TransferOptimization {
 public:
   TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
   void deadStoreOp(vector::TransferWriteOp);
-  void deadStoreOpTensor(vector::TransferWriteOp);
   void storeToLoadForwarding(vector::TransferReadOp);
-  void storeToLoadForwardingTensor(vector::TransferReadOp);
   void removeDeadOp() {
     for (Operation *op : opToErase)
       op->erase();
@@ -120,7 +99,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
       continue;
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
-      if (transferEncompasses(nextWrite, write) &&
+      if (checkSameValueWAW(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
             postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -192,8 +171,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
               cast<VectorTransferOpInterface>(write.getOperation()),
               cast<VectorTransferOpInterface>(read.getOperation())))
         continue;
-      if (dominators.dominates(write, read) &&
-          transferEncompasses(write, read)) {
+      if (dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
         if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
           lastwrite = write;
         else
@@ -231,44 +209,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
-/// Walk up the SSA links, if any write gets fully overwritten we can skip it.
-/// If it has no more uses it becomes dead.
-void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
-  auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
-  while (defWrite) {
-    if (transferEncompasses(write, defWrite)) {
-      write.sourceMutable().assign(defWrite.source());
-      if (defWrite->use_empty())
-        opToErase.push_back(defWrite.getOperation());
-      return;
-    }
-    if (!isDisjointTransferIndices(
-            cast<VectorTransferOpInterface>(defWrite.getOperation()),
-            cast<VectorTransferOpInterface>(write.getOperation())))
-      break;
-    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
-  }
-}
-
-/// Walk up the SSA links, if any write fully match the written vector we can
-/// replace the read by the vector. The read becomes dead and can be removed.
-void TransferOptimization::storeToLoadForwardingTensor(
-    vector::TransferReadOp read) {
-  auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
-  while (defWrite) {
-    if (transferEncompasses(defWrite, read)) {
-      read.replaceAllUsesWith(defWrite.vector());
-      opToErase.push_back(read.getOperation());
-      return;
-    }
-    if (!isDisjointTransferIndices(
-            cast<VectorTransferOpInterface>(defWrite.getOperation()),
-            cast<VectorTransferOpInterface>(read.getOperation())))
-      break;
-    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
-  }
-}
-
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -278,15 +218,11 @@ void mlir::vector::transferOpflowOpt(FuncOp func) {
   func.walk([&](vector::TransferReadOp read) {
     if (read.getShapedType().isa<MemRefType>())
       opt.storeToLoadForwarding(read);
-    else
-      opt.storeToLoadForwardingTensor(read);
   });
   opt.removeDeadOp();
   func.walk([&](vector::TransferWriteOp write) {
     if (write.getShapedType().isa<MemRefType>())
       opt.deadStoreOp(write);
-    else
-      opt.deadStoreOpTensor(write);
   });
   opt.removeDeadOp();
 }

diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index 8e6fb32bde5e4..27e7a772f01cf 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -354,3 +354,19 @@ bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
     return false;
   return isDisjointTransferIndices(transferA, transferB);
 }
+
+bool mlir::checkSameValueRAW(vector::TransferWriteOp defWrite,
+                             vector::TransferReadOp read) {
+  return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
+         defWrite.indices() == read.indices() &&
+         defWrite.getVectorType() == read.getVectorType() &&
+         defWrite.permutation_map() == read.permutation_map();
+}
+
+bool mlir::checkSameValueWAW(vector::TransferWriteOp write,
+                             vector::TransferWriteOp priorWrite) {
+  return priorWrite.indices() == write.indices() &&
+         priorWrite.mask() == write.mask() &&
+         priorWrite.getVectorType() == write.getVectorType() &&
+         priorWrite.permutation_map() == write.permutation_map();
+}

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 9da9e340ddffc..6d25e40ace3e4 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -799,3 +799,136 @@ func @transfer_folding_1(%t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>)
   // CHECK-NEXT: return %[[T0]], %[[T0]], %[[T0]]
   return %r0, %r1, %r2: tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @store_after_load_tensor
+//  CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>)
+//   CHECK-NOT:   vector.transfer_read
+//   CHECK-NOT:   vector.transfer_write
+//       CHECK:   return %[[ARG]] : tensor<4x4xf32>
+func @store_after_load_tensor(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c1, %c0], %cf0 :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %w0 = vector.transfer_write %0, %arg0[%c1, %c0] :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w0 : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_after_load_tensor_negative
+//       CHECK:   vector.transfer_read
+//       CHECK:   vector.transfer_write
+//       CHECK:   return
+func @store_after_load_tensor_negative(%arg0 : tensor<4x4xf32>) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %0 = vector.transfer_read %arg0[%c1, %c0], %cf0 :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %w0 = vector.transfer_write %0, %arg0[%c0, %c0] :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w0 : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor
+//  CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<1x4xf32>, %[[V1:.*]]: vector<1x4xf32>)
+//   CHECK-NOT:   vector.transfer_write
+//   CHECK-NOT:   vector.transfer_read
+//       CHECK:   return %[[V0]] : vector<1x4xf32>
+func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>) -> vector<1x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v1, %w0[%c2, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_negative_tensor
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   %[[V:.*]] = vector.transfer_read
+//       CHECK:   return %[[V]] : vector<1x4xf32>
+func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v0, %w0[%i, %i] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+
+// CHECK-LABEL: func @dead_store_tensor
+//   CHECK-DAG:      %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:      %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG:      %[[C2:.*]] = constant 2 : index
+//   CHECK-NOT:   vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
+//       CHECK:   vector.transfer_write {{.*}}, {{.*}}[%[[C2]], %[[C0]]
+//       CHECK:   %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
+//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
+func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w2 : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_store_tensor_negative
+//   CHECK-DAG:      %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:      %[[C1:.*]] = constant 1 : index
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
+//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
+func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w1[%i, %i], %cf0 {in_bounds = [true, true]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %x = addf %0, %0 : vector<1x4xf32>
+  %w2 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w2 : tensor<4x4xf32>
+}

diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
index 15b68275decf9..d63809c3063b8 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir
@@ -112,11 +112,11 @@ func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>,
 //  CHECK-NEXT:   %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32>
 //  CHECK-NEXT:   return %[[VTW3]] : tensor<4x4xf32>
 
-func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>) ->
+func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>, %arg1 : tensor<4x4xf32>) ->
   tensor<4x4xf32> {
   %c0 = constant 0 : index
   %cf0 = constant 0.0 : f32
   %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32>
-  %r = vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
+  %r = vector.transfer_write %0, %arg1[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32>
   return %r: tensor<4x4xf32>
 }

diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 97d618932f9b8..559f11f6d9cf8 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -184,56 +184,3 @@ func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
   return
 }
 
-// CHECK-LABEL: func @forward_dead_store_tensor
-//   CHECK-NOT:   vector.transfer_write
-//   CHECK-NOT:   vector.transfer_read
-//       CHECK:   scf.for
-//       CHECK:   }
-//       CHECK:   %[[VTW:.*]] = vector.transfer_write
-//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
-func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
-  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
-  %c1 = constant 1 : index
-  %c4 = constant 4 : index
-  %c0 = constant 0 : index
-  %cf0 = constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
-    tensor<4x4xf32>, vector<1x4xf32>
-  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-    -> (vector<1x4xf32>) {
-    %1 = addf %acc, %acc : vector<1x4xf32>
-    scf.yield %1 : vector<1x4xf32>
-  }
-  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  return %w1 : tensor<4x4xf32>
-}
-
-// CHECK-LABEL: func @forward_dead_store_negative_tensor
-//       CHECK:   vector.transfer_write
-//       CHECK:   vector.transfer_read
-//       CHECK:   scf.for
-//       CHECK:   }
-//       CHECK:   %[[VTW:.*]] = vector.transfer_write
-//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
-func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
-  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
-  %c1 = constant 1 : index
-  %c4 = constant 4 : index
-  %c0 = constant 0 : index
-  %cf0 = constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg1[%c1, %i] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
-    tensor<4x4xf32>, vector<1x4xf32>
-  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-    -> (vector<1x4xf32>) {
-    %1 = addf %acc, %acc : vector<1x4xf32>
-    scf.yield %1 : vector<1x4xf32>
-  }
-  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  return %w1 : tensor<4x4xf32>
-}


        

