[Mlir-commits] [mlir] efd0504 - [mlir] Add hoisting transformation for transfer ops on tensor

Thomas Raoux llvmlistbot at llvm.org
Wed Jan 6 14:24:21 PST 2021


Author: Thomas Raoux
Date: 2021-01-06T14:23:59-08:00
New Revision: efd05040e13e942a4fbb79eb798fb9833e319b51

URL: https://github.com/llvm/llvm-project/commit/efd05040e13e942a4fbb79eb798fb9833e319b51
DIFF: https://github.com/llvm/llvm-project/commit/efd05040e13e942a4fbb79eb798fb9833e319b51.diff

LOG: [mlir] Add hoisting transformation for transfer ops on tensor

Add the same hoisting transformation that exists for transfer ops on buffers
for transfer ops on tensors. The logic is significantly different, so this is
done as a separate transformation, and it is expected that users know which
transformation to use based on the flow.

Differential Revision: https://reviews.llvm.org/D94115

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
    mlir/include/mlir/Dialect/Vector/VectorUtils.h
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Dialect/Vector/VectorUtils.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir
    mlir/test/lib/Transforms/TestLinalgHoisting.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 32693555ff40..ed585d1f5cf5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -21,8 +21,9 @@ namespace linalg {
 // TODO: generalize on a per-need basis.
 void hoistViewAllocOps(FuncOp func);
 
-/// Hoist vector.transfer_read/vector.transfer_write pairs out of immediately
-/// enclosing scf::ForOp iteratively, if the following conditions are true:
+/// Hoist vector.transfer_read/vector.transfer_write pairs on buffers out of
+/// the immediately enclosing scf::ForOp iteratively, if the following
+/// conditions are true:
 ///   1. The two ops access the same memref with the same indices.
 ///   2. All operands are invariant under the enclosing scf::ForOp.
 ///   3. No uses of the memref either dominate the transfer_read or are
@@ -35,6 +36,10 @@ void hoistViewAllocOps(FuncOp func);
 // TODO: generalize on a per-need basis.
 void hoistRedundantVectorTransfers(FuncOp func);
 
+/// Same behavior as `hoistRedundantVectorTransfers` but works on tensors
+/// instead of buffers.
+void hoistRedundantVectorTransfersOnTensor(FuncOp func);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index a06bc8cf6562..666603250f0a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -165,6 +165,12 @@ AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
                            VectorTransferOpInterface transferB);
 
+/// Same behavior as `isDisjointTransferSet` but doesn't require the operations
+/// to have the same tensor/memref. This allows comparing operations accessing
+/// different tensors.
+bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
+                               VectorTransferOpInterface transferB);
+
 namespace matcher {
 
 /// Matches vector.transfer_read, vector.transfer_write and ops that return a

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index a1797fde7da6..98d61fa6a8d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -81,12 +81,151 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
   }
 }
 
+/// Look for a transfer_read, among the users of `srcTensor`, that accesses the
+/// same chunk of data as `write`: same indices, same vector type and same
+/// permutation map. Returns the first such read found in use-list order, or a
+/// null op if there is none.
+static vector::TransferReadOp
+findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) {
+  for (Operation *user : srcTensor.getUsers()) {
+    auto read = dyn_cast<vector::TransferReadOp>(user);
+    // Equal indices and vector types are not sufficient on their own: two
+    // transfers with different permutation maps starting at the same offset
+    // access different elements, so they must not be treated as a pair.
+    if (read && read.indices() == write.indices() &&
+        read.getVectorType() == write.getVectorType() &&
+        read.permutation_map() == write.permutation_map()) {
+      return read;
+    }
+  }
+  return nullptr;
+}
+
+/// Returns true if some use of `srcTensor`, other than `write` and
+/// `candidateRead`, may access the chunk of data written by `write`.
+///
+/// The walk starts from the uses of `srcTensor` and transitively follows the
+/// tensors derived from it: results of other vector.transfer_write ops, region
+/// iter_args of scf.for ops it is passed into, and results of the parent op of
+/// an scf.yield nested under the parent of `write`. A vector.transfer_read user
+/// is accepted only when `isDisjointTransferIndices` proves it disjoint from
+/// `write`; any other user counts as an unknown access and makes the function
+/// conservatively return true.
+static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
+                                           vector::TransferReadOp candidateRead,
+                                           Value srcTensor) {
+  // Make sure none of the other uses read the part of the tensor modified
+  // by the transfer_write.
+  // Worklist of use lists that still have to be inspected.
+  llvm::SmallVector<Value::use_range, 1> uses;
+  uses.push_back(srcTensor.getUses());
+  while (!uses.empty()) {
+    for (OpOperand &use : uses.pop_back_val()) {
+      Operation *user = use.getOwner();
+      // Skip the candidate use, only inspect the "other" uses.
+      if (user == candidateRead.getOperation() || user == write.getOperation())
+        continue;
+      // Consider all transitive uses through a vector.transfer_write.
+      // Only the result of the other write is inspected (recursively); the
+      // write itself is not checked for overlap with `write`.
+      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
+        uses.push_back(writeUser->getResult(0).getUses());
+        continue;
+      }
+      // Consider all nested uses through an scf::ForOp. We may have
+      // pass-through tensor arguments left from previous level of
+      // hoisting.
+      // The iter_arg operands follow the loop's control operands, while the
+      // body block arguments start with the induction variable, hence the +1.
+      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
+        Value arg = forUser.getLoopBody().getArgument(
+            use.getOperandNumber() - forUser.getNumControlOperands() +
+            /*iv value*/ 1);
+        uses.push_back(arg.getUses());
+        continue;
+      }
+      // Follow the use yield as long as it doesn't escape the original
+      // region.
+      // The yielded value is exposed as the result of the yield's parent op at
+      // the same position, so continue with the uses of that result.
+      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
+      if (yieldUser &&
+          write->getParentOp()->isAncestor(yieldUser->getParentOp())) {
+        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
+        uses.push_back(ret.getUses());
+        continue;
+      }
+      // Any remaining user must be a transfer_read that is provably disjoint
+      // from the written chunk. The indices-only comparison is used because
+      // the read may access a tensor derived from `srcTensor` rather than
+      // `srcTensor` itself.
+      auto read = dyn_cast<vector::TransferReadOp>(user);
+      if (!read || !isDisjointTransferIndices(
+                       cast<VectorTransferOpInterface>(read.getOperation()),
+                       cast<VectorTransferOpInterface>(write.getOperation()))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// To hoist transfer ops on tensors, the logic can be significantly simplified
+// compared to the case on buffers. The transformation follows this logic:
+// 1. Look for a transfer_write whose single use is the scf.for yield and whose
+//    indices are defined outside of the loop.
+// 2. Check the uses of the matching region iter_arg and look for a
+//    transfer_read accessing the same chunk (see `findMatchingTransferRead`).
+// 3. Check that the read's other operands (e.g. the padding value) are defined
+//    outside of the loop, since the read is moved above it.
+// 4. Check that all the other uses of the tensor argument are either disjoint
+//    transfer_read or transfer_write ops. For transfer_write uses, recurse to
+//    make sure the new tensor has the same restrictions on its uses.
+// 5. Hoist the transfer_read/transfer_write pair and update the tensor SSA
+//    links; the vector value is carried through a new loop iter_arg.
+// After this transformation the scf.for may have unused arguments that can be
+// removed by the canonicalization pass.
+void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    func.walk([&](scf::ForOp forOp) {
+      Operation *yield = forOp.getBody()->getTerminator();
+      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
+        Value ret = yield->getOperand(it.index());
+        auto write = ret.getDefiningOp<vector::TransferWriteOp>();
+        // The write is sunk below the loop, so the yield must be its only use.
+        if (!write || !write->hasOneUse())
+          continue;
+        LLVM_DEBUG(DBGS() << "Candidate write for hoisting: "
+                          << *write.getOperation() << "\n");
+        if (llvm::any_of(write.indices(), [&forOp](Value index) {
+              return !forOp.isDefinedOutsideOfLoop(index);
+            }))
+          continue;
+        // Find a read with the same type and indices.
+        vector::TransferReadOp matchingRead =
+            findMatchingTransferRead(write, it.value());
+        if (!matchingRead)
+          continue;
+        // The read is moved above the loop, so every operand other than the
+        // carried tensor (e.g. the padding value) must already be available
+        // there; otherwise hoisting would break dominance.
+        if (llvm::any_of(matchingRead->getOperands(), [&](Value operand) {
+              return operand != it.value() &&
+                     !forOp.isDefinedOutsideOfLoop(operand);
+            }))
+          continue;
+        // Make sure none of the other uses read the part of the tensor modified
+        // by the transfer_write.
+        if (tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
+          continue;
+
+        // Hoist read before.
+        if (failed(forOp.moveOutOfLoop({matchingRead})))
+          llvm_unreachable(
+              "Unexpected failure to move transfer read out of loop");
+        // Update the source tensor: outside the loop the read consumes the
+        // loop's init value instead of the region iter_arg.
+        matchingRead.sourceMutable().assign(forOp.initArgs()[it.index()]);
+
+        // Hoist write after, and yield the tensor the write used to update.
+        write->moveAfter(forOp);
+        yield->setOperand(it.index(), write.source());
+
+        // Rewrite `loop` with new yields by cloning and erase the original
+        // loop. The hoisted read's vector becomes the initial value of a new
+        // iter_arg and the written vector is yielded in its place.
+        OpBuilder b(matchingRead);
+        auto newForOp =
+            cloneWithNewYields(b, forOp, matchingRead.vector(), write.vector());
+
+        // Transfer write has been hoisted, need to update the vector and tensor
+        // source. Replace the result of the loop to use the new tensor created
+        // outside the loop. The RAUW must happen before the write's source is
+        // set to that same loop result, or the write would be rewired onto
+        // its own result.
+        newForOp.getResult(it.index()).replaceAllUsesWith(write.getResult(0));
+        write.vectorMutable().assign(newForOp.getResults().back());
+        write.sourceMutable().assign(newForOp.getResult(it.index()));
+
+        changed = true;
+        forOp.erase();
+        // Need to interrupt and restart because erasing the loop messes up the
+        // walk.
+        return WalkResult::interrupt();
+      }
+      return WalkResult::advance();
+    });
+  }
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
   bool changed = true;
   while (changed) {
     changed = false;
 
     func.walk([&](vector::TransferReadOp transferRead) {
+      if (!transferRead.getShapedType().isa<MemRefType>())
+        return WalkResult::advance();
+
       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                         << *transferRead.getOperation() << "\n");
       auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());

diff  --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index fc08d21b27a5..ef3ef3db1f81 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -312,10 +312,8 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
   return true;
 }
 
-bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
-                                 VectorTransferOpInterface transferB) {
-  if (transferA.source() != transferB.source())
-    return false;
+bool mlir::isDisjointTransferIndices(VectorTransferOpInterface transferA,
+                                     VectorTransferOpInterface transferB) {
   // For simplicity only look at transfer of same type.
   if (transferA.getVectorType() != transferB.getVectorType())
     return false;
@@ -345,3 +343,10 @@ bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
   }
   return false;
 }
+
+/// Returns true only when both transfers use the very same source value and
+/// `isDisjointTransferIndices` proves the accessed slices do not overlap.
+/// Transfers on different source values are conservatively reported as not
+/// disjoint.
+bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
+                                 VectorTransferOpInterface transferB) {
+  const bool sameSource = transferA.source() == transferB.source();
+  return sameSource && isDisjointTransferIndices(transferA, transferB);
+}

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 4a6fca554a09..504e85f4d4b1 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -230,3 +230,169 @@ func @hoist_vector_transfer_pairs_disjoint(
   }
   return
 }
+
+// Hoisting of vector.transfer_read/transfer_write pairs on tensor iter_args.
+// Each iter_arg of the inner loop exercises a different case:
+//  - %arg7: indices are constants defined outside both loops, so the
+//    vector<1xf32> pair is hoisted out of both loops.
+//  - %arg6: indices depend on %i, so the vector<2xf32> pair is only hoisted
+//    out of the inner loop.
+//  - %arg8: the tensor itself is also consumed by "some_use", an unknown op,
+//    so nothing is hoisted.
+//  - %arg9: the written tensor %w3 has a second user ("some_crippling_use"),
+//    so the write is not hoisted.
+//  - %arg10, %arg11: the tensor is consumed by "some_crippling_use", an
+//    unknown op, so nothing is hoisted.
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_tensor
+func @hoist_vector_transfer_pairs_tensor(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+     tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS:   vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS:   scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS:     "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:     "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// VECTOR_TRANSFERS:     scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
+// VECTOR_TRANSFERS:   }
+// VECTOR_TRANSFERS:   vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:   scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
+  %0:6 = scf.for %i = %lb to %ub step %step
+  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+            %arg3 = %tensor3,  %arg4 = %tensor4, %arg5 = %tensor5)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+     tensor<?x?xf32>, tensor<?x?xf32>)  {
+    %1:6 = scf.for %j = %lb to %ub step %step
+    iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
+              %arg9 = %arg3,  %arg10 = %arg4, %arg11 = %arg5)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+       tensor<?x?xf32>, tensor<?x?xf32>)  {
+      %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+      %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r2 = vector.transfer_read %arg8[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
+      "some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
+      %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
+      %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
+      "some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32>
+      %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+      %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+      %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
+      %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+      %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+      %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+      %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
+      %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
+      %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
+      "some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
+      scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>
+      }
+      scf.yield %1#0,  %1#1, %1#2, %1#3, %1#4, %1#5 :
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  return %0#0,  %0#1, %0#2, %0#3, %0#4,  %0#5 :
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+        tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// Several read/write pairs on the same tensor iter_arg can only be hoisted
+// when their accesses are provably disjoint:
+//  - %arg5: vector<2xf32> at [%c0, %c0] and [%c0, %c1] overlap, so nothing is
+//    hoisted.
+//  - %arg6: vector<3xf32> at [%c0, %c0] and [%c0, %c3] do not overlap, so both
+//    pairs are hoisted out of both loops.
+//  - %arg7: vector<4xf32> at [%c0, %random_index] and [%c1, %random_index]
+//    access different rows, so both pairs are hoisted out of both loops.
+//  - %arg4: indices [%i, %i] and [%random_index, %random_index] cannot be
+//    proven disjoint, so nothing is hoisted.
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+func @hoist_vector_transfer_pairs_disjoint_tensor(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
+    %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index,
+    %random_index : index) ->
+    (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c3 = constant 3 : index
+  %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS:   scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:     scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS:   }
+// VECTOR_TRANSFERS:   scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: %[[TENSOR4:.*]] = vector.transfer_write %{{.*}}, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: %[[TENSOR5:.*]] = vector.transfer_write %{{.*}}, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+  %0:4 = scf.for %i = %lb to %ub step %step
+  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+            %arg3 = %tensor3)
+  -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+    %1:4 = scf.for %j = %lb to %ub step %step
+    iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
+              %arg7 = %arg3)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+      %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
+      %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+      %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+      %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+      %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
+      %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+      %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+      %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
+      %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
+      %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
+      %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
+      %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
+      %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
+      %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+      %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
+      %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+      %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
+      %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
+      %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
+      %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+      %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
+      scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    }
+    scf.yield %1#0,  %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+  }
+  return %0#0,  %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}

diff  --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
index d78833e78f15..76d41f1fcdc4 100644
--- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
@@ -47,6 +47,7 @@ void TestLinalgHoisting::runOnFunction() {
   }
   if (testHoistRedundantTransfers) {
     hoistRedundantVectorTransfers(getFunction());
+    hoistRedundantVectorTransfersOnTensor(getFunction());
     return;
   }
 }


        


More information about the Mlir-commits mailing list