[Mlir-commits] [mlir] 542a8cf - [mlir][linalg][bufferize] Fix insertion point of result buffers
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 15 02:33:22 PST 2021
Author: Matthias Springer
Date: 2021-11-15T19:27:33+09:00
New Revision: 542a8cfba7fb5cc62aec4442b7d2f13d72da37fb
URL: https://github.com/llvm/llvm-project/commit/542a8cfba7fb5cc62aec4442b7d2f13d72da37fb
DIFF: https://github.com/llvm/llvm-project/commit/542a8cfba7fb5cc62aec4442b7d2f13d72da37fb.diff
LOG: [mlir][linalg][bufferize] Fix insertion point of result buffers
Differential Revision: https://reviews.llvm.org/D113723
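In short (per the hunks below): when an OpResult bufferizes out-of-place, the new buffer is now allocated right after the definition of the operand buffer it replaces (or at the start of the block if that buffer is a block argument), and the copy into it is emitted immediately before the op being bufferized. A condensed sketch of that flow, with simplified signatures; createAllocDealloc and createMemCpy are stand-ins for the state.allocationFns callbacks and are not real API names:

  // Condensed sketch of the new getResultBuffer flow (simplified signatures).
  static Value bufferizeOutOfPlaceResult(OpBuilder &b, Operation *op,
                                         Value operandBuffer, bool needsCopy) {
    OpBuilder::InsertionGuard guard(b);
    Location loc = op->getLoc();
    // Allocate right after the operand buffer (or at block start for a
    // block argument). This is where the alloc belongs in the absence of
    // allocation hoisting.
    if (auto bbArg = operandBuffer.dyn_cast<BlockArgument>())
      b.setInsertionPointToStart(bbArg.getOwner());
    else
      b.setInsertionPointAfter(operandBuffer.getDefiningOp());
    Value resultBuffer = createAllocDealloc(b, loc, operandBuffer);
    // The copy, if one is needed, goes right before the bufferized op,
    // independently of where the allocation ended up.
    if (needsCopy) {
      b.setInsertionPoint(op);
      createMemCpy(b, loc, operandBuffer, resultBuffer);
    }
    return resultBuffer;
  }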
Added:
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index bada67c7c1b79..0778f1dbd3b6b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -350,9 +350,16 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
aliasingOperands.size() == 1 &&
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
+ // Move insertion point right after `operandBuffer`. That is where the
+ // allocation should be inserted (in the absence of allocation hoisting).
+ if (auto bbArg = operandBuffer.dyn_cast<BlockArgument>()) {
+ b.setInsertionPointToStart(bbArg.getOwner());
+ } else {
+ b.setInsertionPointAfter(operandBuffer.getDefiningOp());
+ }
// Allocate the result buffer.
Value resultBuffer =
- state.allocationFns.createAllocDeallocFn(b, loc, operand, state);
+ state.allocationFns.createAllocDeallocFn(b, loc, operandBuffer, state);
bool skipCopy = false;
// Do not copy if the last preceding write of `operand` is an op that does
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -372,7 +379,7 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
!bufferizesToMemoryRead(*opOperand))
skipCopy = true;
if (!skipCopy) {
- // Set insertion point now that potential alloc/dealloc are introduced.
+ // The copy happens right before the op that is bufferized.
b.setInsertionPoint(op);
state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index c447e0d33fcea..f0024214909b3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -756,53 +756,68 @@ static FunctionType getOrCreateBufferizedFunctionType(
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
-/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
-/// the type of `source`.
-static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
- int64_t dim) {
- if (source.getType().isa<UnrankedMemRefType, MemRefType>())
- return b.createOrFold<memref::DimOp>(loc, source, dim);
- if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
- return b.createOrFold<tensor::DimOp>(loc, source, dim);
- llvm_unreachable("Expected MemRefType or TensorType");
+/// Move the insertion point of the given builder to the beginning of a
+/// surrounding block as much as possible, while not crossing any allocation
+/// hoisting barriers.
+static void moveInsertionPointToAllocationHoistingBarrier(OpBuilder &b) {
+ Operation *op = b.getInsertionBlock()->getParentOp();
+ while (op) {
+ if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+ if (bufferizableOp.isAllocationHoistingBarrier())
+ break;
+ op = op->getParentOp();
+ }
+
+ // FuncOp is an allocation hoisting barrier, so the above loop should never
+ // run out of parents.
+ assert(
+ (op && cast<BufferizableOpInterface>(op).isAllocationHoistingBarrier()) &&
+ "expected traversal to end at allocation hoisting barrier");
+
+ // TODO: Handle cases where allocation hoisting barrier has more than one
+ // region or block.
+ assert(op->getNumRegions() == 1 &&
+ "allocation hoisting barriers with >1 regions not supported");
+ assert(op->getRegion(0).getBlocks().size() == 1 &&
+ "allocation hoisting barriers with >1 blocks not supported");
+ b.setInsertionPointToStart(&(op->getRegion(0).front()));
}
/// Compute the type of the `memref` to use for allocating the buffer for
/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function also sets the
-/// insertion point of the builder `b` to the position where the allocation is
-/// to be inserted.
+/// dynamic dimensions in the returned `memref` type. The function may also set
+/// the insertion point to an earlier location, where the allocation should
+/// happen ("allocation hoisting").
static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
Value shapedValue,
SmallVectorImpl<Value> &dynShape) {
MemRefType allocMemRefType =
getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
- if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
- b.setInsertionPointToStart(bbArg.getOwner());
- loc = bbArg.getOwner()->getParentOp()->getLoc();
- } else {
- b.setInsertionPoint(shapedValue.getDefiningOp());
- loc = shapedValue.getDefiningOp()->getLoc();
- }
// Compute the dynamic part of the shape.
- bool foundDynamicShapes = false;
+ bool reifiedShapes = false;
if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
shapedValue.getDefiningOp())) {
ReifiedRankedShapedTypeDims resultDims;
if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
- foundDynamicShapes = true;
+ reifiedShapes = true;
OpResult resultValue = shapedValue.dyn_cast<OpResult>();
auto &shape = resultDims[resultValue.getResultNumber()];
for (auto dim : enumerate(allocMemRefType.getShape()))
- if (dim.value() == ShapedType::kDynamicSize)
+ if (ShapedType::isDynamic(dim.value()))
dynShape.push_back(shape[dim.index()]);
}
}
- if (!foundDynamicShapes) {
+
+ if (!reifiedShapes) {
for (auto dim : enumerate(allocMemRefType.getShape()))
- if (dim.value() == ShapedType::kDynamicSize)
- dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+ if (ShapedType::isDynamic(dim.value())) {
+ assert((shapedValue.getType().isa<UnrankedMemRefType>() ||
+ shapedValue.getType().isa<MemRefType>()) &&
+ "expected MemRef type");
+ dynShape.push_back(
+ b.create<memref::DimOp>(loc, shapedValue, dim.index()));
+ }
}
// If the buffer is statically shaped, try to hoist it to the first enclosing
@@ -811,28 +826,9 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
// calls to LICM and buffer hoisting which will most likely not succeed.
// TODO: when packing, allocate a static bounding box which will enable more
// hoisting.
- if (dynShape.empty()) {
- Operation *parent;
- if (auto bbArg = shapedValue.dyn_cast<BlockArgument>())
- parent = bbArg.getOwner()->getParentOp();
- else
- parent = shapedValue.getDefiningOp()->getParentOp();
- while (parent) {
- if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(parent))
- if (bufferizableOp.isAllocationHoistingBarrier())
- break;
- parent = parent->getParentOp();
- }
+ if (dynShape.empty())
+ moveInsertionPointToAllocationHoistingBarrier(b);
- // FuncOp is an allocation hoisting barrier, so the above loop should never
- // run out of parents.
- assert(
- (parent &&
- cast<BufferizableOpInterface>(parent).isAllocationHoistingBarrier()) &&
- "expected traversal to end at allocation hoisting barrier");
-
- b.setInsertionPointToStart(&(parent->getRegion(0).front()));
- }
return allocMemRefType;
}
@@ -2247,6 +2243,7 @@ struct ExtractSliceOpInterface
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(extractSliceOp);
LDBG("bufferize: " << *extractSliceOp << '\n');
@@ -2263,9 +2260,6 @@ struct ExtractSliceOpInterface
alloc = createNewAllocDeallocPairForShapedValue(
b, loc, extractSliceOp.result(), state);
- // Set insertion point now that potential alloc/dealloc are introduced.
- b.setInsertionPoint(extractSliceOp);
-
// Bufferize to subview.
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
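For reference, a sketch (not part of this patch) of how an op opts into being an allocation hoisting barrier. The method name comes from the hunk above; MyOp and MyOpInterface are hypothetical, and the ExternalModel signature is assumed to follow the pattern of the other op interfaces in this file (e.g. ExtractSliceOpInterface):

  struct MyOpInterface
      : public BufferizableOpInterface::ExternalModel<MyOpInterface, MyOp> {
    // Returning true stops the upward walk in
    // moveInsertionPointToAllocationHoistingBarrier: static allocations
    // created inside this op are hoisted at most to the start of its
    // (single) region, never past it.
    bool isAllocationHoistingBarrier(Operation *op) const { return true; }
    // ... bufferize() and the other interface methods go here.
  };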
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
index dcd409da32d06..583c58ed12fd9 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -168,9 +168,9 @@ func @insert_slice_fun(%A0 : tensor<?xf32>,
-> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>)
{
// Hoisted allocs.
- // CHECK: %[[REALLOC_A1:.*]] = memref.alloc
// CHECK: %[[REALLOC_A0_2:.*]] = memref.alloc
// CHECK: %[[REALLOC_A0:.*]] = memref.alloc
+ // CHECK: %[[REALLOC_A1:.*]] = memref.alloc
// Alloc and copy the whole result tensor. Copy the tensor.extract_slice.
// CHECK: linalg.copy(%[[A0]], %[[REALLOC_A0]]
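Note on the test update: the allocations are still hoisted to the top of the function, but they are now created from insertion points next to their respective operand buffers before being moved to the hoisting barrier, which presumably changes their relative order there; hence the reordered CHECK lines for %[[REALLOC_A1]].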