[Mlir-commits] [mlir] 2975407 - [mlir][linalg][bufferize][NFC] Pass BufferizationState as const reference
Matthias Springer
llvmlistbot at llvm.org
Thu Jan 6 07:22:56 PST 2022
Author: Matthias Springer
Date: 2022-01-07T00:18:46+09:00
New Revision: 2975407bd41c0a416832a879fb4137f5c90385ba
URL: https://github.com/llvm/llvm-project/commit/2975407bd41c0a416832a879fb4137f5c90385ba
DIFF: https://github.com/llvm/llvm-project/commit/2975407bd41c0a416832a879fb4137f5c90385ba.diff
LOG: [mlir][linalg][bufferize][NFC] Pass BufferizationState as const reference
This is mostly for documentation purposes: Passing the object as a const reference signifies that analysis decisions cannot be changed after the analysis.
Differential Revision: https://reviews.llvm.org/D116742
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 3ec15fb301988..0d51677cab9ea 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -304,29 +304,29 @@ class BufferizationState {
/// Determine which OpOperand* will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
- SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
+ SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
/// Determine which OpResult will alias with `opOperand` if the op is
/// bufferized in place. Return an empty OpResult if the op is not
/// bufferizable.
- OpResult getAliasingOpResult(OpOperand &opOperand);
+ OpResult getAliasingOpResult(OpOperand &opOperand) const;
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
/// the op is not bufferizable.
- bool bufferizesToMemoryRead(OpOperand &opOperand);
+ bool bufferizesToMemoryRead(OpOperand &opOperand) const;
/// Return true if `opOperand` bufferizes to a memory write. Return `true` if
/// the op is not bufferizable.
- bool bufferizesToMemoryWrite(OpOperand &opOperand);
+ bool bufferizesToMemoryWrite(OpOperand &opOperand) const;
/// Return true if `opOperand` does neither read nor write but bufferizes to
/// an alias. Return false if the op is not bufferizable.
- bool bufferizesToAliasOnly(OpOperand &opOperand);
+ bool bufferizesToAliasOnly(OpOperand &opOperand) const;
/// Return true if the given value is read by an op that bufferizes to a
/// memory read. Also takes into account ops that create an alias but do not
/// read by themselves (e.g., ExtractSliceOp).
- bool isValueRead(Value value);
+ bool isValueRead(Value value) const;
/// Starting from `value`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
@@ -351,9 +351,8 @@ class BufferizationState {
/// In the above example, Values with a star satisfy the condition. When
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
- llvm::SetVector<Value>
- findValueInReverseUseDefChain(Value value,
- llvm::function_ref<bool(Value)> condition);
+ llvm::SetVector<Value> findValueInReverseUseDefChain(
+ Value value, llvm::function_ref<bool(Value)> condition) const;
/// Find the Value of the last preceding write of a given Value.
///
@@ -363,33 +362,34 @@ class BufferizationState {
///
/// Note: When reaching an end of the reverse SSA use-def chain, that value
/// is returned regardless of whether it is a memory write or not.
- Value findLastPrecedingWrite(Value value);
+ Value findLastPrecedingWrite(Value value) const;
/// Creates a memref allocation.
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ArrayRef<Value> dynShape);
+ ArrayRef<Value> dynShape) const;
/// Creates an alloc-dealloc pair. This function may perform additional
/// optimizations such as buffer allocation hoisting.
Value createAllocDeallocPair(OpBuilder &builder, Location loc,
- Value shapedValue);
+ Value shapedValue) const;
/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
- void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer);
+ void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const;
/// Creates a memcpy between two given buffers.
- void createMemCpy(OpBuilder &b, Location loc, Value from, Value to);
+ void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
- void replaceOp(RewriterBase &rewriter, Operation *op, ValueRange values);
+ void replaceOp(RewriterBase &rewriter, Operation *op,
+ ValueRange values) const;
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
/// values.
template <typename OpTy, typename... Args>
OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
- Args &&...args) {
+ Args &&...args) const {
Operation *newOp =
rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
replaceOp(rewriter, op, newOp->getResults());
@@ -398,7 +398,7 @@ class BufferizationState {
/// Lookup the memref buffer that is associated to the given tensor value.
/// Asserts if no buffer is associated.
- Value lookupBuffer(RewriterBase &rewriter, Value tensor);
+ Value lookupBuffer(RewriterBase &rewriter, Value tensor) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpResult opResult) const;
@@ -406,10 +406,19 @@ class BufferizationState {
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
- Value getResultBuffer(RewriterBase &rewriter, OpResult result);
+ Value getResultBuffer(RewriterBase &rewriter, OpResult result) const;
/// Return dialect-specific bufferization state.
- template <typename StateT> StateT &getDialectState(StringRef name) {
+ template <typename StateT>
+ Optional<const StateT *> getDialectState(StringRef name) const {
+ auto it = dialectState.find(name);
+ if (it == dialectState.end())
+ return None;
+ return static_cast<const StateT *>(it->getSecond().get());
+ }
+
+ /// Return dialect-specific bufferization state or create one if none exists.
+ template <typename StateT> StateT &getOrCreateDialectState(StringRef name) {
// Create state if it does not exist yet.
if (!dialectState.count(name))
dialectState[name] = std::make_unique<StateT>();
@@ -419,15 +428,10 @@ class BufferizationState {
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const { return options; }
-private:
- friend LogicalResult
- runComprehensiveBufferize(Operation *op, const BufferizationOptions &options,
- BufferizationState &state);
-
- friend LogicalResult
- runComprehensiveBufferize(ModuleOp moduleOp,
- std::unique_ptr<BufferizationOptions> options);
+ /// Return a reference to the BufferizationAliasInfo.
+ BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
+private:
/// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
/// functions and `runComprehensiveBufferize` may access this object.
BufferizationAliasInfo aliasInfo;
@@ -441,17 +445,17 @@ class BufferizationState {
/// Bufferize all ops in the given region.
LogicalResult bufferize(RewriterBase &rewriter, Region *region,
- BufferizationState &state);
+ const BufferizationState &state);
/// Bufferize all ops in the given block.
LogicalResult bufferize(RewriterBase &rewriter, Block *block,
- BufferizationState &state);
+ const BufferizationState &state);
/// Bufferize the given op. If the op has no tensor OpOperands/OpResults, this
/// function returns immediately. Otherwise, it calls the `bufferize` interface
/// method of `BufferizableOpInterface`.
LogicalResult bufferize(RewriterBase &rewriter, Operation *op,
- BufferizationState &state);
+ const BufferizationState &state);
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
@@ -492,38 +496,39 @@ struct AllocationHoistingBarrierOnly
: public BufferizableOpInterface::ExternalModel<
AllocationHoistingBarrierOnly<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return {};
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::None;
}
- bool isWritable(Operation *op, Value value, BufferizationState &state) const {
+ bool isWritable(Operation *op, Value value,
+ const BufferizationState &state) const {
return false;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
if (any_of(op->getOperandTypes(), isaTensor) ||
any_of(op->getResultTypes(), isaTensor))
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 56c6b848c5f3f..3bf0420381a63 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -33,7 +33,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"bufferizesToMemoryRead",
/*args=*/(ins "OpOperand &":$opOperand,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -62,7 +62,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"bufferizesToMemoryWrite",
/*args=*/(ins "OpOperand &":$opOperand,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -85,7 +85,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"isMemoryWrite",
/*args=*/(ins "OpResult":$opResult,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
@@ -116,7 +116,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"mustBufferizeInPlace",
/*args=*/(ins "OpResult":$opResult,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
@@ -131,7 +131,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"OpResult",
/*methodName=*/"getAliasingOpResult",
/*args=*/(ins "OpOperand &":$opOperand,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -155,7 +155,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"SmallVector<OpOperand *>",
/*methodName=*/"getAliasingOpOperand",
/*args=*/(ins "OpResult":$opResult,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opResult.getType().isa<TensorType>() &&
@@ -188,7 +188,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"bufferRelation",
/*args=*/(ins "OpResult":$opResult,
"const BufferizationAliasInfo &":$aliasInfo,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpResults
@@ -210,7 +210,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "RewriterBase &":$rewriter,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
@@ -236,7 +236,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"isWritable",
/*args=*/(ins "Value":$value,
- "BufferizationState &":$state),
+ "const BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return value.isa<OpResult>();
@@ -275,7 +275,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"isNotConflicting",
/*args=*/(ins "OpOperand *":$uRead,
"OpOperand *":$uWrite,
- "BufferizationState &":$state,
+ "const BufferizationState &":$state,
"const BufferizationAliasInfo &":$aliasInfo),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -292,7 +292,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
///
/// Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
bool bufferizesToAliasOnly(OpOperand &opOperand,
- BufferizationState &state) {
+ const BufferizationState &state) {
auto bufferizableOp =
cast<BufferizableOpInterface>(getOperation());
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
index e8d0fa984bb03..8474c127b1206 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp
@@ -24,7 +24,7 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
"not a constant ranked tensor");
@@ -40,7 +40,8 @@ struct ConstantOpInterface
return success();
}
- bool isWritable(Operation *op, Value value, BufferizationState &state) const {
+ bool isWritable(Operation *op, Value value,
+ const BufferizationState &state) const {
// Memory locations returned by memref::GetGlobalOp may not be written to.
assert(value.isa<OpResult>());
return false;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 816fafb0691ca..9ee674fe4ff73 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -199,7 +199,7 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
- OpResult result) {
+ OpResult result) const {
if (Operation *op = result.getDefiningOp())
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.getAliasingOpOperand(result, *this);
@@ -210,7 +210,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
/// in place. Return an empty OpResult if the op is not bufferizable.
OpResult
mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
- OpOperand &opOperand) {
+ OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.getAliasingOpResult(opOperand, *this);
@@ -220,7 +220,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
- bufferizesToMemoryRead(OpOperand &opOperand) {
+ bufferizesToMemoryRead(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
@@ -233,7 +233,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
- bufferizesToMemoryWrite(OpOperand &opOperand) {
+ bufferizesToMemoryWrite(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
@@ -246,7 +246,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
bool mlir::linalg::comprehensive_bufferize::BufferizationState::
- bufferizesToAliasOnly(OpOperand &opOperand) {
+ bufferizesToAliasOnly(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
@@ -260,7 +260,7 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
- Value value) {
+ Value value) const {
SmallVector<OpOperand *> workingSet;
for (OpOperand &use : value.getUses())
workingSet.push_back(&use);
@@ -282,10 +282,9 @@ bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
-llvm::SetVector<Value>
-mlir::linalg::comprehensive_bufferize::BufferizationState::
- findValueInReverseUseDefChain(Value value,
- llvm::function_ref<bool(Value)> condition) {
+llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
+ BufferizationState::findValueInReverseUseDefChain(
+ Value value, llvm::function_ref<bool(Value)> condition) const {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -312,7 +311,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::
// Find the Value of the last preceding write of a given Value.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
- findLastPrecedingWrite(Value value) {
+ findLastPrecedingWrite(Value value) const {
SetVector<Value> result =
findValueInReverseUseDefChain(value, [&](Value value) {
Operation *op = value.getDefiningOp();
@@ -360,7 +359,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
- getResultBuffer(RewriterBase &rewriter, OpResult result) {
+ getResultBuffer(RewriterBase &rewriter, OpResult result) const {
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
@@ -424,7 +423,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
}
void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
- RewriterBase &rewriter, Operation *op, ValueRange values) {
+ RewriterBase &rewriter, Operation *op, ValueRange values) const {
OpBuilder::InsertionGuard g(rewriter);
// Replace all OpResults with the given values.
@@ -454,7 +453,7 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
}
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
- RewriterBase &rewriter, Region *region, BufferizationState &state) {
+ RewriterBase &rewriter, Region *region, const BufferizationState &state) {
for (Block &block : *region)
if (failed(bufferize(rewriter, &block, state)))
return failure();
@@ -462,7 +461,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
}
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
- RewriterBase &rewriter, Block *block, BufferizationState &state) {
+ RewriterBase &rewriter, Block *block, const BufferizationState &state) {
// Ops may get deleted during the traversal, so do not iterate over `block`
// directly.
SmallVector<Operation *> ops;
@@ -476,7 +475,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
}
LogicalResult mlir::linalg::comprehensive_bufferize::bufferize(
- RewriterBase &rewriter, Operation *op, BufferizationState &state) {
+ RewriterBase &rewriter, Operation *op, const BufferizationState &state) {
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
@@ -592,7 +591,8 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
- createAllocDeallocPair(OpBuilder &b, Location loc, Value shapedValue) {
+ createAllocDeallocPair(OpBuilder &b, Location loc,
+ Value shapedValue) const {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
@@ -621,19 +621,20 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
/// Create a memref allocation.
Optional<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
- OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape) {
+ OpBuilder &b, Location loc, MemRefType type,
+ ArrayRef<Value> dynShape) const {
return options.allocationFns->allocationFn(b, loc, type, dynShape);
}
/// Create a memref deallocation.
void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
- OpBuilder &b, Location loc, Value allocatedBuffer) {
+ OpBuilder &b, Location loc, Value allocatedBuffer) const {
return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
}
/// Create a memory copy between two memref buffers.
void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
- OpBuilder &b, Location loc, Value from, Value to) {
+ OpBuilder &b, Location loc, Value from, Value to) const {
return options.allocationFns->memCpyFn(b, loc, from, to);
}
@@ -649,7 +650,7 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
}
Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
- RewriterBase &rewriter, Value tensor) {
+ RewriterBase &rewriter, Value tensor) const {
assert(tensor.getType().isa<TensorType>() && "unexpected non-tensor type");
// Replace "%t = to_tensor %m" with %m.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index eab925f02420a..05bac6fa132e3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -40,18 +40,18 @@ struct ToMemrefOpInterface
: public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
bufferization::ToMemrefOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// It is unknown whether the resulting MemRef will be read or not.
return true;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
// Fold to_memref(to_tensor(x)) to x.
@@ -86,11 +86,12 @@ struct ToTensorOpInterface
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
bufferization::ToTensorOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return success();
}
- bool isWritable(Operation *op, Value value, BufferizationState &state) const {
+ bool isWritable(Operation *op, Value value,
+ const BufferizationState &state) const {
// It is unknown whether the MemRef operand is writable or not.
return false;
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 1d7ebfa39988b..b912d1ea34f50 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -661,7 +661,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
IRRewriter rewriter(op->getContext());
DominanceInfo domInfo(op);
- BufferizationAliasInfo &aliasInfo = state.aliasInfo;
+ BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
return failure();
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 536664a6dfb70..482131a7fec52 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -24,7 +24,7 @@ namespace {
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
- BufferizationState &state) {
+ const BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
@@ -142,13 +142,13 @@ struct LinalgOpInterface
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
OpTy> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
return genericOp.payloadUsesValueFromOperand(&opOperand);
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
return static_cast<bool>(
bufferizableOp.getAliasingOpResult(opOperand, state));
@@ -156,7 +156,7 @@ struct LinalgOpInterface
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
@@ -166,7 +166,7 @@ struct LinalgOpInterface
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
return pairs[&opOperand];
@@ -174,12 +174,12 @@ struct LinalgOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
}
};
@@ -188,13 +188,13 @@ struct InitTensorOpInterface
: public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
linalg::InitTensorOp> {
bool isMemoryWrite(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// InitTensorOps allocate but do not write.
return false;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto initTensorOp = cast<linalg::InitTensorOp>(op);
// The InitTensorOp may have been eliminated.
@@ -212,7 +212,7 @@ struct TiledLoopOpInterface
: public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
linalg::TiledLoopOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of
// its matching bbArg may.
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
@@ -220,7 +220,7 @@ struct TiledLoopOpInterface
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// TiledLoop alone doesn't bufferize to a memory write, one of the uses of
// its matching bbArg may.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
@@ -229,18 +229,19 @@ struct TiledLoopOpInterface
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
return tiledLoopOp.getTiedOpResult(opOperand);
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
- bool isWritable(Operation *op, Value value, BufferizationState &state) const {
+ bool isWritable(Operation *op, Value value,
+ const BufferizationState &state) const {
// Interestingly, linalg::TiledLoopOp's bbArg can **always** be viewed
// inplace from the perspective of ops nested under:
// 1. Either the matching iter operand is not bufferized inplace and an
@@ -253,7 +254,7 @@ struct TiledLoopOpInterface
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
// Compute new inputs, outputs and results.
@@ -355,22 +356,22 @@ struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
linalg::YieldOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto yieldOp = cast<linalg::YieldOp>(op);
if (!yieldOp->getParentOfType<TiledLoopOp>())
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 171b47b6447c4..434458bf56d1b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -34,9 +34,20 @@ struct ModuleBufferizationState : public DialectBufferizationState {
};
} // namespace
+/// Get ModuleBufferizationState.
+static const ModuleBufferizationState &
+getModuleBufferizationState(const BufferizationState &state) {
+ Optional<const ModuleBufferizationState *> maybeState =
+ state.getDialectState<ModuleBufferizationState>(
+ StandardOpsDialect::getDialectNamespace());
+ assert(maybeState.hasValue() && "ModuleBufferizationState does not exist");
+ return **maybeState;
+}
+
+/// Get or create ModuleBufferizationState.
static ModuleBufferizationState &
getModuleBufferizationState(BufferizationState &state) {
- return state.getDialectState<ModuleBufferizationState>(
+ return state.getOrCreateDialectState<ModuleBufferizationState>(
StandardOpsDialect::getDialectNamespace());
}
@@ -471,19 +482,25 @@ namespace std_ext {
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t>
-getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
+getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
int64_t returnValIdx) {
- if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
+ if (!state.equivalentFuncArgs.count(funcOp))
+ // No equivalence info stored for funcOp.
+ return None;
+
+ const DenseMap<int64_t, int64_t> &equivFuncArgs =
+ state.equivalentFuncArgs.lookup(funcOp);
+ if (!equivFuncArgs.count(returnValIdx))
// Return value has no equivalent bbArg.
return None;
- return state.equivalentFuncArgs[funcOp][returnValIdx];
+ return equivFuncArgs.lookup(returnValIdx);
}
struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses
// of the matching bbArg may. It is the responsibility of the caller to
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
@@ -492,7 +509,7 @@ struct CallOpInterface
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// CallOpInterface is special, it needs to wait for the callee to be
// bufferized and needs to inspect the BufferAliasInfo object. It can't
// make a proper determination by itself and needs to be conservative.
@@ -503,14 +520,15 @@ struct CallOpInterface
/// marked inplaceable. For now, it is the responsibility of the `callOp`
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op);
unsigned numResults = callOp.getNumResults();
unsigned numOperands = callOp->getNumOperands();
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected CallOp to a FuncOp");
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ const ModuleBufferizationState &moduleState =
+ getModuleBufferizationState(state);
// Result types of the bufferized CallOp.
SmallVector<Type> resultTypes;
@@ -626,22 +644,22 @@ struct ReturnOpInterface
: public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
ReturnOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto returnOp = cast<ReturnOp>(op);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
"only support FuncOp parent for ReturnOp");
@@ -662,7 +680,7 @@ struct ReturnOpInterface
struct FuncOpInterface
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto funcOp = cast<FuncOp>(op);
// Bufferize function body.
@@ -670,11 +688,13 @@ struct FuncOpInterface
}
/// Return `true` if the given function argument is writable.
- bool isWritable(Operation *op, Value value, BufferizationState &state) const {
+ bool isWritable(Operation *op, Value value,
+ const BufferizationState &state) const {
auto funcOp = cast<FuncOp>(op);
BlockArgument bbArg = value.dyn_cast<BlockArgument>();
assert(bbArg && "expected BlockArgument");
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ const ModuleBufferizationState &moduleState =
+ getModuleBufferizationState(state);
// In a first approximation:
// =========================
@@ -720,8 +740,9 @@ static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
}
/// Annotate the IR with the result of the analysis. For testing/debugging only.
-static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
- BufferizationState &state) {
+static void
+annotateOpsWithBufferizationMarkers(FuncOp funcOp,
+ const BufferizationState &state) {
auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<TensorType>())
@@ -733,7 +754,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
IRRewriter rewriter(moduleOp.getContext());
BufferizationState state(moduleOp, *options);
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- BufferizationAliasInfo &aliasInfo = state.aliasInfo;
+ BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
moduleState.callerMap)))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index d008607ed4c1a..156ac160ff3e4 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -24,7 +24,7 @@ struct ExecuteRegionOpInterface
scf::ExecuteRegionOp> {
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
// any SSA value that is in scope. To allow for use-def chain traversal
// through ExecuteRegionOps in the analysis, the corresponding yield value
@@ -41,7 +41,7 @@ struct ExecuteRegionOpInterface
}
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// ExecuteRegionOp results always bufferize in-place. Since they have no
// OpOperands, they are mostly ignored by the analysis once alias sets are
// set up.
@@ -51,7 +51,7 @@ struct ExecuteRegionOpInterface
// TODO: For better bufferization results, this could return `true` only if
// there is a memory write in the region.
bool isMemoryWrite(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// Similar to scf.if, results of this op are always considered memory writes
// in the analysis. This is a useful pattern for all ops that have tensor
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
@@ -61,7 +61,7 @@ struct ExecuteRegionOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// TODO: Add bufferization support when needed. scf.execute_region should be
// bufferized similar to scf.if.
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
@@ -76,7 +76,7 @@ struct ExecuteRegionOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
};
@@ -85,7 +85,7 @@ struct IfOpInterface
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
// value that is in scope. To allow for use-def chain traversal through
// IfOps in the analysis, both corresponding yield values from the then/else
@@ -102,7 +102,7 @@ struct IfOpInterface
// allowed at the moment, we should never encounter scf.ifs that yield
// unmodified tensors. Such scf.yield ops could just fold away.
bool isMemoryWrite(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// IfOp results are always considered memory writes in the analysis. This
// design decision simplifies the analysis considerably. E.g., consider the
// following test case:
@@ -129,14 +129,14 @@ struct IfOpInterface
}
bool mustBufferizeInPlace(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// IfOp results always bufferize in-place. Since they have no OpOperands,
// they are mostly ignored by the analysis once alias sets are set up.
return true;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto ifOp = cast<scf::IfOp>(op);
// Compute new types of the bufferized scf.if op.
@@ -209,7 +209,7 @@ struct IfOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// IfOp results are equivalent to their corresponding yield values if both
// yield values are equivalent to each other.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
@@ -226,7 +226,7 @@ struct ForOpInterface
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
scf::ForOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
// its matching bbArg may.
auto forOp = cast<scf::ForOp>(op);
@@ -234,7 +234,7 @@ struct ForOpInterface
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// Tensor iter_args of scf::ForOps are always considered as a write. This is
// to simplify the analysis.
// TODO: Consider doing sth. like isValueWritten.
@@ -242,7 +242,7 @@ struct ForOpInterface
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto forOp = cast<scf::ForOp>(op);
if (!opOperand.get().getType().isa<RankedTensorType>())
return OpResult();
@@ -251,7 +251,7 @@ struct ForOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
auto forOp = cast<scf::ForOp>(op);
@@ -263,7 +263,8 @@ struct ForOpInterface
return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
}
- bool isWritable(Operation *op, Value value, BufferizationState &state) const {
+ bool isWritable(Operation *op, Value value,
+ const BufferizationState &state) const {
// Interestingly, scf::ForOp's bbArg can **always** be viewed
// inplace from the perspective of ops nested under:
// 1. Either the matching iter operand is not bufferized inplace and an
@@ -274,7 +275,7 @@ struct ForOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto forOp = cast<scf::ForOp>(op);
Block *oldLoopBody = &forOp.getLoopBody().front();
@@ -416,22 +417,22 @@ struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
scf::YieldOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
yieldOp->getParentOp()))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 9ee1d23d5d8af..550e585e4736a 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -27,28 +27,28 @@ struct CastOpInterface
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
tensor::CastOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return op->getResult(0);
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
Value resultBuffer = state.getResultBuffer(rewriter, castOp->getResult(0));
@@ -78,22 +78,22 @@ struct DimOpInterface
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
tensor::DimOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
if (!dimOp.source().getType().isa<RankedTensorType>())
return dimOp.emitError("unranked tensor not supported");
@@ -107,17 +107,17 @@ struct ExtractSliceOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
tensor::ExtractSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return &opOperand == &op->getOpOperand(0) /*source*/
? op->getResult(0)
: OpResult();
@@ -125,12 +125,12 @@ struct ExtractSliceOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc();
Value srcMemref = state.lookupBuffer(rewriter, extractSliceOp.source());
@@ -173,22 +173,22 @@ struct ExtractOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
tensor::ExtractOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref = state.lookupBuffer(rewriter, extractOp.tensor());
state.replaceOpWithNewOp<memref::LoadOp>(rewriter, op, srcMemref,
@@ -201,17 +201,17 @@ struct InsertOpInterface
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
tensor::InsertOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
"expected dest OpOperand");
return op->getOpResult(0);
@@ -219,12 +219,12 @@ struct InsertOpInterface
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return {&op->getOpOperand(1) /*dest*/};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
Location loc = insertOp.getLoc();
Value destMemref =
@@ -237,7 +237,7 @@ struct InsertOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
};
@@ -263,8 +263,8 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state, Value value,
- InsertSliceOp insertOp) {
+ const BufferizationState &state,
+ Value value, InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
@@ -280,17 +280,17 @@ struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/
? op->getResult(0)
: OpResult();
@@ -298,12 +298,13 @@ struct InsertSliceOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
- OpOperand *uConflictingWrite, BufferizationState &state,
+ OpOperand *uConflictingWrite,
+ const BufferizationState &state,
const BufferizationAliasInfo &aliasInfo) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -380,7 +381,7 @@ struct InsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index c8d335e66bc8c..0d66e8879563c 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -21,26 +21,26 @@ struct TransferReadOpInterface
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
vector::TransferReadOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return false;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return OpResult();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto readOp = cast<vector::TransferReadOp>(op);
assert(readOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
@@ -60,21 +60,21 @@ struct TransferWriteOpInterface
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return op->getOpResult(0);
@@ -82,12 +82,12 @@ struct TransferWriteOpInterface
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
assert(writeOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
More information about the Mlir-commits
mailing list