[llvm-branch-commits] [mlir] 080943f - [mlir][vector] Support transfer op on tensor optimizations
Thomas Raoux via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jan 6 15:14:19 PST 2021
Author: Thomas Raoux
Date: 2021-01-06T15:09:03-08:00
New Revision: 080943f7525f277579a000cf30364cc96fba6773
URL: https://github.com/llvm/llvm-project/commit/080943f7525f277579a000cf30364cc96fba6773
DIFF: https://github.com/llvm/llvm-project/commit/080943f7525f277579a000cf30364cc96fba6773.diff
LOG: [mlir][vector] Support transfer op on tensor optimizations
Support store-to-load forwarding and dead-store transformations for transfer ops
on tensors.
Differential Revision: https://reviews.llvm.org/D94148
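For tensors the optimization works on SSA use-def chains rather than on memory:
a transfer_read of the result of a transfer_write that covers the same data can
be replaced by the written vector. An illustrative sketch (values are
hypothetical, modeled on the tests added in this patch):

  %w = vector.transfer_write %v, %t[%c0, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  %r = vector.transfer_read %w[%c0, %c0], %cf0 {masked = [false, false]} :
    tensor<4x4xf32>, vector<1x4xf32>

Since the unmasked write and the read use the same indices, vector type, and
permutation map, %r can be replaced by %v and the transfer_read erased.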
Added:
Modified:
mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
mlir/test/Dialect/Vector/vector-transferop-opt.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
index ea1189d53b31..161d02cd3435 100644
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -34,13 +34,33 @@ static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
return op;
}
+/// Return true if the transfer_write fully writes the data accessed by the
+/// transfer_read.
+static bool transferEncompasses(vector::TransferWriteOp defWrite,
+ vector::TransferReadOp read) {
+ return !defWrite.hasMaskedDim() && defWrite.indices() == read.indices() &&
+ defWrite.getVectorType() == read.getVectorType() &&
+ defWrite.permutation_map() == read.permutation_map();
+}
+
+/// Return true if the write op fully overwrites the priorWrite transfer_write
+/// op.
+static bool transferEncompasses(vector::TransferWriteOp write,
+ vector::TransferWriteOp priorWrite) {
+ return priorWrite.indices() == write.indices() &&
+ priorWrite.getVectorType() == write.getVectorType() &&
+ priorWrite.permutation_map() == write.permutation_map();
+}
+
namespace {
class TransferOptimization {
public:
TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
void deadStoreOp(vector::TransferWriteOp);
+ void deadStoreOpTensor(vector::TransferWriteOp);
void storeToLoadForwarding(vector::TransferReadOp);
+ void storeToLoadForwardingTensor(vector::TransferReadOp);
void removeDeadOp() {
for (Operation *op : opToErase)
op->erase();
@@ -99,9 +119,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
// Check candidate that can override the store.
- if (write.indices() == nextWrite.indices() &&
- write.getVectorType() == nextWrite.getVectorType() &&
- write.permutation_map() == write.permutation_map() &&
+ if (transferEncompasses(nextWrite, write) &&
postDominators.postDominates(nextWrite, write)) {
if (firstOverwriteCandidate == nullptr ||
postDominators.postDominates(firstOverwriteCandidate, nextWrite))
@@ -173,10 +191,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
cast<VectorTransferOpInterface>(write.getOperation()),
cast<VectorTransferOpInterface>(read.getOperation())))
continue;
- if (dominators.dominates(write, read) && !write.hasMaskedDim() &&
- write.indices() == read.indices() &&
- write.getVectorType() == read.getVectorType() &&
- write.permutation_map() == read.permutation_map()) {
+ if (dominators.dominates(write, read) &&
+ transferEncompasses(write, read)) {
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
lastwrite = write;
else
@@ -214,15 +230,62 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
opToErase.push_back(read.getOperation());
}
+/// Walk up the SSA links; if any write gets fully overwritten, we can skip it.
+/// If it has no more uses, it becomes dead.
+void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
+ auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
+ while (defWrite) {
+ if (transferEncompasses(write, defWrite)) {
+ write.sourceMutable().assign(defWrite.source());
+ if (defWrite->use_empty())
+ opToErase.push_back(defWrite.getOperation());
+ return;
+ }
+ if (!isDisjointTransferIndices(
+ cast<VectorTransferOpInterface>(defWrite.getOperation()),
+ cast<VectorTransferOpInterface>(write.getOperation())))
+ break;
+ defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+ }
+}
+
+/// Walk up the SSA links; if any write fully matches the read, we can replace
+/// the read with the written vector. The read becomes dead and can be removed.
+void TransferOptimization::storeToLoadForwardingTensor(
+ vector::TransferReadOp read) {
+ auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
+ while (defWrite) {
+ if (transferEncompasses(defWrite, read)) {
+ read.replaceAllUsesWith(defWrite.vector());
+ opToErase.push_back(read.getOperation());
+ return;
+ }
+ if (!isDisjointTransferIndices(
+ cast<VectorTransferOpInterface>(defWrite.getOperation()),
+ cast<VectorTransferOpInterface>(read.getOperation())))
+ break;
+ defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+ }
+}
+
} // namespace
void mlir::vector::transferOpflowOpt(FuncOp func) {
TransferOptimization opt(func);
// Run store to load forwarding first since it can expose more dead store
// opportunities.
- func.walk(
- [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); });
+ func.walk([&](vector::TransferReadOp read) {
+ if (read.getShapedType().isa<MemRefType>())
+ opt.storeToLoadForwarding(read);
+ else
+ opt.storeToLoadForwardingTensor(read);
+ });
opt.removeDeadOp();
- func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); });
+ func.walk([&](vector::TransferWriteOp write) {
+ if (write.getShapedType().isa<MemRefType>())
+ opt.deadStoreOp(write);
+ else
+ opt.deadStoreOpTensor(write);
+ });
opt.removeDeadOp();
}
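The tensor dead-store case similarly walks up the chain of transfer_write ops
feeding each other's source operand. A minimal sketch (hypothetical values, in
the style of the new test below):

  %w0 = vector.transfer_write %v0, %t[%c1, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>
  %w1 = vector.transfer_write %v1, %w0[%c1, %c0] {masked = [false, false]} :
    vector<1x4xf32>, tensor<4x4xf32>

Because %w1 fully overwrites the data written by %w0 (same indices, vector type,
and permutation map), %w1 is rewired to take %t as its source; %w0 then has no
remaining uses and is erased.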
diff --git a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
index 0ed061cab4d8..30464a135e29 100644
--- a/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ b/mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -13,16 +13,16 @@ func @forward_dead_store(%arg0: i1, %arg1 : memref<4x4xf32>,
%c4 = constant 4 : index
%c0 = constant 0 : index
%cf0 = constant 0.0 : f32
- vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+ vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
vector<1x4xf32>, memref<4x4xf32>
- %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
+ %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} :
memref<4x4xf32>, vector<1x4xf32>
- %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+ %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-> (vector<1x4xf32>) {
%1 = addf %acc, %acc : vector<1x4xf32>
scf.yield %1 : vector<1x4xf32>
}
- vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
+ vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} :
vector<1x4xf32>, memref<4x4xf32>
return
}
@@ -103,7 +103,7 @@ func @forward_nested_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
// CHECK: vector.transfer_read
// CHECK: return
func @dead_store_region(%arg0: i1, %arg1 : memref<4x4xf32>,
- %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index)
+ %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index)
-> (vector<1x4xf32>) {
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -184,3 +184,56 @@ func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>,
return
}
+// CHECK-LABEL: func @forward_dead_store_tensor
+// CHECK-NOT: vector.transfer_write
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: %[[VTW:.*]] = vector.transfer_write
+// CHECK: return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+ %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} :
+ vector<1x4xf32>, tensor<4x4xf32>
+ %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
+ tensor<4x4xf32>, vector<1x4xf32>
+ %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+ -> (vector<1x4xf32>) {
+ %1 = addf %acc, %acc : vector<1x4xf32>
+ scf.yield %1 : vector<1x4xf32>
+ }
+ %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
+ vector<1x4xf32>, tensor<4x4xf32>
+ return %w1 : tensor<4x4xf32>
+}
+
+// CHECK-LABEL: func @forward_dead_store_negative_tensor
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: scf.for
+// CHECK: }
+// CHECK: %[[VTW:.*]] = vector.transfer_write
+// CHECK: return %[[VTW]] : tensor<4x4xf32>
+func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
+ %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %c0 = constant 0 : index
+ %cf0 = constant 0.0 : f32
+ %w0 = vector.transfer_write %v0, %arg1[%c1, %i] {masked = [false, false]} :
+ vector<1x4xf32>, tensor<4x4xf32>
+ %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {masked = [false, false]} :
+ tensor<4x4xf32>, vector<1x4xf32>
+ %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
+ -> (vector<1x4xf32>) {
+ %1 = addf %acc, %acc : vector<1x4xf32>
+ scf.yield %1 : vector<1x4xf32>
+ }
+ %w1 = vector.transfer_write %x, %w0[%c1, %c0] {masked = [false, false]} :
+ vector<1x4xf32>, tensor<4x4xf32>
+ return %w1 : tensor<4x4xf32>
+}