[Mlir-commits] [mlir] cf2d374 - [mlir][bufferize][NFC] Merge AnalysisState and BufferizationAliasInfo
Matthias Springer
llvmlistbot at llvm.org
Wed Feb 8 00:16:44 PST 2023
Author: Matthias Springer
Date: 2023-02-08T09:12:09+01:00
New Revision: cf2d374e990e4784c7f2bf3bd66c76bb00843a11
URL: https://github.com/llvm/llvm-project/commit/cf2d374e990e4784c7f2bf3bd66c76bb00843a11
DIFF: https://github.com/llvm/llvm-project/commit/cf2d374e990e4784c7f2bf3bd66c76bb00843a11.diff
LOG: [mlir][bufferize][NFC] Merge AnalysisState and BufferizationAliasInfo
There is no longer a need to keep the two separate. This is in preparation for reusing the same AnalysisState for tensor.empty elimination and One-Shot Bufferize (to address performance bottlenecks).
Differential Revision: https://reviews.llvm.org/D143313
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 63a8b389e0d5f..3dc045d93bf4e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -17,7 +17,6 @@ namespace mlir {
namespace bufferization {
struct OneShotBufferizationOptions;
-class BufferizationAliasInfo;
struct BufferizationStatistics;
class OneShotAnalysisState;
@@ -40,108 +39,11 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
llvm::ArrayRef<std::string> noAnalysisFuncFilter;
};
-/// The BufferizationAliasInfo class maintains a list of buffer aliases and
-/// equivalence classes to support bufferization.
-class BufferizationAliasInfo {
-public:
- explicit BufferizationAliasInfo(Operation *rootOp);
-
- // BufferizationAliasInfo should be passed as a reference.
- BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
-
- /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
- /// beginning the alias and equivalence sets only contain `v` itself.
- void createAliasInfoEntry(Value v);
-
- /// Insert an info entry for `newValue` and merge its alias set with that of
- /// `alias`.
- void insertNewBufferAlias(Value newValue, Value alias);
-
- /// Insert an info entry for `newValue` and merge its alias set with that of
- /// `alias`. Additionally, merge their equivalence classes.
- void insertNewBufferEquivalence(Value newValue, Value alias);
-
- /// Set the inPlace bufferization spec to true.
- /// Merge result's and operand's aliasing sets and iterate to a fixed point.
- void bufferizeInPlace(OpOperand &operand, AnalysisState &state);
-
- /// Set the inPlace bufferization spec to false.
- void bufferizeOutOfPlace(OpOperand &operand);
-
- /// Return true if `v1` and `v2` may bufferize to aliasing buffers.
- bool areAliasingBufferizedValues(Value v1, Value v2) const {
- return aliasInfo.isEquivalent(v1, v2);
- }
-
- /// Return true if `v1` and `v2` bufferize to equivalent buffers.
- bool areEquivalentBufferizedValues(Value v1, Value v2) const {
- return equivalentInfo.isEquivalent(v1, v2);
- }
-
- /// Union the alias sets of `v1` and `v2`.
- void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
-
- /// Union the equivalence classes of `v1` and `v2`.
- void unionEquivalenceClasses(Value v1, Value v2) {
- equivalentInfo.unionSets(v1, v2);
- }
-
- /// Apply `fun` to all the members of the equivalence class of `v`.
- void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
-
- /// Apply `fun` to all aliases of `v`.
- void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
-
- /// Mark a value as in-place bufferized.
- void markInPlace(OpOperand &o) { inplaceBufferized.insert(&o); }
-
- /// Return `true` if a value was marked as in-place bufferized.
- bool isInPlace(OpOperand &opOperand) const;
-
- int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; }
- int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; }
-
-private:
- /// llvm::EquivalenceClasses wants comparable elements. This comparator uses
- /// uses pointer comparison on the defining op. This is a poor man's
- /// comparison but it's not like UnionFind needs ordering anyway.
- struct ValueComparator {
- bool operator()(const Value &lhs, const Value &rhs) const {
- return lhs.getImpl() < rhs.getImpl();
- }
- };
-
- using EquivalenceClassRangeType = llvm::iterator_range<
- llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
- /// Check that aliasInfo for `v` exists and return a reference to it.
- EquivalenceClassRangeType getAliases(Value v) const;
-
- /// Set of all OpResults that were decided to bufferize in-place.
- llvm::DenseSet<OpOperand *> inplaceBufferized;
-
- /// Auxiliary structure to store all the values a given value may alias with.
- /// Alias information is "may be" conservative: In the presence of branches, a
- /// value may alias with one of multiple other values. The concrete aliasing
- /// value may not even be known at compile time. All such values are
- /// considered to be aliases.
- llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
-
- /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
- /// buffer information is "must be" conservative: Only if two values are
- /// guaranteed to be equivalent at runtime, they said to be equivalent. It is
- /// possible that, in the presence of branches, it cannot be determined
- /// statically if two values are equivalent. In that case, the values are
- /// considered to be not equivalent.
- llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
-
- // Bufferization statistics.
- int64_t statNumTensorOutOfPlace = 0;
- int64_t statNumTensorInPlace = 0;
-};
-
/// State for analysis-enabled bufferization. This class keeps track of alias
-/// (via BufferizationAliasInfo) to decide if tensor OpOperands should bufferize
-/// in-place.
+/// sets, equivalence sets, in-place OpOperands and other things.
+///
+/// Note: Modifying the IR generally invalidates the result of the analysis.
+/// Adding new operations is safe if they are analyzed subsequently.
class OneShotAnalysisState : public AnalysisState {
public:
OneShotAnalysisState(Operation *op,
@@ -161,11 +63,11 @@ class OneShotAnalysisState : public AnalysisState {
AnalysisState::getOptions());
}
- /// Return a reference to the BufferizationAliasInfo.
- BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
+ /// Apply `fun` to all the members of the equivalence class of `v`.
+ void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
- /// Return `true` if the given OpResult has been decided to bufferize inplace.
- bool isInPlace(OpOperand &opOperand) const override;
+ /// Apply `fun` to all aliases of `v`.
+ void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
/// Return true if `v1` and `v2` bufferize to equivalent buffers.
bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
@@ -173,12 +75,16 @@ class OneShotAnalysisState : public AnalysisState {
/// Return true if `v1` and `v2` may bufferize to aliasing buffers.
bool areAliasingBufferizedValues(Value v1, Value v2) const override;
- /// Return `true` if the given tensor has undefined contents.
- bool hasUndefinedContents(OpOperand *opOperand) const override;
+ /// Mark the given OpOperand as in-place and merge the results' and operand's
+ /// aliasing sets.
+ void bufferizeInPlace(OpOperand &operand);
- /// Return true if the given tensor (or an aliasing tensor) is yielded from
- /// the containing block. Also include all aliasing tensors in the same block.
- bool isTensorYielded(Value tensor) const override;
+ /// Mark the given OpOperand as out-of-place.
+ void bufferizeOutOfPlace(OpOperand &operand);
+
+ /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
+ /// beginning the alias and equivalence sets only contain `v` itself.
+ void createAliasInfoEntry(Value v);
/// Find all tensor values in the given operation that have undefined contents
/// and store them in `undefinedTensorUses`.
@@ -188,6 +94,19 @@ class OneShotAnalysisState : public AnalysisState {
/// `yieldedTensors`. Also include all aliasing tensors in the same block.
void gatherYieldedTensors(Operation *op);
+ int64_t getStatNumTensorOutOfPlace() const { return statNumTensorOutOfPlace; }
+ int64_t getStatNumTensorInPlace() const { return statNumTensorInPlace; }
+
+ /// Return `true` if the given tensor has undefined contents.
+ bool hasUndefinedContents(OpOperand *opOperand) const override;
+
+ /// Return `true` if the given OpOperand has been decided to bufferize inplace.
+ bool isInPlace(OpOperand &opOperand) const override;
+
+ /// Return true if the given tensor (or an aliasing tensor) is yielded from
+ /// the containing block. Also include all aliasing tensors in the same block.
+ bool isTensorYielded(Value tensor) const override;
+
/// Return true if the buffer of the given tensor value is written to. Must
/// not be called for values inside not yet analyzed functions.
bool isValueWritten(Value value) const;
@@ -195,6 +114,12 @@ class OneShotAnalysisState : public AnalysisState {
/// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const;
+ /// Union the alias sets of `v1` and `v2`.
+ void unionAliasSets(Value v1, Value v2);
+
+ /// Union the equivalence classes of `v1` and `v2`.
+ void unionEquivalenceClasses(Value v1, Value v2);
+
/// Base class for OneShotAnalysisState extensions that allow
/// OneShotAnalysisState to contain user-specified information in the state
/// object. Clients are expected to derive this class, add the desired fields,
@@ -279,9 +204,41 @@ class OneShotAnalysisState : public AnalysisState {
}
private:
- /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
- /// functions and `runOneShotBufferize` may access this object.
- BufferizationAliasInfo aliasInfo;
+ /// llvm::EquivalenceClasses wants comparable elements. This comparator uses
+ /// pointer comparison on the underlying Value impl. This is a poor man's
+ /// comparison but it's not like UnionFind needs ordering anyway.
+ struct ValueComparator {
+ bool operator()(const Value &lhs, const Value &rhs) const {
+ return lhs.getImpl() < rhs.getImpl();
+ }
+ };
+
+ using EquivalenceClassRangeType = llvm::iterator_range<
+ llvm::EquivalenceClasses<Value, ValueComparator>::member_iterator>;
+ /// Check that aliasInfo for `v` exists and return a reference to it.
+ EquivalenceClassRangeType getAliases(Value v) const;
+
+ /// Set of all OpOperands that were decided to bufferize in-place.
+ llvm::DenseSet<OpOperand *> inplaceBufferized;
+
+ /// Auxiliary structure to store all the values a given value may alias with.
+ /// Alias information is "may be" conservative: In the presence of branches, a
+ /// value may alias with one of multiple other values. The concrete aliasing
+ /// value may not even be known at compile time. All such values are
+ /// considered to be aliases.
+ llvm::EquivalenceClasses<Value, ValueComparator> aliasInfo;
+
+ /// Auxiliary structure to store all the equivalent buffer classes. Equivalent
+ /// buffer information is "must be" conservative: Only if two values are
+ /// guaranteed to be equivalent at runtime, they are said to be equivalent. It
+ /// is possible that, in the presence of branches, it cannot be determined
+ /// statically if two values are equivalent. In that case, the values are
+ /// considered to be not equivalent.
+ llvm::EquivalenceClasses<Value, ValueComparator> equivalentInfo;
+
+ // Bufferization statistics.
+ int64_t statNumTensorOutOfPlace = 0;
+ int64_t statNumTensorInPlace = 0;
/// A set of all tensors (and maybe aliasing tensors) that yielded from a
/// block.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 3cd84c529be61..4231b85363451 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -71,7 +71,7 @@ static bool isaTensor(Type t) { return t.isa<TensorType>(); }
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is stored
-// in BufferizationAliasInfo. When run with `testAnalysisOnly`, the IR is
+// in OneShotAnalysisState. When run with `testAnalysisOnly`, the IR is
// annotated with the results of the analysis, so that they can be checked in
// tests.
//===----------------------------------------------------------------------===//
@@ -98,11 +98,14 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
}
//===----------------------------------------------------------------------===//
-// BufferizationAliasInfo
+// OneShotAnalysisState
//===----------------------------------------------------------------------===//
-BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
- rootOp->walk([&](Operation *op) {
+OneShotAnalysisState::OneShotAnalysisState(
+ Operation *op, const OneShotBufferizationOptions &options)
+ : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
+ // Set up alias sets.
+ op->walk([&](Operation *op) {
for (Value v : op->getResults())
if (v.getType().isa<TensorType>())
createAliasInfoEntry(v);
@@ -112,55 +115,20 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
if (bbArg.getType().isa<TensorType>())
createAliasInfoEntry(bbArg);
});
-}
-
-/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
-/// beginning the alias and equivalence sets only contain `v` itself.
-void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
- aliasInfo.insert(v);
- equivalentInfo.insert(v);
-}
-
-/// Insert an info entry for `newValue` and merge its alias set with that of
-/// `alias`.
-void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
- createAliasInfoEntry(newValue);
- aliasInfo.unionSets(newValue, alias);
-}
-
-/// Insert an info entry for `newValue` and merge its alias set with that of
-/// `alias`. Additionally, merge their equivalence classes.
-void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
- Value alias) {
- insertNewBufferAlias(newValue, alias);
- equivalentInfo.unionSets(newValue, alias);
-}
-
-/// Return `true` if a value was marked as in-place bufferized.
-bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
- return inplaceBufferized.contains(&operand);
-}
-
-/// Set the inPlace bufferization spec to true.
-void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
- AnalysisState &state) {
- if (inplaceBufferized.contains(&operand))
- return;
- markInPlace(operand);
- for (OpResult result : state.getAliasingOpResults(operand))
- aliasInfo.unionSets(result, operand.get());
- ++statNumTensorInPlace;
-}
-/// Set the inPlace bufferization spec to false.
-void BufferizationAliasInfo::bufferizeOutOfPlace(OpOperand &operand) {
- assert(!inplaceBufferized.contains(&operand) &&
- "OpOperand was already decided to bufferize inplace");
- ++statNumTensorOutOfPlace;
+ // Mark OpOperands in-place that must bufferize in-place.
+ op->walk([&](BufferizableOpInterface bufferizableOp) {
+ if (!options.isOpAllowed(bufferizableOp))
+ return WalkResult::skip();
+ for (OpOperand &opOperand : bufferizableOp->getOpOperands())
+ if (opOperand.get().getType().isa<TensorType>())
+ if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
+ bufferizeInPlace(opOperand);
+ return WalkResult::advance();
+ });
}
-/// Apply `fun` to all the members of the equivalence class of `v`.
-void BufferizationAliasInfo::applyOnEquivalenceClass(
+void OneShotAnalysisState::applyOnEquivalenceClass(
Value v, function_ref<void(Value)> fun) const {
auto leaderIt = equivalentInfo.findLeader(v);
for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
@@ -169,66 +137,48 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
}
}
-/// Apply `fun` to all aliases of `v`.
-void BufferizationAliasInfo::applyOnAliases(
- Value v, function_ref<void(Value)> fun) const {
+void OneShotAnalysisState::applyOnAliases(Value v,
+ function_ref<void(Value)> fun) const {
auto leaderIt = aliasInfo.findLeader(v);
for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
fun(*mit);
}
}
-BufferizationAliasInfo::EquivalenceClassRangeType
-BufferizationAliasInfo::getAliases(Value v) const {
- DenseSet<Value> res;
- auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
- for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
- mit != meit; ++mit) {
- res.insert(static_cast<Value>(*mit));
- }
- return BufferizationAliasInfo::EquivalenceClassRangeType(
- aliasInfo.member_begin(it), aliasInfo.member_end());
+bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
+ Value v2) const {
+ return equivalentInfo.isEquivalent(v1, v2);
}
-//===----------------------------------------------------------------------===//
-// OneShotAnalysisState
-//===----------------------------------------------------------------------===//
-
-OneShotAnalysisState::OneShotAnalysisState(
- Operation *op, const OneShotBufferizationOptions &options)
- : AnalysisState(options, TypeID::get<OneShotAnalysisState>()),
- aliasInfo(op) {
- // Set up alias sets for OpResults that must bufferize in-place. This should
- // be done before making any other bufferization decisions.
- op->walk([&](BufferizableOpInterface bufferizableOp) {
- if (!options.isOpAllowed(bufferizableOp))
- return WalkResult::skip();
- for (OpOperand &opOperand : bufferizableOp->getOpOperands())
- if (opOperand.get().getType().isa<TensorType>())
- if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
- aliasInfo.bufferizeInPlace(opOperand, *this);
- return WalkResult::advance();
- });
+bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
+ Value v2) const {
+ return aliasInfo.isEquivalent(v1, v2);
}
-bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
- return aliasInfo.isInPlace(opOperand);
+void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
+ if (inplaceBufferized.contains(&operand))
+ return;
+ inplaceBufferized.insert(&operand);
+ for (OpResult result : getAliasingOpResults(operand))
+ aliasInfo.unionSets(result, operand.get());
+ ++statNumTensorInPlace;
}
-bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
- Value v2) const {
- return aliasInfo.areEquivalentBufferizedValues(v1, v2);
+void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
+ assert(!inplaceBufferized.contains(&operand) &&
+ "OpOperand was already decided to bufferize inplace");
+ ++statNumTensorOutOfPlace;
}
-bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
- Value v2) const {
- return aliasInfo.areAliasingBufferizedValues(v1, v2);
+void OneShotAnalysisState::createAliasInfoEntry(Value v) {
+ aliasInfo.insert(v);
+ equivalentInfo.insert(v);
}
// Gather yielded tensors in `yieldedTensors` by querying all aliases. This is
// to ensure that such information is available during bufferization time.
-// Alias information can no longer be queried through BufferizationAliasInfo
-// once we have started modifying the IR.
+// Alias information can no longer be queried once we have started modifying
+// the IR.
void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
op->walk([&](Operation *returnOp) {
if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp))
@@ -242,7 +192,7 @@ void OneShotAnalysisState::gatherYieldedTensors(Operation *op) {
// Add all aliases of the returned value. But only the ones that are in
// the same block.
- aliasInfo.applyOnAliases(returnVal, [&](Value v) {
+ applyOnAliases(returnVal, [&](Value v) {
if (auto bbArg = v.dyn_cast<BlockArgument>()) {
if (bbArg.getOwner()->getParentOp() == returnOp->getParentOp())
yieldedTensors.insert(bbArg);
@@ -285,13 +235,17 @@ bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
return undefinedTensorUses.contains(opOperand);
}
+bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
+ return inplaceBufferized.contains(&opOperand);
+}
+
bool OneShotAnalysisState::isTensorYielded(Value tensor) const {
return yieldedTensors.contains(tensor);
}
bool OneShotAnalysisState::isValueWritten(Value value) const {
bool isWritten = false;
- aliasInfo.applyOnAliases(value, [&](Value val) {
+ applyOnAliases(value, [&](Value val) {
for (OpOperand &use : val.getUses())
if (isInPlace(use) && bufferizesToMemoryWrite(use))
isWritten = true;
@@ -314,6 +268,14 @@ bool OneShotAnalysisState::isWritable(Value value) const {
return false;
}
+void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
+ aliasInfo.unionSets(v1, v2);
+}
+
+void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
+ equivalentInfo.unionSets(v1, v2);
+}
+
OneShotAnalysisState::Extension::~Extension() = default;
//===----------------------------------------------------------------------===//
@@ -322,13 +284,12 @@ OneShotAnalysisState::Extension::~Extension() = default;
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
- const BufferizationAliasInfo &aliasInfo,
- const AnalysisState &state) {
+ const OneShotAnalysisState &state) {
// OpOperands that do not bufferize to a memory write do not write in-place.
if (!state.bufferizesToMemoryWrite(opOperand))
return false;
// Check current bufferization decisions.
- return aliasInfo.isInPlace(opOperand);
+ return state.isInPlace(opOperand);
}
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
@@ -489,10 +450,11 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a definition W1. But because of bufferization decisions, R
/// actually reads another definition W2.
-static bool hasReadAfterWriteInterference(
- const DenseSet<OpOperand *> &usesRead,
- const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
- AnalysisState &state, const BufferizationAliasInfo &aliasInfo) {
+static bool
+hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
+ const DenseSet<OpOperand *> &usesWrite,
+ const DominanceInfo &domInfo,
+ OneShotAnalysisState &state) {
const BufferizationOptions &options = state.getOptions();
for (OpOperand *uRead : usesRead) {
@@ -654,21 +616,19 @@ static bool hasReadAfterWriteInterference(
// Helper function to iterate on aliases of `root` and capture the writes.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
- const BufferizationAliasInfo &aliasInfo,
- const AnalysisState &state) {
- aliasInfo.applyOnAliases(root, [&](Value alias) {
+ const OneShotAnalysisState &state) {
+ state.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses())
// Inplace write to a value that aliases root.
- if (isInplaceMemoryWrite(use, aliasInfo, state))
+ if (isInplaceMemoryWrite(use, state))
res.insert(&use);
});
}
// Helper function to iterate on aliases of `root` and capture the reads.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
- const BufferizationAliasInfo &aliasInfo,
- const AnalysisState &state) {
- aliasInfo.applyOnAliases(root, [&](Value alias) {
+ const OneShotAnalysisState &state) {
+ state.applyOnAliases(root, [&](Value alias) {
for (auto &use : alias.getUses()) {
// Read of a value that aliases root.
if (state.bufferizesToMemoryRead(use)) {
@@ -731,22 +691,20 @@ static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
/// OpResult. In that case, only the consistency of bufferization decisions
/// involving aliases of the given OpOperand are checked.
static bool wouldCreateReadAfterWriteInterference(
- OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
- const BufferizationAliasInfo &aliasInfo,
- bool checkConsistencyOnly = false) {
+ OpOperand &operand, const DominanceInfo &domInfo,
+ OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
// Collect reads and writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesRead, usesWrite;
- getAliasingReads(usesRead, operand.get(), aliasInfo, state);
- getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
+ getAliasingReads(usesRead, operand.get(), state);
+ getAliasingInplaceWrites(usesWrite, operand.get(), state);
for (OpResult result : state.getAliasingOpResults(operand)) {
- getAliasingReads(usesRead, result, aliasInfo, state);
- getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
+ getAliasingReads(usesRead, result, state);
+ getAliasingInplaceWrites(usesWrite, result, state);
}
if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
- return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
- aliasInfo);
+ return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
}
/// Annotate IR with details about the detected non-writability conflict.
@@ -773,7 +731,6 @@ static void annotateNonWritableTensor(Value value) {
/// materialized in `aliasInfo` yet.
static bool
hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
- const BufferizationAliasInfo &aliasInfo,
const OneShotAnalysisState &state) {
SmallVector<Value> worklist;
worklist.push_back(value);
@@ -794,7 +751,7 @@ hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
AliasingOpOperandList aliasingOpOperands =
state.getAliasingOpOperands(opResult);
for (OpOperand *opOperand : aliasingOpOperands)
- if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
+ if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
worklist.push_back(opOperand->get());
}
return false;
@@ -802,14 +759,15 @@ hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
/// Return true if bufferizing `operand` inplace would create a write to a
/// non-writable buffer.
-static bool wouldCreateWriteToNonWritableBuffer(
- OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
- OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
+static bool
+wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
+ OneShotAnalysisState &state,
+ bool checkConsistencyOnly = false) {
// Collect writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesWrite;
- getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
+ getAliasingInplaceWrites(usesWrite, operand.get(), state);
for (OpResult result : state.getAliasingOpResults(operand)) {
- getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
+ getAliasingInplaceWrites(usesWrite, result, state);
}
if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
@@ -818,8 +776,7 @@ static bool wouldCreateWriteToNonWritableBuffer(
// alias), check if there is a non-writable tensor in the reverse SSA use-def
// chain.
for (OpOperand *uWrite : usesWrite) {
- if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
- aliasInfo, state)) {
+ if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
return true;
}
@@ -833,22 +790,22 @@ static bool wouldCreateWriteToNonWritableBuffer(
//===----------------------------------------------------------------------===//
/// Determine if `operand` can be bufferized in-place.
-static LogicalResult bufferizableInPlaceAnalysisImpl(
- OpOperand &operand, BufferizationAliasInfo &aliasInfo,
- OneShotAnalysisState &state, const DominanceInfo &domInfo) {
+static LogicalResult
+bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
+ const DominanceInfo &domInfo) {
LLVM_DEBUG(
llvm::dbgs() << "//===-------------------------------------------===//\n"
<< "Analyzing operand #" << operand.getOperandNumber()
<< " of " << *operand.getOwner() << "\n");
bool foundInterference =
- wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
- wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
+ wouldCreateWriteToNonWritableBuffer(operand, state) ||
+ wouldCreateReadAfterWriteInterference(operand, domInfo, state);
if (foundInterference)
- aliasInfo.bufferizeOutOfPlace(operand);
+ state.bufferizeOutOfPlace(operand);
else
- aliasInfo.bufferizeInPlace(operand, state);
+ state.bufferizeInPlace(operand);
LLVM_DEBUG(llvm::dbgs()
<< "//===-------------------------------------------===//\n");
@@ -874,7 +831,6 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
/// An analysis is required to ensure inplace bufferization would not result in
/// RaW dependence violations.
static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
- BufferizationAliasInfo &aliasInfo,
OneShotAnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
@@ -890,8 +846,7 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
auto analyzeOp = [&](Operation *op) {
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>())
- if (failed(bufferizableInPlaceAnalysisImpl(opOperand, aliasInfo, state,
- domInfo)))
+ if (failed(bufferizableInPlaceAnalysisImpl(opOperand, state, domInfo)))
return failure();
return success();
};
@@ -924,7 +879,6 @@ static bool hasTensorSemantics(Operation *op) {
/// Analyze all ops that are contained in `op`.
static LogicalResult inPlaceAnalysis(Operation *op,
- BufferizationAliasInfo &aliasInfo,
OneShotAnalysisState &state,
const DominanceInfo &domInfo,
unsigned analysisFuzzerSeed = 0) {
@@ -937,13 +891,12 @@ static LogicalResult inPlaceAnalysis(Operation *op,
ops.push_back(op);
});
- return inPlaceAnalysis(ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
+ return inPlaceAnalysis(ops, state, domInfo, analysisFuzzerSeed);
}
/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
- BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
+ OneShotAnalysisState &state) {
for (Operation *op : ops)
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
for (OpResult opResult : op->getOpResults())
@@ -953,14 +906,12 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
if (state.isInPlace(*opOperand))
if (bufferizableOp.bufferRelation(opResult, state) ==
BufferRelation::Equivalent)
- aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
+ state.unionEquivalenceClasses(opResult, opOperand->get());
}
/// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
/// in `op`.
-static void equivalenceAnalysis(Operation *op,
- BufferizationAliasInfo &aliasInfo,
- AnalysisState &state) {
+static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
// Traverse ops in PostOrder: Nested ops first, then enclosing ops.
SmallVector<Operation *> ops;
op->walk<WalkOrder::PostOrder>([&](Operation *op) {
@@ -970,14 +921,13 @@ static void equivalenceAnalysis(Operation *op,
ops.push_back(op);
});
- equivalenceAnalysis(ops, aliasInfo, state);
+ equivalenceAnalysis(ops, state);
}
/// Assert that the current bufferization decisions are consistent.
-static LogicalResult
-checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
- AnalysisState &state,
- const BufferizationAliasInfo &aliasInfo) {
+static LogicalResult checkAliasInfoConsistency(Operation *op,
+ const DominanceInfo &domInfo,
+ OneShotAnalysisState &state) {
const BufferizationOptions &options = state.getOptions();
WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
@@ -995,7 +945,7 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
for (OpOperand &opOperand : op->getOpOperands()) {
if (opOperand.get().getType().isa<TensorType>()) {
if (wouldCreateReadAfterWriteInterference(
- opOperand, domInfo, state, aliasInfo,
+ opOperand, domInfo, state,
/*checkConsistencyOnly=*/true)) {
// This error can happen if certain "mustBufferizeInPlace" interface
// methods are implemented incorrectly, such that the IR already has
@@ -1015,13 +965,12 @@ checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
- const BufferizationAliasInfo &aliasInfo,
- const BufferizationOptions &options) {
+ const OneShotAnalysisState &state) {
// Add __inplace_operands_attr__.
op->walk([&](Operation *op) {
for (OpOperand &opOperand : op->getOpOperands())
if (opOperand.get().getType().isa<TensorType>())
- setInPlaceOpOperand(opOperand, aliasInfo.isInPlace(opOperand));
+ setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
});
}
@@ -1056,12 +1005,12 @@ annotateOpsWithBufferizationMarkers(Operation *op,
// TODO: Remove buffer deallocation from One-Shot Bufferize and fix the buffer
// deallocation pass.
static LogicalResult assertNoAllocsReturned(Operation *op,
- const BufferizationOptions &options,
- BufferizationAliasInfo &aliasInfo) {
+ const OneShotAnalysisState &state) {
LogicalResult status = success();
DominanceInfo domInfo(op);
op->walk([&](Operation *returnOp) {
- if (!isRegionReturnLike(returnOp) || !options.isOpAllowed(returnOp))
+ if (!isRegionReturnLike(returnOp) ||
+ !state.getOptions().isOpAllowed(returnOp))
return WalkResult::advance();
for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
@@ -1071,7 +1020,7 @@ static LogicalResult assertNoAllocsReturned(Operation *op,
continue;
bool foundEquivValue = false;
- aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
+ state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
Operation *definingOp = bbArg.getOwner()->getParentOp();
if (definingOp->isProperAncestor(returnOp))
@@ -1105,27 +1054,25 @@ LogicalResult bufferization::analyzeOp(Operation *op,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
DominanceInfo domInfo(op);
- BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
const OneShotBufferizationOptions &options = state.getOptions();
- if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
+ if (failed(checkAliasInfoConsistency(op, domInfo, state)))
return failure();
// If the analysis fails, just return.
- if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
- options.analysisFuzzerSeed)))
+ if (failed(inPlaceAnalysis(op, state, domInfo, options.analysisFuzzerSeed)))
return failure();
if (statistics) {
- statistics->numTensorInPlace = aliasInfo.getStatNumTensorInPlace();
- statistics->numTensorOutOfPlace = aliasInfo.getStatNumTensorOutOfPlace();
+ statistics->numTensorInPlace = state.getStatNumTensorInPlace();
+ statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
}
- equivalenceAnalysis(op, aliasInfo, state);
+ equivalenceAnalysis(op, state);
bool failedAnalysis = false;
if (!options.allowReturnAllocs)
- failedAnalysis |= failed(assertNoAllocsReturned(op, options, aliasInfo));
+ failedAnalysis |= failed(assertNoAllocsReturned(op, state));
// Gather some extra analysis data.
state.gatherYieldedTensors(op);
@@ -1142,7 +1089,7 @@ LogicalResult bufferization::analyzeOp(Operation *op,
// Annotate operations if we only want to report the analysis.
if (options.testAnalysisOnly)
- annotateOpsWithBufferizationMarkers(op, aliasInfo, options);
+ annotateOpsWithBufferizationMarkers(op, state);
return success(!failedAnalysis);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 943efe8e18b70..9562ac5920f58 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -250,7 +250,6 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
/// analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(func::FuncOp funcOp,
- BufferizationAliasInfo &aliasInfo,
OneShotAnalysisState &state,
FuncAnalysisState &funcState) {
funcOp->walk([&](func::CallOp callOp) {
@@ -268,7 +267,7 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
continue;
Value returnVal = callOp.getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
- aliasInfo.unionEquivalenceClasses(returnVal, argVal);
+ state.unionEquivalenceClasses(returnVal, argVal);
}
return WalkResult::advance();
@@ -365,7 +364,6 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
assert(state.getOptions().bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
- BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
// A list of functions in the order in which they are analyzed + bufferized.
SmallVector<func::FuncOp> orderedFuncOps;
@@ -385,7 +383,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
funcState.startFunctionAnalysis(funcOp);
// Gather equivalence info for CallOps.
- equivalenceAnalysis(funcOp, aliasInfo, state, funcState);
+ equivalenceAnalysis(funcOp, state, funcState);
// Analyze funcOp.
if (failed(analyzeOp(funcOp, state, statistics)))
More information about the Mlir-commits mailing list