[Mlir-commits] [mlir] 75d6529 - [mlir][linalg][bufferize][NFC] Clean up comments and minor code refactorings
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 6 13:26:17 PST 2022
Author: Matthias Springer
Date: 2022-01-07T06:23:01+09:00
New Revision: 75d65293ca83a7ff24e4c6634e46e63e8ae8c24c
URL: https://github.com/llvm/llvm-project/commit/75d65293ca83a7ff24e4c6634e46e63e8ae8c24c
DIFF: https://github.com/llvm/llvm-project/commit/75d65293ca83a7ff24e4c6634e46e63e8ae8c24c.diff
LOG: [mlir][linalg][bufferize][NFC] Clean up comments and minor code refactorings
Differential Revision: https://reviews.llvm.org/D116451
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 0bd42e34f047..921353a23ea7 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -64,14 +64,14 @@ struct AllocationCallbacks {
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
-/// executed after the analysis, but before bufferization. They can be used
+/// executed after the analysis, but before bufferization. They can be used to
/// implement custom dialect-specific optimizations.
struct PostAnalysisStep {
virtual ~PostAnalysisStep() {}
/// Run the post analysis step. This function may modify the IR, but must keep
- /// `aliasInfo` (inside `state`) consistent. Newly created operations and
- /// operations that should be re-analyzed must be stored in `newOps`.
+ /// `aliasInfo` consistent. Newly created operations and operations that
+ /// should be re-analyzed must be added to `newOps`.
virtual LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) = 0;
@@ -102,7 +102,8 @@ struct BufferizationOptions {
}
/// Allow-list the given dialects in the dialect filter. Only ops from
- /// allow-listed dialects will be bufferized.
+ /// allow-listed dialects will be bufferized. If no dialect is added, ops from
+ /// any dialect will be bufferized.
template <typename... DialectTs>
void addToDialectFilter() {
// The following expands a call to addToDialectFilterImpl for each dialect
@@ -288,17 +289,7 @@ struct DialectBufferizationState {
};
/// BufferizationState provides a variety of helper functions for dealing with
-/// tensor values and memref buffers. In particular,
-/// `BufferizableOpInterface::bufferize` implementation should utilize the
-/// following helper functions.
-///
-/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops
-/// that allocate and/or deallocate memref buffers.
-/// * `lookupBuffer` returns the memref buffer of a given tensor value.
-/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
-/// Based on inplace bufferization decisions of the analysis, it may either
-/// directly return a mapped buffer or allocate a new brand new buffer.
-/// * `replaceOp` replaces an op with new values.
+/// tensor values and memref buffers.
class BufferizationState {
public:
BufferizationState(Operation *op, const BufferizationOptions &options);
@@ -396,7 +387,8 @@ class BufferizationState {
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
- Value getResultBuffer(RewriterBase &rewriter, OpResult result) const;
+ FailureOr<Value> getResultBuffer(RewriterBase &rewriter,
+ OpResult result) const;
/// Return dialect-specific bufferization state.
template <typename StateT>
@@ -455,12 +447,9 @@ MemRefType getContiguousMemRefType(ShapedType shapedType,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace` or an UnrankedMemRefType otherwise.
-Type getContiguousOrUnrankedMemRefType(Type type,
- MemRefLayoutAttrInterface layout = {},
- Attribute memorySpace = {});
+/// Return an UnrankedMemRefType with the given element type and memory space.
+UnrankedMemRefType getUnrankedMemRefType(Type elementType,
+ Attribute memorySpace = {});
/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
@@ -493,7 +482,7 @@ struct AllocationHoistingBarrierOnly
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
- return false;
+ return true;
}
SmallVector<OpOperand *>
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 1da5903f1048..e56371617b97 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -23,6 +23,12 @@ class BufferizationAliasInfo;
namespace linalg_ext {
struct InitTensorEliminationStep : public PostAnalysisStep {
+ /// A function that matches anchor OpOperands for InitTensorOp elimination.
+ using AnchorMatchFn = std::function<bool(OpOperand &)>;
+
+ /// A function that rewrites matched anchors.
+ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
+
/// Try to eliminate InitTensorOps inside `op`.
///
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
@@ -33,12 +39,11 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
/// InitTensorOp.
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
/// This analysis can be skipped with `skipAnalysis`.
- LogicalResult eliminateInitTensors(
- Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- std::function<bool(OpOperand &)> anchorMatchFunc,
- std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
- SmallVector<Operation *> &newOps);
+ LogicalResult eliminateInitTensors(Operation *op, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
+ AnchorMatchFn anchorMatchFunc,
+ RewriteFn rewriteFunc,
+ SmallVector<Operation *> &newOps);
};
/// Try to eliminate InitTensorOps inside `op` that are anchored on an
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
index 6dcc7c5fca92..bed69f19582f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp
@@ -13,6 +13,8 @@
void mlir::linalg::comprehensive_bufferize::affine_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+ // AffineParallelOp bufferization not implemented yet. However, never hoist
+ // memref allocations across AffineParallelOp boundaries.
registry.addOpInterface<AffineParallelOp,
AllocationHoistingBarrierOnly<AffineParallelOp>>();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index 40d54445c5e8..3c0926e3fae6 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -20,23 +20,30 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {
+/// Bufferization of arith.constant. Replace with memref.get_global.
struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
- assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
- "not a constant ranked tensor");
+
+ // Only ranked tensors are supported.
+ if (!constantOp.getType().isa<RankedTensorType>())
+ return failure();
+
+ // Only constants inside a module are supported.
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
- return constantOp.emitError(
- "cannot bufferize constants not within builtin.module op");
+ return failure();
+ // Create global memory segment and replace tensor with memref pointing to
+ // that memory segment.
GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
rewriter, op, globalMemref.type(), globalMemref.getName());
+
return success();
}
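For illustration, a minimal IR sketch of the rewrite this interface performs (the global symbol name chosen by GlobalCreator is purely illustrative here):

```
// Before bufferization:
%cst = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>

// After bufferization (sketch; the symbol name is hypothetical):
memref.global "private" constant @__constant_4xf32 : memref<4xf32> =
    dense<[1.0, 2.0, 3.0, 4.0]>
%0 = memref.get_global @__constant_4xf32 : memref<4xf32>
```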
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 84dba99b0840..118e25a23148 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -74,6 +74,21 @@ mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
BufferizationOptions::BufferizationOptions()
: allocationFns(defaultAllocationCallbacks()) {}
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+ BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
+ if (isOpAllowed(op))
+ return dyn_cast<BufferizableOpInterface>(op);
+ return nullptr;
+}
+
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+ BufferizationOptions::dynCastBufferizableOp(Value value) const {
+ if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
+ if (isOpAllowed(bufferizableOp.getOperation()))
+ return bufferizableOp;
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//
@@ -180,21 +195,6 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
}
}
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
- BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
- if (isOpAllowed(op))
- return dyn_cast<BufferizableOpInterface>(op);
- return nullptr;
-}
-
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
- BufferizationOptions::dynCastBufferizableOp(Value value) const {
- if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
- if (isOpAllowed(bufferizableOp.getOperation()))
- return bufferizableOp;
- return nullptr;
-}
-
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
@@ -358,8 +358,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::
- getResultBuffer(RewriterBase &rewriter, OpResult result) const {
+FailureOr<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
+ RewriterBase &rewriter, OpResult result) const {
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
@@ -375,10 +376,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
if (aliasingOperands.size() > 1 &&
!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
return lookupBuffer(rewriter, o->get()) == operandBuffer;
- })) {
- op->emitError("result buffer is ambiguous");
- return Value();
- }
+ }))
+ return FailureOr<Value>(op->emitError("result buffer is ambiguous"));
// If bufferizing out-of-place, allocate a new buffer.
if (!aliasInfo.isInPlace(result)) {
@@ -610,10 +609,13 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, tensor);
- Type memrefType =
- tensor.getType().isa<RankedTensorType>()
- ? getDynamicMemRefType(tensor.getType().cast<RankedTensorType>())
- : getContiguousOrUnrankedMemRefType(tensor.getType());
+ Type memrefType;
+ if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
+ memrefType = getDynamicMemRefType(rankedTensorType);
+ } else {
+ memrefType = getUnrankedMemRefType(
+ tensor.getType().cast<TensorType>().getElementType());
+ }
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
@@ -630,13 +632,9 @@ MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
layout, memorySpace);
}
-Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
- Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
- if (type.isa<RankedTensorType, MemRefType>())
- return getContiguousMemRefType(type.cast<ShapedType>(), layout,
- memorySpace);
- assert(!layout && "expected empty layout with UnrankedMemRefType");
- return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
+UnrankedMemRefType mlir::linalg::comprehensive_bufferize::getUnrankedMemRefType(
+ Type elementType, Attribute memorySpace) {
+ return UnrankedMemRefType::get(elementType, memorySpace);
}
MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
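To illustrate the `lookupBuffer` change above, a small sketch of the to_memref ops it inserts (%t0 and %t1 are placeholder tensor values; the ranked case uses the fully dynamic strided layout from `getDynamicMemRefType`, the unranked case the new `getUnrankedMemRefType`):

```
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Ranked tensor: memref with a fully dynamic strided layout.
%m0 = bufferization.to_memref %t0 : memref<?xf32, #map>
// Unranked tensor: unranked memref with the same element type.
%m1 = bufferization.to_memref %t1 : memref<*xf32>
```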
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 5051d43bb584..aaa304b2c91f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -25,6 +25,9 @@ namespace bufferization_ext {
// TODO: These ops should implement BufferizableOpInterface directly when moved
// to the Bufferization dialect.
+/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded
+/// to x. Other to_memref ops are ignored during bufferization.
+///
/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
/// location of the incoming tensor once it will be bufferized. In the analysis,
/// the incoming tensor is assumed to bufferize to a memory read and to an
@@ -41,7 +44,7 @@ struct ToMemrefOpInterface
bufferization::ToMemrefOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
- // It is unknown whether the resulting MemRef will be read or not.
+ // It is unknown whether the resulting memref will be read or not.
return true;
}
@@ -58,10 +61,13 @@ struct ToMemrefOpInterface
if (auto toTensorOp =
toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
Value buffer = toTensorOp.memref();
+
+ // Insert cast in case to_memref(to_tensor(x))'s type is different from
+ // x's type.
if (toTensorOp.memref().getType() != toMemrefOp.getType())
buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
toMemrefOp.getType());
- rewriter.replaceOp(toMemrefOp, buffer);
+ replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
return success();
}
@@ -69,16 +75,19 @@ struct ToMemrefOpInterface
}
};
-/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
-/// not lower any further, and they should have disappeared by the time the
-/// input is fully bufferized.
+/// Bufferization of bufferization.to_tensor. Such ops cannot be bufferized.
+/// However, other ops that are using to_tensor's result will eventually be
+/// bufferized. At that point, they will start using to_tensor's memref operand.
+/// Once all users of to_tensor are bufferized, the op will not have any users
+/// anymore and DCE away.
///
-/// The analysis has no information about the memref that is loaded from by the
-/// ToTensorOp. We have to assume that the loaded tensor may after bufferization
-/// potentially alias with any other bufferized tensor. Since ToTensorOp and
-/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded
-/// directly in the analysis. However, declaring ToTensorOp results as not
-/// writable also enforces a buffer copy and has the same effect.
+/// ToTensorOp conceptually loads a tensor from a memory location. The analysis
+/// has no information about the memref that is loaded from by ToTensorOp. We
+/// have to assume that the loaded tensor may after bufferization potentially
+/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp have
+/// no aliasing OpOperand/OpResult pairs, this cannot be encoded directly in the
+/// analysis. However, declaring ToTensorOp results as not writable enforces a
+/// buffer copy and has the same effect.
struct ToTensorOpInterface
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
bufferization::ToTensorOp> {
@@ -89,7 +98,7 @@ struct ToTensorOpInterface
bool isWritable(Operation *op, Value value,
const BufferizationState &state) const {
- // It is unknown whether the MemRef operand is writable or not.
+ // It is unknown whether the memref operand is writable or not.
return false;
}
};
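A minimal sketch of the to_memref(to_tensor(x)) folding described above, assuming the two memref types differ only in their layout:

```
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Before:
%t = bufferization.to_tensor %x : memref<?xf32>
%m = bufferization.to_memref %t : memref<?xf32, #map>

// After: the pair folds to a single cast of %x.
%m = memref.cast %x : memref<?xf32> to memref<?xf32, #map>
```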
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 88a8f861c543..60ca3623fb95 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -6,98 +6,37 @@
//
//===----------------------------------------------------------------------===//
//
-// Perform inplace bufferization within function boundaries.
-// This is a specialized pass that supports inplace analysis for a fixed subset
-// of ops that have well-defined inplace semantics.
-// This pass caters to high-performance codegen where buffer reuse is deemed
-// critical: the pass should fail if the bufferized form of the function needs
-// to return any buffer.
-// Generic control-flow and branching are unsupported.
-// Composability with extensible set of ops is not a first-class concern.
-//
-// Bufferization occurs by:
-// a. performing an inPlace analysis `inPlaceAnalysis` which marks each
-// operation within the op with the `kInPlaceResultsAttrName` attribute.
-// b. traversing each operation in the op and rewriting it in
-// buffer form and keeping a BlockAndValueMapping mapping of the
-// rewrites. New allocations are introduced during this step.
-// TODO: Allocation + depending op hoisting to outermost enclosing
-// sequential scope.
-// c. at the end of this bufferization, 3 cases may occur:
-// i. inplaceable function arguments may be reused in place after the
-// function itself has been bufferized. This is encoded by IR resembling:
-//
-// ```
-// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-// -> tensor<?xf32> {
-// %0 = bufferization.to_memref %A : memref<?xf32, #map>
-// // ... uses of %0
-// %res = bufferization.to_tensor %0 : memref<?xf32, #map>
-// return %res : tensor<?xf32>
-// }
-// ```
+// Comprehensive Bufferize bufferizes function bodies. Function boundaries
+// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
+// ModuleBufferization.cpp is an extension of Comprehensive Bufferize for simple
+// call graphs.
//
-// this is the cue for the bufferization of the function foo (and calls
-// to it) may bufferize to `func @foo(%A: memref<?xf32, some_layout>)`.
-// To fully achieve bufferization, an additional analysis is needed to
-// determine whether function argument/operand pairs bufferize to a
-// single inplace buffer argument (i.e. functions may return tensors in
-// arbitrary order that may not match argument numbers).
+// Comprehensive Bufferize consists of two phases.
//
-// ii. results that don't map to an inplaceable function argument are
-// generally allocated. Since memref semantics wrt ownership of the
-// underlying memory region are not well-defined, comprehensive
-// bufferization chooses to perform allocations in a scoped fashion:
-// returning memrefs is always considered illegal.
-// Such scenarios are encoded by IR resembling:
+// 1. Analyze ops to decide which OpResults can bufferize inplace, i.e., without
+// inserting buffer copies. The analysis queries op bufferization semantics
+// via `BufferizableOpInterface`.
+// 2. Bufferize ops by calling `BufferizableOpInterface::bufferize`. This
+// function does not generate buffer copies for OpResults that were decided
+// to bufferize inplace during the analysis phase.
//
-// ```
-// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
-// -> tensor<?xf32> {
-// %0 = bufferization.to_memref %A : memref<?xf32, #map>
-// %1 = memref.dim %0, %c0 : memref<?xf32, #map>
-// %2 = memref.alloc(%1) : memref<?xf32>
-// %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
-// // ... uses of %3
-// memref.dealloc %2 : memref<?xf32, #map>
-// %res = bufferization.to_tensor %3 : memref<?xf32, #map>
-// return %res : tensor<?xf32>
-// }
-// ```
+// Inplace bufferization decisions are passed from the analysis to the
+// bufferization phase via `BufferizationState` and `BufferizationAliasInfo`.
+// They can be printed for debugging purposes with `testAnalysisOnly`.
//
-// this is the cue for the bufferization of the function foo (and calls
-// to it) that it must bufferize to `func @foo(%A: memref<?xf32,
-// some_layout>,
-// %B: memref<?xf32, some_layout>)` (i.e. make a cloned
-// allocation of the result tensor)
-// To fully achieve bufferization, the alloc/dealloc pair must be lifted
-// out of the function at each call site.
+// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
+// treated conservatively. E.g., the analysis has to assume that their
+// OpOperands bufferize to memory writes. While such ops can be analyzed, they
+// are not bufferized and remain in the IR. to_tensor and to_memref ops are
+// inserted at the bufferization boundary.
//
-// iii. as an optimization over ii., it may be possible to reuse an argument
-// and only want to return a slice.
-// This may forego allocation by letting *all* callers decide whether to
-// pass a new *aliasing* memref function argument (i.e. a subview).
-// Without loss of generality, callers may agree to allocate a new buffer
-// to avoid this aliasing. Such scenarios are encoded by IR resembling:
+// Note: If `allowUnknownOps` is set to false, bufferization fails when an
+// unknown op (that does not implement `BufferizableOpInterface`) is found. No
+// to_tensor/to_memref ops are inserted.
//
-// ```
-// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
-// func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
-// -> tensor<4xf32> {
-// %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
-// %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
-// memref<4xf32, #map>
-// // ... inplace computes into %1
-// %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
-// return %3 : tensor<4xf32>
-// }
-// ```
-//
-// Note: In the future, it may be worthwhile to design special bufferization
-// ops to encode the desired semantics at function boundaries for i., ii. and
-// iii.
+// This pass caters to high-performance codegen where buffer reuse is deemed
+// critical: the pass should fail if the bufferized form of the function needs
+// to return any buffer, unless `allowReturnMemref` is enabled.
//
// Lastly, note that layout map chosen to bufferize is the most dynamic
// canonical strided layout of the proper rank. This ensures compatibility with
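As a sketch of the `allowUnknownOps` behavior described above ("test.unknown_op" stands in for any op that does not implement BufferizableOpInterface; %c0 is an index value defined elsewhere):

```
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Before:
%0 = "test.unknown_op"() : () -> tensor<4xf32>
%1 = tensor.extract %0[%c0] : tensor<4xf32>

// After (with allowUnknownOps): the unknown op is left untouched and a
// to_memref op bridges into the bufferized world.
%0 = "test.unknown_op"() : () -> tensor<4xf32>
%m = bufferization.to_memref %0 : memref<4xf32, #map>
%1 = memref.load %m[%c0] : memref<4xf32, #map>
```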
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 5c5462527e9b..c4f42afb9828 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -38,6 +38,7 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
if (!op.hasTensorSemantics())
return op->emitError() << "op does not have tensor semantics";
+ // New input operands for the cloned op.
SmallVector<Value> newInputBuffers;
newInputBuffers.reserve(op.getNumInputs());
for (OpOperand *opOperand : op.getInputOperands()) {
@@ -48,22 +49,23 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
newInputBuffers.push_back(state.lookupBuffer(rewriter, opOperand->get()));
}
+ // New output operands for the cloned op.
SmallVector<Value> newOutputBuffers;
for (OpOperand *opOperand : op.getOutputOperands()) {
OpResult opResult = op.getTiedOpResult(opOperand);
assert(opResult && "could not find correspond OpResult");
- Value resultBuffer = state.getResultBuffer(rewriter, opResult);
- if (!resultBuffer)
- return failure();
- newOutputBuffers.push_back(resultBuffer);
+ FailureOr<Value> resultBuffer = state.getResultBuffer(rewriter, opResult);
+ newOutputBuffers.push_back(*resultBuffer);
}
- // Clone the newly bufferized op.
+ // Merge input/output operands.
SmallVector<Value> newOperands = newInputBuffers;
newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
// Set insertion point now that potential alloc/dealloc are introduced.
rewriter.setInsertionPoint(op);
+ // Clone the op, but use the new operands. Since the new op does not have any
+ // tensor results, it does not return anything.
op.clone(rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands);
// Replace the results of the old op with the new output buffers.
@@ -135,18 +137,23 @@ static DenseMap<OpOperand *, OpResult> computeAliasingPairs(LinalgOp op) {
return mapping;
}
+/// Bufferization of linalg.generic. Replace with a new linalg.generic that
+/// operates entirely on memrefs.
template <typename OpTy>
struct LinalgOpInterface
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
OpTy> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
+ // Operand is read if it is used in the computation.
auto genericOp = cast<linalg::LinalgOp>(op);
return genericOp.payloadUsesValueFromOperand(&opOperand);
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
+ // Operand is written to if it has an aliasing OpResult. For more details,
+ // see `computeAliasingPairs`.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
return static_cast<bool>(
bufferizableOp.getAliasingOpResult(opOperand, state));
@@ -156,6 +163,8 @@ struct LinalgOpInterface
getAliasingOpOperand(Operation *op, OpResult opResult,
const BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
+
+ // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
if (pairs[opOperand] == opResult)
@@ -166,6 +175,8 @@ struct LinalgOpInterface
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
+
+ // Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
return pairs[&opOperand];
}
@@ -207,22 +218,26 @@ struct InitTensorOpInterface
}
};
+/// Bufferization of linalg.tiled_loop. Replace with a new linalg.tiled_loop
+/// that operates entirely on memrefs.
struct TiledLoopOpInterface
: public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
linalg::TiledLoopOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
- // TiledLoop alone doesn't bufferize to a memory read, one of the uses of
- // its matching bbArg may.
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
+
+ // linalg.tiled_loop operands alone do not bufferize to a memory read, but
+ // one of the uses of their matching bbArgs may.
return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
- // TiledLoop alone doesn't bufferize to a memory write, one of the uses of
- // its matching bbArg may.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
+
+ // Only operands with an aliasing OpResult (i.e., output operands) bufferize
+ // to a memory write.
return static_cast<bool>(
bufferizableOp.getAliasingOpResult(opOperand, state));
}
@@ -230,6 +245,8 @@ struct TiledLoopOpInterface
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
+
+ // Output operands are tied to their corresponding OpResults.
return tiledLoopOp.getTiedOpResult(opOperand);
}
@@ -241,8 +258,8 @@ struct TiledLoopOpInterface
bool isWritable(Operation *op, Value value,
const BufferizationState &state) const {
- // Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed
- // inplace from the perspective of ops nested under:
+ // Interestingly, linalg::TiledLoopOp's bbArgs can **always** be viewed
+ // inplace from the perspective of nested ops:
// 1. Either the matching iter operand is not bufferized inplace and an
// alloc + optional copy makes the bbArg itself inplaceable.
// 2. Or the matching iter operand is bufferized inplace and bbArg just
@@ -268,10 +285,10 @@ struct TiledLoopOpInterface
int nextResultNum = 0;
for (Value value : tiledLoopOp.outputs()) {
if (value.getType().isa<TensorType>()) {
- Value buffer = state.getResultBuffer(
+ FailureOr<Value> buffer = state.getResultBuffer(
rewriter, tiledLoopOp->getResult(nextResultNum++));
- newOutputs.push_back(buffer);
- newResults.push_back(buffer);
+ newOutputs.push_back(*buffer);
+ newResults.push_back(*buffer);
} else {
newOutputs.push_back(value);
}
@@ -349,6 +366,8 @@ struct TiledLoopOpInterface
}
};
+/// Bufferization of linalg.yield. Bufferized as part of linalg.tiled_loop's
+/// bufferization.
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
linalg::YieldOp> {
@@ -407,13 +426,12 @@ struct LinalgOpInterfaceHelper<> {
/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
-LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
- InitTensorEliminationStep::eliminateInitTensors(
- Operation *op, BufferizationState &state,
- BufferizationAliasInfo &aliasInfo,
- std::function<bool(OpOperand &)> anchorMatchFunc,
- std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
- SmallVector<Operation *> &newOps) {
+LogicalResult
+mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
+ eliminateInitTensors(Operation *op, BufferizationState &state,
+ BufferizationAliasInfo &aliasInfo,
+ AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
+ SmallVector<Operation *> &newOps) {
OpBuilder b(op->getContext());
WalkResult status = op->walk([&](Operation *op) {
@@ -506,6 +524,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
return eliminateInitTensors(
op, state, aliasInfo,
+ /*anchorMatchFunc=*/
[&](OpOperand &operand) {
auto insertSliceOp =
dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
@@ -516,6 +535,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
return false;
return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
},
+ /*rewriteFunc=*/
[](OpBuilder &b, Location loc, OpOperand &operand) {
auto insertSliceOp = cast<tensor::InsertSliceOp>(operand.getOwner());
auto extractOp = b.create<tensor::ExtractSliceOp>(
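To illustrate the insert_slice-anchored elimination (a sketch; %v, %t and %c0 are placeholder values): the InitTensorOp is replaced by a matching extract_slice of the insert_slice destination, so the computation later writes directly into the destination buffer.

```
// Before: the insert_slice source traces back to an init_tensor.
%0 = linalg.init_tensor [5] : tensor<5xf32>
%1 = vector.transfer_write %v, %0[%c0] {in_bounds = [true]}
    : vector<5xf32>, tensor<5xf32>
%2 = tensor.insert_slice %1 into %t[0] [5] [1]
    : tensor<5xf32> into tensor<?xf32>

// After elimination (sketch): rewriteFunc materialized an extract_slice.
%0 = tensor.extract_slice %t[0] [5] [1] : tensor<?xf32> to tensor<5xf32>
%1 = vector.transfer_write %v, %0[%c0] {in_bounds = [true]}
    : vector<5xf32>, tensor<5xf32>
%2 = tensor.insert_slice %1 into %t[0] [5] [1]
    : tensor<5xf32> into tensor<?xf32>
```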
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 7d9a5648b128..5cac342296f8 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -5,6 +5,88 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// Module bufferization is an extension of Comprehensive Bufferize that
+// bufferizes function boundaries. It provides `BufferizableOpInterface`
+// implementations for FuncOp, CallOp and ReturnOp, along with a few helper
+// functions that control the order in which functions are bufferized.
+//
+// Three cases can occur during bufferization of FuncOps.
+//
+// i. inplaceable function arguments may be reused in place after the
+// function itself has been bufferized. This is encoded by IR resembling:
+//
+// ```
+// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
+// -> tensor<?xf32> {
+// %0 = bufferization.to_memref %A : memref<?xf32, #map>
+// // ... uses of %0
+// %res = bufferization.to_tensor %0 : memref<?xf32, #map>
+// return %res : tensor<?xf32>
+// }
+// ```
+//
+// this is the cue for the bufferization of the function foo (and calls
+// to it) may bufferize to `func @foo(%A: memref<?xf32, some_layout>)`.
+// To fully achieve bufferization, an additional analysis is needed to
+// determine whether function argument/operand pairs bufferize to a
+// single inplace buffer argument (i.e. functions may return tensors in
+// arbitrary order that may not match argument numbers).
+//
+// ii. results that don't map to an inplaceable function argument are
+// generally allocated. Since memref semantics wrt ownership of the
+// underlying memory region are not well-defined, comprehensive
+// bufferization chooses to perform allocations in a scoped fashion:
+// returning memrefs is always considered illegal.
+// Such scenarios are encoded by IR resembling:
+//
+// ```
+// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// func @foo(%A: tensor<?xf32> {linalg.inplaceable = true})
+// -> tensor<?xf32> {
+// %0 = bufferization.to_memref %A : memref<?xf32, #map>
+// %1 = memref.dim %0, %c0 : memref<?xf32, #map>
+// %2 = memref.alloc(%1) : memref<?xf32>
+// %3 = memref.cast %2 : memref<?xf32> to memref<?xf32, #map>
+// // ... uses of %3
+// memref.dealloc %2 : memref<?xf32, #map>
+// %res = bufferization.to_tensor %3 : memref<?xf32, #map>
+// return %res : tensor<?xf32>
+// }
+// ```
+//
+// this is the cue for the bufferization of the function foo (and calls
+// to it) that it must bufferize to `func @foo(%A: memref<?xf32,
+// some_layout>,
+// %B: memref<?xf32, some_layout>)` (i.e. make a cloned
+// allocation of the result tensor)
+// To fully achieve bufferization, the alloc/dealloc pair must be lifted
+// out of the function at each call site.
+//
+// iii. as an optimization over ii., it may be possible to reuse an argument
+// and only want to return a slice.
+// This may forego allocation by letting *all* callers decide whether to
+// pass a new *aliasing* memref function argument (i.e. a subview).
+// Without loss of generality, callers may agree to allocate a new buffer
+// to avoid this aliasing. Such scenarios are encoded by IR resembling:
+//
+// ```
+// #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// func @foo(%arg0: tensor<?xf32> {linalg.inplaceable = true})
+// -> tensor<4xf32> {
+// %0 = bufferization.to_memref %arg0 : memref<?xf32, #map>
+// %1 = memref.subview %0[0] [4] [1] : memref<?xf32, #map> to
+// memref<4xf32, #map>
+// // ... inplace computes into %1
+// %3 = bufferization.to_tensor %1 : memref<4xf32, #map>
+// return %3 : tensor<4xf32>
+// }
+// ```
+//
+// Note: In the future, it may be worthwhile to design special bufferization
+// ops to encode the desired semantics at function boundaries for i., ii. and
+// iii.
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
@@ -161,7 +243,7 @@ static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
return getDynamicMemRefType(rankedTensorType);
if (auto tensorType = t.dyn_cast<TensorType>())
- return getContiguousOrUnrankedMemRefType(tensorType);
+ return getUnrankedMemRefType(tensorType.getElementType());
return t;
};
auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 1c4185c8ffef..5983d421aaed 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -19,6 +19,8 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace scf_ext {
+/// Bufferization of scf.execute_region. Can be analyzed, but bufferization not
+/// fully implemented at the moment.
struct ExecuteRegionOpInterface
: public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
scf::ExecuteRegionOp> {
@@ -79,6 +81,7 @@ struct ExecuteRegionOpInterface
}
};
+/// Bufferization of scf.if. Replace with a new scf.if that yields memrefs.
struct IfOpInterface
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
SmallVector<OpOperand *>
@@ -212,6 +215,8 @@ struct IfOpInterface
}
};
+/// Bufferization of scf.for. Replace with a new scf.for that operates on
+/// memrefs.
struct ForOpInterface
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
scf::ForOp> {
@@ -292,7 +297,7 @@ struct ForOpInterface
// Construct a new scf.for op with memref instead of tensor values.
SmallVector<Value> initArgs =
convert(forOp.getInitArgs(), [&](Value val, int64_t index) {
- return state.getResultBuffer(rewriter, forOp->getOpResult(index));
+ return *state.getResultBuffer(rewriter, forOp->getOpResult(index));
});
auto newForOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
@@ -399,6 +404,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
return status;
}
+/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
+/// this is for analysis only.
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
scf::YieldOp> {
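Roughly, the scf.for bufferization turns a tensor loop into a memref loop. A sketch under the assumption that the iter_arg bufferizes inplace (%Am is the buffer of %A; all types and the layout are illustrative):

```
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Before:
%r = scf.for %iv = %lb to %ub step %c1 iter_args(%t = %A) -> (tensor<?xf32>) {
  %w = vector.transfer_write %v, %t[%iv] {in_bounds = [true]}
      : vector<4xf32>, tensor<?xf32>
  scf.yield %w : tensor<?xf32>
}

// After (sketch): the loop carries the memref of %A instead of the tensor.
%rm = scf.for %iv = %lb to %ub step %c1 iter_args(%m = %Am)
    -> (memref<?xf32, #map>) {
  vector.transfer_write %v, %m[%iv] {in_bounds = [true]}
      : vector<4xf32>, memref<?xf32, #map>
  scf.yield %m : memref<?xf32, #map>
}
```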
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index e1cd933de3b8..6b8b8983972a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -51,29 +51,38 @@ struct CastOpInterface
const BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
- Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
- if (!resultBuffer)
- return failure();
- Type sourceType = resultBuffer.getType();
- auto rankedMemRefType = sourceType.dyn_cast<MemRefType>();
- auto unrankedMemRefType = sourceType.dyn_cast<UnrankedMemRefType>();
- assert(rankedMemRefType || unrankedMemRefType);
- Attribute memorySpace = rankedMemRefType
- ? rankedMemRefType.getMemorySpace()
- : unrankedMemRefType.getMemorySpace();
- TensorType tensorType = castOp.getResult().getType().cast<TensorType>();
- MemRefLayoutAttrInterface layout =
- rankedMemRefType && tensorType.isa<RankedTensorType>()
- ? rankedMemRefType.getLayout()
- : MemRefLayoutAttrInterface();
- Type memRefType = getContiguousOrUnrankedMemRefType(
- castOp.getResult().getType(), layout, memorySpace);
- replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, memRefType,
- resultBuffer);
+ // The result buffer still has the old (pre-cast) type.
+ FailureOr<Value> resultBuffer =
+ state.getResultBuffer(rewriter, castOp->getResult(0));
+ auto sourceMemRefType = resultBuffer->getType().cast<BaseMemRefType>();
+ Attribute memorySpace = sourceMemRefType.getMemorySpace();
+ TensorType resultTensorType =
+ castOp.getResult().getType().cast<TensorType>();
+ MemRefLayoutAttrInterface layout;
+
+ if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
+ if (resultTensorType.isa<RankedTensorType>())
+ layout = rankedMemRefType.getLayout();
+
+ // Compute the new memref type.
+ Type resultMemRefType;
+ if (auto rankedTensorType = resultTensorType.isa<RankedTensorType>()) {
+ resultMemRefType =
+ getContiguousMemRefType(resultTensorType, layout, memorySpace);
+ } else {
+ resultMemRefType =
+ getUnrankedMemRefType(resultTensorType.getElementType(), memorySpace);
+ }
+
+ // Replace the op with a memref.cast.
+ replaceOpWithNewBufferizedOp<memref::CastOp>(rewriter, op, resultMemRefType,
+ *resultBuffer);
+
return success();
}
};
+/// Bufferization of tensor.dim. Replace with memref.dim.
struct DimOpInterface
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
tensor::DimOp> {
@@ -95,14 +104,13 @@ struct DimOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
- if (!dimOp.source().getType().isa<RankedTensorType>())
- return dimOp.emitError("unranked tensor not supported");
Value v = state.lookupBuffer(rewriter, dimOp.source());
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success();
}
};
+/// Bufferization of tensor.extract_slice. Replace with memref.subview.
struct ExtractSliceOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
tensor::ExtractSliceOp> {
@@ -156,7 +164,7 @@ struct ExtractSliceOpInterface
loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
- /// If not inplaceable, copy.
+ // If not inplaceable, copy.
if (!inplace) {
// Do not copy if the copied data is never read.
if (state.isValueRead(extractSliceOp.result()))
@@ -169,6 +177,7 @@ struct ExtractSliceOpInterface
}
};
+/// Bufferization of tensor.extract. Replace with memref.load.
struct ExtractOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
tensor::ExtractOp> {
@@ -197,6 +206,7 @@ struct ExtractOpInterface
}
};
+/// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
tensor::InsertOp> {
@@ -226,12 +236,11 @@ struct InsertOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
- Location loc = insertOp.getLoc();
- Value destMemref =
+ FailureOr<Value> destMemref =
state.getResultBuffer(rewriter, insertOp->getOpResult(0));
- rewriter.create<memref::StoreOp>(loc, insertOp.scalar(), destMemref,
- insertOp.indices());
- replaceOpWithBufferizedValues(rewriter, op, destMemref);
+ rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
+ *destMemref, insertOp.indices());
+ replaceOpWithBufferizedValues(rewriter, op, *destMemref);
return success();
}
@@ -276,6 +285,8 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
condition);
}
+/// Bufferization of tensor.insert_slice. Replace with a memory copy. Under
+/// certain circumstances, this op can also be a no-op.
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
@@ -391,13 +402,11 @@ struct InsertSliceOpInterface
Location loc = insertSliceOp.getLoc();
// When bufferizing out-of-place, `getResultBuffer` allocates.
- Value dstMemref =
+ FailureOr<Value> dstMemref =
state.getResultBuffer(rewriter, insertSliceOp->getResult(0));
- if (!dstMemref)
- return failure();
// Take a subview of the dst.
- auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+ auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
auto subviewMemRefType =
memref::SubViewOp::inferRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstMemrefType,
@@ -405,15 +414,15 @@ struct InsertSliceOpInterface
insertSliceOp.getMixedStrides())
.cast<MemRefType>();
Value subView = rewriter.create<memref::SubViewOp>(
- loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+ loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(),
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
- state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
+ state.createMemCpy(rewriter, loc, srcMemref, subView);
- replaceOpWithBufferizedValues(rewriter, op, dstMemref);
+ replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
return success();
}
};
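A minimal sketch of the insert_slice bufferization above, assuming the destination bufferizes inplace and that the memCpyFn callback creates a memref.copy (an assumption; %src_m and %dst_m are the buffers of %src and %dst):

```
// Before:
%r = tensor.insert_slice %src into %dst[4] [8] [1]
    : tensor<8xf32> into tensor<16xf32>

// After (sketch): take a subview of the destination buffer and copy into it.
%sv = memref.subview %dst_m[4] [8] [1]
    : memref<16xf32> to memref<8xf32, affine_map<(d0) -> (d0 + 4)>>
memref.copy %src_m, %sv
    : memref<8xf32> to memref<8xf32, affine_map<(d0) -> (d0 + 4)>>
```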
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index d4c57617b004..3c8d6a9c96e5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -17,6 +17,8 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace vector_ext {
+/// Bufferization of vector.transfer_read. Replaced with a new
+/// vector.transfer_read that operates on a memref.
struct TransferReadOpInterface
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
vector::TransferReadOp> {
@@ -55,6 +57,8 @@ struct TransferReadOpInterface
}
};
+/// Bufferization of vector.transfer_write. Replace with a new
+/// vector.transfer_write that operates on a memref.
struct TransferWriteOpInterface
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
@@ -94,13 +98,12 @@ struct TransferWriteOpInterface
// Create a new transfer_write on buffer that doesn't have a return value.
// Leave the previous transfer_write to dead code as it still has uses at
// this point.
- Value resultBuffer = state.getResultBuffer(rewriter, op->getResult(0));
- if (!resultBuffer)
- return failure();
+ FailureOr<Value> resultBuffer =
+ state.getResultBuffer(rewriter, op->getResult(0));
rewriter.create<vector::TransferWriteOp>(
- writeOp.getLoc(), writeOp.vector(), resultBuffer, writeOp.indices(),
+ writeOp.getLoc(), writeOp.vector(), *resultBuffer, writeOp.indices(),
writeOp.permutation_mapAttr(), writeOp.in_boundsAttr());
- replaceOpWithBufferizedValues(rewriter, op, resultBuffer);
+ replaceOpWithBufferizedValues(rewriter, op, *resultBuffer);
return success();
}
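And a corresponding sketch for the transfer_write bufferization above, assuming the written tensor bufferizes inplace (%t_m is the buffer of %t; the layout is illustrative):

```
#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Before:
%r = vector.transfer_write %v, %t[%c0] {in_bounds = [true]}
    : vector<4xf32>, tensor<8xf32>

// After (sketch): the new transfer_write writes into %t's buffer and has no
// tensor result; the old op becomes dead and is removed later.
vector.transfer_write %v, %t_m[%c0] {in_bounds = [true]}
    : vector<4xf32>, memref<8xf32, #map>
```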