[Mlir-commits] [mlir] 9597b16 - [mlir][bufferize][NFC] Split BufferizationState into AnalysisState/BufferizationState
Matthias Springer
llvmlistbot at llvm.org
Tue Mar 15 01:39:55 PDT 2022
Author: Matthias Springer
Date: 2022-03-15T17:35:47+09:00
New Revision: 9597b16aa91b5efba7457c7c7885fbb82647eb24
URL: https://github.com/llvm/llvm-project/commit/9597b16aa91b5efba7457c7c7885fbb82647eb24
DIFF: https://github.com/llvm/llvm-project/commit/9597b16aa91b5efba7457c7c7885fbb82647eb24.diff
LOG: [mlir][bufferize][NFC] Split BufferizationState into AnalysisState/BufferizationState
Differential Revision: https://reviews.llvm.org/D121361
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index fd556b2cebf07..c7cfc7241a509 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -26,11 +26,11 @@ class DominanceInfo;
namespace bufferization {
+class AnalysisState;
class BufferizableOpInterface;
-class BufferizationState;
-struct DialectBufferizationState;
+struct DialectAnalysisState;
-/// Options for ComprehensiveBufferize.
+/// Options for BufferizableOpInterface-based bufferization.
struct BufferizationOptions {
/// Allocator function: Generate a memref allocation with the given type,
/// dynamic extents and alignment.
@@ -43,11 +43,11 @@ struct BufferizationOptions {
/// Memcpy function: Generate a memcpy between two buffers.
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
- /// Initializer function for bufferization state.
- using BufferizationStateInitFn = std::function<void(BufferizationState &)>;
- /// Initializer function for dialect-specific bufferization state.
+ /// Initializer function for analysis state.
+ using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
+ /// Initializer function for dialect-specific analysis state.
using DialectStateInitFn =
- std::function<std::unique_ptr<DialectBufferizationState>()>;
+ std::function<std::unique_ptr<DialectAnalysisState>()>;
/// An op filter entry. Filters can be used to specify which ops should be
/// processed by the bufferization.
@@ -232,12 +232,12 @@ struct BufferizationOptions {
/// DENY-filtered and have at least one matching ALLOW filter are processed.
SmallVector<OpFilterEntry> opFilter;
- /// Initializer functions for bufferization state. These can be used to
- /// initialize dialect-specific bufferization state.
- SmallVector<BufferizationStateInitFn> stateInitializers;
+ /// Initializer functions for analysis state. These can be used to
+ /// initialize dialect-specific analysis state.
+ SmallVector<AnalysisStateInitFn> stateInitializers;
- /// Add a bufferization state initializer that initializes the specified
- /// dialect-specific bufferization state.
+ /// Add an analysis state initializer that initializes the specified
+ /// dialect-specific analysis state.
void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
private:
@@ -265,21 +265,21 @@ enum class BufferRelation {
/// Return `true` if the given value is a BlockArgument of a FuncOp.
bool isFunctionArgument(Value value);
-/// Dialect-specific bufferization state. Analysis/bufferization information
+/// Dialect-specific analysis state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
/// variants of this struct.
-struct DialectBufferizationState {
- DialectBufferizationState() = default;
+struct DialectAnalysisState {
+ DialectAnalysisState() = default;
- virtual ~DialectBufferizationState() = default;
+ virtual ~DialectAnalysisState() = default;
// Copying state is forbidden. Always pass as reference.
- DialectBufferizationState(const DialectBufferizationState &) = delete;
+ DialectAnalysisState(const DialectAnalysisState &) = delete;
};
-/// BufferizationState provides a variety of helper functions for dealing with
-/// tensor values and memref buffers.
-class BufferizationState {
+/// AnalysisState provides a variety of helper functions for dealing with
+/// tensor values.
+class AnalysisState {
public:
/// Determine which OpOperand* will alias with `result` if the op is
/// bufferized in place. Return an empty vector if the op is not bufferizable.
@@ -348,15 +348,7 @@ class BufferizationState {
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
- /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
- /// a new buffer and copy over data from the existing buffer if out-of-place
- /// bufferization was decided.
- FailureOr<Value>
- getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
- bool forceInPlace = false,
- Optional<Operation *> customCopyInsertionPoint = None) const;
-
- /// Return dialect-specific bufferization state.
+ /// Return dialect-specific analysis state.
template <typename StateT>
Optional<const StateT *> getDialectState(StringRef name) const {
auto it = dialectState.find(name);
@@ -365,7 +357,7 @@ class BufferizationState {
return static_cast<const StateT *>(it->getSecond().get());
}
- /// Return dialect-specific bufferization state or create one if none exists.
+ /// Return dialect-specific analysis state or create one if none exists.
template <typename StateT>
StateT &getOrCreateDialectState(StringRef name) {
// Create state if it does not exist yet.
@@ -375,7 +367,7 @@ class BufferizationState {
}
void insertDialectState(StringRef name,
- std::unique_ptr<DialectBufferizationState> state) {
+ std::unique_ptr<DialectAnalysisState> state) {
assert(!dialectState.count(name) && "dialect state already initialized");
dialectState[name] = std::move(state);
}
@@ -384,31 +376,31 @@ class BufferizationState {
const BufferizationOptions &getOptions() const { return options; }
protected:
- explicit BufferizationState(const BufferizationOptions &options);
+ explicit AnalysisState(const BufferizationOptions &options);
- // BufferizationState should be passed as a reference.
- BufferizationState(const BufferizationState &) = delete;
+ // AnalysisState should be passed as a reference.
+ AnalysisState(const AnalysisState &) = delete;
- ~BufferizationState() = default;
+ ~AnalysisState() = default;
private:
- /// Dialect-specific bufferization state.
- DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
+ /// Dialect-specific analysis state.
+ DenseMap<StringRef, std::unique_ptr<DialectAnalysisState>> dialectState;
/// A reference to current bufferization options.
const BufferizationOptions &options;
};
-/// This a "no analysis, always copy" BufferizationState. In the absence of an
+/// This is a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore, all
/// OpOperands that bufferize to a memory write must bufferize out-of-place.
-class AlwaysCopyBufferizationState : public BufferizationState {
+class AlwaysCopyAnalysisState : public AnalysisState {
public:
- explicit AlwaysCopyBufferizationState(const BufferizationOptions &options);
+ explicit AlwaysCopyAnalysisState(const BufferizationOptions &options);
- AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+ AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;
- virtual ~AlwaysCopyBufferizationState() = default;
+ virtual ~AlwaysCopyAnalysisState() = default;
/// Return `true` if the given OpOperand has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override;
@@ -417,6 +409,35 @@ class AlwaysCopyBufferizationState : public BufferizationState {
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
};
+/// BufferizationState provides helper functions for performing bufferization
+/// rewrites and handling memref buffers.
+struct BufferizationState {
+ BufferizationState(const AnalysisState &analysisState)
+ : analysisState(analysisState) {}
+
+ /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
+ /// a new buffer and copy over data from the existing buffer if out-of-place
+ /// bufferization was decided.
+ FailureOr<Value>
+ getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+ bool forceInPlace = false,
+ Optional<Operation *> customCopyInsertionPoint = None) const;
+
+ /// Return a reference to the BufferizationOptions.
+ const BufferizationOptions &getOptions() const {
+ return analysisState.getOptions();
+ }
+
+ const AnalysisState &getAnalysisState() const { return analysisState; }
+
+protected:
+ // BufferizationState should be passed as a reference.
+ BufferizationState(const BufferizationState &) = delete;
+
+private:
+ const AnalysisState &analysisState;
+};
+
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -503,39 +524,38 @@ struct AllocationHoistingBarrierOnly
: public BufferizableOpInterface::ExternalModel<
AllocationHoistingBarrierOnly<OpTy>, OpTy> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return {};
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::None;
}
bool isWritable(Operation *op, Value value,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
return failure();
}
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index ac26d327d2e31..37bab9531c316 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -33,7 +33,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"bufferizesToMemoryRead",
/*args=*/(ins "OpOperand &":$opOperand,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -62,7 +62,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"bufferizesToMemoryWrite",
/*args=*/(ins "OpOperand &":$opOperand,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -85,7 +85,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"isMemoryWrite",
/*args=*/(ins "OpResult":$opResult,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
@@ -112,7 +112,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"mustBufferizeInPlace",
/*args=*/(ins "OpOperand &":$opOperand,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
@@ -127,7 +127,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"SmallVector<OpResult>",
/*methodName=*/"getAliasingOpResult",
/*args=*/(ins "OpOperand &":$opOperand,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
@@ -151,7 +151,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"SmallVector<OpOperand *>",
/*methodName=*/"getAliasingOpOperand",
/*args=*/(ins "OpResult":$opResult,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
assert(opResult.getType().isa<TensorType>() &&
@@ -185,7 +185,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"BufferRelation",
/*methodName=*/"bufferRelation",
/*args=*/(ins "OpResult":$opResult,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpResults
@@ -220,7 +220,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "RewriterBase &":$rewriter,
- "const BufferizationState &":$state),
+ "BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
@@ -246,7 +246,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"bool",
/*methodName=*/"isWritable",
/*args=*/(ins "Value":$value,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return value.isa<OpResult>();
@@ -285,7 +285,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"isNotConflicting",
/*args=*/(ins "OpOperand *":$uRead,
"OpOperand *":$uWrite,
- "const BufferizationState &":$state),
+ "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return false;
@@ -302,7 +302,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
}],
/*retType=*/"LogicalResult",
/*methodName=*/"verifyAnalysis",
- /*args=*/(ins "const BufferizationState &":$state),
+ /*args=*/(ins "const AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return success();
@@ -318,7 +318,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
///
/// Examples of such ops are `tensor.extract_slice` and `tensor.cast`.
bool bufferizesToAliasOnly(OpOperand &opOperand,
- const BufferizationState &state) {
+ const AnalysisState &state) {
auto bufferizableOp =
cast<BufferizableOpInterface>(getOperation());
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 559f9b4380813..5f1a132c7d48c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -125,7 +125,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// results as not writable enforces a buffer copy and has the same effect.
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
// to_tensor cannot be bufferized. However, other ops that are using
// to_tensor's result will eventually be bufferized. At that point, they
// will start using to_tensor's memref operand. Once all users of
@@ -136,7 +136,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
return failure();
}
- bool isWritable(Value value, const BufferizationState &state) const {
+ bool isWritable(Value value, const AnalysisState &state) const {
// It is unknown whether the memref operand is writable or not.
return false;
}
@@ -194,30 +194,30 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
// but such IR may no longer be analyzable by One-Shot analysis.
bool bufferizesToMemoryRead(OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// It is unknown whether the resulting memref will be read or not.
return true;
}
bool bufferizesToMemoryWrite(OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// It is unknown whether the resulting MemRef will be written or not.
return true;
}
bool mustBufferizeInPlace(OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// ToMemrefOps always bufferize inplace.
return true;
}
SmallVector<OpResult> getAliasingOpResult(
- OpOperand &opOperand, const BufferizationState &state) const {
+ OpOperand &opOperand, const AnalysisState &state) const {
return {};
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationState &state);
+ BufferizationState &state);
}];
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index e2ede4b63d2f3..13b48d09d33ff 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -28,7 +28,8 @@
namespace mlir {
namespace bufferization {
-class BufferizationState;
+class AnalysisState;
+struct BufferizationState;
struct BufferizationOptions;
/// A helper type converter class that automatically populates the relevant
@@ -67,7 +68,14 @@ void populateEliminateBufferizeMaterializationsPatterns(
/// layouts after transformations. Combinations of memref.cast +
/// canonicalization are responsible for clean ups.
// TODO: Extract `options` from `state` and pass as separate argument.
-LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
+LogicalResult bufferizeOp(Operation *op, const AnalysisState &analysisState);
+
+/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
+/// Reuse an existing `BufferizationState`.
+///
+/// Note: This function overload is useful for extending the bufferization.
+LogicalResult bufferizeOp(Operation *op,
+ BufferizationState &bufferizationState);
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
/// Buffers are duplicated and copied before any tensor use that bufferizes to
@@ -77,11 +85,6 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
/// can be used to implement partial bufferization passes.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
-/// Populate the pattern set with a pattern that bufferizes ops that implement
-/// `BufferizableOpInterface`.
-void populateBufferizationPattern(const BufferizationState &state,
- RewritePatternSet &patterns);
-
BufferizationOptions getPartialBufferizationOptions();
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 8641bc1702712..de555988dd549 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -16,9 +16,9 @@
namespace mlir {
namespace bufferization {
-class AnalysisBufferizationState;
+struct OneShotBufferizationOptions;
class BufferizationAliasInfo;
-struct AnalysisBufferizationOptions;
+class OneShotAnalysisState;
/// PostAnalysisStepFns can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
@@ -26,14 +26,14 @@ struct AnalysisBufferizationOptions;
/// must keep `aliasInfo` consistent. Newly created operations and operations
/// that should be re-analyzed must be added to `newOps`.
using PostAnalysisStepFn = std::function<LogicalResult(
- Operation *, BufferizationState &, BufferizationAliasInfo &,
+ Operation *, AnalysisState &, BufferizationAliasInfo &,
SmallVector<Operation *> &)>;
using PostAnalysisStepList = SmallVector<PostAnalysisStepFn>;
/// Options for analysis-enabled bufferization.
-struct AnalysisBufferizationOptions : public BufferizationOptions {
- AnalysisBufferizationOptions() = default;
+struct OneShotBufferizationOptions : public BufferizationOptions {
+ OneShotBufferizationOptions() = default;
/// Register a "post analysis" step. Such steps are executed after the
/// analysis, but before bufferization.
@@ -68,7 +68,7 @@ class BufferizationAliasInfo {
/// Set the inPlace bufferization spec to true.
/// Merge result's and operand's aliasing sets and iterate to a fixed point.
- void bufferizeInPlace(OpOperand &operand, BufferizationState &state);
+ void bufferizeInPlace(OpOperand &operand, AnalysisState &state);
/// Set the inPlace bufferization spec to false.
void bufferizeOutOfPlace(OpOperand &operand);
@@ -135,14 +135,14 @@ class BufferizationAliasInfo {
/// State for analysis-enabled bufferization. This class keeps track of alias
/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
/// in-place.
-class AnalysisBufferizationState : public BufferizationState {
+class OneShotAnalysisState : public AnalysisState {
public:
- AnalysisBufferizationState(Operation *op,
- const AnalysisBufferizationOptions &options);
+ OneShotAnalysisState(Operation *op,
+ const OneShotBufferizationOptions &options);
- AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
+ OneShotAnalysisState(const OneShotAnalysisState &) = delete;
- virtual ~AnalysisBufferizationState() = default;
+ virtual ~OneShotAnalysisState() = default;
/// Return a reference to the BufferizationAliasInfo.
BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
@@ -161,11 +161,11 @@ class AnalysisBufferizationState : public BufferizationState {
/// Analyze `op` and its nested ops. Bufferization decisions are stored in
/// `state`.
-LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
+LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state);
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
LogicalResult runOneShotBufferize(Operation *op,
- const AnalysisBufferizationOptions &options);
+ const OneShotBufferizationOptions &options);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 1aa40da14bcdb..bff0b197544af 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -5,7 +5,7 @@
namespace mlir {
namespace bufferization {
-struct AnalysisBufferizationOptions;
+struct OneShotBufferizationOptions;
//===----------------------------------------------------------------------===//
// Passes
@@ -37,7 +37,7 @@ std::unique_ptr<Pass> createOneShotBufferizePass();
/// Create a pass that bufferizes all ops that implement BufferizableOpInterface
/// with One-Shot Bufferize and the specified bufferization options.
std::unique_ptr<Pass>
-createOneShotBufferizePass(const AnalysisBufferizationOptions &options);
+createOneShotBufferizePass(const OneShotBufferizationOptions &options);
/// Creates a pass that promotes heap-based allocations to stack-based ones.
/// Only buffers smaller than the provided size are promoted.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index d6e50072317b4..9248b9bbb247e 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -18,7 +18,7 @@ struct LogicalResult;
class ModuleOp;
namespace bufferization {
-struct AnalysisBufferizationOptions;
+struct OneShotBufferizationOptions;
} // namespace bufferization
namespace linalg {
@@ -29,7 +29,7 @@ namespace comprehensive_bufferize {
/// analyzes and bufferizes FuncOps one-by-one with One-Shot Bufferize.
LogicalResult
runModuleBufferize(ModuleOp moduleOp,
- bufferization::AnalysisBufferizationOptions options);
+ bufferization::OneShotBufferizationOptions options);
namespace std_ext {
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 3f8719b0782b5..ac9b3b2ace240 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -19,7 +19,7 @@
namespace mlir {
namespace bufferization {
-struct AnalysisBufferizationOptions;
+struct OneShotBufferizationOptions;
} // namespace bufferization
std::unique_ptr<Pass> createConvertElementwiseToLinalgPass();
@@ -64,7 +64,7 @@ std::unique_ptr<OperationPass<FuncOp>> createConvertLinalgToAffineLoopsPass();
/// with the 'inplaceable' attribute.
std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
- const bufferization::AnalysisBufferizationOptions &options);
+ const bufferization::OneShotBufferizationOptions &options);
/// Create a pass to convert Linalg operations which work on tensors to use
/// buffers instead.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
index 13d6f189721be..64c3232e86372 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
@@ -36,7 +36,7 @@ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
/// This analysis can be skipped with `skipAnalysis`.
LogicalResult
-eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
+eliminateInitTensors(Operation *op, bufferization::AnalysisState &state,
bufferization::BufferizationAliasInfo &aliasInfo,
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
SmallVector<Operation *> &newOps);
@@ -45,7 +45,7 @@ eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
/// (and some other conditions are met).
LogicalResult insertSliceAnchoredInitTensorEliminationStep(
- Operation *op, bufferization::BufferizationState &state,
+ Operation *op, bufferization::AnalysisState &state,
bufferization::BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps);
diff --git a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
index 08c6ca2ee0d29..d0f758e6b4e9e 100644
--- a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
@@ -9,16 +9,9 @@
#ifndef MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
#define MLIR_DIALECT_SCF_BUFFERIZABLEOPINTERFACEIMPL_H
-#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
-
namespace mlir {
class DialectRegistry;
-namespace bufferization {
-class BufferizationState;
-class BufferizationAliasInfo;
-} // namespace bufferization
-
namespace scf {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace scf
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 3a01b397abef7..2b08300bb7127 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -23,7 +23,7 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
// Only ranked tensors are supported.
@@ -49,7 +49,7 @@ struct ConstantOpInterface
}
bool isWritable(Operation *op, Value value,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Memory locations returned by memref::GetGlobalOp may not be written to.
assert(value.isa<OpResult>());
return false;
@@ -60,28 +60,27 @@ struct IndexCastOpInterface
: public BufferizableOpInterface::ExternalModel<IndexCastOpInterface,
arith::IndexCastOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {op->getResult(0)};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
@@ -106,30 +105,29 @@ struct SelectOpInterface
: public BufferizableOpInterface::ExternalModel<SelectOpInterface,
arith::SelectOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {op->getOpResult(0) /*result*/};
}
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return {&op->getOpOperand(1) /*true_value*/,
&op->getOpOperand(2) /*false_value*/};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto selectOp = cast<arith::SelectOp>(op);
// `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
@@ -147,7 +145,7 @@ struct SelectOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::None;
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 7fd538cf4d7fa..d2b8f1de5628f 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -67,7 +67,7 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::addDialectStateInitializer(
StringRef name, const DialectStateInitFn &fn) {
stateInitializers.push_back(
- [=](BufferizationState &state) { state.insertDialectState(name, fn()); });
+ [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
}
//===----------------------------------------------------------------------===//
@@ -85,7 +85,7 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
-BufferizationState::getAliasingOpOperand(OpResult result) const {
+AnalysisState::getAliasingOpOperand(OpResult result) const {
if (Operation *op = result.getDefiningOp())
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
return bufferizableOp.getAliasingOpOperand(result, *this);
@@ -95,7 +95,7 @@ BufferizationState::getAliasingOpOperand(OpResult result) const {
/// Determine which OpResult will alias with `opOperand` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpResult>
-BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
+AnalysisState::getAliasingOpResult(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.getAliasingOpResult(opOperand, *this);
@@ -104,7 +104,7 @@ BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
/// op is not bufferizable.
-bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const {
+bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryRead(opOperand, *this);
@@ -116,7 +116,7 @@ bool BufferizationState::bufferizesToMemoryRead(OpOperand &opOperand) const {
/// Return true if `opOperand` bufferizes to a memory write. Return
/// `true` if the op is not bufferizable.
-bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
+bool AnalysisState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToMemoryWrite(opOperand, *this);
@@ -128,7 +128,7 @@ bool BufferizationState::bufferizesToMemoryWrite(OpOperand &opOperand) const {
/// Return true if `opOperand` does neither read nor write but bufferizes to an
/// alias. Return false if the op is not bufferizable.
-bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const {
+bool AnalysisState::bufferizesToAliasOnly(OpOperand &opOperand) const {
if (auto bufferizableOp =
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
return bufferizableOp.bufferizesToAliasOnly(opOperand, *this);
@@ -141,7 +141,7 @@ bool BufferizationState::bufferizesToAliasOnly(OpOperand &opOperand) const {
/// Return true if the given value is read by an op that bufferizes to a memory
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
-bool BufferizationState::isValueRead(Value value) const {
+bool AnalysisState::isValueRead(Value value) const {
assert(value.getType().isa<TensorType>() && "expected TensorType");
SmallVector<OpOperand *> workingSet;
for (OpOperand &use : value.getUses())
@@ -165,7 +165,7 @@ bool BufferizationState::isValueRead(Value value) const {
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further.
-llvm::SetVector<Value> BufferizationState::findValueInReverseUseDefChain(
+llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition) const {
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
@@ -193,7 +193,7 @@ llvm::SetVector<Value> BufferizationState::findValueInReverseUseDefChain(
// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value>
-BufferizationState::findLastPrecedingWrite(Value value) const {
+AnalysisState::findLastPrecedingWrite(Value value) const {
return findValueInReverseUseDefChain(value, [&](Value value) {
Operation *op = value.getDefiningOp();
if (!op)
@@ -205,9 +205,9 @@ BufferizationState::findLastPrecedingWrite(Value value) const {
});
}
-BufferizationState::BufferizationState(const BufferizationOptions &options)
+AnalysisState::AnalysisState(const BufferizationOptions &options)
: options(options) {
- for (const BufferizationOptions::BufferizationStateInitFn &fn :
+ for (const BufferizationOptions::AnalysisStateInitFn &fn :
options.stateInitializers)
fn(*this);
}
@@ -246,13 +246,14 @@ Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
FailureOr<Value> BufferizationState::getBuffer(
RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
Optional<Operation *> customCopyInsertionPoint) const {
+ const BufferizationOptions &options = analysisState.getOptions();
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = opOperand.getOwner();
Location loc = op->getLoc();
Value operand = opOperand.get();
Value operandBuffer = lookupBuffer(rewriter, operand, options);
- if (forceInPlace || isInPlace(opOperand))
+ if (forceInPlace || analysisState.isInPlace(opOperand))
return operandBuffer;
// Bufferizing out-of-place: Allocate a new buffer.
@@ -269,22 +270,26 @@ FailureOr<Value> BufferizationState::getBuffer(
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
// use-def chain, it returns that value, regardless of whether it is a
// memory write or not.
- SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
+ SetVector<Value> lastWrites = analysisState.findLastPrecedingWrite(operand);
if (llvm::none_of(lastWrites, [&](Value lastWrite) {
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
- *this);
+ analysisState);
return true;
}))
return resultBuffer;
// Do not copy if the copied data is never read.
- SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
- if (!aliasingOpResults.empty() && !bufferizesToMemoryRead(opOperand) &&
- llvm::none_of(aliasingOpResults,
- [&](OpResult opResult) { return isValueRead(opResult); }))
+ SmallVector<OpResult> aliasingOpResults =
+ analysisState.getAliasingOpResult(opOperand);
+ if (!aliasingOpResults.empty() &&
+ !analysisState.bufferizesToMemoryRead(opOperand) &&
+ llvm::none_of(aliasingOpResults, [&](OpResult opResult) {
+ return analysisState.isValueRead(opResult);
+ }))
return resultBuffer;
// Do not copy if this op does not read the data, but writes it.
- if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))
+ if (analysisState.bufferizesToMemoryWrite(opOperand) &&
+ !analysisState.bufferizesToMemoryRead(opOperand))
return resultBuffer;
if (customCopyInsertionPoint) {
@@ -330,20 +335,20 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
rewriter.replaceOp(op, replacements);
}
-AlwaysCopyBufferizationState::AlwaysCopyBufferizationState(
+AlwaysCopyAnalysisState::AlwaysCopyAnalysisState(
const BufferizationOptions &options)
- : BufferizationState(options) {}
+ : AnalysisState(options) {}
/// Return `true` if the given OpResult has been decided to bufferize inplace.
-bool AlwaysCopyBufferizationState::isInPlace(OpOperand &opOperand) const {
+bool AlwaysCopyAnalysisState::isInPlace(OpOperand &opOperand) const {
// OpOperands that bufferize to a memory write are out-of-place, i.e., an
// alloc and copy is inserted.
return !bufferizesToMemoryWrite(opOperand);
}
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
-bool AlwaysCopyBufferizationState::areEquivalentBufferizedValues(
- Value v1, Value v2) const {
+bool AlwaysCopyAnalysisState::areEquivalentBufferizedValues(Value v1,
+ Value v2) const {
// There is no analysis, so we do not know if the values are equivalent. The
// conservative answer is "false".
return false;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 8b07b1f97b835..87f1d480ec340 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -349,7 +349,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
- const BufferizationState &state) {
+ BufferizationState &state) {
// Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
return foldToMemrefToTensorPair(rewriter, *this);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 510d1e13ecc72..ba6cd37751119 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -153,7 +153,7 @@ struct OneShotBufferizePass
: public OneShotBufferizeBase<OneShotBufferizePass> {
OneShotBufferizePass() : OneShotBufferizeBase<OneShotBufferizePass>() {}
- explicit OneShotBufferizePass(const AnalysisBufferizationOptions &options)
+ explicit OneShotBufferizePass(const OneShotBufferizationOptions &options)
: options(options) {}
void getDependentDialects(DialectRegistry &registry) const override {
@@ -161,7 +161,7 @@ struct OneShotBufferizePass
}
void runOnOperation() override {
- AnalysisBufferizationOptions opt;
+ OneShotBufferizationOptions opt;
if (!options) {
// Make new bufferization options if none were provided when creating the
// pass.
@@ -209,7 +209,7 @@ struct OneShotBufferizePass
}
private:
- llvm::Optional<AnalysisBufferizationOptions> options;
+ llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace
@@ -218,7 +218,7 @@ std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass() {
}
std::unique_ptr<Pass> mlir::bufferization::createOneShotBufferizePass(
- const AnalysisBufferizationOptions &options) {
+ const OneShotBufferizationOptions &options) {
return std::make_unique<OneShotBufferizePass>(options);
}
@@ -243,23 +243,25 @@ static bool hasTensorSemantics(Operation *op) {
/// Rewrite pattern that bufferizes bufferizable ops.
struct BufferizationPattern
: public OpInterfaceRewritePattern<BufferizableOpInterface> {
- BufferizationPattern(MLIRContext *context, const BufferizationState &state,
+ BufferizationPattern(MLIRContext *context, BufferizationState &state,
PatternBenefit benefit = 1)
: OpInterfaceRewritePattern<BufferizableOpInterface>(context, benefit),
- state(state) {}
+ state(&state) {}
LogicalResult matchAndRewrite(BufferizableOpInterface bufferizableOp,
PatternRewriter &rewriter) const override {
+ const BufferizationOptions &options = state->getOptions();
+
// No tensors => no buffers.
if (!hasTensorSemantics(bufferizableOp.getOperation()))
return failure();
- if (!state.getOptions().isOpAllowed(bufferizableOp.getOperation()))
+ if (!options.isOpAllowed(bufferizableOp.getOperation()))
return failure();
- return bufferizableOp.bufferize(rewriter, state);
+ return bufferizableOp.bufferize(rewriter, *state);
}
private:
- const BufferizationState &state;
+ BufferizationState *const state;
};
/// Check the result of bufferization. Return an error if an op was not
@@ -298,10 +300,17 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
}
LogicalResult bufferization::bufferizeOp(Operation *op,
- const BufferizationState &state) {
+ const AnalysisState &analysisState) {
+ BufferizationState bufferizationState(analysisState);
+ return bufferizeOp(op, bufferizationState);
+}
+
+LogicalResult
+bufferization::bufferizeOp(Operation *op,
+ BufferizationState &bufferizationState) {
// Bufferize the op and its nested ops.
RewritePatternSet patterns(op->getContext());
- populateBufferizationPattern(state, patterns);
+ patterns.add<BufferizationPattern>(patterns.getContext(), bufferizationState);
// Bufferize ops top-to-bottom. When creating a new op, we should ideally
// know the exact memref type of all operands. Otherwise, we have to use a
@@ -323,21 +332,21 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return failure();
- return checkBufferizationResult(op, state.getOptions());
+ return checkBufferizationResult(op, bufferizationState.getOptions());
}
namespace {
-/// This a "no analysis, always copy" BufferizationState. In the absence of an
+/// This a "no analysis, always copy" AnalysisState. In the absence of an
/// analysis, a buffer must be copied each time it is written to. Therefore, all
/// OpOperands that bufferize to a memory write must bufferize out-of-place.
-class AlwaysCopyBufferizationState : public BufferizationState {
+class AlwaysCopyAnalysisState : public AnalysisState {
public:
- AlwaysCopyBufferizationState(const BufferizationOptions &options)
- : BufferizationState(options) {}
+ AlwaysCopyAnalysisState(const BufferizationOptions &options)
+ : AnalysisState(options) {}
- AlwaysCopyBufferizationState(const AlwaysCopyBufferizationState &) = delete;
+ AlwaysCopyAnalysisState(const AlwaysCopyAnalysisState &) = delete;
- virtual ~AlwaysCopyBufferizationState() = default;
+ virtual ~AlwaysCopyAnalysisState() = default;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const override {
@@ -357,15 +366,10 @@ class AlwaysCopyBufferizationState : public BufferizationState {
LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options) {
- AlwaysCopyBufferizationState state(options);
+ AlwaysCopyAnalysisState state(options);
return bufferizeOp(op, state);
}
-void bufferization::populateBufferizationPattern(
- const BufferizationState &state, RewritePatternSet &patterns) {
- patterns.add<BufferizationPattern>(patterns.getContext(), state);
-}
-
BufferizationOptions bufferization::getPartialBufferizationOptions() {
BufferizationOptions options;
options.allowReturnMemref = true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index efe7e36956c74..706072d7b9c10 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -26,7 +26,7 @@
// ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
-// bufferization phase via `BufferizationState` and `BufferizationAliasInfo`.
+// bufferization phase via `AnalysisState` and `BufferizationAliasInfo`.
// They can be printed for debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
@@ -138,7 +138,7 @@ bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
/// Set the inPlace bufferization spec to true.
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
- BufferizationState &state) {
+ AnalysisState &state) {
markInPlace(operand);
for (OpResult result : state.getAliasingOpResult(operand))
aliasInfo.unionSets(result, operand.get());
@@ -182,12 +182,12 @@ BufferizationAliasInfo::getAliases(Value v) const {
}
//===----------------------------------------------------------------------===//
-// AnalysisBufferizationState
+// OneShotAnalysisState
//===----------------------------------------------------------------------===//
-AnalysisBufferizationState::AnalysisBufferizationState(
- Operation *op, const AnalysisBufferizationOptions &options)
- : BufferizationState(options), aliasInfo(op) {
+OneShotAnalysisState::OneShotAnalysisState(
+ Operation *op, const OneShotBufferizationOptions &options)
+ : AnalysisState(options), aliasInfo(op) {
// Set up alias sets for OpResults that must bufferize in-place. This should
// be done before making any other bufferization decisions.
op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -206,12 +206,12 @@ AnalysisBufferizationState::AnalysisBufferizationState(
});
}
-bool AnalysisBufferizationState::isInPlace(OpOperand &opOperand) const {
+bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
return aliasInfo.isInPlace(opOperand);
}
-bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1,
- Value v2) const {
+bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
+ Value v2) const {
return aliasInfo.areEquivalentBufferizedValues(v1, v2);
}
@@ -222,7 +222,7 @@ bool AnalysisBufferizationState::areEquivalentBufferizedValues(Value v1,
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) {
+ AnalysisState &state) {
// OpOperands that do not bufferize to a memory write do not write in-place.
if (!state.bufferizesToMemoryWrite(opOperand))
return false;
@@ -234,7 +234,7 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
/// is not writable.
static bool aliasesNonWritableBuffer(Value value,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) {
+ AnalysisState &state) {
bool foundNonWritableBuffer = false;
aliasInfo.applyOnAliases(value, [&](Value v) {
// Query BufferizableOpInterface to see if the value is writable.
@@ -260,7 +260,7 @@ static bool aliasesNonWritableBuffer(Value value,
/// to some buffer write.
static bool aliasesInPlaceWrite(Value value,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) {
+ AnalysisState &state) {
bool foundInplaceWrite = false;
aliasInfo.applyOnAliases(value, [&](Value v) {
for (auto &use : v.getUses()) {
@@ -331,7 +331,7 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
static bool hasReadAfterWriteInterference(
const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
- BufferizationState &state, const BufferizationAliasInfo &aliasInfo) {
+ AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
const BufferizationOptions &options = state.getOptions();
for (OpOperand *uRead : usesRead) {
@@ -452,7 +452,7 @@ static bool hasReadAfterWriteInterference(
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
- OpOperand &operand, const DominanceInfo &domInfo, BufferizationState &state,
+ OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
const BufferizationAliasInfo &aliasInfo,
bool checkConsistencyOnly = false) {
// Helper function to iterate on aliases of `root` and capture the reads.
@@ -495,7 +495,7 @@ static bool wouldCreateReadAfterWriteInterference(
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) {
+ AnalysisState &state) {
// Certain buffers are not writeable:
// 1. A function bbArg that is not inplaceable or
// 2. A constant op.
@@ -520,8 +520,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
/// Determine if `operand` can be bufferized in-place.
static LogicalResult bufferizableInPlaceAnalysisImpl(
- OpOperand &operand, BufferizationAliasInfo &aliasInfo,
- BufferizationState &state, const DominanceInfo &domInfo) {
+ OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
+ const DominanceInfo &domInfo) {
bool foundInterference =
wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
@@ -554,7 +554,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
/// RaW dependence violations.
static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
- BufferizationState &state,
+ AnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
if (analysisFuzzerSeed) {
@@ -587,7 +587,7 @@ static bool hasTensorSemantics(Operation *op) {
/// Analyze all ops that are contained in `op`.
static LogicalResult inPlaceAnalysis(Operation *op,
BufferizationAliasInfo &aliasInfo,
- BufferizationState &state,
+ AnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
// Collect ops so we can build our own reverse traversal.
@@ -605,7 +605,7 @@ static LogicalResult inPlaceAnalysis(Operation *op,
/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) {
+ AnalysisState &state) {
for (Operation *op : ops)
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
for (OpResult opResult : op->getOpResults())
@@ -622,7 +622,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
/// in `op`.
static void equivalenceAnalysis(Operation *op,
BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) {
+ AnalysisState &state) {
// Traverse ops in PostOrder: Nested ops first, then enclosing ops.
SmallVector<Operation *> ops;
op->walk<WalkOrder::PostOrder>([&](Operation *op) {
@@ -638,7 +638,7 @@ static void equivalenceAnalysis(Operation *op,
/// Assert that the current bufferization decisions are consistent.
static LogicalResult
checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
- BufferizationState &state,
+ AnalysisState &state,
const BufferizationAliasInfo &aliasInfo) {
const BufferizationOptions &options = state.getOptions();
Operation *inconsistentOp = nullptr;
@@ -668,7 +668,7 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
static void
annotateOpsWithBufferizationMarkers(Operation *op,
const BufferizationAliasInfo &aliasInfo,
- BufferizationState &state) {
+ AnalysisState &state) {
op->walk([&](Operation *op) {
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
for (OpOperand &opOperand : op->getOpOperands())
@@ -701,7 +701,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
// for aliasing values because the analysis is a maybe-alias analysis and we
// need a must-alias analysis here.
static LogicalResult
-assertDestinationPassingStyle(Operation *op, BufferizationState &state,
+assertDestinationPassingStyle(Operation *op, AnalysisState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
LogicalResult status = success();
@@ -748,11 +748,11 @@ assertDestinationPassingStyle(Operation *op, BufferizationState &state,
}
LogicalResult bufferization::analyzeOp(Operation *op,
- AnalysisBufferizationState &state) {
+ OneShotAnalysisState &state) {
DominanceInfo domInfo(op);
BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
const auto &options =
- static_cast<const AnalysisBufferizationOptions &>(state.getOptions());
+ static_cast<const OneShotBufferizationOptions &>(state.getOptions());
if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
return failure();
@@ -796,9 +796,10 @@ LogicalResult bufferization::analyzeOp(Operation *op,
return success(!failedAnalysis);
}
-LogicalResult bufferization::runOneShotBufferize(
- Operation *op, const AnalysisBufferizationOptions &options) {
- AnalysisBufferizationState state(op, options);
+LogicalResult
+bufferization::runOneShotBufferize(Operation *op,
+ const OneShotBufferizationOptions &options) {
+ OneShotAnalysisState state(op, options);
if (failed(analyzeOp(op, state)))
return failure();
if (options.testAnalysisOnly)
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index a8abea24d819a..85c425623eacc 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -17,7 +17,7 @@
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered through PostAnalysisStepFns and stored in
-// `ModuleBufferizationState`.
+// `ModuleAnalysisState`.
//
// * `equivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
// tensor return value (if any).
@@ -90,9 +90,9 @@ namespace {
/// The state of analysis of a FuncOp.
enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed };
-/// Extra bufferization state that is required for bufferization of function
+/// Extra analysis state that is required for bufferization of function
/// boundaries.
-struct ModuleBufferizationState : public DialectBufferizationState {
+struct ModuleAnalysisState : public DialectAnalysisState {
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
@@ -117,28 +117,26 @@ struct ModuleBufferizationState : public DialectBufferizationState {
};
} // namespace
-/// Get ModuleBufferizationState.
-static const ModuleBufferizationState &
-getModuleBufferizationState(const BufferizationState &state) {
- Optional<const ModuleBufferizationState *> maybeState =
- state.getDialectState<ModuleBufferizationState>(
+/// Get ModuleAnalysisState.
+static const ModuleAnalysisState &
+getModuleAnalysisState(const AnalysisState &state) {
+ Optional<const ModuleAnalysisState *> maybeState =
+ state.getDialectState<ModuleAnalysisState>(
func::FuncDialect::getDialectNamespace());
- assert(maybeState.hasValue() && "ModuleBufferizationState does not exist");
+ assert(maybeState.hasValue() && "ModuleAnalysisState does not exist");
return **maybeState;
}
-/// Get or create ModuleBufferizationState.
-static ModuleBufferizationState &
-getModuleBufferizationState(BufferizationState &state) {
- return state.getOrCreateDialectState<ModuleBufferizationState>(
+/// Get or create ModuleAnalysisState.
+static ModuleAnalysisState &getModuleAnalysisState(AnalysisState &state) {
+ return state.getOrCreateDialectState<ModuleAnalysisState>(
func::FuncDialect::getDialectNamespace());
}
/// Return the state (phase) of analysis of the FuncOp.
-static FuncOpAnalysisState
-getFuncOpAnalysisState(const BufferizationState &state, FuncOp funcOp) {
- const ModuleBufferizationState &moduleState =
- getModuleBufferizationState(state);
+static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
+ FuncOp funcOp) {
+ const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
auto it = moduleState.analyzedFuncOps.find(funcOp);
if (it == moduleState.analyzedFuncOps.end())
return FuncOpAnalysisState::NotAnalyzed;
@@ -183,12 +181,12 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
}
/// Store function BlockArguments that are equivalent to a returned value in
-/// ModuleBufferizationState.
+/// ModuleAnalysisState.
static LogicalResult
-equivalentFuncOpBBArgsAnalysis(Operation *op, BufferizationState &state,
+equivalentFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
// Support only single return-terminated block in the function.
auto funcOp = cast<FuncOp>(op);
@@ -213,7 +211,7 @@ equivalentFuncOpBBArgsAnalysis(Operation *op, BufferizationState &state,
/// Return true if the buffer of the given tensor value is written to. Must not
/// be called for values inside not yet analyzed functions. (Post-analysis
/// steps do not have to be run yet, i.e., "in progress" is also OK.)
-static bool isValueWritten(Value value, const BufferizationState &state,
+static bool isValueWritten(Value value, const AnalysisState &state,
const BufferizationAliasInfo &aliasInfo) {
#ifndef NDEBUG
assert(value.getType().isa<TensorType>() && "expected TensorType");
@@ -259,10 +257,10 @@ static void annotateFuncArgAccess(FuncOp funcOp, BlockArgument bbArg,
/// PostAnalysisStepFn is run on a function with unknown ops, it will
/// conservatively assume that such ops bufferize to a read + write.
static LogicalResult
-funcOpBbArgReadWriteAnalysis(Operation *op, BufferizationState &state,
+funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
auto funcOp = cast<FuncOp>(op);
// If the function has no body, conservatively assume that all args are
@@ -349,7 +347,7 @@ getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes,
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(FuncOp funcOp,
BufferizationAliasInfo &aliasInfo,
- ModuleBufferizationState &moduleState) {
+ ModuleAnalysisState &moduleState) {
funcOp->walk([&](func::CallOp callOp) {
FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called FuncOp");
@@ -387,7 +385,8 @@ static void equivalenceAnalysis(FuncOp funcOp,
static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
RewriterBase &rewriter,
BufferizationState &state) {
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+ const ModuleAnalysisState &moduleState =
+ getModuleAnalysisState(state.getAnalysisState());
// If nothing to do then we are done.
if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) &&
@@ -439,8 +438,9 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
}
// If return operand is equivalent to some bbArg, no need to return it.
- if (moduleState.equivalentFuncArgs[funcOp].count(
- returnOperand.getOperandNumber()))
+ auto funcOpIt = moduleState.equivalentFuncArgs.find(funcOp);
+ if (funcOpIt != moduleState.equivalentFuncArgs.end() &&
+ funcOpIt->second.count(returnOperand.getOperandNumber()))
continue;
// Cast values at the call site if necessary.
@@ -674,7 +674,7 @@ namespace std_ext {
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t>
-getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleBufferizationState &state,
+getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleAnalysisState &state,
int64_t returnValIdx) {
auto funcOpIt = state.equivalentFuncArgs.find(funcOp);
if (funcOpIt == state.equivalentFuncArgs.end())
@@ -693,13 +693,12 @@ struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface,
func::CallOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleBufferizationState &moduleState =
- getModuleBufferizationState(state);
+ const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is read.
return true;
@@ -709,13 +708,12 @@ struct CallOpInterface
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleBufferizationState &moduleState =
- getModuleBufferizationState(state);
+ const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed)
// FuncOp not analyzed yet. Assume that OpOperand is written.
return true;
@@ -724,14 +722,12 @@ struct CallOpInterface
funcOp.getArgument(opOperand.getOperandNumber()));
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleBufferizationState &moduleState =
- getModuleBufferizationState(state);
+ const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
SmallVector<OpResult> result;
for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
@@ -746,12 +742,11 @@ struct CallOpInterface
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleBufferizationState &moduleState =
- getModuleBufferizationState(state);
+ const ModuleAnalysisState &moduleState = getModuleAnalysisState(state);
// TODO: We should be looking for aliasing block arguments here. The current
// condition is actually stronger than neccesary. Once we check for aliasing
@@ -766,7 +761,7 @@ struct CallOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
@@ -774,14 +769,14 @@ struct CallOpInterface
/// marked inplaceable. For now, it is the responsibility of the `callOp`
/// bufferization to allow FuncOp that are inplaceable to write inPlace.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
unsigned numResults = callOp.getNumResults();
unsigned numOperands = callOp->getNumOperands();
FuncOp funcOp = getCalledFunction(callOp);
assert(funcOp && "expected CallOp to a FuncOp");
- const ModuleBufferizationState &moduleState =
- getModuleBufferizationState(state);
+ const ModuleAnalysisState &moduleState =
+ getModuleAnalysisState(state.getAnalysisState());
// Result types of the bufferized CallOp.
SmallVector<Type> resultTypes;
@@ -906,23 +901,22 @@ struct ReturnOpInterface
: public BufferizableOpInterface::ExternalModel<ReturnOpInterface,
func::ReturnOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
#ifndef NDEBUG
auto returnOp = cast<func::ReturnOp>(op);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -935,13 +929,13 @@ struct ReturnOpInterface
struct FuncOpInterface
: public BufferizableOpInterface::ExternalModel<FuncOpInterface, FuncOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
return failure();
}
/// Return `true` if the given function argument is writable.
bool isWritable(Operation *op, Value value,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
auto funcOp = cast<FuncOp>(op);
BlockArgument bbArg = value.dyn_cast<BlockArgument>();
assert(bbArg && "expected BlockArgument");
@@ -982,9 +976,8 @@ static void setInPlaceFuncArgument(BlockArgument bbArg, bool inPlace) {
}
/// Annotate the IR with the result of the analysis. For testing/debugging only.
-static void
-annotateOpsWithBufferizationMarkers(FuncOp funcOp,
- const BufferizationState &state) {
+static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
+ const AnalysisState &state) {
auto bufferizableOp = cast<BufferizableOpInterface>(funcOp.getOperation());
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<TensorType>())
@@ -992,11 +985,12 @@ annotateOpsWithBufferizationMarkers(FuncOp funcOp,
}
LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
- ModuleOp moduleOp, AnalysisBufferizationOptions options) {
+ ModuleOp moduleOp, OneShotBufferizationOptions options) {
IRRewriter rewriter(moduleOp.getContext());
- AnalysisBufferizationState state(moduleOp, options);
- ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
- BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
+ OneShotAnalysisState analysisState(moduleOp, options);
+ BufferizationState bufferizationState(analysisState);
+ ModuleAnalysisState &moduleState = getModuleAnalysisState(analysisState);
+ BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo();
if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps,
moduleState.callerMap)))
@@ -1016,7 +1010,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::InProgress;
// Analyze funcOp.
- if (failed(analyzeOp(funcOp, state)))
+ if (failed(analyzeOp(funcOp, analysisState)))
return failure();
// Gather equivalence info for CallOps.
@@ -1028,7 +1022,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
// Add annotations to function arguments.
if (options.testAnalysisOnly)
- annotateOpsWithBufferizationMarkers(funcOp, state);
+ annotateOpsWithBufferizationMarkers(funcOp, analysisState);
}
if (options.testAnalysisOnly)
@@ -1040,7 +1034,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
if (funcOp.getBody().empty())
continue;
- if (failed(bufferizeOp(funcOp, state)))
+ if (failed(bufferizeOp(funcOp, bufferizationState)))
return failure();
}
@@ -1048,7 +1042,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
for (FuncOp funcOp : moduleState.orderedFuncOps) {
// Note: It would be good to apply cleanups here but we cannot as aliasInfo
// would be invalidated.
- if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, state)))
+ if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, bufferizationState)))
return failure();
if (!options.allowReturnMemref &&
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index b9d1fc2c29f71..9153bce176fa0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -25,7 +25,7 @@ namespace {
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
- const BufferizationState &state) {
+ BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
@@ -56,7 +56,7 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
SmallVector<Value> newOutputBuffers;
for (OpResult opResult : op->getOpResults()) {
SmallVector<OpOperand *> aliasingOpOperands =
- state.getAliasingOpOperand(opResult);
+ state.getAnalysisState().getAliasingOpOperand(opResult);
assert(aliasingOpOperands.size() == 1 && "expected 1 OpOperand");
FailureOr<Value> resultBuffer =
state.getBuffer(rewriter, *aliasingOpOperands.front());
@@ -156,14 +156,14 @@ struct LinalgOpInterface
: public BufferizableOpInterface::ExternalModel<LinalgOpInterface<OpTy>,
OpTy> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Operand is read if it is used in the computation.
auto genericOp = cast<linalg::LinalgOp>(op);
return genericOp.payloadUsesValueFromOperand(&opOperand);
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Operand is written to if it has an aliasing OpResult.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
@@ -171,7 +171,7 @@ struct LinalgOpInterface
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
// By default, the i-th OpResult may alias with the i-th "out" tensor.
@@ -188,9 +188,8 @@ struct LinalgOpInterface
return {};
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
auto genericOp = cast<linalg::LinalgOp>(op);
// By default, the i-th "out" tensor may alias with the i-th OpResult.
@@ -209,12 +208,12 @@ struct LinalgOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
}
};
@@ -223,13 +222,13 @@ struct InitTensorOpInterface
: public BufferizableOpInterface::ExternalModel<InitTensorOpInterface,
linalg::InitTensorOp> {
bool isMemoryWrite(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// InitTensorOps allocate but do not write.
return false;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto initTensorOp = cast<linalg::InitTensorOp>(op);
// The InitTensorOp may have been eliminated.
@@ -345,7 +344,7 @@ findValidInsertionPoint(Operation *initTensorOp,
/// chain, starting from the OpOperand and always following the aliasing
/// OpOperand, that eventually ends at a single InitTensorOp.
LogicalResult mlir::linalg::eliminateInitTensors(
- Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
+ Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
SmallVector<Operation *> &newOps) {
OpBuilder b(op->getContext());
@@ -447,7 +446,7 @@ LogicalResult mlir::linalg::eliminateInitTensors(
/// Note that the newly inserted ExtractSliceOp may have to bufferize
/// out-of-place due to RaW conflicts.
LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
- Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
+ Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) {
return eliminateInitTensors(
op, state, aliasInfo,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index f1b1d97f58576..a7e86e11fbc3a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -40,7 +40,7 @@ struct LinalgComprehensiveModuleBufferize
const LinalgComprehensiveModuleBufferize &p) = default;
explicit LinalgComprehensiveModuleBufferize(
- const AnalysisBufferizationOptions &options)
+ const OneShotBufferizationOptions &options)
: options(options) {}
void runOnOperation() override;
@@ -61,7 +61,7 @@ struct LinalgComprehensiveModuleBufferize
}
private:
- llvm::Optional<AnalysisBufferizationOptions> options;
+ llvm::Optional<OneShotBufferizationOptions> options;
};
} // namespace
@@ -81,7 +81,7 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
}
void LinalgComprehensiveModuleBufferize::runOnOperation() {
- AnalysisBufferizationOptions opt;
+ OneShotBufferizationOptions opt;
if (!options) {
// Make new bufferization options if none were provided when creating the
// pass.
@@ -129,6 +129,6 @@ std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
}
std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
- const AnalysisBufferizationOptions &options) {
+ const OneShotBufferizationOptions &options) {
return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index d4dd3489841c6..b15d3460fa105 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -40,7 +40,7 @@ struct ExecuteRegionOpInterface
scf::ExecuteRegionOp> {
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// ExecuteRegionOps do not have tensor OpOperands. The yielded value can be
// any SSA value that is in scope. To allow for use-def chain traversal
// through ExecuteRegionOps in the analysis, the corresponding yield value
@@ -60,7 +60,7 @@ struct ExecuteRegionOpInterface
// TODO: For better bufferization results, this could return `true` only if
// there is a memory write in the region.
bool isMemoryWrite(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Similar to scf.if, results of this op are always considered memory writes
// in the analysis. This is a useful pattern for all ops that have tensor
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
@@ -70,7 +70,7 @@ struct ExecuteRegionOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
// Compute new result types.
@@ -125,7 +125,7 @@ struct ExecuteRegionOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
};
@@ -135,7 +135,7 @@ struct IfOpInterface
: public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> {
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// IfOps do not have tensor OpOperands. The yielded value can be any SSA
// value that is in scope. To allow for use-def chain traversal through
// IfOps in the analysis, both corresponding yield values from the then/else
@@ -152,7 +152,7 @@ struct IfOpInterface
// allowed at the moment, we should never encounter scf.ifs that yield
// unmodified tensors. Such scf.yield ops could just fold away.
bool isMemoryWrite(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// IfOp results are always considered memory writes in the analysis. This
// design decision simplifies the analysis considerably. E.g., consider the
// following test case:
@@ -179,7 +179,7 @@ struct IfOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto ifOp = cast<scf::IfOp>(op);
// Compute new types of the bufferized scf.if op.
@@ -244,7 +244,7 @@ struct IfOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// IfOp results are equivalent to their corresponding yield values if both
// yield values are equivalent to each other.
auto bufferizableOp = cast<BufferizableOpInterface>(op);
@@ -263,7 +263,7 @@ struct ForOpInterface
: public BufferizableOpInterface::ExternalModel<ForOpInterface,
scf::ForOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of
// its matching bbArg may.
auto forOp = cast<scf::ForOp>(op);
@@ -271,16 +271,15 @@ struct ForOpInterface
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Tensor iter_args of scf::ForOps are always considered as a write. This is
// to simplify the analysis.
// TODO: Consider doing sth. like isValueWritten.
return true;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
auto forOp = cast<scf::ForOp>(op);
if (!opOperand.get().getType().isa<RankedTensorType>())
return {};
@@ -288,7 +287,7 @@ struct ForOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// ForOp results are equivalent to their corresponding init_args if the
// corresponding iter_args and yield values are equivalent.
auto forOp = cast<scf::ForOp>(op);
@@ -301,7 +300,7 @@ struct ForOpInterface
}
bool isWritable(Operation *op, Value value,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Interestingly, scf::ForOp's bbArg can **always** be viewed
// inplace from the perspective of ops nested under:
// 1. Either the matching iter operand is not bufferized inplace and an
@@ -312,7 +311,7 @@ struct ForOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto forOp = cast<scf::ForOp>(op);
Block *oldLoopBody = &forOp.getLoopBody().front();
@@ -391,7 +390,7 @@ struct ForOpInterface
/// scf.for op is currently assumed to alias with the i-th iter_arg (in the
/// absence of conflicts).
LogicalResult verifyAnalysis(Operation *op,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
auto forOp = cast<scf::ForOp>(op);
auto yieldOp =
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
@@ -424,18 +423,17 @@ struct YieldOpInterface
: public BufferizableOpInterface::ExternalModel<YieldOpInterface,
scf::YieldOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
if (isa<scf::IfOp>(op->getParentOp()))
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
@@ -444,7 +442,7 @@ struct YieldOpInterface
}
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Yield operands always bufferize inplace. Otherwise, an alloc + copy
// may be generated inside the block. We should not return/yield allocations
// when possible.
@@ -452,7 +450,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(
yieldOp->getParentOp()))
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 9e21e0b062023..ca102a2dbddd9 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -29,7 +29,7 @@ struct AssumingOpInterface
shape::AssumingOp> {
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// AssumingOps do not have tensor OpOperands. The yielded value can be any
// SSA value that is in scope. To allow for use-def chain traversal through
// AssumingOps in the analysis, the corresponding yield value is considered
@@ -49,7 +49,7 @@ struct AssumingOpInterface
// TODO: For better bufferization results, this could return `true` only if
// there is a memory write in the region.
bool isMemoryWrite(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Similar to scf.if, results of this op are always considered memory writes
// in the analysis. This is a useful pattern for all ops that have tensor
// OpResults but no tensor OpOperands. By default, `isMemoryWrite` is
@@ -59,7 +59,7 @@ struct AssumingOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto assumingOp = cast<shape::AssumingOp>(op);
// Compute new result types.
@@ -115,7 +115,7 @@ struct AssumingOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
};
@@ -126,25 +126,24 @@ struct AssumingYieldOpInterface
: public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
shape::AssumingOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
assert(isa<shape::AssumingOp>(op->getParentOp()) &&
"expected that parent is an AssumingOp");
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
}
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
// Yield operands always bufferize inplace. Otherwise, an alloc + copy
// may be generated inside the block. We should not return/yield allocations
// when possible.
@@ -152,7 +151,7 @@ struct AssumingYieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
// Op is bufferized as part of AssumingOp.
return failure();
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ad5485de63317..f068010a50895 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,28 +26,27 @@ struct CastOpInterface
: public BufferizableOpInterface::ExternalModel<CastOpInterface,
tensor::CastOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {op->getResult(0)};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
// The result buffer still has the old (pre-cast) type.
@@ -85,30 +84,29 @@ struct CollapseShapeOpInterface
: public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
tensor::CollapseShapeOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
if (&opOperand == &op->getOpOperand(0) /*src*/)
return {op->getOpResult(0)};
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
Value buffer =
*state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
@@ -125,23 +123,22 @@ struct DimOpInterface
: public BufferizableOpInterface::ExternalModel<DimOpInterface,
tensor::DimOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
Value v = *state.getBuffer(rewriter, dimOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
@@ -154,30 +151,29 @@ struct ExpandShapeOpInterface
: public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
tensor::ExpandShapeOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
if (&opOperand == &op->getOpOperand(0) /*src*/)
return {op->getOpResult(0)};
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
Value buffer =
*state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
@@ -194,30 +190,29 @@ struct ExtractSliceOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
tensor::ExtractSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
if (&opOperand == &op->getOpOperand(0) /*source*/)
return {op->getOpResult(0)};
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc();
Value srcMemref =
@@ -228,7 +223,8 @@ struct ExtractSliceOpInterface
extractSliceOp.result().getType().cast<RankedTensorType>();
// If not inplaceable, alloc.
- bool inplace = state.isInPlace(extractSliceOp->getOpOperand(0));
+ bool inplace =
+ state.getAnalysisState().isInPlace(extractSliceOp->getOpOperand(0));
Value alloc;
if (!inplace) {
FailureOr<Value> allocOrFailure =
@@ -264,7 +260,7 @@ struct ExtractSliceOpInterface
// If not inplaceable, copy.
if (!inplace) {
// Do not copy if the copied data is never read.
- if (state.isValueRead(extractSliceOp.result()))
+ if (state.getAnalysisState().isValueRead(extractSliceOp.result()))
if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
alloc, state.getOptions())))
return failure();
@@ -281,23 +277,22 @@ struct ExtractOpInterface
: public BufferizableOpInterface::ExternalModel<ExtractOpInterface,
tensor::ExtractOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
Value srcMemref =
*state.getBuffer(rewriter, extractOp->getOpOperand(0) /*tensor*/);
@@ -334,7 +329,7 @@ struct FromElementsOpInterface
: public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
tensor::FromElementsOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
// Allocate a buffer for the result.
@@ -387,7 +382,7 @@ struct GenerateOpInterface
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
tensor::GenerateOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto generateOp = cast<tensor::GenerateOp>(op);
// Allocate memory.
@@ -446,18 +441,17 @@ struct InsertOpInterface
: public BufferizableOpInterface::ExternalModel<InsertOpInterface,
tensor::InsertOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
assert(&opOperand == &op->getOpOperand(1) /*dest*/ &&
"expected dest OpOperand");
return {op->getOpResult(0)};
@@ -465,12 +459,12 @@ struct InsertOpInterface
SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return {&op->getOpOperand(1) /*dest*/};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref =
state.getBuffer(rewriter, insertOp->getOpOperand(1) /*dest*/);
@@ -483,7 +477,7 @@ struct InsertOpInterface
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
};
@@ -494,7 +488,7 @@ struct InsertOpInterface
/// This is one particular type of relationship between ops on tensors that
/// reduce to an equivalence on buffers. This should be generalized and
/// exposed as interfaces on the proper types.
-static bool areEquivalentExtractSliceOps(const BufferizationState &state,
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
ExtractSliceOp st, InsertSliceOp sti) {
if (!st || !sti)
return false;
@@ -508,8 +502,8 @@ static bool areEquivalentExtractSliceOps(const BufferizationState &state,
/// Return true if `value` is originating from an ExtractSliceOp that matches
/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationState &state,
- Value value, InsertSliceOp insertOp) {
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
+ InsertSliceOp insertOp) {
auto condition = [&](Value val) {
if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
@@ -527,31 +521,30 @@ struct InsertSliceOpInterface
: public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
tensor::InsertSliceOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return &opOperand == &op->getOpOperand(1) /*dest*/;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
if (&opOperand == &op->getOpOperand(1) /*dest*/)
return {op->getResult(0)};
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
bool isNotConflicting(Operation *op, OpOperand *uRead,
OpOperand *uConflictingWrite,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
Operation *readingOp = uRead->getOwner();
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -626,7 +619,7 @@ struct InsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
@@ -683,23 +676,22 @@ struct RankOpInterface
: public BufferizableOpInterface::ExternalModel<RankOpInterface,
tensor::RankOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto rankOp = cast<tensor::RankOp>(op);
Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index eecb7bc42eaa7..48008934e46b6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -27,27 +27,26 @@ struct TransferReadOpInterface
: public BufferizableOpInterface::ExternalModel<TransferReadOpInterface,
vector::TransferReadOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
assert(opOperand.get().getType().isa<RankedTensorType>() &&
"only tensor types expected");
return false;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
return {};
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto readOp = cast<vector::TransferReadOp>(op);
assert(readOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
@@ -69,34 +68,33 @@ struct TransferWriteOpInterface
: public BufferizableOpInterface::ExternalModel<TransferWriteOpInterface,
vector::TransferWriteOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return true;
}
- SmallVector<OpResult>
- getAliasingOpResult(Operation *op, OpOperand &opOperand,
- const BufferizationState &state) const {
+ SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+ const AnalysisState &state) const {
assert(opOperand.get().getType().isa<TensorType>() &&
"only tensor types expected");
return {op->getOpResult(0)};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
- const BufferizationState &state) const {
+ const AnalysisState &state) const {
return BufferRelation::Equivalent;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationState &state) const {
+ BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
assert(writeOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
More information about the Mlir-commits mailing list