[Mlir-commits] [mlir] 21debea - [mlir][Linalg] Generalize vector::transfer hoisting on tensors.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Feb 16 01:48:26 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-16T09:45:14Z
New Revision: 21debeae785dc4d7c9718fe5b46857a2c2ce6670

URL: https://github.com/llvm/llvm-project/commit/21debeae785dc4d7c9718fe5b46857a2c2ce6670
DIFF: https://github.com/llvm/llvm-project/commit/21debeae785dc4d7c9718fe5b46857a2c2ce6670.diff

LOG: [mlir][Linalg] Generalize vector::transfer hoisting on tensors.

This revision adds support for hoisting "subtensor + vector.transfer_read" / "subtensor_insert + vector.transfer_write" pairs across scf.for.
The unit of hoisting becomes a pair of HoistableRead / HoistableWrite structs, which hold a "vector.transfer_read + optional subtensor" and a "vector.transfer_write + optional subtensor_insert", respectively.
scf::ForOp canonicalization patterns are applied greedily after each successful application of the transformation, to clean up the IR eagerly and potentially expose more transformation opportunities.
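
For illustration, a minimal before/after sketch of the new pattern (hypothetical IR, not part of this commit: the names %T, %lb, %ub, %s, the "some_use" op, and the constants %c0 / %cst are illustrative). The "after" form is what remains once the scf.for canonicalization patterns fold away the now-unused tensor iter_arg:

  // Before: the subtensor / transfer indexings are loop-invariant.
  %r = scf.for %i = %lb to %ub step %s iter_args(%t = %T) -> (tensor<?x?xf32>) {
    %st  = subtensor %t[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
    %v   = vector.transfer_read %st[%c0, %c0], %cst : tensor<4x4xf32>, vector<4xf32>
    %u   = "some_use"(%v) : (vector<4xf32>) -> vector<4xf32>
    %w   = vector.transfer_write %u, %st[%c0, %c0] : vector<4xf32>, tensor<4x4xf32>
    %sti = subtensor_insert %w into %t[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
    scf.yield %sti : tensor<?x?xf32>
  }

  // After hoisting and canonicalization: only the vector travels through the
  // loop; subtensor + transfer_read moved above the loop, transfer_write +
  // subtensor_insert moved below it.
  %st = subtensor %T[0, 0][4, 4][1, 1] : tensor<?x?xf32> to tensor<4x4xf32>
  %v  = vector.transfer_read %st[%c0, %c0], %cst : tensor<4x4xf32>, vector<4xf32>
  %rv = scf.for %i = %lb to %ub step %s iter_args(%iv = %v) -> (vector<4xf32>) {
    %u = "some_use"(%iv) : (vector<4xf32>) -> vector<4xf32>
    scf.yield %u : vector<4xf32>
  }
  %w = vector.transfer_write %rv, %st[%c0, %c0] : vector<4xf32>, tensor<4x4xf32>
  %r = subtensor_insert %w into %T[0, 0][4, 4][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>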

Differential Revision: https://reviews.llvm.org/D96731

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/test/Dialect/Linalg/hoisting.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index fc79e68e4eb7..e18a0b7ea985 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -63,6 +63,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
     promoteSingleIterationLoops(cast<FuncOp>(op));
     hoistViewAllocOps(cast<FuncOp>(op));
     hoistRedundantVectorTransfers(cast<FuncOp>(op));
+    hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op));
     return success();
   };
   (void)linalg::applyStagedPatterns(

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index c3bc73aea720..a0cb80fe3032 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -21,10 +21,13 @@
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
 
+using llvm::dbgs;
+
 #define DEBUG_TYPE "linalg-hoisting"
 
 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
@@ -32,8 +35,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-using llvm::dbgs;
-
 void mlir::linalg::hoistViewAllocOps(FuncOp func) {
   bool changed = true;
   while (changed) {
@@ -81,35 +82,145 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
   }
 }
 
-/// Look for a transfer_read, in the given tensor uses, accessing the same
-/// offset as the transfer_write.
-static vector::TransferReadOp
-findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) {
+namespace {
+/// Represents a unit of hoistable TransferWriteOp. This may comprise other
+/// instructions that need to be hoisted too.
+struct HoistableWrite {
+  vector::TransferWriteOp transferWriteOp;
+  SubTensorInsertOp subTensorInsertOp;
+};
+/// Represents a unit of hoistable TransferReadOp. This may comprise other
+/// instructions that need to be hoisted too.
+struct HoistableRead {
+  vector::TransferReadOp transferReadOp;
+  SubTensorOp subTensorOp;
+};
+} // namespace
+
+/// Return true if op1 and op2 are the same constant or the same SSA value.
+static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
+  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
+    Attribute attr = ofr.dyn_cast<Attribute>();
+    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
+    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+      return intAttr.getValue().getSExtValue();
+    return llvm::None;
+  };
+  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+  if (cst1 && cst2 && *cst1 == *cst2)
+    return true;
+  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+  return v1 && v2 && v1 == v2;
+}
+
+/// Return true if all offsets, sizes and strides are equal.
+static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) {
+  if (s.static_offsets().size() != si.static_offsets().size())
+    return false;
+  if (s.static_sizes().size() != si.static_sizes().size())
+    return false;
+  if (s.static_strides().size() != si.static_strides().size())
+    return false;
+  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
+    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
+      return false;
+  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
+    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
+      return false;
+  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
+    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
+      return false;
+  return true;
+}
+
+/// Look for a HoistableRead, in the given tensor uses, accessing the same
+/// offset as the HoistableWrite.
+static HoistableRead findMatchingTransferRead(HoistableWrite write,
+                                              Value srcTensor) {
+  assert(write.transferWriteOp &&
+         "expected hoistable write to have a .transfer_write");
+
+  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
+                    << *write.transferWriteOp.getOperation() << "\n");
+  if (write.subTensorInsertOp)
+    LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: "
+                      << *write.subTensorInsertOp.getOperation() << "\n");
+
   for (Operation *user : srcTensor.getUsers()) {
-    auto read = dyn_cast<vector::TransferReadOp>(user);
-    if (read && read.indices() == write.indices() &&
-        read.getVectorType() == write.getVectorType()) {
-      return read;
+    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
+                      << "\n");
+
+    // If HoistableWrite involves a SubTensorInsertOp, we need to find a
+    // matching SubTensorOp.
+    SubTensorOp subTensorOp;
+    Operation *maybeTransferReadUser = user;
+    if (write.subTensorInsertOp) {
+      subTensorOp = dyn_cast<SubTensorOp>(user);
+      if (!subTensorOp || subTensorOp.getResult().getType() !=
+                              write.subTensorInsertOp.source().getType())
+        continue;
+
+      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
+                        << *subTensorOp << " vs " << *write.subTensorInsertOp
+                        << "\n");
+      if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp))
+        continue;
+
+      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
+      // If we got here, subTensorOp is hoistable iff it has exactly 2 uses:
+      //   1. the transfer_write we want to hoist.
+      //   2. a matching transfer_read.
+      // Anything else, we skip.
+      bool skip = false;
+      Operation *otherUser = nullptr;
+      for (Operation *u : subTensorOp->getUsers()) {
+        if (u == write.transferWriteOp)
+          continue;
+        if (otherUser) {
+          skip = true;
+          break;
+        }
+        otherUser = u;
+      }
+      if (skip || !otherUser)
+        continue;
+      maybeTransferReadUser = otherUser;
     }
+
+    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
+                      << "\n");
+    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
+    if (read && read.indices() == write.transferWriteOp.indices() &&
+        read.getVectorType() == write.transferWriteOp.getVectorType())
+      return HoistableRead{read, subTensorOp};
   }
-  return nullptr;
+  return HoistableRead();
 }
 
-/// Check if the chunk of data inserted by the transfer_write in the given
-/// tensor are read by any other op than the read candidate.
-static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
-                                           vector::TransferReadOp candidateRead,
-                                           Value srcTensor) {
+/// Check if the chunk of data inserted by the HoistableWrite is read by any
+/// op other than the HoistableRead candidate.
+static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
+                                           HoistableRead candidateRead,
+                                           BlockArgument tensorArg) {
   // Make sure none of the other uses read the part of the tensor modified
   // by the transfer_write.
   llvm::SmallVector<Value::use_range, 1> uses;
-  uses.push_back(srcTensor.getUses());
+  uses.push_back(tensorArg.getUses());
   while (!uses.empty()) {
     for (OpOperand &use : uses.pop_back_val()) {
       Operation *user = use.getOwner();
       // Skip the candidate use, only inspect the "other" uses.
-      if (user == candidateRead.getOperation() || user == write.getOperation())
+      if (user == candidateRead.transferReadOp ||
+          user == candidateRead.subTensorOp || user == write.transferWriteOp ||
+          user == write.subTensorInsertOp)
         continue;
+      // Consider all transitive uses through a subtensor / subtensor_insert.
+      // TODO: for now we just bail because a stronger analysis is needed for
+      // these cases.
+      if (isa<SubTensorOp, SubTensorInsertOp>(user))
+        return true;
       // Consider all transitive uses through a vector.transfer_write.
       if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
         uses.push_back(writeUser->getResult(0).getUses());
@@ -128,8 +239,8 @@ static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
       // Follow the use yield as long as it doesn't escape the original
       // region.
       scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
-      if (yieldUser &&
-          write->getParentOp()->isAncestor(yieldUser->getParentOp())) {
+      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
+                           yieldUser->getParentOp())) {
         Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
         uses.push_back(ret.getUses());
         continue;
@@ -137,7 +248,8 @@ static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
       auto read = dyn_cast<vector::TransferReadOp>(user);
       if (!read || !isDisjointTransferIndices(
                        cast<VectorTransferOpInterface>(read.getOperation()),
-                       cast<VectorTransferOpInterface>(write.getOperation()))) {
+                       cast<VectorTransferOpInterface>(
+                           write.transferWriteOp.getOperation()))) {
         return true;
       }
     }
@@ -145,6 +257,118 @@ static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
   return false;
 }
 
+/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
+/// Return the null HoistableWrite() if it does not consist of a
+/// vector.transfer_write + optional subtensor_insert or if any of the
+/// indexings is `forOp`-dependent.
+static HoistableWrite
+getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
+                                        OpOperand &yieldOperand) {
+  Value v = yieldOperand.get();
+  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
+    // Indexing must not depend on `forOp`.
+    for (Value operand : write.indices())
+      if (!forOp.isDefinedOutsideOfLoop(operand))
+        return HoistableWrite();
+
+    return HoistableWrite{write, nullptr};
+  }
+
+  if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) {
+    // The inserted subtensor must come from a vector.transfer_write.
+    auto write =
+        subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
+    if (!write)
+      return HoistableWrite();
+
+    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
+    auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>();
+    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
+        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
+      return HoistableWrite();
+
+    // Indexing inserted into must not depend on `forOp`.
+    for (Value operand : subTensorInsertOp->getOperands().drop_front(
+             SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex()))
+      if (!forOp.isDefinedOutsideOfLoop(operand))
+        return HoistableWrite();
+
+    return HoistableWrite{write, subTensorInsertOp};
+  }
+
+  return HoistableWrite();
+}
+
+/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
+static void hoistReadWrite(HoistableRead read, HoistableWrite write,
+                           BlockArgument tensorBBArg) {
+  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
+  assert(read.transferReadOp && write.transferWriteOp &&
+         "expected transfer_read and transfer_write ops to be set");
+  assert(((read.subTensorOp && write.subTensorInsertOp) ||
+          (!read.subTensorOp && !write.subTensorInsertOp)) &&
+         "expected matching subtensor / subtensor_insert");
+  LLVM_DEBUG(DBGS() << "In forOp:\n"
+                    << *forOp.getOperation()
+                    << "\nHoist: " << *read.transferReadOp.getOperation()
+                    << "\nHoist: " << *write.transferWriteOp.getOperation()
+                    << "\nInvolving: " << tensorBBArg << "\n");
+
+  // If a read subtensor is present, hoist it.
+  if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp})))
+    llvm_unreachable("Unexpected failure moving subtensor out of loop");
+
+  // Hoist the transfer_read op.
+  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
+    llvm_unreachable("Unexpected failure moving transfer read out of loop");
+
+  // TODO: don't hardcode /*numIvs=*/1.
+  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
+  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;
+
+  // Update the source tensor.
+  if (read.subTensorOp)
+    read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
+  else
+    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
+
+  // Hoist write after.
+  if (write.subTensorInsertOp)
+    write.subTensorInsertOp->moveAfter(forOp);
+  write.transferWriteOp->moveAfter(forOp);
+
+  // Update the yield.
+  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
+  if (write.subTensorInsertOp)
+    yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest());
+  else
+    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());
+
+  // Rewrite `forOp` with additional new yields.
+  OpBuilder b(read.transferReadOp);
+  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
+                                     write.transferWriteOp.vector());
+  // The transfer write has been hoisted, so update its vector and tensor
+  // sources and replace the result of the loop with the new tensor created
+  // outside the loop.
+  // Depending on whether a subtensor_insert is present or not, the tensor
+  // operands are updated differently.
+  if (write.subTensorInsertOp) {
+    newForOp.getResult(initArgNumber)
+        .replaceAllUsesWith(write.subTensorInsertOp.getResult());
+    write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result());
+    write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source());
+  } else {
+    newForOp.getResult(initArgNumber)
+        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
+    write.transferWriteOp.sourceMutable().assign(
+        newForOp.getResult(initArgNumber));
+  }
+
+  // Always update the vector operand with the newly yielded vector.
+  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
+}
+
 // To hoist transfer ops on tensors, the logic can be significantly simplified
 // compared to the buffer case. The transformation follows this logic:
 // 1. Look for transfer_write with a single use from ForOp yield
@@ -163,57 +387,48 @@ void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
     func.walk([&](scf::ForOp forOp) {
       Operation *yield = forOp.getBody()->getTerminator();
       for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
-        Value ret = yield->getOperand(it.index());
-        auto write = ret.getDefiningOp<vector::TransferWriteOp>();
-        if (!write || !write->hasOneUse())
+        OpOperand &ret = yield->getOpOperand(it.index());
+        HoistableWrite write =
+            getLoopInvariantTransferWriteOpDefining(forOp, ret);
+        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
           continue;
-        LLVM_DEBUG(DBGS() << "Candidate write for hoisting: "
-                          << *write.getOperation() << "\n");
-        if (llvm::any_of(write.indices(), [&forOp](Value index) {
-              return !forOp.isDefinedOutsideOfLoop(index);
-            }))
+        LLVM_DEBUG(dbgs() << "\n";
+                   DBGS() << "Candidate write for hoisting: "
+                          << *write.transferWriteOp.getOperation() << "\n");
+        if (write.subTensorInsertOp)
+          LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: "
+                            << *write.subTensorInsertOp.getOperation() << "\n");
+        if (llvm::any_of(write.transferWriteOp.indices(),
+                         [&forOp](Value index) {
+                           return !forOp.isDefinedOutsideOfLoop(index);
+                         }))
           continue;
         // Find a read with the same type and indices.
-        vector::TransferReadOp matchingRead =
+        HoistableRead matchingRead =
             findMatchingTransferRead(write, it.value());
         // Make sure none of the other uses read the part of the tensor modified
         // by the transfer_write.
-        if (!matchingRead ||
+        if (!matchingRead.transferReadOp ||
             tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
           continue;
 
-        // Hoist read before.
-        if (failed(forOp.moveOutOfLoop({matchingRead})))
-          llvm_unreachable(
-              "Unexpected failure to move transfer read out of loop");
-        // Update the source tensor.
-        matchingRead.sourceMutable().assign(forOp.initArgs()[it.index()]);
-
-        // Hoist write after.
-        write->moveAfter(forOp);
-        yield->setOperand(it.index(), write.source());
-
-        // Rewrite `loop` with new yields by cloning and erase the original
-        // loop.
-        OpBuilder b(matchingRead);
-        auto newForOp =
-            cloneWithNewYields(b, forOp, matchingRead.vector(), write.vector());
-
-        // Transfer write has been hoisted, need to update the vector and tensor
-        // source. Replace the result of the loop to use the new tensor created
-        // outside the loop.
-        newForOp.getResult(it.index()).replaceAllUsesWith(write.getResult(0));
-        write.vectorMutable().assign(newForOp.getResults().back());
-        write.sourceMutable().assign(newForOp.getResult(it.index()));
-
+        LLVM_DEBUG(DBGS() << "Start hoisting\n");
+        hoistReadWrite(matchingRead, write, it.value());
         changed = true;
         forOp.erase();
-        // Need to interrupt and restart because erasing the loop messes up the
-        // walk.
+
+        // Need to interrupt and restart: erasing the loop messes up the walk.
         return WalkResult::interrupt();
       }
       return WalkResult::advance();
     });
+    // Apply canonicalization so the new ForOp + yield fold immediately, thus
+    // cleaning up the IR and potentially enabling more hoisting.
+    if (changed) {
+      OwningRewritePatternList patterns;
+      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
+      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+    }
   }
 }
 

diff  --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 504e85f4d4b1..540d1734dd63 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -1,5 +1,7 @@
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect | FileCheck %s --check-prefix=VECTOR_TRANSFERS
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect -split-input-file | FileCheck %s --check-prefix=VECTOR_TRANSFERS
+
+// -----
 
 // CHECK-LABEL: func @hoist_allocs(
 //  CHECK-SAME:   %[[VAL:[a-zA-Z0-9]*]]: index,
@@ -82,6 +84,8 @@ func @hoist_allocs(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1
   return
 }
 
+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs(
 //  VECTOR_TRANSFERS-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  VECTOR_TRANSFERS-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -152,6 +156,8 @@ func @hoist_vector_transfer_pairs(
   return
 }
 
+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint(
 //  VECTOR_TRANSFERS-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
 //  VECTOR_TRANSFERS-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
@@ -231,6 +237,8 @@ func @hoist_vector_transfer_pairs_disjoint(
   return
 }
 
+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_tensor
 func @hoist_vector_transfer_pairs_tensor(
     %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
@@ -243,11 +251,10 @@ func @hoist_vector_transfer_pairs_tensor(
 
 // VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
 // VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
 // VECTOR_TRANSFERS:   vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
 // VECTOR_TRANSFERS:   scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
-// VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
 // VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
 // VECTOR_TRANSFERS:     "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
 // VECTOR_TRANSFERS:     vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
@@ -261,11 +268,11 @@ func @hoist_vector_transfer_pairs_tensor(
 // VECTOR_TRANSFERS:     vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS:     "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
 // VECTOR_TRANSFERS:     scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
 // VECTOR_TRANSFERS:   }
 // VECTOR_TRANSFERS:   vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS:   scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
 // VECTOR_TRANSFERS: }
 // VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
   %0:6 = scf.for %i = %lb to %ub step %step
@@ -280,7 +287,6 @@ func @hoist_vector_transfer_pairs_tensor(
        tensor<?x?xf32>, tensor<?x?xf32>)  {
       %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
       %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
-      %r2 = vector.transfer_read %arg8[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
       %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
       "some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
       %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
@@ -312,6 +318,8 @@ func @hoist_vector_transfer_pairs_tensor(
         tensor<?x?xf32>, tensor<?x?xf32>
 }
 
+// -----
+
 // VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
 //  VECTOR_TRANSFERS-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
 //  VECTOR_TRANSFERS-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
@@ -332,10 +340,10 @@ func @hoist_vector_transfer_pairs_disjoint_tensor(
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
 // VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
-// VECTOR_TRANSFERS: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS: %[[R:.*]]:6 = scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
 // VECTOR_TRANSFERS:   scf.for {{.*}} iter_args({{.*}}) ->
-// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
 // VECTOR_TRANSFERS:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
 // VECTOR_TRANSFERS:     vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
 // VECTOR_TRANSFERS:     "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
@@ -349,15 +357,15 @@ func @hoist_vector_transfer_pairs_disjoint_tensor(
 // VECTOR_TRANSFERS:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS:     vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
 // VECTOR_TRANSFERS:     scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
 // VECTOR_TRANSFERS:   }
 // VECTOR_TRANSFERS:   scf.yield {{.*}} :
-// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
 // VECTOR_TRANSFERS: }
-// VECTOR_TRANSFERS: %[[TENSOR4:.*]] = vector.transfer_write %{{.*}}, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
-// VECTOR_TRANSFERS: %[[TENSOR5:.*]] = vector.transfer_write %{{.*}}, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
-// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: %[[TENSOR4:.*]] = vector.transfer_write %[[R]]#5, %[[TENSOR3]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:                   vector.transfer_write %[[R]]#4, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: %[[TENSOR5:.*]] = vector.transfer_write %[[R]]#3, %[[TENSOR2]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS:                   vector.transfer_write %[[R]]#2, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
   %0:4 = scf.for %i = %lb to %ub step %step
   iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
             %arg3 = %tensor3)
@@ -396,3 +404,111 @@ func @hoist_vector_transfer_pairs_disjoint_tensor(
   }
   return %0#0,  %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
+
+// -----
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_tensor_and_subtensors
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR4:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[TENSOR5:[a-zA-Z0-9]*]]: tensor<?x?xf32>
+func @hoist_vector_transfer_pairs_tensor_and_subtensors(
+    %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+    %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+    %val: index, %lb : index, %ub : index, %step: index) ->
+    (
+      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    ) {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+
+  //      VECTOR_TRANSFERS: scf.for %[[I:.*]] = {{.*}} iter_args(
+  // VECTOR_TRANSFERS-SAME:   %[[TENSOR0_ARG:[0-9a-zA-Z]+]] = %[[TENSOR0]],
+  // VECTOR_TRANSFERS-SAME:   %[[TENSOR1_ARG:[0-9a-zA-Z]+]] = %[[TENSOR1]],
+  // VECTOR_TRANSFERS-SAME:   %[[TENSOR2_ARG:[0-9a-zA-Z]+]] = %[[TENSOR2]]
+  // VECTOR_TRANSFERS-SAME: ) ->
+  // VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+  %0:3 = scf.for %i = %lb to %ub step %step
+  iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
+
+    // Hoisted
+    // VECTOR_TRANSFERS:   %[[ST0:.*]] = subtensor %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+    // VECTOR_TRANSFERS:   %[[V0:.*]] = vector.transfer_read %[[ST0]]{{.*}} : tensor<?x?xf32>, vector<1xf32>
+
+    //      VECTOR_TRANSFERS:   %[[R:.*]]:3 = scf.for %[[J:.*]] = {{.*}} iter_args(
+    // VECTOR_TRANSFERS-SAME:   %[[TENSOR1_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR1_ARG]]
+    // VECTOR_TRANSFERS-SAME:   %[[TENSOR2_ARG_L2:[0-9a-zA-Z]+]] = %[[TENSOR2_ARG]]
+    // VECTOR_TRANSFERS-SAME:   %[[V0_ARG_L2:[0-9a-zA-Z]+]] = %[[V0]]
+    // VECTOR_TRANSFERS-SAME: ) ->
+    // VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+    %1:3 = scf.for %j = %lb to %ub step %step
+    iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)  {
+      // Hoists.
+      %st0 = subtensor %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r0 = vector.transfer_read %st0[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+
+      // VECTOR_TRANSFERS:     %[[ST1:.*]] = subtensor %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+      // VECTOR_TRANSFERS:     %[[V1:.*]] = vector.transfer_read %[[ST1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+      // Does not hoist (subtensor depends on %j)
+      %st1 = subtensor %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r1 = vector.transfer_read %st1[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+
+      // VECTOR_TRANSFERS:     %[[ST2:.*]] = subtensor %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> to tensor<?x?xf32>
+      // VECTOR_TRANSFERS:     %[[V2:.*]] = vector.transfer_read %[[ST2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+      // Does not hoist, 2 subtensor uses of %arg8.
+      %st2 = subtensor %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %r2 = vector.transfer_read %st2[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+
+      // VECTOR_TRANSFERS:     %[[U0:.*]] = "some_use"(%[[V0_ARG_L2]]) : (vector<1xf32>) -> vector<1xf32>
+      // VECTOR_TRANSFERS:     %[[U1:.*]] = "some_use"(%[[V1]]) : (vector<2xf32>) -> vector<2xf32>
+      // VECTOR_TRANSFERS:     %[[U2:.*]] = "some_use"(%[[V2]]) : (vector<3xf32>) -> vector<3xf32>
+      %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+      %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+      %u2 = "some_use"(%r2) : (vector<3xf32>) -> vector<3xf32>
+
+      // Hoists
+      %w0 = vector.transfer_write %u0, %st0[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG:     %[[STI1:.*]] = vector.transfer_write %[[U1]], %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+      // Does not hoist (associated subtensor depends on %j).
+      %w1 = vector.transfer_write %u1, %st1[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG:     %[[STI2:.*]] = vector.transfer_write %[[U2]], %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+      // Does not hoist, 2 subtensor / subtensor_insert pairs for %arg8.
+      %w2 = vector.transfer_write %u2, %st2[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+
+      // Hoists.
+      %sti0 = subtensor_insert %w0 into %arg6[%i, %i][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG:     subtensor_insert %[[STI1]] into %[[TENSOR1_ARG_L2]][%[[J]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+      // Does not hoist (depends on %j).
+      %sti1 = subtensor_insert %w1 into %arg7[%j, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS-DAG:     subtensor_insert %[[STI2]] into %[[TENSOR2_ARG_L2]][%[[I]],{{.*}}: tensor<?x?xf32> into tensor<?x?xf32>
+      // Does not hoist, 2 subtensor / subtensor_insert pairs for %arg8.
+      %sti2 = subtensor_insert %w2 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+      %st22 = subtensor %sti2[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+      %sti22 = subtensor_insert %st22 into %arg8[%i, %c0][%step, %step][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+
+      // VECTOR_TRANSFERS:     scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+      // VECTOR_TRANSFERS:   }
+      scf.yield %sti0, %sti1, %sti22:
+        tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    }
+
+    // Hoisted
+    // VECTOR_TRANSFERS:   %[[STI0:.*]] = vector.transfer_write %[[R]]#2, %[[ST0]]{{.*}} : vector<1xf32>, tensor<?x?xf32>
+    // VECTOR_TRANSFERS:   subtensor_insert %[[STI0]] into %[[TENSOR0_ARG]][%[[I]], %[[I]]]{{.*}} : tensor<?x?xf32> into tensor<?x?xf32>
+
+    // VECTOR_TRANSFERS:   scf.yield {{.*}} : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+    scf.yield %1#0, %1#1, %1#2 :
+      tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+
+    // VECTOR_TRANSFERS: }
+  }
+  return %0#0, %0#1, %0#2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}