[Mlir-commits] [mlir] e3889b3 - [mlir][Linalg] Replace DenseSet by UnionFind in ComprehensiveBufferize - NFC
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Sep 15 03:36:00 PDT 2021
Author: Nicolas Vasilache
Date: 2021-09-15T10:35:54Z
New Revision: e3889b30590a8d61eec42b08bbe3708322c15205
URL: https://github.com/llvm/llvm-project/commit/e3889b30590a8d61eec42b08bbe3708322c15205
DIFF: https://github.com/llvm/llvm-project/commit/e3889b30590a8d61eec42b08bbe3708322c15205.diff
LOG: [mlir][Linalg] Replace DenseSet by UnionFind in ComprehensiveBufferize - NFC
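BufferizationAliasInfo now tracks aliases with a union-find structure (llvm::EquivalenceClasses) instead of a DenseMap from each Value to a DenseSet of its aliases, which had to be merged iteratively to a fixed point. This makes the implementation much more efficient.
This is not a functional change, but it yields large performance gains on more complex examples.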
Differential Revision: https://reviews.llvm.org/D109819
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index b165ec752491..1ca179ce8ec3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -780,18 +780,25 @@ class BufferizationAliasInfo {
void dumpEquivalences() const { printEquivalences(llvm::errs()); }
private:
- /// Check that aliasInfo for `v` exists and return a reference to it.
- DenseSet<Value> &getAliasInfoRef(Value v);
-
- const DenseSet<Value> &getAliasInfoRef(Value v) const {
- return const_cast<BufferizationAliasInfo *>(this)->getAliasInfoRef(v);
- }
-
- /// Union all the aliasing sets of all aliases of v1 and v2.
- bool mergeAliases(Value v1, Value v2);
+ /// llvm::EquivalenceClasses wants comparable elements because it uses
+ /// std::set as the underlying impl.
+ /// ValueWrapper wraps Value and uses pointer comparison on the defining op.
+ /// This is a poor man's comparison but it's not like UnionFind needs ordering
+ /// anyway ..
+ struct ValueWrapper {
+ ValueWrapper(Value val) : v(val) {}
+ operator Value() const { return v; }
+ bool operator<(const ValueWrapper &wrap) const {
+ return v.getImpl() < wrap.v.getImpl();
+ }
+ bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; }
+ Value v;
+ };
- /// Iteratively merge alias sets until a fixed-point.
- void mergeAliasesToFixedPoint();
+ using EquivalenceClassRangeType = llvm::iterator_range<
+ llvm::EquivalenceClasses<ValueWrapper>::member_iterator>;
+ /// Check that aliasInfo for `v` exists and return a reference to it.
+ EquivalenceClassRangeType getAliases(Value v) const;
/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
/// equivalent operand / result and same offset/sizes/strides specification).
@@ -849,24 +856,10 @@ class BufferizationAliasInfo {
OpOperand &aliasingWrite,
const DominanceInfo &domInfo) const;
- /// EquivalenceClasses wants comparable elements because it uses std::set.
- /// ValueWrapper wraps Value and uses pointer comparison on the defining op.
- /// This is a poor man's comparison but it's not like UnionFind needs ordering
- /// anyway ..
- struct ValueWrapper {
- ValueWrapper(Value val) : v(val) {}
- operator Value() const { return v; }
- bool operator<(const ValueWrapper &wrap) const {
- return v.getImpl() < wrap.v.getImpl();
- }
- bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; }
- Value v;
- };
-
/// Auxiliary structure to store all the values a given value aliases with.
/// These are the conservative cases that can further decompose into
/// "equivalent" buffer relationships.
- DenseMap<Value, DenseSet<Value>> aliasInfo;
+ llvm::EquivalenceClasses<ValueWrapper> aliasInfo;
/// Auxiliary structure to store all the equivalent buffer classes.
llvm::EquivalenceClasses<ValueWrapper> equivalentInfo;
@@ -889,19 +882,15 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
/// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
/// beginning the alias and equivalence sets only contain `v` itself.
void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
- DenseSet<Value> selfSet;
- selfSet.insert(v);
- aliasInfo.try_emplace(v, selfSet);
+ aliasInfo.insert(v);
equivalentInfo.insert(v);
}
/// Insert an info entry for `newValue` and merge its alias set with that of
/// `alias`.
void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
- assert(aliasInfo.find(alias) != aliasInfo.end() && "Missing alias entry");
createAliasInfoEntry(newValue);
- mergeAliases(newValue, alias);
- mergeAliasesToFixedPoint();
+ aliasInfo.unionSets(newValue, alias);
}
/// Insert an info entry for `newValue` and merge its alias set with that of
@@ -920,7 +909,7 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
LDBG("----Start aliasesNonWriteableBuffer\n");
LDBG("-------for -> #" << operand.getOperandNumber() << ": "
<< printOperationInfo(operand.getOwner()) << '\n');
- for (Value v : getAliasInfoRef(operand.get())) {
+ for (Value v : getAliases(operand.get())) {
LDBG("-----------examine: " << printValueInfo(v) << '\n');
if (auto bbArg = v.dyn_cast<BlockArgument>()) {
if (getInPlace(bbArg) == InPlaceSpec::True) {
@@ -948,7 +937,7 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
LDBG("----Start aliasesInPlaceWrite\n");
LDBG("-------for : " << printValueInfo(value) << '\n');
- for (Value v : getAliasInfoRef(value)) {
+ for (Value v : getAliases(value)) {
for (auto &use : v.getUses()) {
if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
LDBG("-----------wants to bufferize to inPlace write: "
@@ -967,8 +956,7 @@ void BufferizationAliasInfo::bufferizeInPlace(OpResult result,
OpOperand &operand,
BufferRelation bufferRelation) {
setInPlaceOpResult(result, InPlaceSpec::True);
- if (mergeAliases(result, operand.get()))
- mergeAliasesToFixedPoint();
+ aliasInfo.unionSets(result, operand.get());
// Dump the updated alias analysis.
LLVM_DEBUG(dumpAliases());
if (bufferRelation == BufferRelation::Equivalent)
@@ -1009,7 +997,7 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
// opToBufferize is not yet inplace, we want to determine if it can be inplace
// so we also consider all its write uses, not just the inplace ones.
DenseSet<OpOperand *> usesWrite;
- for (Value vWrite : getAliasInfoRef(root)) {
+ for (Value vWrite : getAliases(root)) {
for (auto &uWrite : vWrite.getUses()) {
if (!bufferizesToMemoryWrite(uWrite))
continue;
@@ -1018,7 +1006,7 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
usesWrite.insert(&uWrite);
}
}
- for (Value vWrite : getAliasInfoRef(result))
+ for (Value vWrite : getAliases(result))
for (auto &uWrite : vWrite.getUses())
if (bufferizesToMemoryWrite(uWrite, InPlaceSpec::True))
usesWrite.insert(&uWrite);
@@ -1027,12 +1015,12 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
// opToBufferize is not yet inplace, we want to determine if it can be inplace
// so we also consider all read uses of its result.
DenseSet<OpOperand *> usesRead;
- auto &aliasListRead = getAliasInfoRef(root);
+ auto aliasListRead = getAliases(root);
for (Value vRead : aliasListRead)
for (auto &uRead : vRead.getUses())
if (bufferizesToMemoryRead(uRead))
usesRead.insert(&uRead);
- for (Value vRead : getAliasInfoRef(result))
+ for (Value vRead : getAliases(result))
for (auto &uRead : vRead.getUses())
if (bufferizesToMemoryRead(uRead))
usesRead.insert(&uRead);
@@ -1116,16 +1104,21 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
}
void BufferizationAliasInfo::printAliases(raw_ostream &os) const {
- os << "\n/========================== AliasInfo "
- "==========================\n";
- for (auto it : aliasInfo) {
- os << "|\n| -- source: " << printValueInfo(it.getFirst(), /*prefix=*/false)
+ os << "\n/===================== AliasInfo =====================\n";
+ for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) {
+ if (!it->isLeader())
+ continue;
+ Value leader = it->getData();
+ os << "|\n| -- leader: " << printValueInfo(leader, /*prefix=*/false)
<< '\n';
- for (auto v : it.getSecond())
- os << "| ---- target: " << printValueInfo(v, /*prefix=*/false) << '\n';
+ for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
+ mit != meit; ++mit) {
+ Value v = static_cast<Value>(*mit);
+ os << "| ---- equivalent member: " << printValueInfo(v, /*prefix=*/false)
+ << '\n';
+ }
}
- os << "|\n\\====================== End AliasInfo "
- "======================\n\n";
+ os << "\n/===================== End AliasInfo =====================\n\n";
}
void BufferizationAliasInfo::printEquivalences(raw_ostream &os) const {
@@ -1148,37 +1141,16 @@ void BufferizationAliasInfo::printEquivalences(raw_ostream &os) const {
os << "|\n\\***************** End Equivalent Buffers *****************\n\n";
}
-DenseSet<Value> &BufferizationAliasInfo::getAliasInfoRef(Value v) {
- auto it = aliasInfo.find(v);
- if (it == aliasInfo.end())
- llvm_unreachable("Missing alias");
- return it->getSecond();
-}
-
-/// Union all the aliasing sets of all aliases of v1 and v2.
-bool BufferizationAliasInfo::mergeAliases(Value v1, Value v2) {
- // Avoid invalidation of iterators by pre unioning the aliases for v1 and v2.
- bool changed = set_union(getAliasInfoRef(v1), getAliasInfoRef(v2)) ||
- set_union(getAliasInfoRef(v2), getAliasInfoRef(v1));
- for (auto v : getAliasInfoRef(v1))
- if (v != v1)
- changed |= set_union(getAliasInfoRef(v), getAliasInfoRef(v2));
- for (auto v : getAliasInfoRef(v2))
- if (v != v2)
- changed |= set_union(getAliasInfoRef(v), getAliasInfoRef(v1));
- return changed;
-}
-
-/// Iteratively merge alias sets until a fixed-point.
-void BufferizationAliasInfo::mergeAliasesToFixedPoint() {
- while (true) {
- bool changed = false;
- for (auto it : aliasInfo)
- for (auto v : it.getSecond())
- changed |= mergeAliases(it.getFirst(), v);
- if (!changed)
- break;
+BufferizationAliasInfo::EquivalenceClassRangeType
+BufferizationAliasInfo::getAliases(Value v) const {
+ DenseSet<Value> res;
+ auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
+ for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
+ mit != meit; ++mit) {
+ res.insert(static_cast<Value>(*mit));
}
+ return BufferizationAliasInfo::EquivalenceClassRangeType(
+ aliasInfo.member_begin(it), aliasInfo.member_end());
}
/// This is one particular type of relationship between ops on tensors that
More information about the Mlir-commits
mailing list