[Mlir-commits] [mlir] 73bea97 - [mlir][Linalg] Add support for CallOp bufferization (10/n)
Nicolas Vasilache
llvmlistbot at llvm.org
Thu Jul 1 04:13:23 PDT 2021
Author: Nicolas Vasilache
Date: 2021-07-01T10:33:12Z
New Revision: 73bea97a336ba2da276ef34fd21b2c5c676b0a97
URL: https://github.com/llvm/llvm-project/commit/73bea97a336ba2da276ef34fd21b2c5c676b0a97
DIFF: https://github.com/llvm/llvm-project/commit/73bea97a336ba2da276ef34fd21b2c5c676b0a97.diff
LOG: [mlir][Linalg] Add support for CallOp bufferization (10/n)
Support for bufferization across function boundaries is added.
This is enabled by cross-function-boundary alias analysis, for which the bufferization process is extended: it can now update the BufferizationAliasInfo as new ops are introduced.
A number of simplifying assumptions are made:
1. By default, we bufferize to the most dynamic strided memref type; subsequent memref::CastOp canonicalizations are expected to clean up the IR.
2. In the current implementation, the stride information is always erased at function boundaries. A subsequent pass will be required to analyze the meet of all call ops to a function and decide whether more static buffer types can be used. That pass may clone functions when doing so is deemed profitable (e.g. when the stride-1 dimension may vary).
3. External functions always bufferize to the most dynamic strided memref version. This may require special annotations to specify that particular operands of top-level functions have a contiguous buffer layout.
An alternative to point 3 would be tensor layout annotations, which MLIR does not currently support.
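A hand-written sketch (not taken from this commit or its tests; function names are invented) of what assumptions 1-3 mean for a call to a bodiless external function:

  // Input IR.
  func private @external_consumer(tensor<?xf32>)

  func @caller(%A : tensor<?xf32> {linalg.inplaceable = true}) {
    call @external_consumer(%A) : (tensor<?xf32>) -> ()
    return
  }

  // Expected bufferized form, roughly: every boundary tensor becomes the most
  // dynamic strided memref; where caller and callee buffer types differ, a
  // memref.cast is inserted and is expected to canonicalize away.
  //   #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  //   func private @external_consumer(memref<?xf32, #map>)
  //   func @caller(%A: memref<?xf32, #map>) {
  //     call @external_consumer(%A) : (memref<?xf32, #map>) -> ()
  //     return
  //   }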
Differential revision: https://reviews.llvm.org/D104873
Added:
mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 14acc36fbf22..824092df292c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -114,7 +114,9 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/EquivalenceClasses.h"
@@ -136,6 +138,8 @@ using namespace tensor;
// Generic helpers.
//===----------------------------------------------------------------------===//
+static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+
/// Return the FuncOp called by `callOp`.
static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
@@ -145,6 +149,20 @@ static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
+/// Return the unique ReturnOp that terminates `funcOp`.
+/// Return nullptr if there is no such unique ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+ ReturnOp returnOp;
+ for (Block &b : funcOp.body()) {
+ if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+ if (returnOp)
+ return nullptr;
+ returnOp = candidateOp;
+ }
+ }
+ return returnOp;
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//
@@ -163,7 +181,7 @@ static void map(BlockAndValueMapping &bvm, Value key, Value value) {
}
/// Wrapper for better debugging.
-static Value lookup(BlockAndValueMapping &bvm, Value key) {
+static Value lookup(const BlockAndValueMapping &bvm, Value key) {
// TODO: if key comes from bbArg, forward.
assert(key.getType().isa<TensorType>());
Value v = bvm.lookupOrNull(key);
@@ -347,10 +365,8 @@ static bool hasKnownBufferizationAliasingBehavior(Operation *op) {
VectorTransferOpInterface,
scf::YieldOp>(op)
// clang-format on
- || (none_of(op->getResultTypes(),
- [](Type t) { return t.isa<TensorType>(); }) &&
- none_of(op->getOperandTypes(),
- [](Type t) { return t.isa<TensorType>(); }));
+ || (none_of(op->getResultTypes(), isaTensor) &&
+ none_of(op->getOperandTypes(), isaTensor));
}
/// Return the OpResult that may bufferize into the same buffer as `opOperand`
@@ -577,14 +593,22 @@ class BufferizationAliasInfo {
/// beginning the alias and equivalence sets only contain `v` itself.
void createAliasInfoEntry(Value v);
+ /// Insert an info entry for `newValue` and merge its alias set with that of
+ /// `alias`.
+ void insertNewBufferAlias(Value newValue, Value alias);
+
+ /// Insert an info entry for `newValue` and merge its alias set with that of
+ /// `alias`. Additionally, merge their equivalence classes.
+ void insertNewBufferEquivalence(Value newValue, Value alias);
+
/// Return true if the buffer to which `operand` would bufferize aliases a
/// buffer that is known to not be writeable. This implies that the matching
/// OpResult cannot be bufferized inplace.
bool aliasesNonWriteableBuffer(OpOperand &operand) const;
/// Return true if the buffer to which `operand` would bufferize is equivalent
- /// to some use that would bufferize to a write to a buffer.
- bool aliasesInPlaceWrite(ExtractSliceOp extractSliceOp) const;
+ /// to some buffer write.
+ bool aliasesInPlaceWrite(Value v) const;
/// Set the inPlace bufferization spec to true.
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
@@ -619,6 +643,9 @@ class BufferizationAliasInfo {
bool isSourceEquivalentToAMatchingExtractSliceOp(
InsertSliceOp insertSliceOp) const;
+ /// Apply `fun` to all the members of the equivalence class of `v`.
+ void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
+
/// Print to `os`.
void print(raw_ostream &os) const;
@@ -626,8 +653,9 @@ class BufferizationAliasInfo {
void dump() const { print(llvm::errs()); }
private:
- /// Check aliasInfo for `v` exists and return a reference to it.
+ /// Check that aliasInfo for `v` exists and return a reference to it.
DenseSet<Value> &getAliasInfoRef(Value v);
+
const DenseSet<Value> &getAliasInfoRef(Value v) const {
return const_cast<BufferizationAliasInfo *>(this)->getAliasInfoRef(v);
}
@@ -740,6 +768,23 @@ void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
equivalentInfo.insert(v);
}
+/// Insert an info entry for `newValue` and merge its alias set with that of
+/// `alias`.
+void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
+ assert(aliasInfo.find(alias) != aliasInfo.end() && "Missing alias entry");
+ createAliasInfoEntry(newValue);
+ mergeAliases(newValue, alias);
+ mergeAliasesToFixedPoint();
+}
+
+/// Insert an info entry for `newValue` and merge its alias set with that of
+/// `alias`. Additionally, merge their equivalence classes.
+void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
+ Value alias) {
+ insertNewBufferAlias(newValue, alias);
+ equivalentInfo.unionSets(newValue, alias);
+}
+
/// Return true if the buffer to which `operand` would bufferize aliases a
/// buffer that is known to not be writeable. This implies that the matching
/// OpResult cannot be bufferized inplace.
@@ -755,13 +800,13 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
LDBG("-----------bbArg is writeable -> skip: " << bbArg << '\n');
continue;
}
- LDBG("-----------notWriteable: " << v << '\n');
+ LDBG("-----------notWriteable\n");
return true;
}
if (Operation *op = v.getDefiningOp()) {
if (isa<ConstantOp>(op) || !hasKnownBufferizationAliasingBehavior(op)) {
- LDBG("-----------notWriteable: " << v << '\n');
+ LDBG("-----------notWriteable\n");
return true;
}
}
@@ -771,12 +816,11 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
}
/// Return true if the buffer to which `operand` would bufferize is equivalent
-/// to some use that would bufferize to a write to a buffer.
-bool BufferizationAliasInfo::aliasesInPlaceWrite(
- ExtractSliceOp extractSliceOp) const {
+/// to some buffer write.
+bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
LDBG("----Start aliasesInPlaceWrite\n");
- LDBG("-------for op: " << *extractSliceOp.getOperation() << '\n');
- for (Value v : getAliasInfoRef(extractSliceOp.result())) {
+ LDBG("-------for : " << value << '\n');
+ for (Value v : getAliasInfoRef(value)) {
for (auto &use : v.getUses()) {
if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
LDBG("-----------wants to bufferize to inPlace write: "
@@ -785,7 +829,7 @@ bool BufferizationAliasInfo::aliasesInPlaceWrite(
}
}
}
- LDBG("----------->extract_slice does not alias an inplace write");
+ LDBG("----------->does not alias an inplace write\n");
return false;
}
@@ -920,6 +964,16 @@ bool BufferizationAliasInfo::isSourceEquivalentToAMatchingExtractSliceOp(
return false;
}
+/// Apply `fun` to all the members of the equivalence class of `v`.
+void BufferizationAliasInfo::applyOnEquivalenceClass(
+ Value v, function_ref<void(Value)> fun) const {
+ for (auto it = equivalentInfo.findLeader(v),
+ eit = equivalentInfo.member_end();
+ it != eit; ++it) {
+ fun(*it);
+ }
+}
+
void BufferizationAliasInfo::print(raw_ostream &os) const {
os << "\n/========================== AliasInfo "
"==========================\n";
@@ -1106,6 +1160,21 @@ bool BufferizationAliasInfo::isClobberedWriteBeforeRead(
return existsInterleavedValueClobber(aliasingRead, aliasingWrite, domInfo);
}
+//===----------------------------------------------------------------------===//
+// Forward declarations.
+//===----------------------------------------------------------------------===//
+
+/// Return the op with Allocate MemoryEffect if `v` is equivalent to such
+/// an op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+ const BufferizationAliasInfo &aliasInfo);
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+ const BufferizationAliasInfo &aliasInfo);
+
//===----------------------------------------------------------------------===//
// Bufferization-specific MemRefType support.
//===----------------------------------------------------------------------===//
@@ -1152,6 +1221,47 @@ static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
stridedLayout, addressSpace);
}
+/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
+/// tensor is replaced by the corresponding buffer type.
+/// In order for all the callers to agree, this *must* bufferize to the most
+/// dynamic buffer type supported.
+/// A later pass across all CallOps in the module can decide whether to simplify
+/// the types to a more specific version according to some cost model.
+static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
+ TypeRange argumentTypes,
+ TypeRange resultTypes) {
+ auto rewrite = [](Type t) -> Type {
+ // TODO: non-zero address space.
+ // TODO: layout information if relevant.
+ if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
+ return getDynamicMemRefType(rankedTensorType);
+ if (auto tensorType = t.dyn_cast<TensorType>())
+ return getContiguousOrUnrankedMemRefType(tensorType);
+ return t;
+ };
+ auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
+ auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
+ return FunctionType::get(ctx, argTypes, retTypes);
+}
+
+/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
+/// it. Otherwise, construct a new entry based on `argumentTypes` and
+/// `resultTypes`.
+// TODO: improve the layering.
+static FunctionType getOrCreateBufferizedFunctionType(
+ FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
+ DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+ auto it = bufferizedFunctionTypes.find(funcOp);
+ if (it != bufferizedFunctionTypes.end())
+ return it->second;
+
+ auto it2 = bufferizedFunctionTypes.try_emplace(
+ funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
+ resultTypes));
+ LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
+ return it2.first->second;
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
@@ -1159,8 +1269,10 @@ static MemRefType getDynamicMemRefType(RankedTensorType tensorType,
/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
-static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
- Value shapedValue) {
+static Value
+createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
+ Value shapedValue,
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1189,9 +1301,12 @@ static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
Value allocated = b.create<memref::AllocOp>(loc, allocMemRefType, dynShape);
+ aliasInfo.createAliasInfoEntry(allocated);
Value casted = allocated;
- if (memRefType != allocMemRefType)
+ if (memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated);
+ aliasInfo.insertNewBufferEquivalence(casted, allocated);
+ }
b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
b.create<memref::DeallocOp>(loc, allocated);
return casted;
@@ -1212,7 +1327,8 @@ static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
static LogicalResult
allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
SmallVectorImpl<Value> &resultBuffers,
- BlockAndValueMapping &bvm) {
+ BlockAndValueMapping &bvm,
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1236,7 +1352,8 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
// Otherwise, `op` is not inplaceable and we need to allocate its result.
Value dimTensor = bvm.lookupOrDefault(output);
- Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
+ Value alloc =
+ createNewAllocDeallocPairForShapedValue(b, loc, dimTensor, aliasInfo);
b.setInsertionPointAfter(alloc.getDefiningOp());
resultBuffers.push_back(alloc);
@@ -1258,7 +1375,7 @@ allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -1267,8 +1384,6 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
if (!op.hasTensorSemantics())
return failure();
- LDBG("bufferize: " << *op << '\n');
-
b.setInsertionPoint(op);
Location loc = op.getLoc();
SmallVector<Value> newInputBuffers;
@@ -1284,7 +1399,8 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
}
SmallVector<Value> newOutputBuffers;
// Try to allocate new buffers depending on op's inplace semantics.
- if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
+ if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
+ aliasInfo)))
return failure();
// Clone the newly bufferized op.
@@ -1301,11 +1417,153 @@ static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
return success();
}
+/// In a first approximation, all the function arguments of a FuncOp are marked
+/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
+/// to allow FuncOp arguments that are inplaceable to be written in place.
+static LogicalResult
+bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
+ BufferizationAliasInfo &aliasInfo,
+ DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+ FuncOp funcOp = getCalledFunction(callOp);
+ assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
+ "expected Callop to a FuncOp");
+
+ // If nothing to do then we are done.
+ if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
+ !llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return success();
+
+ // Take a guard before anything else.
+ OpBuilder::InsertionGuard g(b);
+ b.setInsertionPoint(callOp);
+
+ // 1. Filter return types:
+ // - if the callee is bodiless / external, we cannot inspect it and we
+ // cannot assume anything. We can just assert that it does not return a
+ // tensor as this would have to bufferize to "return a memref", whose
+ // semantics is ill-defined.
+ // - if the callee has a body, we perform inter-procedural equivalence
+ // analysis. When successful, a result folds onto an operand. When
+ // unsuccessful, additional work is needed to either:
+ // * hoist a result into an inplaceable operand or
+ // * devise a better representation to truly return a buffer.
+ SmallVector<Type> resultTypes;
+ SmallVector<Value> hoistedArguments;
+ if (funcOp.body().empty()) {
+ if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return callOp->emitError()
+ << "cannot bufferize bodiless function that returns a tensor";
+ } else {
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (!returnOp)
+ return funcOp->emitError() << "cannot bufferize a FuncOp with tensors "
+ "and without a unique ReturnOp";
+
+ // For each FuncOp result, keep track of which inplace argument it reuses.
+ for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+ Type returnType = returnOperand.get().getType();
+ if (!isaTensor(returnType)) {
+ resultTypes.push_back(returnType);
+ continue;
+ }
+
+ // If return operand is equivalent to some bbArg, no need to return it.
+ Value returnVal = returnOperand.get();
+ if (BlockArgument bbArg =
+ getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
+ Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
+ int64_t idx = bbArg.getArgNumber();
+ Value buffer = bvm.lookupOrNull(callOp->getOperand(idx));
+ if (!buffer)
+ return callOp->emitError() << "operand #" << idx << " not bufferized";
+ // Add CallOp operand/result equivalence: this is interprocedural info.
+ aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+ map(bvm, oldRes, buffer);
+ // Add a TensorLoadOp to kill all uses of the CallOp return.
+ // Replace all uses of the CallOp results so we can erase the CallOp.
+ // This TensorLoadOp must fold/DCE away or bufferization should be
+ // considered failed.
+ Value tensorLoad =
+ b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
+ oldRes.replaceAllUsesWith(tensorLoad);
+ // Add new op equivalence info.
+ aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+ map(bvm, tensorLoad, buffer);
+ continue;
+ }
+
+ // TODO: Need to hoist above function boundary and add to
+ // `hoistedArgumentTypes`.
+ if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+ return allocOp->emitError()
+ << " needs hoist across function boundary\n";
+
+ // Other cases legitimately need to return a tensor; this is currently not
+ // supported. For instance, if hoisting across the function boundary has
+ // failed, it may be due to e.g. data-dependent sizes. In such a case, we
+ // would need a better type than memref.
+ resultTypes.push_back(returnType);
+
+ int64_t returnIdx = returnOperand.getOperandNumber();
+ return returnOp->emitError()
+ << " bufferize result #" << returnIdx << "\n";
+ }
+ }
+
+ // 2. Compute bufferized FunctionType.
+ SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
+ llvm::append_range(argumentTypes, ValueRange{hoistedArguments}.getTypes());
+ // Get the bufferized FunctionType for funcOp or construct it if not yet
+ // available.
+ FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+ funcOp, argumentTypes, resultTypes, bufferizedFunctionTypes);
+
+ // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
+ SmallVector<Value> newOperands;
+ newOperands.reserve(callOp->getNumOperands());
+ for (OpOperand &opOperand : callOp->getOpOperands()) {
+ Value tensorOperand = opOperand.get();
+ // Non-tensor operands are just copied.
+ if (!tensorOperand.getType().isa<TensorType>()) {
+ newOperands.push_back(tensorOperand);
+ continue;
+ }
+
+ // Tensor operands are guaranteed to have been bufferized.
+ int64_t idx = opOperand.getOperandNumber();
+ Value buffer = bvm.lookupOrNull(tensorOperand);
+ assert(buffer && " missing buffer for operand");
+
+ // Caller / callee type mismatch is handled with a CastOp.
+ auto memRefType = bufferizedFuncType.getInput(idx);
+ // Since we don't yet have a clear layout story, buffer_cast may
+ // conservatively turn tensors into more dynamic memref than necessary.
+ // If the buffer type does not match the callee's memref type, introduce a memref.cast
+ // that will either canonicalize away or fail compilation until we can do
+ // something better.
+ if (buffer.getType() != memRefType) {
+ Value castBuffer =
+ b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+ // Add new op equivalence info.
+ aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+ map(bvm, tensorOperand, castBuffer);
+ buffer = castBuffer;
+ }
+ newOperands.push_back(buffer);
+ }
+
+ // 4. Create the new CallOp.
+ Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
+ resultTypes, newOperands);
+ newCallOp->setAttrs(callOp->getAttrs());
+ return success();
+}
+
/// DimOp tensor operand is modified inplace. This allows leaving dead
/// tensors behind that will get DCE'd.
static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
if (dimOp.source().getType().isa<RankedTensorType>()) {
Value v = lookup(bvm, dimOp.source());
if (!v)
@@ -1317,13 +1575,11 @@ static LogicalResult bufferize(OpBuilder &b, tensor::DimOp dimOp,
static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
Location loc = forOp.getLoc();
- LLVM_DEBUG(DBGS() << "bufferize: " << *forOp << "\n");
-
// If inPlace, just forward the buffer.
// Otherwise alloc and copy.
b.setInsertionPoint(forOp);
@@ -1337,11 +1593,12 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
Value operandBuffer = lookup(bvm, operand);
Value resultBuffer = operandBuffer;
if (getInPlace(opResult) != InPlaceSpec::True) {
- resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand);
+ resultBuffer =
+ createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
// If the tensor comes from `linalg::InitTensorOp`, the value is
// unitialized and we do not need to copy.
- // TODO: if the matching bbArg does not bufferize to a read is more
- // general.
+ // TODO: "matching bbArg does not bufferize to a read" is a more general
+ // check.
if (!operand.getDefiningOp<linalg::InitTensorOp>())
b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
}
@@ -1356,7 +1613,7 @@ static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
/// FuncOp always creates TensorToMemRef ops.
static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&funcOp.body().front());
@@ -1370,9 +1627,10 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
Type memRefType = rankedTensorType
? getDynamicMemRefType(rankedTensorType)
: getContiguousOrUnrankedMemRefType(tensorType);
- Value tensorToMemref =
+ Value bufferCast =
b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
- map(bvm, bbArg, tensorToMemref);
+ aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+ map(bvm, bbArg, bufferCast);
}
return success();
}
@@ -1380,7 +1638,7 @@ static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
/// ReturnOp always creates memref::TensorLoadOp.
static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(returnOp);
@@ -1394,7 +1652,10 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
Value v = lookup(bvm, operand.get());
if (!v)
return failure();
- operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(), v));
+ Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
+ operand.set(returnTensor);
+ aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+ map(bvm, returnTensor, v);
}
return success();
}
@@ -1406,7 +1667,7 @@ static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
/// isolation.
static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
LDBG("bufferize: " << *extractSliceOp << '\n');
// Take a guard before anything else.
@@ -1426,8 +1687,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
Value alloc;
auto inPlace = getInPlace(extractSliceOp->getResult(0));
if (inPlace != InPlaceSpec::True) {
- alloc = createNewAllocDeallocPairForShapedValue(b, loc,
- extractSliceOp.result());
+ alloc = createNewAllocDeallocPairForShapedValue(
+ b, loc, extractSliceOp.result(), aliasInfo);
b.setInsertionPointAfter(alloc.getDefiningOp());
}
@@ -1441,6 +1702,8 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+ // Insert new alias.
+ aliasInfo.insertNewBufferAlias(subView, srcMemref);
/// If not inplaceable, copy.
if (alloc) {
@@ -1454,7 +1717,7 @@ static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
LDBG("bufferize: " << *insertSliceOp << '\n');
// Take a guard before anything else.
@@ -1472,8 +1735,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
// cloning the whole tensor on every single iteration and is a symptom
// of a catastrophically bad scheduling decision.
// TODO: be very loud about it or even consider failing the pass.
- Value newDstMemref =
- createNewAllocDeallocPairForShapedValue(b, loc, insertSliceOp.result());
+ Value newDstMemref = createNewAllocDeallocPairForShapedValue(
+ b, loc, insertSliceOp.result(), aliasInfo);
b.setInsertionPointAfter(newDstMemref.getDefiningOp());
b.create<CopyOp>(insertSliceOp.getLoc(), dstMemref, newDstMemref);
dstMemref = newDstMemref;
@@ -1503,6 +1766,8 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
Value subView = b.create<memref::SubViewOp>(
loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+ // Insert new alias.
+ aliasInfo.insertNewBufferAlias(subView, dstMemref);
b.create<CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
}
@@ -1513,7 +1778,7 @@ static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(op);
@@ -1522,8 +1787,6 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
if (op.getShapedType().isa<MemRefType>())
return failure();
- LDBG("bufferize: " << *op << '\n');
-
/// transfer_read from buffer always reads from the bufferized
/// op.source().
if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
@@ -1540,8 +1803,8 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
// If transfer_write is not inPlace, allocate a new buffer.
Value newInputBuffer;
if (inPlace != InPlaceSpec::True) {
- newInputBuffer =
- createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+ newInputBuffer = createNewAllocDeallocPairForShapedValue(
+ b, loc, writeOp.result(), aliasInfo);
b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
map(bvm, writeOp.result(), newInputBuffer);
} else {
@@ -1567,7 +1830,7 @@ static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+ BufferizationAliasInfo &aliasInfo) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(yieldOp);
@@ -1618,7 +1881,7 @@ bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
// If `extractSliceOp` were to be bufferized inplace, it cannot end up
// aliasing a write into a non-writeable buffer.
bool wouldCreateAliasingWriteToNonWriteableBuffer =
- aliasInfo.aliasesInPlaceWrite(extractSliceOp) &&
+ aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) &&
aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0));
if (wouldCreateAliasingWriteToNonWriteableBuffer)
@@ -1743,7 +2006,6 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
return extractSliceOps.push_back(extractSliceOp);
if (auto insertSliceOp = dyn_cast<InsertSliceOp>(op))
return insertSliceOps.push_back(insertSliceOp);
- auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
// No tensors => no buffers.
if (none_of(op->getOperandTypes(), isaTensor) &&
none_of(op->getResultTypes(), isaTensor))
@@ -1792,12 +2054,12 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
}
//===----------------------------------------------------------------------===//
-// Bufferization entry-point.
+// Bufferization entry-point for functions.
//===----------------------------------------------------------------------===//
-static LogicalResult
-bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
- const BufferizationAliasInfo &aliasInfo) {
+static LogicalResult bufferizeFuncOpInternals(
+ FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
+ DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
LLVM_DEBUG(llvm::dbgs() << "\n\n");
LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
OpBuilder b(funcOp->getContext());
@@ -1805,42 +2067,54 @@ bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
return failure();
// Walk in PreOrder to ensure ops with regions are handled before their body.
- WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
- LogicalResult status =
- TypeSwitch<Operation *, LogicalResult>(op)
- // Skip BufferCast and TensorLoad ops.
- // clang-format off
- .Case<memref::BufferCastOp,
- memref::TensorLoadOp>(
- [&](auto) { return success(); })
- .Case<scf::ForOp,
- tensor::DimOp,
- LinalgOp,
- ReturnOp,
- ExtractSliceOp,
- InsertSliceOp,
- VectorTransferOpInterface,
- scf::YieldOp>(
- [&](auto op) {
- LDBG("Begin buferize:\n" << op << '\n');
- return bufferize(b, op, bvm, aliasInfo);
- })
- // clang-format on
- .Default([&](Operation *op) {
- auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
- if (any_of(op->getOperandTypes(), isaTensor) ||
- any_of(op->getResultTypes(), isaTensor))
- return failure();
- return success();
- });
- if (failed(status)) {
- op->emitError("Failed bufferization");
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
+ // Since the walk has to be PreOrder, ops that must be erased (e.g. CallOp)
+ // are collected and erased separately after the walk.
+ SmallVector<Operation *> toErase;
+ WalkResult result = funcOp.walk<WalkOrder::PreOrder>([&](Operation *op)
+ -> WalkResult {
+ // clang-format off
+ WalkResult result =
+ TypeSwitch<Operation *, LogicalResult>(op)
+ // Skip BufferCast and TensorLoad ops.
+ .Case<memref::BufferCastOp,
+ memref::TensorLoadOp>([&](auto) { return success(); })
+ .Case<tensor::DimOp,
+ scf::ForOp,
+ LinalgOp,
+ ReturnOp,
+ ExtractSliceOp,
+ InsertSliceOp,
+ VectorTransferOpInterface,
+ scf::YieldOp>([&](auto op) {
+ LDBG("Begin bufferize:\n" << op << '\n');
+ return bufferize(b, op, bvm, aliasInfo);
+ })
+ .Case([&](CallOpInterface op) {
+ LDBG("Begin bufferize:\n" << op << '\n');
+ return bufferize(b, op, bvm, aliasInfo, bufferizedFunctionTypes);
+ })
+ .Default([&](Operation *op) {
+ auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+ if (any_of(op->getOperandTypes(), isaTensor) ||
+ any_of(op->getResultTypes(), isaTensor))
+ return failure();
+ return success();
+ });
+ // clang-format on
+
+ // Register post-walk erasure, if necessary.
+ if (isa<CallOpInterface>(op))
+ if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+ llvm::any_of(op->getResultTypes(), isaTensor))
+ toErase.push_back(op);
+
+ return result;
});
LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
+ for (Operation *op : toErase)
+ op->erase();
+
return failure(result.wasInterrupted());
}
@@ -1874,7 +2148,9 @@ void LinalgComprehensiveFuncBufferize::runOnFunction() {
// Bufferization phase.
BlockAndValueMapping bvm;
- if (failed(bufferizeFuncOpInternals(funcOp, bvm, aliasInfo)))
+ DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
+ if (failed(bufferizeFuncOpInternals(funcOp, bvm, aliasInfo,
+ bufferizedFunctionTypes)))
signalPassFailure();
// Post-pass cleanup of inplaceable attributes.
@@ -1889,6 +2165,168 @@ std::unique_ptr<Pass> mlir::createLinalgComprehensiveFuncBufferizePass() {
// Bufferization entry-point for modules.
//===----------------------------------------------------------------------===//
+/// Return the op with Allocate MemoryEffect if `v` is equivalent to such
+/// an op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+ const BufferizationAliasInfo &aliasInfo) {
+ Operation *res = nullptr;
+ aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
+ if (!res)
+ if (auto interface =
+ dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
+ if (auto effect =
+ interface.getEffectOnValue<MemoryEffects::Allocate>(value))
+ res = v.getDefiningOp();
+ });
+ return res;
+}
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+ const BufferizationAliasInfo &aliasInfo) {
+ Operation *op = v.getParentBlock()->getParentOp();
+ FuncOp funcOp = dyn_cast<FuncOp>(op);
+ if (!funcOp)
+ funcOp = op->getParentOfType<FuncOp>();
+ assert(funcOp && "expected non-null FuncOp");
+ for (BlockArgument bbArg : funcOp.getArguments())
+ if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
+ return bbArg;
+ return nullptr;
+}
+
+/// Rewrite the `funcOp` arguments, analysis return values, and terminator into
+/// buffer form (using the canonical memref layout for now), according to the
+/// inPlace-bufferizable information of the function arguments.
+/// This relies on a buffer equivalence analysis of each return operand. When a
+/// result buffer is equivalent to:
+/// 1. a BlockArgument of `funcOp`, it can be dropped from the return values
+/// and becomes inplaceable at all callers. This assumes all CallOp perform
+/// the necessary work to clone operands so as to make them inplaceable.
+// Reliance on this logic will need to be relaxed in the future.
+/// 2. an op with an Alloc effect, this currently fails bufferization but is a
+/// candidate for hoisting and creating a new inplace operand at all caller
+/// sites.
+/// 3. if such a hoisting for 2. is not possible (e.g. data-dependent sizes
+/// prevent hoisting), this is currently unsupported and will require a
+/// refcounted buffer type.
+static LogicalResult bufferizeFuncOpBoundary(
+ FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+ DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+ LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+
+ // If nothing to do then we are done.
+ if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
+ !llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return success();
+
+ // Get the bufferized FunctionType for funcOp or construct it if not yet
+ // available.
+ // TODO: Atm we have 3 cases:
+ // 1. if a function is called from within the Module, it must have bufferized
+ // to inplaceable tensor results.
+ // 2. if it is bodiless, it must have bufferized and is not allowed to have
+ // result tensors.
+ // 3. if it is not called internally, it still must bufferize to inplaceable
+ // tensor results and we construct it now (e.g. top-level function called
+ // externally).
+ // -> Figure out a better layering.
+ TypeRange resultTypes;
+ FunctionType bufferizedFuncType =
+ getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
+ resultTypes, bufferizedFunctionTypes);
+
+ // Corner case: Bodiless FuncOp
+ // ============================
+ // The body of such functions is assumed opaque and we can't know the
+ // bufferization contract they want to enforce atm.
+ // As a consequence, only support functions that don't return any tensor atm.
+ if (funcOp.getBody().empty()) {
+ if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+ return funcOp->emitError() << "cannot bufferize bodiless function that "
+ << "returns a tensor";
+ funcOp.setType(bufferizedFuncType);
+ LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
+ return success();
+ }
+
+ // Support only single return-terminated block in the function.
+ ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+ if (!returnOp)
+ return funcOp->emitError() << "cannot bufferize a FuncOp with tensors and "
+ "without a unique ReturnOp";
+
+ // 1. For each FuncOp result, keep track of which inplace argument it reuses.
+ SmallVector<Value> returnValues;
+ for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+ // If return operand is equivalent to some bbArg, no need to return it.
+ Value returnVal = returnOperand.get();
+ if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+ continue;
+ // TODO: Need to hoist above function boundary. If this is not possible due
+ // to data-dependent sizes, we need a better type than memref.
+ if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+ return allocOp->emitError() << " needs hoist across function boundary\n";
+ int64_t returnIdx = returnOperand.getOperandNumber();
+ return returnOp->emitError() << " bufferize result #" << returnIdx << "\n";
+ }
+
+ // 2. Rewrite the terminator without the inPlace bufferizable values.
+ OpBuilder(returnOp).create<ReturnOp>(returnOp.getLoc(), returnValues);
+ returnOp->erase();
+
+ // 3. Rewrite the bbArgs.
+ // Iterate on the original `numArgs` and replace them in order.
+ // This guarantees the argument order still matches after the rewrite.
+ Block &frontBlock = funcOp.body().front();
+ unsigned numArgs = frontBlock.getNumArguments();
+ for (unsigned idx = 0; idx < numArgs; ++idx) {
+ auto bbArg = frontBlock.getArgument(0);
+ auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+ // Non-tensor types are just forwarded.
+ if (!tensorType) {
+ frontBlock.addArgument(bbArg.getType());
+ bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
+ frontBlock.eraseArgument(0);
+ continue;
+ }
+
+ // Get the buffer type from the bufferized function type.
+ Type memrefType = bufferizedFuncType.getInput(idx);
+ Value memref = frontBlock.addArgument(memrefType);
+ OpBuilder b(funcOp->getContext());
+ b.setInsertionPointToStart(&frontBlock);
+ // Replace all uses of bbArg through a BufferCastOp by a memref::CastOp.
+ for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
+ if (auto bufferCastOp = dyn_cast<memref::BufferCastOp>(use.getOwner())) {
+ auto castOp = b.create<memref::CastOp>(
+ funcOp.getLoc(), bufferCastOp.memref().getType(), memref);
+ bufferCastOp.memref().replaceAllUsesWith(castOp);
+ aliasInfo.insertNewBufferEquivalence(castOp.dest(),
+ bufferCastOp.memref());
+ }
+ }
+ // Replace all remaining uses by a tensor_load.
+ if (!bbArg.use_empty()) {
+ auto tensorLoadOp =
+ b.create<memref::TensorLoadOp>(funcOp.getLoc(), memref);
+ aliasInfo.insertNewBufferEquivalence(tensorLoadOp, bbArg);
+ bbArg.replaceAllUsesWith(tensorLoadOp);
+ }
+ frontBlock.eraseArgument(0);
+ // TODO: add support to erase aliasInfo entries if deemed necessary.
+ }
+
+ // 4. Rewrite the FuncOp type to buffer form.
+ funcOp.setType(bufferizedFuncType);
+
+ LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
+
+ return success();
+}
+
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e. callees without callers first).
/// Store the map of FuncOp to all its callers in `callerMap`.
@@ -1905,10 +2343,12 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
DenseMap<FuncOp, unsigned> numberCallOpsContainedInFuncOp;
WalkResult res = moduleOp.walk([&](FuncOp funcOp) {
numberCallOpsContainedInFuncOp[funcOp] = 0;
- return funcOp.walk([&](CallOpInterface callOp) {
+ return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
+ // Only support CallOp for now.
+ if (!isa<CallOp>(callOp.getOperation()))
+ return callOp->emitError() << "expected a CallOp";
FuncOp calledFunction = getCalledFunction(callOp);
- if (!calledFunction)
- return WalkResult::interrupt();
+ assert(calledFunction && "could not retrieve called FuncOp");
auto it = callerMap.try_emplace(calledFunction, DenseSet<Operation *>{});
it.first->getSecond().insert(callOp);
if (calledBy[calledFunction].count(funcOp) == 0) {
@@ -1954,6 +2394,7 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
SmallVector<FuncOp> orderedFuncOps;
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+ DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return signalPassFailure();
@@ -1985,12 +2426,30 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
return;
}
- // TODO: Bufferization phase.
+ // Bufferization phase.
+ if (!testAnalysisOnly) {
+ BlockAndValueMapping tensorToBufferMap;
+ if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
+ bufferizedFunctionTypes))) {
+ signalPassFailure();
+ return;
+ }
+ }
}
// Don't drop the attributes if we only want to report the analysis.
if (testAnalysisOnly)
return;
+ for (FuncOp funcOp : orderedFuncOps) {
+ // Note: It would be good to apply cleanups here but we cannot as aliasInfo
+ // would be invalidated.
+ if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
+ bufferizedFunctionTypes))) {
+ signalPassFailure();
+ return;
+ }
+ }
+
// Post-pass cleanup of inplaceable attributes.
moduleOp.walk(
[&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
@@ -1998,6 +2457,12 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
for (BlockArgument bbArg : op.getArguments())
removeInPlaceFuncArgument(bbArg);
});
+
+ OpPassManager cleanupPipeline(OpPassManager("module"));
+ cleanupPipeline.addPass(createCanonicalizerPass());
+ cleanupPipeline.addPass(createCSEPass());
+ cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+ (void)runPipeline(cleanupPipeline, moduleOp);
}
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index 0e378a89ef58..d6a6d7c67f6c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -1,5 +1,36 @@
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file -verify-diagnostics
+func private @foo() -> tensor<?xf32>
+
+func @bar() -> tensor<?xf32> {
+ %foo = constant @foo : () -> (tensor<?xf32>)
+// expected-error @+1 {{expected a CallOp}}
+ %res = call_indirect %foo() : () -> (tensor<?xf32>)
+ return %res : tensor<?xf32>
+}
+
+// -----
+
+// expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}}
+func private @foo() -> tensor<?xf32>
+
+// -----
+
+// expected-error @+1 {{cannot bufferize a FuncOp with tensors and without a unique ReturnOp}}
+func @switch(%flag : i32, %caseOperand : i32, %t1 : tensor<f32>, %t2 : tensor<f32>)
+ -> (tensor<f32>)
+{
+ switch %flag : i32, [
+ default: ^bb1(%caseOperand : i32),
+ 42: ^bb2(%caseOperand : i32)
+ ]
+
+ ^bb1(%bb1arg : i32):
+ return %t1 : tensor<f32>
+ ^bb2(%bb2arg : i32):
+ return %t2 : tensor<f32>
+}
+
// -----
// expected-error @-3 {{expected callgraph to be free of circular dependencies}}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
new file mode 100644
index 000000000000..7756587560ea
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s
+
+// CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func private @some_external_func(memref<?xf32, #[[$DYN_1D_MAP]]>)
+func private @some_external_func(tensor<?xf32>)
+
+// CHECK: func @scf_for_with_tensor_insert_slice(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$DYN_1D_MAP]]>
+func @scf_for_with_tensor_insert_slice(
+ %A : tensor<?xf32>, %B : tensor<?xf32>, %C : tensor<4xf32>,
+ %lb : index, %ub : index, %step : index)
+ -> (tensor<?xf32>, tensor<?xf32>)
+{
+ // CHECK-NEXT: scf.for
+ %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+ -> (tensor<?xf32>, tensor<?xf32>)
+ {
+ // CHECK-NEXT: %[[SVA:.*]] = memref.subview %[[A]]
+ // CHECK-NEXT: linalg.copy(%[[C]], %[[SVA]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+ %ttA = tensor.insert_slice %C into %tA[%i][4][1] : tensor<4xf32> into tensor<?xf32>
+
+ // CHECK-NEXT: %[[SVB:.*]] = memref.subview %[[B]]
+ // CHECK-NEXT: linalg.copy(%[[C]], %[[SVB]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+ %ttB = tensor.insert_slice %C into %tB[%i][4][1] : tensor<4xf32> into tensor<?xf32>
+
+ // scf.yield is empty and is elided
+ // CHECK-NOT: scf.yield
+ scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+ }
+
+ // Swaparoo requires bufferizing the whole function to figure out who's who.
+ return %r0#1, %r0#0: tensor<?xf32>, tensor<?xf32>
+}
+
+// CHECK: func @bar(
+// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]>
+// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$DYN_1D_MAP]]>
+func @bar(
+ %A : tensor<?xf32> {linalg.inplaceable = true},
+ %B : tensor<?xf32> {linalg.inplaceable = true},
+ %C : tensor<4xf32> {linalg.inplaceable = true},
+ %lb : index, %ub : index, %step : index)
+ -> (tensor<?xf32>, tensor<?xf32>)
+{
+// CHECK-NEXT: call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
+ %r0:2 = call @scf_for_with_tensor_insert_slice(%A, %B, %C, %lb, %ub, %step) :
+ (tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
+ -> (tensor<?xf32>, tensor<?xf32>)
+
+ // %r0#0 is actually %B after inplaceable results are swapped in the callee.
+// CHECK-NEXT: call @some_external_func(%[[B]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
+ call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT: return
+ return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}
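As a closing illustration, a hand-written sketch (not part of the committed test; names are invented) of the equivalence case that makes the result swap in @bar above possible: a returned tensor that the analysis proves equivalent to an inplaceable argument is dropped from the bufferized signature by bufferizeFuncOpBoundary.

  func @insert_into_inplaceable_arg(
      %A : tensor<?xf32> {linalg.inplaceable = true},
      %C : tensor<4xf32>, %i : index) -> tensor<?xf32>
  {
    %r = tensor.insert_slice %C into %A[%i][4][1] : tensor<4xf32> into tensor<?xf32>
    return %r : tensor<?xf32>
  }

  // Expected bufferized form, roughly: the result folds onto %A, the return
  // becomes empty, and callers observe the update through their own buffer.
  //   func @insert_into_inplaceable_arg(%A: memref<?xf32, #map>,
  //                                     %C: memref<4xf32, #map>, %i: index) {
  //     %sv = memref.subview %A[%i] [4] [1] : memref<?xf32, #map> to memref<4xf32, #map>
  //     linalg.copy(%C, %sv) : memref<4xf32, #map>, memref<4xf32, #map>
  //     return
  //   }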