[Mlir-commits] [mlir] ec8628b - [mlir][linalg][bufferize][NFC] Pass BufferizationState into all op interface methods
Matthias Springer
llvmlistbot at llvm.org
Wed Dec 15 18:50:14 PST 2021
Author: Matthias Springer
Date: 2021-12-16T11:45:13+09:00
New Revision: ec8628b1d615270e0e86a4efb71c9477dd95b195
URL: https://github.com/llvm/llvm-project/commit/ec8628b1d615270e0e86a4efb71c9477dd95b195
DIFF: https://github.com/llvm/llvm-project/commit/ec8628b1d615270e0e86a4efb71c9477dd95b195.diff
LOG: [mlir][linalg][bufferize][NFC] Pass BufferizationState into all op interface methods
This allows op interface implementations to make decisions based on dialect-specific bufferization state.
This is in preparation for fixing conflict detection of CallOps in ModuleBufferization.
Differential Revision: https://reviews.llvm.org/D115705
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 35955af49efa..891d59b61616 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -179,8 +179,7 @@ enum class BufferRelation {
/// equivalence classes to support bufferization.
class BufferizationAliasInfo {
public:
- explicit BufferizationAliasInfo(Operation *rootOp,
- const BufferizationOptions &options);
+ explicit BufferizationAliasInfo(Operation *rootOp);
// BufferizationAliasInfo should be passed as a reference.
BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
@@ -271,68 +270,6 @@ class BufferizationAliasInfo {
/// Return `true` if the given value is a BlockArgument of a FuncOp.
bool isFunctionArgument(Value value);
-/// Determine which OpOperand* will alias with `result` if the op is bufferized
-/// in place. Return an empty vector if the op is not bufferizable.
-SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
-
-/// Determine which OpResult will alias with `opOperand` if the op is bufferized
-/// in place. Return an empty OpResult if the op is not bufferizable.
-OpResult getAliasingOpResult(OpOperand &opOperand);
-
-/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
-/// op is not bufferizable.
-bool bufferizesToMemoryRead(OpOperand &opOperand);
-
-/// Return true if `opOperand` bufferizes to a memory write. Return
-/// `true` if the op is not bufferizable.
-bool bufferizesToMemoryWrite(OpOperand &opOperand);
-
-/// Return true if `opOperand` does neither read nor write but bufferizes to an
-/// alias. Return false if the op is not bufferizable.
-bool bufferizesToAliasOnly(OpOperand &opOperand);
-
-/// Return true if the given value is read by an op that bufferizes to a memory
-/// read. Also takes into account ops that create an alias but do not read by
-/// themselves (e.g., ExtractSliceOp).
-bool isValueRead(Value value);
-
-/// Starting from `value`, follow the use-def chain in reverse, always selecting
-/// the aliasing OpOperands. Find and return Values for which `condition`
-/// evaluates to true. OpOperands of such matching Values are not traversed any
-/// further.
-///
-/// When reaching the end of a chain (BlockArgument or Value without aliasing
-/// OpOperands), also return the last Value of that chain.
-///
-/// Example:
-///
-/// 8
-/// |
-/// 6* 7* +-----+----+
-/// | | | |
-/// 2* 3 4* 5
-/// | | | |
-/// +----------+----------+----------+
-/// |
-/// 1
-///
-/// In the above example, Values with a star satisfy the condition. When
-/// starting the traversal from Value 1, the resulting SetVector is:
-/// { 2, 7, 8, 5 }
-llvm::SetVector<Value>
-findValueInReverseUseDefChain(Value value, const BufferizationOptions &options,
- std::function<bool(Value)> condition);
-
-/// Find the Value of the last preceding write of a given Value.
-///
-/// Note: Unknown ops are handled conservatively and assumed to be writes.
-/// Furthermore, BlockArguments are also assumed to be writes. There is no
-/// analysis across block boundaries.
-///
-/// Note: When reaching an end of the reverse SSA use-def chain, that value
-/// is returned regardless of whether it is a memory write or not.
-Value findLastPrecedingWrite(Value value, const BufferizationOptions &options);
-
/// Dialect-specific bufferization state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
/// variants of this struct.
@@ -359,12 +296,74 @@ struct DialectBufferizationState {
/// * `replaceOp` replaces an op with new values.
class BufferizationState {
public:
- BufferizationState(Operation *op, const BufferizationOptions &options)
- : aliasInfo(op, options), options(options), builder(op->getContext()) {}
+ BufferizationState(Operation *op, const BufferizationOptions &options);
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
+ /// Determine which OpOperand* will alias with `result` if the op is
+ /// bufferized in place. Return an empty vector if the op is not bufferizable.
+ SmallVector<OpOperand *> getAliasingOpOperand(OpResult result);
+
+ /// Determine which OpResult will alias with `opOperand` if the op is
+ /// bufferized in place. Return an empty OpResult if the op is not
+ /// bufferizable.
+ OpResult getAliasingOpResult(OpOperand &opOperand);
+
+ /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
+ /// the op is not bufferizable.
+ bool bufferizesToMemoryRead(OpOperand &opOperand);
+
+ /// Return true if `opOperand` bufferizes to a memory write. Return `true` if
+ /// the op is not bufferizable.
+ bool bufferizesToMemoryWrite(OpOperand &opOperand);
+
+ /// Return true if `opOperand` does neither read nor write but bufferizes to
+ /// an alias. Return false if the op is not bufferizable.
+ bool bufferizesToAliasOnly(OpOperand &opOperand);
+
+ /// Return true if the given value is read by an op that bufferizes to a
+ /// memory read. Also takes into account ops that create an alias but do not
+ /// read by themselves (e.g., ExtractSliceOp).
+ bool isValueRead(Value value);
+
+ /// Starting from `value`, follow the use-def chain in reverse, always
+ /// selecting the aliasing OpOperands. Find and return Values for which
+ /// `condition` evaluates to true. OpOperands of such matching Values are not
+ /// traversed any further.
+ ///
+ /// When reaching the end of a chain (BlockArgument or Value without aliasing
+ /// OpOperands), also return the last Value of that chain.
+ ///
+ /// Example:
+ ///
+ /// 8
+ /// |
+ /// 6* 7* +-----+----+
+ /// | | | |
+ /// 2* 3 4* 5
+ /// | | | |
+ /// +----------+----------+----------+
+ /// |
+ /// 1
+ ///
+ /// In the above example, Values with a star satisfy the condition. When
+ /// starting the traversal from Value 1, the resulting SetVector is:
+ /// { 2, 7, 8, 5 }
+ llvm::SetVector<Value>
+ findValueInReverseUseDefChain(Value value,
+ std::function<bool(Value)> condition);
+
+ /// Find the Value of the last preceding write of a given Value.
+ ///
+ /// Note: Unknown ops are handled conservatively and assumed to be writes.
+ /// Furthermore, BlockArguments are also assumed to be writes. There is no
+ /// analysis across block boundaries.
+ ///
+ /// Note: When reaching an end of the reverse SSA use-def chain, that value
+ /// is returned regardless of whether it is a memory write or not.
+ Value findLastPrecedingWrite(Value value);
+
/// Creates a memref allocation.
Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape);
@@ -494,25 +493,30 @@ template <typename OpTy>
struct AllocationHoistingBarrierOnly
: public BufferizableOpInterface::ExternalModel<
AllocationHoistingBarrierOnly<OpTy>, OpTy> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
return {};
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::None;
}
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index a81b52d1433f..df9090972bed 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -32,7 +32,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"bool",
/*methodName=*/"bufferizesToMemoryRead",
- /*args=*/(ins "OpOperand &":$opOperand),
+ /*args=*/(ins "OpOperand &":$opOperand,
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -60,7 +61,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"bool",
/*methodName=*/"bufferizesToMemoryWrite",
- /*args=*/(ins "OpOperand &":$opOperand),
+ /*args=*/(ins "OpOperand &":$opOperand,
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -82,19 +84,21 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"bool",
/*methodName=*/"isMemoryWrite",
- /*args=*/(ins "OpResult":$opResult),
+ /*args=*/(ins "OpResult":$opResult,
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
cast<BufferizableOpInterface>($_op.getOperation());
SmallVector<OpOperand*> opOperands =
- bufferizableOp.getAliasingOpOperand(opResult);
+ bufferizableOp.getAliasingOpOperand(opResult, state);
if (opOperands.empty())
return true;
return llvm::any_of(
opOperands,
[&](OpOperand *operand) {
- return bufferizableOp.bufferizesToMemoryWrite(*operand);
+ return bufferizableOp.bufferizesToMemoryWrite(*operand,
+ state);
});
}]
>,
@@ -111,7 +115,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"bool",
/*methodName=*/"mustBufferizeInPlace",
- /*args=*/(ins "OpResult":$opResult),
+ /*args=*/(ins "OpResult":$opResult,
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
@@ -125,7 +130,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"OpResult",
/*methodName=*/"getAliasingOpResult",
- /*args=*/(ins "OpOperand &":$opOperand),
+ /*args=*/(ins "OpOperand &":$opOperand,
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -148,7 +154,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"SmallVector<OpOperand *>",
/*methodName=*/"getAliasingOpOperand",
- /*args=*/(ins "OpResult":$opResult),
+ /*args=*/(ins "OpResult":$opResult,
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opResult.getType().isa<TensorType>() &&
@@ -159,7 +166,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) {
if (!opOperand.get().getType().isa<TensorType>())
continue;
- if (bufferizableOp.getAliasingOpResult(opOperand) == opResult)
+ if (bufferizableOp.getAliasingOpResult(opOperand, state) ==
+ opResult)
result.push_back(&opOperand);
}
return result;
@@ -179,7 +187,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"BufferRelation",
/*methodName=*/"bufferRelation",
/*args=*/(ins "OpResult":$opResult,
- "const BufferizationAliasInfo &":$aliasInfo),
+ "const BufferizationAliasInfo &":$aliasInfo,
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpResults
@@ -282,13 +291,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/// be called on OpOperands that do not have a tensor type.
///
/// Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
- bool bufferizesToAliasOnly(OpOperand &opOperand) {
+ bool bufferizesToAliasOnly(OpOperand &opOperand,
+ BufferizationState &state) {
auto bufferizableOp =
cast<BufferizableOpInterface>(getOperation());
- return !bufferizableOp.bufferizesToMemoryRead(opOperand)
- && !bufferizableOp.bufferizesToMemoryWrite(opOperand)
+ return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
+ && !bufferizableOp.bufferizesToMemoryWrite(opOperand, state)
&& static_cast<bool>(
- bufferizableOp.getAliasingOpResult(opOperand));
+ bufferizableOp.getAliasingOpResult(opOperand, state));
}
// TODO: The following two attributes should belong to the tensor dialect.
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index d6d3d28a1022..e2edc9d15267 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -78,8 +78,7 @@ BufferizationOptions::BufferizationOptions()
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//
-BufferizationAliasInfo::BufferizationAliasInfo(
- Operation *rootOp, const BufferizationOptions &options) {
+BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
rootOp->walk([&](Operation *op) {
for (Value v : op->getResults())
if (v.getType().isa<TensorType>())
@@ -90,26 +89,6 @@ BufferizationAliasInfo::BufferizationAliasInfo(
if (bbArg.getType().isa<TensorType>())
createAliasInfoEntry(bbArg);
});
-
- // Set up alias sets for OpResults that must bufferize in-place. This should
- // be done before making any other bufferization decisions.
- rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
- if (!options.isOpAllowed(bufferizableOp))
- return WalkResult::skip();
- for (OpResult opResult : bufferizableOp->getOpResults()) {
- if (opResult.getType().isa<TensorType>())
- if (bufferizableOp.mustBufferizeInPlace(opResult)) {
- SmallVector<OpOperand *> operands =
- bufferizableOp.getAliasingOpOperand(opResult);
- assert(!operands.empty() &&
- "expected that OpResult has aliasing OpOperand");
- for (OpOperand *operand : operands)
- aliasInfo.unionSets(operand->get(), opResult);
- markInPlace(opResult);
- }
- }
- return WalkResult::advance();
- });
}
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
@@ -219,30 +198,32 @@ BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
-mlir::linalg::comprehensive_bufferize::getAliasingOpOperand(OpResult result) {
+mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpOperand(
+ OpResult result) {
if (Operation *op = result.getDefiningOp())
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
- return bufferizableOp.getAliasingOpOperand(result);
+ return bufferizableOp.getAliasingOpOperand(result, *this);
return {};
}
/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place. Return an empty OpResult if the op is not bufferizable.
-OpResult mlir::linalg::comprehensive_bufferize::getAliasingOpResult(
+OpResult
+mlir::linalg::comprehensive_bufferize::BufferizationState::getAliasingOpResult(
OpOperand &opOperand) {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
- return bufferizableOp.getAliasingOpResult(opOperand);
+ return bufferizableOp.getAliasingOpResult(opOperand, *this);
return OpResult();
}
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryRead(
- OpOperand &opOperand) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::
+ bufferizesToMemoryRead(OpOperand &opOperand) {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
- return bufferizableOp.bufferizesToMemoryRead(opOperand);
+ return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
// Unknown op that returns a tensor. The inplace analysis does not support it.
// Conservatively return true.
@@ -251,11 +232,11 @@ bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryRead(
/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryWrite(
- OpOperand &opOperand) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::
+ bufferizesToMemoryWrite(OpOperand &opOperand) {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
- return bufferizableOp.bufferizesToMemoryWrite(opOperand);
+ return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
// Unknown op that returns a tensor. The inplace analysis does not support it.
// Conservatively return true.
@@ -264,11 +245,11 @@ bool mlir::linalg::comprehensive_bufferize::bufferizesToMemoryWrite(
/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
-bool mlir::linalg::comprehensive_bufferize::bufferizesToAliasOnly(
- OpOperand &opOperand) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::
+ bufferizesToAliasOnly(OpOperand &opOperand) {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
- return bufferizableOp.bufferizesToAliasOnly(opOperand);
+ return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
// Unknown op that returns a tensor. The inplace analysis does not support it.
// Conservatively return false.
@@ -278,7 +259,8 @@ bool mlir::linalg::comprehensive_bufferize::bufferizesToAliasOnly(
/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
-bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
+bool mlir::linalg::comprehensive_bufferize::BufferizationState::isValueRead(
+ Value value) {
SmallVector<OpOperand *> workingSet;
for (OpOperand &use : value.getUses())
workingSet.push_back(&use);
@@ -301,9 +283,9 @@ bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
llvm::SetVector<Value>
-mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
- Value value, const BufferizationOptions &options,
- std::function<bool(Value)> condition) {
+mlir::linalg::comprehensive_bufferize::BufferizationState::
+ findValueInReverseUseDefChain(Value value,
+ std::function<bool(Value)> condition) {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -329,17 +311,17 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
}
// Find the Value of the last preceding write of a given Value.
-Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
- Value value, const BufferizationOptions &options) {
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::
+ findLastPrecedingWrite(Value value) {
SetVector<Value> result =
- findValueInReverseUseDefChain(value, options, [&](Value value) {
+ findValueInReverseUseDefChain(value, [&](Value value) {
Operation *op = value.getDefiningOp();
if (!op)
return true;
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (!bufferizableOp)
return true;
- return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
+ return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
});
// To simplify the analysis, `scf.if` ops are considered memory writes. There
@@ -350,6 +332,30 @@ Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
return result.front();
}
+mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
+ Operation *op, const BufferizationOptions &options)
+ : aliasInfo(op), options(options), builder(op->getContext()) {
+ // Set up alias sets for OpResults that must bufferize in-place. This should
+ // be done before making any other bufferization decisions.
+ op->walk([&](BufferizableOpInterface bufferizableOp) {
+ if (!options.isOpAllowed(bufferizableOp))
+ return WalkResult::skip();
+ for (OpResult opResult : bufferizableOp->getOpResults()) {
+ if (opResult.getType().isa<TensorType>())
+ if (bufferizableOp.mustBufferizeInPlace(opResult, *this)) {
+ SmallVector<OpOperand *> operands =
+ bufferizableOp.getAliasingOpOperand(opResult, *this);
+ assert(!operands.empty() &&
+ "expected that OpResult has aliasing OpOperand");
+ for (OpOperand *operand : operands)
+ aliasInfo.unionAliasSets(operand->get(), opResult);
+ aliasInfo.markInPlace(opResult);
+ }
+ }
+ return WalkResult::advance();
+ });
+}
+
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
@@ -394,9 +400,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
// use-def chain, it returns that value, regardless of whether it is a
// memory write or not.
- Value lastWrite = findLastPrecedingWrite(operand, options);
+ Value lastWrite = findLastPrecedingWrite(operand);
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
- if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
+ if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
skipCopy = true;
// Do not copy if the copied data is never read.
if (!isValueRead(result))
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
index 97345350835a..3419a6aa4492 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -39,12 +39,14 @@ namespace bufferization_ext {
struct ToMemrefOpInterface
: public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
bufferization::ToMemrefOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
// It is unknown whether the resulting MemRef will be read or not.
return true;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index b9dde90e63ee..babbec5493ae 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -162,7 +162,8 @@ static void setInPlaceOpResult(OpResult opResult, bool inPlace) {
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
- const BufferizationAliasInfo &aliasInfo) {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) {
// The analysis does not know what happens to the result of a ToMemrefOp, so
// we assume that it is written to.
// TODO: This is a conservative implementation. This rule will have to be
@@ -170,11 +171,11 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
if (isa<bufferization::ToMemrefOp>(opOperand.getOwner()))
return true;
// OpOperands without an aliasing OpResult do not write.
- OpResult opResult = getAliasingOpResult(opOperand);
+ OpResult opResult = state.getAliasingOpResult(opOperand);
if (!opResult)
return false;
// OpOperands that do not bufferize to a memory write do not write in-place.
- if (!bufferizesToMemoryWrite(opOperand))
+ if (!state.bufferizesToMemoryWrite(opOperand))
return false;
// Check current bufferization decisions.
return aliasInfo.isInPlace(opResult);
@@ -209,11 +210,12 @@ static bool aliasesNonWritableBuffer(Value value,
/// Return true if the buffer to which `operand` would bufferize is equivalent
/// to some buffer write.
static bool aliasesInPlaceWrite(Value value,
- const BufferizationAliasInfo &aliasInfo) {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) {
bool foundInplaceWrite = false;
aliasInfo.applyOnAliases(value, [&](Value v) {
for (auto &use : v.getUses()) {
- if (isInplaceMemoryWrite(use, aliasInfo)) {
+ if (isInplaceMemoryWrite(use, aliasInfo, state)) {
foundInplaceWrite = true;
return;
}
@@ -295,7 +297,7 @@ static bool hasReadAfterWriteInterference(
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
// is %0. Note that operations that create an alias but do not write (such
// as ExtractSliceOp) are skipped.
- Value lastWrite = findLastPrecedingWrite(uRead->get(), options);
+ Value lastWrite = state.findLastPrecedingWrite(uRead->get());
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
@@ -352,7 +354,7 @@ static bool hasReadAfterWriteInterference(
// No conflict if the conflicting write and the last write are the same
// use.
- if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
+ if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
continue;
// All requirements are met. Conflict found!
@@ -402,7 +404,7 @@ bool wouldCreateReadAfterWriteInterference(
bool checkConsistencyOnly = false) {
#ifndef NDEBUG
if (result) {
- SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+ SmallVector<OpOperand *> opOperands = state.getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
"operand and result do not match");
} else {
@@ -416,7 +418,7 @@ bool wouldCreateReadAfterWriteInterference(
aliasInfo.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses())
// Read to a value that aliases root.
- if (bufferizesToMemoryRead(use))
+ if (state.bufferizesToMemoryRead(use))
res.insert(&use);
});
};
@@ -426,7 +428,7 @@ bool wouldCreateReadAfterWriteInterference(
aliasInfo.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses())
// Inplace write to a value that aliases root.
- if (isInplaceMemoryWrite(use, aliasInfo))
+ if (isInplaceMemoryWrite(use, aliasInfo, state))
res.insert(&use);
});
};
@@ -439,7 +441,7 @@ bool wouldCreateReadAfterWriteInterference(
getAliasingInplaceWrites(usesWrite, operand.get());
if (result)
getAliasingInplaceWrites(usesWrite, result);
- if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
+ if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
@@ -453,7 +455,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
BufferizationState &state) {
#ifndef NDEBUG
- SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
+ SmallVector<OpOperand *> opOperands = state.getAliasingOpOperand(opResult);
assert(llvm::find(opOperands, &opOperand) != opOperands.end() &&
"operand and result do not match");
#endif // NDEBUG
@@ -467,9 +469,9 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
return false;
// This is a problem only if the buffer is written to via some alias.
- bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo) ||
- aliasesInPlaceWrite(opOperand.get(), aliasInfo) ||
- bufferizesToMemoryWrite(opOperand);
+ bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo, state) ||
+ aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
+ state.bufferizesToMemoryWrite(opOperand);
if (!hasWrite)
return false;
@@ -485,7 +487,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo,
BufferizationState &state, const DominanceInfo &domInfo) {
#ifndef NDEBUG
- SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+ SmallVector<OpOperand *> opOperands = state.getAliasingOpOperand(result);
assert(llvm::find(opOperands, &operand) != opOperands.end() &&
"operand and result do not match");
#endif // NDEBUG
@@ -539,7 +541,8 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>())
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
- if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
+ if (OpResult opResult =
+ bufferizableOp.getAliasingOpResult(opOperand, state))
if (failed(bufferizableInPlaceAnalysisImpl(
opOperand, opResult, aliasInfo, state, domInfo)))
return failure();
@@ -569,16 +572,16 @@ static LogicalResult inPlaceAnalysis(Operation *op,
/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
- const BufferizationOptions &options) {
+ BufferizationState &state) {
for (Operation *op : ops)
- if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+ if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
for (OpResult opResult : op->getOpResults())
if (opResult.getType().isa<TensorType>())
if (aliasInfo.isInPlace(opResult)) {
SmallVector<OpOperand *> opOperands =
- bufferizableOp.getAliasingOpOperand(opResult);
+ bufferizableOp.getAliasingOpOperand(opResult, state);
if (!opOperands.empty())
- if (bufferizableOp.bufferRelation(opResult, aliasInfo) ==
+ if (bufferizableOp.bufferRelation(opResult, aliasInfo, state) ==
BufferRelation::Equivalent)
for (OpOperand *opOperand : opOperands)
aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
@@ -589,7 +592,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
/// in `op`.
static void equivalenceAnalysis(Operation *op,
BufferizationAliasInfo &aliasInfo,
- const BufferizationOptions &options) {
+ BufferizationState &state) {
// Traverse ops in PostOrder: Nested ops first, then enclosing ops.
SmallVector<Operation *> ops;
op->walk<WalkOrder::PostOrder>([&](Operation *op) {
@@ -599,7 +602,7 @@ static void equivalenceAnalysis(Operation *op,
ops.push_back(op);
});
- equivalenceAnalysis(ops, aliasInfo, options);
+ equivalenceAnalysis(ops, aliasInfo, state);
}
/// Assert that the current bufferization decisions are consistent.
@@ -613,7 +616,8 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
if (auto bufferizableOp = options.dynCastBufferizableOp(op))
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>()) {
- OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
+ OpResult opResult =
+ bufferizableOp.getAliasingOpResult(opOperand, state);
if (wouldCreateReadAfterWriteInterference(
opOperand, opResult, domInfo, state, aliasInfo,
/*checkConsistencyOnly=*/true)) {
@@ -669,7 +673,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
options.analysisFuzzerSeed)))
return failure();
- equivalenceAnalysis(op, aliasInfo, options);
+ equivalenceAnalysis(op, aliasInfo, state);
auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
@@ -679,7 +683,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
// Analyze ops that were created by the PostAnalysisStep.
if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
return failure();
- equivalenceAnalysis(newOps, aliasInfo, options);
+ equivalenceAnalysis(newOps, aliasInfo, state);
}
return success();
};
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 9984ae1ad122..158ad6a76343 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -140,18 +140,22 @@ template <typename OpTy>
struct LinalgOpInterface
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
OpTy> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
return genericOp.payloadUsesValueFromOperand(&opOperand);
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
- return static_cast<bool>(bufferizableOp.getAliasingOpResult(opOperand));
+ return static_cast<bool>(
+ bufferizableOp.getAliasingOpResult(opOperand, state));
}
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands())
@@ -160,14 +164,16 @@ struct LinalgOpInterface
return {};
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
return pairs[&opOperand];
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -180,7 +186,8 @@ struct LinalgOpInterface
struct InitTensorOpInterface
: public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
linalg::InitTensorOp> {
- bool isMemoryWrite(Operation *op, OpResult opResult) const {
+ bool isMemoryWrite(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
// InitTensorOps allocate but do not write.
return false;
}
@@ -203,27 +210,32 @@ struct InitTensorOpInterface
struct TiledLoopOpInterface
: public BufferizableOpInterface::ExternalModel<TiledLoopOpInterface,
linalg::TiledLoopOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
// TiledLoop alone doesn't bufferize to a memory read, one of the uses of
// its matching bbArg may.
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
- return isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
+ return state.isValueRead(tiledLoopOp.getTiedBlockArgument(opOperand));
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
// TiledLoop alone doesn't bufferize to a memory write, one of the uses of
// its matching bbArg may.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
- return static_cast<bool>(bufferizableOp.getAliasingOpResult(opOperand));
+ return static_cast<bool>(
+ bufferizableOp.getAliasingOpResult(opOperand, state));
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
return tiledLoopOp.getTiedOpResult(opOperand);
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -331,15 +343,18 @@ struct TiledLoopOpInterface
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
linalg::YieldOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
@@ -391,7 +406,6 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
SmallVector<Operation *> &newOps) {
OpBuilder b(op->getContext());
- const BufferizationOptions &options = state.getOptions();
WalkResult status = op->walk([&](Operation *op) {
for (OpOperand &operand : op->getOpOperands()) {
@@ -400,7 +414,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
continue;
SetVector<Value> maybeInitTensor =
- findValueInReverseUseDefChain(operand.get(), options, [&](Value val) {
+ state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
// Continue traversal until this function returns true.
OpResult opResult = val.dyn_cast<OpResult>();
if (!opResult)
@@ -410,7 +424,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
// Only equivalent tensors are supported at the moment.
// TODO: Support cases such as extract_slice(init_tensor).
SmallVector<OpOperand *> opOperands =
- getAliasingOpOperand(opResult);
+ state.getAliasingOpOperand(opResult);
if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
return aliasInfo.areEquivalentBufferizedValues(operand->get(),
opResult);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 0e391a9a4f04..49687ccacd3d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -490,7 +490,8 @@ namespace std_ext {
struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
// CallOpInterface alone doesn't bufferize to a memory read, one of the uses
// of the matching bbArg may. It is the responsibility of the caller to
// inspect bbArgs. In the absence of a BufferizationAliasInfo, we need to be
@@ -498,7 +499,8 @@ struct CallOpInterface
return true;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
// CallOpInterface is special, it needs to wait for the callee to be
// bufferized and needs to inspect the BufferAliasInfo object. It can't
// make a proper determination by itself and needs to be conservative.
@@ -618,15 +620,18 @@ struct CallOpInterface
struct ReturnOpInterface
: public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
ReturnOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index edded005a1ee..f3e4aa4d9c98 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -22,8 +22,9 @@ namespace scf_ext {
struct ExecuteRegionOpInterface
: public BufferizableOpInterface::ExternalModel<ExecuteRegionOpInterface,
scf::ExecuteRegionOp> {
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
// any SSA value that is in scope. To allow for use-def chain traversal
// through ExecuteRegionOps in the analysis, the corresponding yield value
@@ -39,7 +40,8 @@ struct ExecuteRegionOpInterface
return {&yieldOp->getOpOperand(resultNum)};
}
- bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+ bool mustBufferizeInPlace(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
// ExecuteRegionOp results always bufferize in-place. Since they have no
// OpOperands, they are mostly ignored by the analysis once alias sets are
// set up.
@@ -48,7 +50,8 @@ struct ExecuteRegionOpInterface
// TODO: For better bufferization results, this could return `true` only if
// there is a memory write in the region.
- bool isMemoryWrite(Operation *op, OpResult opResult) const {
+ bool isMemoryWrite(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
// Similar to scf.if, results of this op are always considered memory writes
// in the analysis. This is a useful pattern for all ops that have tensor
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
@@ -71,15 +74,17 @@ struct ExecuteRegionOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::Equivalent;
}
};
struct IfOpInterface
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
// value that is in scope. To allow for use-def chain traversal through
// IfOps in the analysis, both corresponding yield values from the then/else
@@ -95,7 +100,8 @@ struct IfOpInterface
// there is a memory write in one (or both) of the branches. Since this is not
// allowed at the moment, we should never encounter scf.ifs that yield
// unmodified tensors. Such scf.yield ops could just fold away.
- bool isMemoryWrite(Operation *op, OpResult opResult) const {
+ bool isMemoryWrite(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
// IfOp results are always considered memory writes in the analysis. This
// design decision simplifies the analysis considerably. E.g., consider the
// following test case:
@@ -121,7 +127,8 @@ struct IfOpInterface
return true;
}
- bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+ bool mustBufferizeInPlace(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
// IfOp results always bufferize in-place. Since they have no OpOperands,
// they are mostly ignored by the analysis once alias sets are set up.
return true;
@@ -203,12 +210,13 @@ struct IfOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
// IfOp results are equivalent to their corresponding yield values if both
// yield values are equivalent to each other.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
SmallVector<OpOperand *> yieldValues =
- bufferizableOp.getAliasingOpOperand(opResult);
+ bufferizableOp.getAliasingOpOperand(opResult, state);
assert(yieldValues.size() == 2 && "expected 2 yield values");
bool equivalentYields = aliasInfo.areEquivalentBufferizedValues(
yieldValues[0]->get(), yieldValues[1]->get());
@@ -219,21 +227,24 @@ struct IfOpInterface
struct ForOpInterface
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
scf::ForOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
// its matching bbArg may.
auto forOp = cast<scf::ForOp>(op);
- return isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
+ return state.isValueRead(forOp.getRegionIterArgForOpOperand(opOperand));
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
// Tensor iter_args of scf::ForOps are always considered as a write. This is
// to simplify the analysis.
// TODO: Consider doing sth. like isValueWritten.
return true;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
auto forOp = cast<scf::ForOp>(op);
if (!opOperand.get().getType().isa<RankedTensorType>())
return OpResult();
@@ -241,7 +252,8 @@ struct ForOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
auto forOp = cast<scf::ForOp>(op);
@@ -410,15 +422,18 @@ LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
scf::YieldOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 7558d792facf..30ca9ed0a78b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -40,20 +40,24 @@ getTensorBufferizationState(BufferizationState &state) {
struct CastOpInterface
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
tensor::CastOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return op->getResult(0);
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -86,15 +90,18 @@ struct CastOpInterface
struct DimOpInterface
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
tensor::DimOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
@@ -112,22 +119,26 @@ struct DimOpInterface
struct ExtractSliceOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
tensor::ExtractSliceOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return &opOperand == &op->getOpOperand(0) /*source*/
? op->getResult(0)
: OpResult();
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::None;
}
@@ -160,7 +171,7 @@ struct ExtractSliceOpInterface
/// If not inplaceable, copy.
if (!inplace) {
// Do not copy if the copied data is never read.
- if (isValueRead(extractSliceOp.result()))
+ if (state.isValueRead(extractSliceOp.result()))
state.createMemCpy(b, extractSliceOp.getLoc(), subView, alloc);
subView = alloc;
}
@@ -173,15 +184,18 @@ struct ExtractSliceOpInterface
struct ExtractOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
tensor::ExtractOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
@@ -198,22 +212,26 @@ struct ExtractOpInterface
struct InsertOpInterface
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
tensor::InsertOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
"expected dest OpOperand");
return op->getOpResult(0);
}
- SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
- OpResult opResult) const {
+ SmallVector<OpOperand *>
+ getAliasingOpOperand(Operation *op, OpResult opResult,
+ BufferizationState &state) const {
return {&op->getOpOperand(1) /*dest*/};
}
@@ -229,7 +247,8 @@ struct InsertOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::Equivalent;
}
};
@@ -272,8 +291,8 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
- const BufferizationOptions &options,
- Value value, InsertSliceOp insertOp) {
+ BufferizationState &state, Value value,
+ InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
@@ -281,29 +300,33 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
return false;
};
- return llvm::all_of(findValueInReverseUseDefChain(value, options, condition),
+ return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
condition);
}
struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/
? op->getResult(0)
: OpResult();
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::Equivalent;
}
@@ -325,8 +348,8 @@ struct InsertSliceOpInterface
// TODO: Use insertSliceOp.getDestOpOperand etc. when available.
if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
- uConflictingWrite->get(), insertSliceOp))
+ hasMatchingExtractSliceOp(aliasInfo, state, uConflictingWrite->get(),
+ insertSliceOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -343,7 +366,7 @@ struct InsertSliceOpInterface
if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
- hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(),
+ hasMatchingExtractSliceOp(aliasInfo, state, uRead->get(),
insertSliceOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
@@ -377,8 +400,8 @@ struct InsertSliceOpInterface
if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
aliasInfo.areEquivalentBufferizedValues(uRead->get(),
insertSliceOp.source()) &&
- hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
- insertSliceOp.source(), insertSliceOp))
+ hasMatchingExtractSliceOp(aliasInfo, state, insertSliceOp.source(),
+ insertSliceOp))
return true;
return false;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 3ccfb5065ed2..50ceb5aa77c9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -20,19 +20,22 @@ namespace vector_ext {
struct TransferReadOpInterface
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
vector::TransferReadOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return false;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
return OpResult();
}
@@ -56,26 +59,30 @@ struct TransferReadOpInterface
struct TransferWriteOpInterface
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const {
+ bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
- OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+ OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ BufferizationState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return op->getOpResult(0);
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationAliasInfo &aliasInfo) const {
+ const BufferizationAliasInfo &aliasInfo,
+ BufferizationState &state) const {
return BufferRelation::Equivalent;
}
More information about the Mlir-commits
mailing list