[Mlir-commits] [mlir] c86f218 - [mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Oct 25 08:50:39 PDT 2021
Author: MaheshRavishankar
Date: 2021-10-25T08:50:25-07:00
New Revision: c86f218fe4ca661a4348d20b66210324224870e8
URL: https://github.com/llvm/llvm-project/commit/c86f218fe4ca661a4348d20b66210324224870e8
DIFF: https://github.com/llvm/llvm-project/commit/c86f218fe4ca661a4348d20b66210324224870e8.diff
LOG: [mlir][Linalg] Allow comprehensive bufferization to use callbacks for alloc/dealloc.
Using callbacks for allocation/deallocation allows users to override
the default.
Also add an option to the comprehensive bufferization pass to use `alloca`
instead of `alloc`s. Note that this option is just for testing. The
option to use `alloca` does not work well with the option to allow for
returning memrefs.
Differential Revision: https://reviews.llvm.org/D112166
Added:
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
Modified:
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d0744a891328..d75437ed6077 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -40,7 +40,10 @@ def LinalgComprehensiveModuleBufferize :
"Only runs inplaceability analysis (for testing purposes only)">,
Option<"allowReturnMemref", "allow-return-memref", "bool",
/*default=*/"false",
- "Allows the return of memrefs (for testing purposes only)">
+ "Allows the return of memrefs (for testing purposes only)">,
+ Option<"useAlloca", "use-alloca", "bool",
+ /*default=*/"false",
+ "Use stack allocations for memrefs (for testing purposes only)">
];
let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
}
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index 5226ab394cd7..da5505504999 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -175,14 +175,36 @@ LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
const DominanceInfo &domInfo);
+/// Default allocation function that is used by the comprehensive bufferization
+/// pass. The default currently creates a ranked memref using `memref.alloc`.
+Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+ Value shapedValue);
+
+/// Default deallocation function that is used by the comprehensive
+/// bufferization pass. It expects to receive back the value returned from the
+/// `defaultAllocationFn`.
+void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
+
+/// Callback functions that are used by the comprehensive bufferization pass to
+/// allocate/deallocate memory. These default to use the
+/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
+/// caller. The `deallocationFn` is guaranteed to receive the `Value` returned
+/// by the `allocationFn`.
+struct AllocationCallbacks {
+ std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
+ allocationFn = defaultAllocationFn;
+ std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
+ defaultDeallocationFn;
+};
+
/// Bufferize one particular op.
/// `bufferizedFunctionTypes` (resp. `globalCreator`) are expected to be
/// non-null if `op` is a CallOpInterface (resp. GlobalCreator).
LogicalResult
bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
- DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr,
- GlobalCreator *globalCreator = nullptr);
+ AllocationCallbacks allocationFns,
+ DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes = nullptr);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4b1de5f2b471..840978bf5648 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -33,6 +33,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRAnalysis
MLIRArithmetic
MLIRComplex
+ MLIRInferTypeOpInterface
MLIRIR
MLIRMemRef
MLIRLinalgAnalysis
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 9e970f3e5e14..ac01bd80ccc8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -118,12 +118,12 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/BufferUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
-
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
@@ -983,7 +983,6 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const DominanceInfo &domInfo) const {
-
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
@@ -1415,66 +1414,27 @@ Operation *getFirstParentOfType(Value v) {
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
-static Value
-createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
- Value shapedValue,
- BufferizationAliasInfo &aliasInfo) {
+static Value createNewAllocDeallocPairForShapedValue(
+ OpBuilder &b, Location loc, Value shapedValue,
+ BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
- // TODO: non-zero address space.
- // TODO: layout information if relevant.
- // Cannot allocate an unranked memref so just always go for the contiguous
- // form.
- MemRefType allocMemRefType =
- getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
- memRefType = memRefType ? memRefType : allocMemRefType;
-
- if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
- b.setInsertionPointToStart(bbArg.getOwner());
- loc = bbArg.getOwner()->getParentOp()->getLoc();
- } else {
- b.setInsertionPoint(shapedValue.getDefiningOp());
- loc = shapedValue.getDefiningOp()->getLoc();
- }
-
- // Compute the dynamic part of the shape.
- SmallVector<Value> dynShape;
- for (auto dim : enumerate(memRefType.getShape()))
- if (dim.value() == ShapedType::kDynamicSize)
- dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
- // If the buffer is statically shaped, try to hoist it to the first enclosing
- // parallel region.
- // TODO: this concept of parallel region and threadlocal needs interfaces.
- // TODO: also hoist in the dynamic case. For now this relies on subsequent
- // calls to LICM and buffer hoisting which will most likely not succeed.
- // TODO: when packing, allocate a static bounding box which will enable more
- // hoisting.
- Value allocated;
- { // Guarded insertion point to potentially hoist the AllocOp.
- OpBuilder::InsertionGuard g(b);
- if (dynShape.empty()) {
- Operation *parent =
- getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
- AffineParallelOp>(shapedValue);
- if (parent)
- b.setInsertionPointToStart(&(parent->getRegion(0).front()));
- }
- allocated = b.create<memref::AllocOp>(
- loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
- aliasInfo.createAliasInfoEntry(allocated);
+ Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
+ // TODO: For now just assert the value is returned. Eventually need to
+ // error-propagate.
+ assert(allocated && "allocation failed");
+ Value casted = allocated.getValue();
+ MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
+ if (memRefType && memRefType != allocMemRefType) {
+ casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
+ aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
- Value casted = allocated;
- if (memRefType != allocMemRefType) {
- casted = b.create<memref::CastOp>(loc, memRefType, allocated);
- aliasInfo.insertNewBufferEquivalence(casted, allocated);
- }
- b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
- b.create<memref::DeallocOp>(loc, allocated);
+ allocationFns.deallocationFn(b, loc, allocated.getValue());
return casted;
}
@@ -1488,6 +1448,7 @@ createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
static Value getResultBuffer(OpBuilder &b, OpResult result,
const BlockAndValueMapping &bvm,
BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks allocationFns,
bool skipCopy = false) {
OpBuilder::InsertionGuard guard(b);
Operation *op = result.getOwner();
@@ -1515,8 +1476,8 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
"ops with multiple aliasing OpOperands cannot bufferize out-of-place");
Location loc = op->getLoc();
// Allocate the result buffer.
- Value resultBuffer =
- createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
+ Value resultBuffer = createNewAllocDeallocPairForShapedValue(
+ b, loc, operand, aliasInfo, allocationFns);
// Do not copy the result of an InitTensorOp.
if (isInitTensorOp(operand))
skipCopy = true;
@@ -1538,11 +1499,10 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
/// Helper function for LinalgOp bufferization.
/// When allocating a new buffer, analyze whether `op` wants to read from that
/// buffer. Only in that case, a copy of the result buffer may be needed.
-static LogicalResult
-allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
- SmallVectorImpl<Value> &resultBuffers,
- BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+static LogicalResult allocateBuffersForResults(
+ OpBuilder &b, Location loc, LinalgOp op,
+ SmallVectorImpl<Value> &resultBuffers, BlockAndValueMapping &bvm,
+ BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -1553,7 +1513,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
OpResult opResult = getInplaceableOpResult(*opOperand);
assert(opResult && "could not find correspond OpResult");
bool skipCopy = !op.payloadUsesValueFromOperand(opOperand);
- Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo, skipCopy);
+ Value resultBuffer =
+ getResultBuffer(b, opResult, bvm, aliasInfo, allocationFns, skipCopy);
if (!resultBuffer)
return failure();
resultBuffers.push_back(resultBuffer);
@@ -1568,7 +1529,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFns) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1591,7 +1553,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
SmallVector<Value> newOutputBuffers;
// Try to allocate new buffers depending on op's inplace semantics.
if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
- aliasInfo)))
+ aliasInfo, allocationFns)))
return failure();
// Clone the newly bufferized op.
@@ -1616,7 +1578,7 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
/// to allow FuncOp that are inplaceable to write inPlace.
static LogicalResult
bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
+ BufferizationAliasInfo &aliasInfo, AllocationCallbacks &allocationFns,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
@@ -1755,12 +1717,14 @@ bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
/// tensor::CastOp bufferizes to memref::CastOp.
static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(castOp);
- Value resultBuffer = getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo);
+ Value resultBuffer =
+ getResultBuffer(b, castOp->getResult(0), bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
Type sourceType = resultBuffer.getType();
@@ -1786,10 +1750,15 @@ static LogicalResult bufferize(OpBuilder &b, tensor::CastOp castOp,
static LogicalResult bufferize(OpBuilder &b, arith::ConstantOp constantOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo,
- GlobalCreator &globalCreator) {
+ BufferizationAliasInfo &aliasInfo) {
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
"not a constant ranked tensor");
+ auto moduleOp = constantOp->getParentOfType<ModuleOp>();
+ if (!moduleOp) {
+ return constantOp.emitError(
+ "cannot bufferize constants not within builtin.module op");
+ }
+ GlobalCreator globalCreator(moduleOp);
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1824,7 +1793,8 @@ static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1837,7 +1807,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
"unsupported unranked tensor");
// TODO: More general: Matching bbArg does not bufferize to a read.
- Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+ Value resultBuffer =
+ getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
@@ -1880,7 +1851,8 @@ static LogicalResult bufferize(OpBuilder &b, scf::IfOp ifOp,
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
@@ -1906,7 +1878,8 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
/// TODO: consider hoisting across function boundaries prior to bufferization.
static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// The InitTensorOp may have been eliminated.
if (initTensorOp->getUses().empty())
return success();
@@ -1916,7 +1889,8 @@ static LogicalResult bufferize(OpBuilder &b, InitTensorOp initTensorOp,
b.setInsertionPoint(initTensorOp);
Value alloc = createNewAllocDeallocPairForShapedValue(
- b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo);
+ b, initTensorOp->getLoc(), initTensorOp.result(), aliasInfo,
+ allocationFn);
map(bvm, initTensorOp.result(), alloc);
return success();
}
@@ -1949,7 +1923,8 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
/// Bufferization for TiledLoopOp.
static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1989,7 +1964,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
const OpResult &opResult = tiledLoopOp->getResult(resultIndex);
OpOperand &yieldOperand = yieldOp->getOpOperand(resultIndex);
- Value resultBuffer = getResultBuffer(b, opResult, bvm, aliasInfo);
+ Value resultBuffer =
+ getResultBuffer(b, opResult, bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
@@ -2073,7 +2049,8 @@ static LogicalResult bufferize(OpBuilder &b, TiledLoopOp tiledLoopOp,
/// isolation.
static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -2093,7 +2070,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
auto inPlace = getInPlace(extractSliceOp->getResult(0));
if (inPlace != InPlaceSpec::True)
alloc = createNewAllocDeallocPairForShapedValue(
- b, loc, extractSliceOp.result(), aliasInfo);
+ b, loc, extractSliceOp.result(), aliasInfo, allocationFn);
// Set insertion point now that potential alloc/dealloc are introduced.
b.setInsertionPoint(extractSliceOp);
@@ -2125,7 +2102,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(insertSliceOp);
@@ -2140,8 +2118,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
// TODO: be very loud about it or even consider failing the pass.
// Alloc a copy for `insertSliceOp.dest()`, it will become the result
// buffer.
- Value dstMemref =
- getResultBuffer(b, insertSliceOp->getResult(0), bvm, aliasInfo);
+ Value dstMemref = getResultBuffer(b, insertSliceOp->getResult(0), bvm,
+ aliasInfo, allocationFn);
if (!dstMemref)
return failure();
auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
@@ -2184,7 +2162,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
BlockAndValueMapping &bvm,
- BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo,
+ AllocationCallbacks &allocationFn) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -2205,7 +2184,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
auto writeOp = cast<vector::TransferWriteOp>(op.getOperation());
- Value resultBuffer = getResultBuffer(b, op->getResult(0), bvm, aliasInfo);
+ Value resultBuffer =
+ getResultBuffer(b, op->getResult(0), bvm, aliasInfo, allocationFn);
if (!resultBuffer)
return failure();
b.create<vector::TransferWriteOp>(
@@ -2436,18 +2416,107 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
// Bufferization entry-point for functions.
//===----------------------------------------------------------------------===//
+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
+/// dynamic dimensions in the returned `memref` type. The function also sets the
+/// insertion point of the builder `b` to the position where the allocation is
+/// to be inserted.
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
+ Value shapedValue,
+ SmallVectorImpl<Value> &dynShape) {
+ MemRefType allocMemRefType =
+ getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+ if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+ b.setInsertionPointToStart(bbArg.getOwner());
+ loc = bbArg.getOwner()->getParentOp()->getLoc();
+ } else {
+ b.setInsertionPoint(shapedValue.getDefiningOp());
+ loc = shapedValue.getDefiningOp()->getLoc();
+ }
+
+ // Compute the dynamic part of the shape.
+ bool foundDynamicShapes = false;
+ if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+ shapedValue.getDefiningOp())) {
+ ReifiedRankedShapedTypeDims resultDims;
+ if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+ foundDynamicShapes = true;
+ OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+ auto &shape = resultDims[resultValue.getResultNumber()];
+ for (auto dim : enumerate(allocMemRefType.getShape()))
+ if (dim.value() == ShapedType::kDynamicSize)
+ dynShape.push_back(shape[dim.index()]);
+ }
+ }
+ if (!foundDynamicShapes) {
+ for (auto dim : enumerate(allocMemRefType.getShape()))
+ if (dim.value() == ShapedType::kDynamicSize)
+ dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+ }
+
+ // If the buffer is statically shaped, try to hoist it to the first enclosing
+ // parallel region.
+ // TODO: this concept of parallel region and threadlocal needs interfaces.
+ // TODO: also hoist in the dynamic case. For now this relies on subsequent
+ // calls to LICM and buffer hoisting which will most likely not succeed.
+ // TODO: when packing, allocate a static bounding box which will enable more
+ // hoisting.
+ if (dynShape.empty()) {
+ Operation *parent =
+ getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+ AffineParallelOp>(shapedValue);
+ if (parent)
+ b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+ }
+ return allocMemRefType;
+}
+
+Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
+ Value shapedValue) {
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ SmallVector<Value> dynShape;
+ MemRefType allocMemRefType =
+ getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+ Value allocated = b.create<memref::AllocOp>(
+ loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+ return allocated;
+}
+
+static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
+ Value shapedValue) {
+ OpBuilder::InsertionGuard g(b);
+ SmallVector<Value> dynShape;
+ MemRefType allocMemRefType =
+ getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+ Value allocated = b.create<memref::AllocaOp>(
+ loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+ return allocated;
+}
+
+void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
+ Value allocatedBuffer) {
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
+ b.create<memref::DeallocOp>(loc, allocatedBuffer);
+}
+
LogicalResult mlir::linalg::bufferizeOp(
Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
- DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes,
- GlobalCreator *globalCreator) {
+ AllocationCallbacks allocationFns,
+ DenseMap<FuncOp, FunctionType> *bufferizedFunctionTypes) {
OpBuilder b(op->getContext());
return TypeSwitch<Operation *, LogicalResult>(op)
// Skip BufferCast and TensorLoad ops.
.Case<memref::BufferCastOp, memref::TensorLoadOp>(
[&](auto) { return success(); })
- .Case<tensor::CastOp, tensor::DimOp, ExtractSliceOp, scf::ForOp,
- InitTensorOp, InsertSliceOp, tensor::ExtractOp, LinalgOp, ReturnOp,
- TiledLoopOp, VectorTransferOpInterface, linalg::YieldOp,
+ .Case<ExtractSliceOp, InitTensorOp, InsertSliceOp, LinalgOp, scf::ForOp,
+ tensor::CastOp, TiledLoopOp, VectorTransferOpInterface>(
+ [&](auto op) {
+ LDBG("Begin bufferize:\n" << op << '\n');
+ return bufferize(b, op, bvm, aliasInfo, allocationFns);
+ })
+ .Case<tensor::DimOp, tensor::ExtractOp, ReturnOp, linalg::YieldOp,
scf::YieldOp>([&](auto op) {
LDBG("Begin bufferize:\n" << op << '\n');
return bufferize(b, op, bvm, aliasInfo);
@@ -2464,15 +2533,14 @@ LogicalResult mlir::linalg::bufferizeOp(
if (!bufferizedFunctionTypes)
llvm_unreachable(
"null bufferizedFunctionTypes when bufferizing CallOpInterface");
- return bufferize(b, op, bvm, aliasInfo, *bufferizedFunctionTypes);
+ return bufferize(b, op, bvm, aliasInfo, allocationFns,
+ *bufferizedFunctionTypes);
})
.Case([&](arith::ConstantOp op) {
if (!isaTensor(op.getResult().getType()))
return success();
LDBG("Begin bufferize:\n" << op << '\n');
- if (!globalCreator)
- llvm_unreachable("null globalCreator when bufferizing ConstantOp");
- return bufferize(b, op, bvm, aliasInfo, *globalCreator);
+ return bufferize(b, op, bvm, aliasInfo);
})
.Default([&](Operation *op) -> LogicalResult {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -2485,15 +2553,13 @@ LogicalResult mlir::linalg::bufferizeOp(
static LogicalResult bufferizeFuncOpInternals(
FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
- DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes,
- GlobalCreator &globalCreator) {
-
+ AllocationCallbacks &allocationFns,
+ DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
OpBuilder b(funcOp->getContext());
-
- // Start by bufferizing `funcOp` arguments.
- if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
+ /// Start by bufferizing `funcOp` arguments.
+ if (failed(bufferize(b, funcOp, bvm, aliasInfo, allocationFns)))
return failure();
// Cannot erase ops during the traversal. Do that afterwards.
@@ -2516,13 +2582,13 @@ static LogicalResult bufferizeFuncOpInternals(
}
for (Operation *op : llvm::reverse(preorderBufferize))
- if (failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
- &globalCreator)))
+ if (failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+ &bufferizedFunctionTypes)))
return failure();
if (!bufferizedOps.contains(op) &&
- failed(bufferizeOp(op, bvm, aliasInfo, &bufferizedFunctionTypes,
- &globalCreator)))
+ failed(bufferizeOp(op, bvm, aliasInfo, allocationFns,
+ &bufferizedFunctionTypes)))
return failure();
// Register post-walk erasure, if necessary.
@@ -2793,12 +2859,19 @@ namespace {
struct LinalgComprehensiveModuleBufferize
: public LinalgComprehensiveModuleBufferizeBase<
LinalgComprehensiveModuleBufferize> {
+ LinalgComprehensiveModuleBufferize() {}
+
+ LinalgComprehensiveModuleBufferize(
+ const LinalgComprehensiveModuleBufferize &p) {}
void runOnOperation() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect, memref::MemRefDialect>();
}
+
+private:
+ std::unique_ptr<AllocationCallbacks> allocationFns;
};
} // end namespace
@@ -2983,6 +3056,22 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
}
void LinalgComprehensiveModuleBufferize::runOnOperation() {
+ if (!allocationFns) {
+ // The allocation functions to use need to be set here. The flag for the
+ // pass and flag for the use of alloca map to LLVM command line
+ // options. These being static global objects have no set order in which
+ // they are defined. So ideally this should be in the constructor, but the
+ // constructor might be called before the flag is initialized using the
+ // command line option. So this is set up at the start of the pass.
+ if (useAlloca) {
+ AllocationCallbacks allocaAllocationFns = {
+ allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
+ allocationFns =
+ std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
+ } else {
+ allocationFns = std::make_unique<AllocationCallbacks>();
+ }
+ }
ModuleOp moduleOp = getOperation();
applyEnablingTransformations(moduleOp);
@@ -2992,7 +3081,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return signalPassFailure();
- GlobalCreator globalCreator(moduleOp);
DominanceInfo domInfo(moduleOp);
BufferizationAliasInfo aliasInfo(moduleOp);
// Interestingly, all function args that are not visible outside of a module
@@ -3032,8 +3120,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
if (!testAnalysisOnly) {
BlockAndValueMapping tensorToBufferMap;
if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
- bufferizedFunctionTypes,
- globalCreator))) {
+ *allocationFns,
+ bufferizedFunctionTypes))) {
signalPassFailure();
return;
}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
new file mode 100644
index 000000000000..71d631c85e0d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-alloca.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt %s -pass-pipeline="linalg-comprehensive-module-bufferize{allow-return-memref use-alloca}" -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$DYN_0D_MAP:.*]] = affine_map<()[s0] -> (s0)>
+// CHECK-DAG: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func @init_and_dot(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<64xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<f32, #[[$DYN_0D_MAP]]>
+func @init_and_dot(%a: tensor<64xf32>, %b: tensor<64xf32>, %c: tensor<f32>) -> tensor<f32> {
+ // CHECK-NEXT: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+ %v0 = arith.constant 0.0 : f32
+
+ // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32, #[[$DYN_0D_MAP]]>
+ %d = linalg.fill(%v0, %c) : f32, tensor<f32> -> tensor<f32>
+
+ // CHECK-NEXT: linalg.dot ins(%[[A]], %[[B]] : memref<64xf32, #[[$DYN_1D_MAP]]>, memref<64xf32, #[[$DYN_1D_MAP]]>) outs(%[[C]] : memref<f32, #[[$DYN_0D_MAP]]>)
+ %e = linalg.dot ins(%a, %b : tensor<64xf32>,tensor<64xf32>)
+ outs(%d: tensor<f32>) -> tensor<f32>
+
+ // CHECK-NEXT: return
+ return %e : tensor<f32>
+}
+
+// CHECK: func @main()
+func @main() {
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0{{.*}} : f32
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1{{.*}} : f32
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2{{.*}} : f32
+ %v0 = arith.constant 0.0 : f32
+ %v1 = arith.constant 1.0 : f32
+ %v2 = arith.constant 2.0 : f32
+
+ // CHECK-NEXT: %[[C:.*]] = memref.alloca() {alignment = 128 : i64} : memref<f32>
+ // CHECK-NEXT: %[[B:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+ // CHECK-NEXT: %[[A:.*]] = memref.alloca() {alignment = 128 : i64} : memref<64xf32>
+ %A = linalg.init_tensor [64] : tensor<64xf32>
+ %B = linalg.init_tensor [64] : tensor<64xf32>
+ %C = linalg.init_tensor [] : tensor<f32>
+
+ // CHECK-NEXT: linalg.fill(%[[C1]], %[[A]]) : f32, memref<64xf32>
+ // CHECK-NEXT: linalg.fill(%[[C2]], %[[B]]) : f32, memref<64xf32>
+ // CHECK-NEXT: linalg.fill(%[[C0]], %[[C]]) : f32, memref<f32>
+ %AA = linalg.fill(%v1, %A) : f32, tensor<64xf32> -> tensor<64xf32>
+ %BB = linalg.fill(%v2, %B) : f32, tensor<64xf32> -> tensor<64xf32>
+ %CC = linalg.fill(%v0, %C) : f32, tensor<f32> -> tensor<f32>
+
+ // CHECK-NEXT: %[[cA:.*]] = memref.cast %[[A]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+ // CHECK-NEXT: %[[cB:.*]] = memref.cast %[[B]] : memref<64xf32> to memref<64xf32, #[[$DYN_1D_MAP]]>
+ // CHECK-NEXT: %[[cC:.*]] = memref.cast %[[C]] : memref<f32> to memref<f32, #[[$DYN_0D_MAP]]>
+ // CHECK-NEXT: call @init_and_dot(%[[cA]], %[[cB]], %[[cC]])
+ %res = call @init_and_dot(%AA, %BB, %CC) :
+ (tensor<64xf32>, tensor<64xf32>, tensor<f32>) -> tensor<f32>
+
+ // CHECK-NEXT: %[[dC:.*]] = memref.cast %[[C]] : memref<f32> to memref<*xf32>
+ %res2 = tensor.cast %res: tensor<f32> to tensor<*xf32>
+
+ // CHECK-NEXT: call @print_memref_f32(%[[dC]]) : (memref<*xf32>) -> ()
+ call @print_memref_f32(%res2) : (tensor<*xf32>) -> ()
+
+ return
+}
+
+// CHECK: func private @print_memref_f32(memref<*xf32>)
+func private @print_memref_f32(tensor<*xf32>)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a0339919063f..39c258ade3fe 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6314,6 +6314,7 @@ cc_library(
":ComplexDialect",
":DialectUtils",
":IR",
+ ":InferTypeOpInterface",
":LinalgOps",
":LinalgPassIncGen",
":LinalgStructuredOpsIncGen",
More information about the Mlir-commits
mailing list