[Mlir-commits] [mlir] efd0504 - [mlir] Add hoisting transformation for transfer ops on tensor
Thomas Raoux
llvmlistbot at llvm.org
Wed Jan 6 14:24:21 PST 2021
Author: Thomas Raoux
Date: 2021-01-06T14:23:59-08:00
New Revision: efd05040e13e942a4fbb79eb798fb9833e319b51
URL: https://github.com/llvm/llvm-project/commit/efd05040e13e942a4fbb79eb798fb9833e319b51
DIFF: https://github.com/llvm/llvm-project/commit/efd05040e13e942a4fbb79eb798fb9833e319b51.diff
LOG: [mlir] Add hoisting transformation for transfer ops on tensor
Add same hoisting transformation existing for transfer ops on buffers for
transfer_ops on tensor. The logic is significantly different, so this is done as
a separate transformation, and it is expected that the user would know which
transformation to use based on the flow.
Differential Revision: https://reviews.llvm.org/D94115
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
mlir/include/mlir/Dialect/Vector/VectorUtils.h
mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
mlir/lib/Dialect/Vector/VectorUtils.cpp
mlir/test/Dialect/Linalg/hoisting.mlir
mlir/test/lib/Transforms/TestLinalgHoisting.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
index 32693555ff40..ed585d1f5cf5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h
@@ -21,8 +21,9 @@ namespace linalg {
// TODO: generalize on a per-need basis.
void hoistViewAllocOps(FuncOp func);
-/// Hoist vector.transfer_read/vector.transfer_write pairs out of immediately
-/// enclosing scf::ForOp iteratively, if the following conditions are true:
+/// Hoist vector.transfer_read/vector.transfer_write on buffers pairs out of
+/// immediately enclosing scf::ForOp iteratively, if the following conditions
+/// are true:
/// 1. The two ops access the same memref with the same indices.
/// 2. All operands are invariant under the enclosing scf::ForOp.
/// 3. No uses of the memref either dominate the transfer_read or are
@@ -35,6 +36,10 @@ void hoistViewAllocOps(FuncOp func);
// TODO: generalize on a per-need basis.
void hoistRedundantVectorTransfers(FuncOp func);
+/// Same behavior as `hoistRedundantVectorTransfers` but works on tensors
+/// instead of buffers.
+void hoistRedundantVectorTransfersOnTensor(FuncOp func);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
index a06bc8cf6562..666603250f0a 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -165,6 +165,12 @@ AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
bool isDisjointTransferSet(VectorTransferOpInterface transferA,
VectorTransferOpInterface transferB);
+/// Same behavior as `isDisjointTransferSet` but doesn't require the operations
+/// to have the same tensor/memref. This allows comparing operations accessing
+/// different tensors.
+bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
+ VectorTransferOpInterface transferB);
+
namespace matcher {
/// Matches vector.transfer_read, vector.transfer_write and ops that return a
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index a1797fde7da6..98d61fa6a8d9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -81,12 +81,151 @@ void mlir::linalg::hoistViewAllocOps(FuncOp func) {
}
}
+/// Look for a transfer_read, in the given tensor uses, accessing the same
+/// offset as the transfer_write.
+static vector::TransferReadOp
+findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) {
+ for (Operation *user : srcTensor.getUsers()) {
+ auto read = dyn_cast<vector::TransferReadOp>(user);
+ if (read && read.indices() == write.indices() &&
+ read.getVectorType() == write.getVectorType()) {
+ return read;
+ }
+ }
+ return nullptr;
+}
+
+/// Check if the chunk of data inserted by the transfer_write in the given
+/// tensor are read by any other op than the read candidate.
+static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write,
+ vector::TransferReadOp candidateRead,
+ Value srcTensor) {
+ // Make sure none of the other uses read the part of the tensor modified
+ // by the transfer_write.
+ llvm::SmallVector<Value::use_range, 1> uses;
+ uses.push_back(srcTensor.getUses());
+ while (!uses.empty()) {
+ for (OpOperand &use : uses.pop_back_val()) {
+ Operation *user = use.getOwner();
+ // Skip the candidate use, only inspect the "other" uses.
+ if (user == candidateRead.getOperation() || user == write.getOperation())
+ continue;
+ // Consider all transitive uses through a vector.transfer_write.
+ if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
+ uses.push_back(writeUser->getResult(0).getUses());
+ continue;
+ }
+ // Consider all nested uses through an scf::ForOp. We may have
+ // pass-through tensor arguments left from previous level of
+ // hoisting.
+ if (auto forUser = dyn_cast<scf::ForOp>(user)) {
+ Value arg = forUser.getLoopBody().getArgument(
+ use.getOperandNumber() - forUser.getNumControlOperands() +
+ /*iv value*/ 1);
+ uses.push_back(arg.getUses());
+ continue;
+ }
+ // Follow the use yield as long as it doesn't escape the original
+ // region.
+ scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
+ if (yieldUser &&
+ write->getParentOp()->isAncestor(yieldUser->getParentOp())) {
+ Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
+ uses.push_back(ret.getUses());
+ continue;
+ }
+ auto read = dyn_cast<vector::TransferReadOp>(user);
+ if (!read || !isDisjointTransferIndices(
+ cast<VectorTransferOpInterface>(read.getOperation()),
+ cast<VectorTransferOpInterface>(write.getOperation()))) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+// To hoist transfer op on tensor the logic can be significantly simplified
+// compared to the case on buffer. The transformation follows this logic:
+// 1. Look for transfer_write with a single use from ForOp yield
+// 2. Check the uses of the matching block argument and look for a transfer_read
+// with the same indices.
+// 3. Check that all the other uses of the tensor argument are either disjoint
+// tensor_read or transfer_write. For transfer_write uses recurse to make sure
+// the new tensor has the same restrictions on its uses.
+// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
+// After this transformation the scf.forOp may have unused arguments that can be
+// removed by the canonicalization pass.
+void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ func.walk([&](scf::ForOp forOp) {
+ Operation *yield = forOp.getBody()->getTerminator();
+ for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
+ Value ret = yield->getOperand(it.index());
+ auto write = ret.getDefiningOp<vector::TransferWriteOp>();
+ if (!write || !write->hasOneUse())
+ continue;
+ LLVM_DEBUG(DBGS() << "Candidate write for hoisting: "
+ << *write.getOperation() << "\n");
+ if (llvm::any_of(write.indices(), [&forOp](Value index) {
+ return !forOp.isDefinedOutsideOfLoop(index);
+ }))
+ continue;
+ // Find a read with the same type and indices.
+ vector::TransferReadOp matchingRead =
+ findMatchingTransferRead(write, it.value());
+ // Make sure none of the other uses read the part of the tensor modified
+ // by the transfer_write.
+ if (!matchingRead ||
+ tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
+ continue;
+
+ // Hoist read before.
+ if (failed(forOp.moveOutOfLoop({matchingRead})))
+ llvm_unreachable(
+ "Unexpected failure to move transfer read out of loop");
+ // Update the source tensor.
+ matchingRead.sourceMutable().assign(forOp.initArgs()[it.index()]);
+
+ // Hoist write after.
+ write->moveAfter(forOp);
+ yield->setOperand(it.index(), write.source());
+
+ // Rewrite `loop` with new yields by cloning and erase the original
+ // loop.
+ OpBuilder b(matchingRead);
+ auto newForOp =
+ cloneWithNewYields(b, forOp, matchingRead.vector(), write.vector());
+
+ // Transfer write has been hoisted, need to update the vector and tensor
+ // source. Replace the result of the loop to use the new tensor created
+ // outside the loop.
+ newForOp.getResult(it.index()).replaceAllUsesWith(write.getResult(0));
+ write.vectorMutable().assign(newForOp.getResults().back());
+ write.sourceMutable().assign(newForOp.getResult(it.index()));
+
+ changed = true;
+ forOp.erase();
+ // Need to interrupt and restart because erasing the loop messes up the
+ // walk.
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ }
+}
+
void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
bool changed = true;
while (changed) {
changed = false;
func.walk([&](vector::TransferReadOp transferRead) {
+ if (!transferRead.getShapedType().isa<MemRefType>())
+ return WalkResult::advance();
+
LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
<< *transferRead.getOperation() << "\n");
auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
index fc08d21b27a5..ef3ef3db1f81 100644
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -312,10 +312,8 @@ bool matcher::operatesOnSuperVectorsOf(Operation &op,
return true;
}
-bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
- VectorTransferOpInterface transferB) {
- if (transferA.source() != transferB.source())
- return false;
+bool mlir::isDisjointTransferIndices(VectorTransferOpInterface transferA,
+ VectorTransferOpInterface transferB) {
// For simplicity only look at transfer of same type.
if (transferA.getVectorType() != transferB.getVectorType())
return false;
@@ -345,3 +343,10 @@ bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
}
return false;
}
+
+bool mlir::isDisjointTransferSet(VectorTransferOpInterface transferA,
+ VectorTransferOpInterface transferB) {
+ if (transferA.source() != transferB.source())
+ return false;
+ return isDisjointTransferIndices(transferA, transferB);
+}
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
index 4a6fca554a09..504e85f4d4b1 100644
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -230,3 +230,169 @@ func @hoist_vector_transfer_pairs_disjoint(
}
return
}
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_tensor
+func @hoist_vector_transfer_pairs_tensor(
+ %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>, %tensor2: tensor<?x?xf32>,
+ %tensor3: tensor<?x?xf32>, %tensor4: tensor<?x?xf32>, %tensor5: tensor<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index) ->
+ (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>) {
+ %c0 = constant 0 : index
+ %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : tensor<?x?xf32>, vector<5xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (tensor<?x?xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<5xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: "some_crippling_use"(%{{.*}}) : (tensor<?x?xf32>) -> ()
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<2xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<1xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, tensor<?x?xf32>
+ %0:6 = scf.for %i = %lb to %ub step %step
+ iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+ %arg3 = %tensor3, %arg4 = %tensor4, %arg5 = %tensor5)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>) {
+ %1:6 = scf.for %j = %lb to %ub step %step
+ iter_args(%arg6 = %arg0, %arg7 = %arg1, %arg8 = %arg2,
+ %arg9 = %arg3, %arg10 = %arg4, %arg11 = %arg5)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>) {
+ %r0 = vector.transfer_read %arg7[%c0, %c0], %cst: tensor<?x?xf32>, vector<1xf32>
+ %r1 = vector.transfer_read %arg6[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+ %r2 = vector.transfer_read %arg8[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+ %r3 = vector.transfer_read %arg9[%c0, %c0], %cst: tensor<?x?xf32>, vector<4xf32>
+ "some_crippling_use"(%arg10) : (tensor<?x?xf32>) -> ()
+ %r4 = vector.transfer_read %arg10[%c0, %c0], %cst: tensor<?x?xf32>, vector<5xf32>
+ %r5 = vector.transfer_read %arg11[%c0, %c0], %cst: tensor<?x?xf32>, vector<6xf32>
+ "some_crippling_use"(%arg11) : (tensor<?x?xf32>) -> ()
+ %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+ %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+ %u2 = "some_use"(%arg8) : (tensor<?x?xf32>) -> vector<3xf32>
+ %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
+ %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
+ %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
+ %w1 = vector.transfer_write %u0, %arg7[%c0, %c0] : vector<1xf32>, tensor<?x?xf32>
+ %w0 = vector.transfer_write %u1, %arg6[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+ %w2 = vector.transfer_write %u2, %arg8[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+ %w3 = vector.transfer_write %u3, %arg9[%c0, %c0] : vector<4xf32>, tensor<?x?xf32>
+ %w4 = vector.transfer_write %u4, %arg10[%c0, %c0] : vector<5xf32>, tensor<?x?xf32>
+ %w5 = vector.transfer_write %u5, %arg11[%c0, %c0] : vector<6xf32>, tensor<?x?xf32>
+ "some_crippling_use"(%w3) : (tensor<?x?xf32>) -> ()
+ scf.yield %w0, %w1, %w2, %w3, %w4, %w5 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %1#0, %1#1, %1#2, %1#3, %1#4, %1#5 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 :
+ tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>,
+ tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint_tensor(
+// VECTOR_TRANSFERS-SAME: %[[TENSOR0:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR1:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR2:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// VECTOR_TRANSFERS-SAME: %[[TENSOR3:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+func @hoist_vector_transfer_pairs_disjoint_tensor(
+ %tensor0: tensor<?x?xf32>, %tensor1: tensor<?x?xf32>,
+ %tensor2: tensor<?x?xf32>, %tensor3: tensor<?x?xf32>,
+ %val: index, %lb : index, %ub : index, %step: index,
+ %random_index : index) ->
+ (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c3 = constant 3 : index
+ %cst = constant 0.0 : f32
+
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR2]]{{.*}} : tensor<?x?xf32>, vector<3xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR3]]{{.*}} : tensor<?x?xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: %[[R:.*]]:8 = scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS: scf.for {{.*}} iter_args({{.*}}) ->
+// VECTOR_TRANSFERS-SAME: (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: vector.transfer_read %[[TENSOR1]]{{.*}} : tensor<?x?xf32>, vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %{{.*}}{{.*}} : vector<2xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: scf.yield {{.*}} :
+// VECTOR_TRANSFERS-SAME: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>
+// VECTOR_TRANSFERS: }
+// VECTOR_TRANSFERS: %[[TENSOR4:.*]] = vector.transfer_write %{{.*}}, %[[R]]#3{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR4]]{{.*}} : vector<4xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: %[[TENSOR5:.*]] = vector.transfer_write %{{.*}}, %[[R]]#2{{.*}} : vector<3xf32>, tensor<?x?xf32>
+// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[TENSOR5]]{{.*}} : vector<3xf32>, tensor<?x?xf32>
+ %0:4 = scf.for %i = %lb to %ub step %step
+ iter_args(%arg0 = %tensor0, %arg1 = %tensor1, %arg2 = %tensor2,
+ %arg3 = %tensor3)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ %1:4 = scf.for %j = %lb to %ub step %step
+ iter_args(%arg4 = %arg0, %arg5 = %arg1, %arg6 = %arg2,
+ %arg7 = %arg3)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+ %r00 = vector.transfer_read %arg5[%c0, %c0], %cst: tensor<?x?xf32>, vector<2xf32>
+ %r01 = vector.transfer_read %arg5[%c0, %c1], %cst: tensor<?x?xf32>, vector<2xf32>
+ %r20 = vector.transfer_read %arg6[%c0, %c0], %cst: tensor<?x?xf32>, vector<3xf32>
+ %r21 = vector.transfer_read %arg6[%c0, %c3], %cst: tensor<?x?xf32>, vector<3xf32>
+ %r30 = vector.transfer_read %arg7[%c0, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+ %r31 = vector.transfer_read %arg7[%c1, %random_index], %cst: tensor<?x?xf32>, vector<4xf32>
+ %r10 = vector.transfer_read %arg4[%i, %i], %cst: tensor<?x?xf32>, vector<2xf32>
+ %r11 = vector.transfer_read %arg4[%random_index, %random_index], %cst: tensor<?x?xf32>, vector<2xf32>
+ %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32>
+ %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32>
+ %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32>
+ %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32>
+ %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32>
+ %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32>
+ %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32>
+ %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32>
+ %w10 = vector.transfer_write %u00, %arg5[%c0, %c0] : vector<2xf32>, tensor<?x?xf32>
+ %w11 = vector.transfer_write %u01, %w10[%c0, %c1] : vector<2xf32>, tensor<?x?xf32>
+ %w20 = vector.transfer_write %u20, %arg6[%c0, %c0] : vector<3xf32>, tensor<?x?xf32>
+ %w21 = vector.transfer_write %u21, %w20[%c0, %c3] : vector<3xf32>, tensor<?x?xf32>
+ %w30 = vector.transfer_write %u30, %arg7[%c0, %random_index] : vector<4xf32>, tensor<?x?xf32>
+ %w31 = vector.transfer_write %u31, %w30[%c1, %random_index] : vector<4xf32>, tensor<?x?xf32>
+ %w00 = vector.transfer_write %u10, %arg4[%i, %i] : vector<2xf32>, tensor<?x?xf32>
+ %w01 = vector.transfer_write %u11, %w00[%random_index, %random_index] : vector<2xf32>, tensor<?x?xf32>
+ scf.yield %w01, %w11, %w21, %w31 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %1#0, %1#1, %1#2, %1#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ return %0#0, %0#1, %0#2, %0#3 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
index d78833e78f15..76d41f1fcdc4 100644
--- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
@@ -47,6 +47,7 @@ void TestLinalgHoisting::runOnFunction() {
}
if (testHoistRedundantTransfers) {
hoistRedundantVectorTransfers(getFunction());
+ hoistRedundantVectorTransfersOnTensor(getFunction());
return;
}
}
More information about the Mlir-commits
mailing list