[Mlir-commits] [mlir] 6f90955 - [mlir][Linalg] Add support for subtensor_insert comprehensive bufferization (3/n)
Nicolas Vasilache
llvmlistbot at llvm.org
Fri May 14 14:54:35 PDT 2021
Author: Nicolas Vasilache
Date: 2021-05-14T21:51:00Z
New Revision: 6f90955f6949397e20d359e4b914c2d46b35f863
URL: https://github.com/llvm/llvm-project/commit/6f90955f6949397e20d359e4b914c2d46b35f863
DIFF: https://github.com/llvm/llvm-project/commit/6f90955f6949397e20d359e4b914c2d46b35f863.diff
LOG: [mlir][Linalg] Add support for subtensor_insert comprehensive bufferization (3/n)
Differential revision: https://reviews.llvm.org/D102417
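
For a subtensor_insert whose destination bufferizes inplace, the new pattern materializes the insertion as a memref.subview of the destination buffer followed by a linalg.copy of the source into that subview. A minimal sketch, distilled from the tests added in this commit (the function name, SSA names, and layout maps below are illustrative only):

  func @subtensor_insert_sketch(%A : tensor<?xf32> {linalg.inplaceable = true},
                                %t : tensor<4xf32>) -> tensor<?xf32> {
    %r = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
    return %r : tensor<?xf32>
  }

  // Bufferizes roughly to (no allocation for the inplaceable destination):
  //   %bA = memref.buffer_cast %A : memref<?xf32, #map>
  //   %bt = memref.buffer_cast %t : memref<4xf32, #map>
  //   %sv = memref.subview %bA[0] [4] [1] : memref<?xf32, #map> to memref<4xf32, #map>
  //   linalg.copy(%bt, %sv) : memref<4xf32, #map>, memref<4xf32, #map>
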
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 758cde2880312..0f271c2f27983 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -128,6 +128,8 @@ OpResult getMatchingOpResult(LinalgOp linalgOp, OpOperand &opOperand) {
return linalgOp->getResult(outputOperandIndex - numOutputBuffers);
}
+/// Return the OpResult that matches an operand.
+/// Return null if no such result exists.
OpResult getMatchingOpResult(VectorTransferOpInterface op,
OpOperand &opOperand) {
if (opOperand.get() != op.source() ||
@@ -136,17 +138,25 @@ OpResult getMatchingOpResult(VectorTransferOpInterface op,
return op->getResult(0);
}
+/// Return the OpResult that matches an operand.
+/// Return null if no such result exists.
+OpResult getMatchingOpResult(SubTensorInsertOp op, OpOperand &opOperand) {
+ if (opOperand.get() != op.dest())
+ return OpResult();
+ return op->getResult(0);
+}
+
/// Determine which results may be reused inplace by the bufferization
/// patterns of `bufferizeFuncOpInternals`.
/// The inplace analysis uses this information along with interfering read
/// analysis to determine which op results reuse the same buffer as some
/// operand.
OpResult getMatchingOpResult(OpOperand &opOperand) {
- OpResult res = llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
- .Case<LinalgOp, VectorTransferOpInterface>([&](auto op) {
- return getMatchingOpResult(op, opOperand);
- })
- .Default([&](Operation *op) { return OpResult(); });
+ OpResult res =
+ llvm::TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
+ .Case<LinalgOp, SubTensorInsertOp, VectorTransferOpInterface>(
+ [&](auto op) { return getMatchingOpResult(op, opOperand); })
+ .Default([&](Operation *op) { return OpResult(); });
return res;
}
@@ -644,8 +654,8 @@ static void finalizeBufferAllocation(OpBuilder &b, LinalgOp op,
/// Generic conversion for any LinalgOp.
/// Operate on mixed tensor + buffer Linalg ops for progressive bufferization.
-static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
- BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
+ BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -668,16 +678,16 @@ static LogicalResult convertAnyLinalgOp(OpBuilder &b, LinalgOp op,
/// DimOp tensor operand is modified inplace. This allows leaving dead tensors
/// behind that will get DCE'd.
-static LogicalResult convertDimOp(OpBuilder &b, memref::DimOp dimOp,
- BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
+ BlockAndValueMapping &bvm) {
if (dimOp.memrefOrTensor().getType().isa<RankedTensorType>())
dimOp.memrefOrTensorMutable().assign(lookup(bvm, dimOp.memrefOrTensor()));
return success();
}
/// FuncOp always creates TensorToMemRef ops.
-static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
- BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
+ BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
@@ -699,8 +709,8 @@ static LogicalResult convertFuncOp(OpBuilder &b, FuncOp funcOp,
}
/// ReturnOp always creates memref::TensorLoadOp.
-static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
- BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
+ BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(returnOp);
@@ -717,9 +727,69 @@ static LogicalResult convertReturnOp(OpBuilder &b, ReturnOp returnOp,
return success();
}
-static LogicalResult convertTransferOp(OpBuilder &b,
- VectorTransferOpInterface op,
- BlockAndValueMapping &bvm) {
+static LogicalResult bufferize(OpBuilder &b,
+ SubTensorInsertOp subTensorInsertOp,
+ BlockAndValueMapping &bvm) {
+ LLVM_DEBUG(DBGS() << "bufferize: " << *subTensorInsertOp << "\n");
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(subTensorInsertOp);
+ Location loc = subTensorInsertOp.getLoc();
+
+ Value dstMemref = lookup(bvm, subTensorInsertOp.dest());
+ auto inPlace = getInPlace(subTensorInsertOp->getResult(0));
+ if (inPlace != InPlaceSpec::True) {
+ // Since subtensor_insert ops arise from tiling and the introduction of
+ // loops, this case is generally a deal breaker. When used with loops, this
+ // ends up cloning the whole tensor on every single iteration and is a
+ // symptom of a catastrophically bad scheduling decision.
+ // TODO: be very loud about it or even consider failing the pass.
+ Value newDstMemref = createNewAllocDeallocPairForShapedValue(
+ b, loc, subTensorInsertOp.result());
+ b.setInsertionPointAfter(newDstMemref.getDefiningOp());
+ b.create<CopyOp>(subTensorInsertOp.getLoc(), dstMemref, newDstMemref);
+ dstMemref = newDstMemref;
+ }
+ auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+
+ Value srcMemref = lookup(bvm, subTensorInsertOp.source());
+ auto subviewMemRefType =
+ memref::SubViewOp::inferRankReducedResultType(
+ subTensorInsertOp.getSourceType().getRank(), dstMemrefType,
+ subTensorInsertOp.getMixedOffsets(),
+ subTensorInsertOp.getMixedSizes(),
+ subTensorInsertOp.getMixedStrides())
+ .cast<MemRefType>();
+
+ // A copy of the source buffer is needed if either:
+ // - The producer of `source` is not inplace. This is the case where a
+ // subtensor is computed out of place into the inplace full tensor.
+ // - The result is not inplace. This is the case where the whole tensor is
+ // cloned and the clone needs to be updated.
+ Value source = subTensorInsertOp.source();
+ InPlaceSpec inPlaceProducer = InPlaceSpec::None;
+ if (auto opResult = source.dyn_cast<OpResult>())
+ inPlaceProducer = getInPlace(opResult);
+ else
+ inPlaceProducer = getInPlace(source.cast<BlockArgument>());
+ if (inPlaceProducer != InPlaceSpec::True) {
+ LLVM_DEBUG(DBGS() << "subtensor_insert needs extra source copy: " << source
+ << " -> copy\n");
+ // Take a subview of the dst.
+ Value subView = b.create<memref::SubViewOp>(
+ loc, subviewMemRefType, dstMemref, subTensorInsertOp.getMixedOffsets(),
+ subTensorInsertOp.getMixedSizes(), subTensorInsertOp.getMixedStrides());
+ b.create<CopyOp>(subTensorInsertOp.getLoc(), srcMemref, subView);
+ }
+
+ map(bvm, subTensorInsertOp.result(), dstMemref);
+
+ return success();
+}
+
+static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
+ BlockAndValueMapping &bvm) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -730,7 +800,7 @@ static LogicalResult convertTransferOp(OpBuilder &b,
LLVM_DEBUG(DBGS() << "convert: " << *op << "\n");
- /// transfer_read from buffer
+ /// transfer_read from buffer always reads from the bufferized op.source().
if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
readOp.sourceMutable().assign(lookup(bvm, op.source()));
return success();
@@ -778,8 +848,8 @@ static LogicalResult bufferizeFuncOpInternals(
FuncOp funcOp, BlockAndValueMapping &bvm,
const DenseMap<FuncOp, SmallVector<int64_t>> &tiedResultsMap) {
OpBuilder b(funcOp->getContext());
- /// Start by converting `funcOp` arguments.
- if (failed(convertFuncOp(b, funcOp, bvm)))
+ /// Start by bufferizing `funcOp` arguments.
+ if (failed(bufferize(b, funcOp, bvm)))
return failure();
WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
LogicalResult status =
@@ -787,12 +857,9 @@ static LogicalResult bufferizeFuncOpInternals(
// Skip BufferCast and TensorLoad ops.
.Case<memref::BufferCastOp, memref::TensorLoadOp>(
[&](auto) { return success(); })
- .Case([&](memref::DimOp op) { return convertDimOp(b, op, bvm); })
- .Case([&](LinalgOp op) { return convertAnyLinalgOp(b, op, bvm); })
- .Case([&](ReturnOp op) { return convertReturnOp(b, op, bvm); })
- .Case([&](VectorTransferOpInterface op) {
- return convertTransferOp(b, op, bvm);
- })
+ .Case<memref::DimOp, LinalgOp, ReturnOp, SubTensorInsertOp,
+ VectorTransferOpInterface>(
+ [&](auto op) { return bufferize(b, op, bvm); })
.Default([&](Operation *op) {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
diff --git a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
index a5636329fd3a0..19599b99866bf 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-func-bufferize.mlir
@@ -96,14 +96,13 @@ func @vec_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<
// -----
// CHECK-LABEL: func @vec_not_inplace
-// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: tensor<?xf32> {linalg.inplaceable = true}
func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vector<4xf32>)
-> (tensor<?xf32>, tensor<?xf32>)
{
%c0 = constant 0 : index
%c1 = constant 1 : index
- // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast %[[A]] : memref<?xf32, #[[$map_2d_dyn]]>
+ // CHECK: %[[BUFFER_CAST:.*]] = memref.buffer_cast {{.*}} : memref<?xf32, #[[$map_2d_dyn]]>
/// Cross-op multiple uses of %A, the first vector.transfer which has interfering reads must alloc.
// CHECK: %[[ALLOC:.*]] = memref.alloc
@@ -117,3 +116,105 @@ func @vec_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %vec : vec
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+ -> tensor<?xf32>
+{
+ // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+ // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+ // CHECK-NOT: alloc
+ // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+ // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+ %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+ return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+ -> tensor<?xf32>
+{
+ %f0 = constant 0.0 : f32
+
+ // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+ // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+ // CHECK-NOT: alloc
+ // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+ // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+ %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+ /// Overwrite BUFFER_CAST_A inplace.
+ // CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+ %r1 = linalg.fill(%r0, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+ return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun
+func @subtensor_insert_fun(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+ -> tensor<?xf32>
+{
+ %f0 = constant 0.0 : f32
+
+ // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+ // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+ // CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+ %r0 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+
+ // CHECK-NOT: alloc
+ // CHECK: %[[SV:.*]] = memref.subview %[[BUFFER_CAST_A]]
+ /// Overwrite BUFFER_CAST_A inplace by copying into the subview.
+ // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]])
+ %r1 = subtensor_insert %t into %r0[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+ return %r1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32>, %t : tensor<4xf32>)
+ -> tensor<?xf32>
+{
+ // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+ // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref<?xf32>
+ // CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC]]) : memref<?xf32{{.*}}, memref<?xf32>
+ // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+ // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+ // CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+ %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+ return %r0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @subtensor_insert_fun_not_inplace
+func @subtensor_insert_fun_not_inplace(%A : tensor<?xf32> {linalg.inplaceable = true}, %t : tensor<4xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>)
+{
+ %f0 = constant 0.0 : f32
+
+ // CHECK: %[[BUFFER_CAST_A:.*]] = memref.buffer_cast {{.*}} : memref<?xf32
+ // CHECK: %[[BUFFER_CAST_B:.*]] = memref.buffer_cast {{.*}} : memref<4xf32
+
+ // CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref<?xf32>
+ // CHECK: linalg.copy(%[[BUFFER_CAST_A]], %[[ALLOC]]) : memref<?xf32{{.*}}, memref<?xf32>
+ // CHECK: %[[SV:.*]] = memref.subview %[[ALLOC]][0] [4] [1] : memref<?xf32> to memref<4xf32>
+ // CHECK: linalg.copy(%[[BUFFER_CAST_B]], %[[SV]]) : memref<4xf32, #map>, memref<4xf32>
+ %r0 = subtensor_insert %t into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
+
+ // TODO: WAW optimization where result is overwritten without being read.
+ // CHECK: linalg.fill(%[[BUFFER_CAST_A]]
+ // CHECK: memref.dealloc %[[ALLOC]] : memref<?xf32>
+ %r1 = linalg.fill(%A, %f0) : tensor<?xf32>, f32 -> tensor<?xf32>
+ return %r0, %r1: tensor<?xf32>, tensor<?xf32>
+}
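
When the destination does not bufferize inplace (e.g. a plain function argument without linalg.inplaceable, as in @subtensor_insert_fun_not_inplace above), the pattern instead allocates a fresh buffer, clones the old destination contents into it, and only then copies the source into a subview of the clone. A rough sketch of the resulting IR, with illustrative SSA names and layout maps:

  %bA    = memref.buffer_cast %A : memref<?xf32, #map>
  %bt    = memref.buffer_cast %t : memref<4xf32, #map>
  %alloc = memref.alloc(%dim) : memref<?xf32>
  linalg.copy(%bA, %alloc) : memref<?xf32, #map>, memref<?xf32>
  %sv    = memref.subview %alloc[0] [4] [1] : memref<?xf32> to memref<4xf32>
  linalg.copy(%bt, %sv) : memref<4xf32, #map>, memref<4xf32>
  memref.dealloc %alloc : memref<?xf32>

As the comment in the new bufferize(SubTensorInsertOp) overload notes, if such a clone ends up inside a tiling loop it is repeated on every iteration, which is why the non-inplace case is flagged as a candidate for a loud diagnostic or outright pass failure.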