[Mlir-commits] [mlir] bebf5d5 - [mlir][Linalg] Add support for vector.transfer ops to comprehensive bufferization (2/n).
Nicolas Vasilache
llvmlistbot at llvm.org
Thu May 13 15:36:11 PDT 2021
Author: Nicolas Vasilache
Date: 2021-05-13T22:26:28Z
New Revision: bebf5d56bff75cd5b74b58cbdcb965885a82916f
URL: https://github.com/llvm/llvm-project/commit/bebf5d56bff75cd5b74b58cbdcb965885a82916f
DIFF: https://github.com/llvm/llvm-project/commit/bebf5d56bff75cd5b74b58cbdcb965885a82916f.diff
LOG: [mlir][Linalg] Add support for vector.transfer ops to comprehensive bufferization (2/n).
Differential revision: https://reviews.llvm.org/D102395
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 2534eeeb7dcd..acac96ea88d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -82,8 +82,8 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Operation.h"
-#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/BufferUtils.h"
@@ -128,16 +128,25 @@ OpResult getMatchingOpResult(LinalgOp linalgOp, OpOperand &opOperand) {
return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
}
+OpResult getMatchingOpResult(VectorTransferOpInterface op,
+ OpOperand &opOperand) {
+ if (opOperand.get() != op.source() ||
+ !op.source().getType().isa<TensorType>())
+ return OpResult();
+ return op->getResult(0);
+}
+
/// Determine which results may be reused inplace by the bufferization
/// patterns of `bufferizeFuncOpInternals`.
/// The inplace analysis uses this information along with interfering read
/// analysis to determine which op results reuse the same buffer as some
/// operand.
OpResult getMatchingOpResult(OpOperand &opOperand) {
- OpResult res =
- llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
- .Case([&](LinalgOp op) { return getMatchingOpResult(op, opOperand); })
- .Default([&](Operation *op) { return OpResult(); });
+ OpResult res = llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+ .Case<LinalgOp, VectorTransferOpInterface>([&](auto op) {
+ return getMatchingOpResult(op, opOperand);
+ })
+ .Default([&](Operation *op) { return OpResult(); });
return res;
}
@@ -708,6 +717,54 @@ static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
return success();
}
+static LogicalResult convertTransferOp(OpBuilder &b,
+ VectorTransferOpInterface op,
+ BlockAndValueMapping &bvm) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(op);
+ Location loc = op.getLoc();
+
+ if (op.getShapedType().isa<MemRefType>())
+ return failure();
+
+ LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
+
+ /// transfer_read from buffer
+ if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
+ readOp.sourceMutable().assign(lookup(bvm, op.source()));
+ return success();
+ }
+
+ auto inPlace = getInPlace(op->getResult(0));
+ auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
+
+ // If transfer_write is not inPlace, allocate a new buffer.
+ Value newInputBuffer;
+ if (inPlace != InPlaceSpec::True) {
+ newInputBuffer =
+ createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+ b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
+ map(bvm, writeOp.result(), newInputBuffer);
+ } else {
+ // InPlace write will result in memref.tensor_load(x) which must
+ // canonicalize away with one of its uses.
+ newInputBuffer = lookup(bvm, writeOp.source());
+ }
+
+ // Create a new transfer_write on buffer that doesn't have a return value.
+ // Leave the previous transfer_write to dead code as it still has uses at
+ // this point.
+ b.create<vector::TransferWriteOp>(
+ loc, writeOp.vector(), newInputBuffer, writeOp.indices(),
+ writeOp.permutation_map(),
+ writeOp.in_bounds() ? *writeOp.in_bounds() : ArrayAttr());
+
+ map(bvm, op->getResult(0), newInputBuffer);
+
+ return success();
+}
+
static void inPlaceAnalysisFuncOpInternals(FuncOp funcOp,
const DominanceInfo &domInfo) {
assert(funcOp && funcOp->getNumRegions() > 0 && !funcOp.body().empty() &&
@@ -733,6 +790,9 @@ static LogicalResult bufferizeFuncOpInternals(
.Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
.Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
.Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
+ .Case([&](VectorTransferOpInterface op) {
+ return convertTransferOp(b, op, bvm);
+ })
.Default([&](Operation *op) {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
index 69c4e3fe5919..a5636329fd3a 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -81,3 +81,39 @@ func @not_inplace(%A : tensor<?x?xf32> {linalg.inplaceable = true}) -> tensor<?x
-> tensor<?x?xf32>
return %r: tensor<?x?xf32>
}
+// -----
+
+// CHECK-LABEL: func @vec_inplace
+func @vec_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+ -> tensor<?xf32>
+{
+ %c0 = constant 0 : index
+ // CHECK-NOT: alloc
+ %r = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+ return %r: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @vec_not_inplace
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
+func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>)
+{
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+
+ // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+
+ /// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
+ // CHECK: %[[ALLOC:.*]] = memref.alloc
+ // CHECK-NEXT: vector.transfer_write {{.*}}, %[[ALLOC]]
+ %r0 = vector.transfer_write %vec, %A[%c0] : vector<4xf32>, tensor<?xf32>
+
+ /// The second vector.transfer has no interfering reads and can reuse the buffer.
+ // CHECK-NOT: alloc
+ // CHECK-NEXT: vector.transfer_write {{.*}}, %[[BUFFER_CAST]]
+ %r1 = vector.transfer_write %vec, %A[%c1] : vector<4xf32>, tensor<?xf32>
+ return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
+
More information about the Mlir-commits mailing list