[Mlir-commits] [mlir] e3889b3 - [mlir][Linalg] Replace DenseSet by UnionFind in ComprehensiveBufferize - NFC

Nicolas Vasilache llvmlistbot at llvm.org
Wed Sep 15 03:36:00 PDT 2021


Author: Nicolas Vasilache
Date: 2021-09-15T10:35:54Z
New Revision: e3889b30590a8d61eec42b08bbe3708322c15205

URL: https://github.com/llvm/llvm-project/commit/e3889b30590a8d61eec42b08bbe3708322c15205
DIFF: https://github.com/llvm/llvm-project/commit/e3889b30590a8d61eec42b08bbe3708322c15205.diff

LOG: [mlir][Linalg] Replace DenseSet by UnionFind in ComprehensiveBufferize - NFC

AliasInfo can now use union-find for a much more efficient implementation.
This brings no functional changes but large performance gains on more complex examples.

Differential Revision: https://reviews.llvm.org/D109819

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index b165ec752491..1ca179ce8ec3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -780,18 +780,25 @@ class BufferizationAliasInfo {
   void dumpEquivalences() const { printEquivalences(llvm::errs()); }
 
 private:
-  /// Check that aliasInfo for `v` exists and return a reference to it.
-  DenseSet<Value> &getAliasInfoRef(Value v);
-
-  const DenseSet<Value> &getAliasInfoRef(Value v) const {
-    return const_cast<BufferizationAliasInfo *>(this)->getAliasInfoRef(v);
-  }
-
-  /// Union all the aliasing sets of all aliases of v1 and v2.
-  bool mergeAliases(Value v1, Value v2);
+  /// llvm::EquivalenceClasses wants comparable elements because it uses
+  /// std::set as the underlying impl.
+  /// ValueWrapper wraps Value and orders by its impl pointer (getImpl()).
+  /// This is a poor man's comparison, but UnionFind only needs some ordering,
+  /// not a meaningful one.
+  struct ValueWrapper {
+    ValueWrapper(Value val) : v(val) {}
+    operator Value() const { return v; }
+    bool operator<(const ValueWrapper &wrap) const {
+      return v.getImpl() < wrap.v.getImpl();
+    }
+    bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; }
+    Value v;
+  };
 
-  /// Iteratively merge alias sets until a fixed-point.
-  void mergeAliasesToFixedPoint();
+  using EquivalenceClassRangeType = llvm::iterator_range<
+      llvm::EquivalenceClasses<ValueWrapper>::member_iterator>;
+  /// Return the members of `v`'s alias class in `aliasInfo`, including `v`.
+  EquivalenceClassRangeType getAliases(Value v) const;
 
   /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
   /// equivalent operand / result and same offset/sizes/strides specification).
@@ -849,24 +856,10 @@ class BufferizationAliasInfo {
                                   OpOperand &aliasingWrite,
                                   const DominanceInfo &domInfo) const;
 
-  /// EquivalenceClasses wants comparable elements because it uses std::set.
-  /// ValueWrapper wraps Value and uses pointer comparison on the defining op.
-  /// This is a poor man's comparison but it's not like UnionFind needs ordering
-  /// anyway ..
-  struct ValueWrapper {
-    ValueWrapper(Value val) : v(val) {}
-    operator Value() const { return v; }
-    bool operator<(const ValueWrapper &wrap) const {
-      return v.getImpl() < wrap.v.getImpl();
-    }
-    bool operator==(const ValueWrapper &wrap) const { return v == wrap.v; }
-    Value v;
-  };
-
   /// Auxiliary structure to store all the values a given value aliases with.
   /// These are the conservative cases that can further decompose into
   /// "equivalent" buffer relationships.
-  DenseMap<Value, DenseSet<Value>> aliasInfo;
+  llvm::EquivalenceClasses<ValueWrapper> aliasInfo;
 
   /// Auxiliary structure to store all the equivalent buffer classes.
   llvm::EquivalenceClasses<ValueWrapper> equivalentInfo;
@@ -889,19 +882,15 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
 /// Add a new entry for `v` in the `aliasInfo` and `equivalentInfo`. In the
 /// beginning the alias and equivalence sets only contain `v` itself.
 void BufferizationAliasInfo::createAliasInfoEntry(Value v) {
-  DenseSet<Value> selfSet;
-  selfSet.insert(v);
-  aliasInfo.try_emplace(v, selfSet);
+  aliasInfo.insert(v);
   equivalentInfo.insert(v);
 }
 
 /// Insert an info entry for `newValue` and merge its alias set with that of
 /// `alias`.
 void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
-  assert(aliasInfo.find(alias) != aliasInfo.end() && "Missing alias entry");
   createAliasInfoEntry(newValue);
-  mergeAliases(newValue, alias);
-  mergeAliasesToFixedPoint();
+  aliasInfo.unionSets(newValue, alias);
 }
 
 /// Insert an info entry for `newValue` and merge its alias set with that of
@@ -920,7 +909,7 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
   LDBG("----Start aliasesNonWriteableBuffer\n");
   LDBG("-------for -> #" << operand.getOperandNumber() << ": "
                          << printOperationInfo(operand.getOwner()) << '\n');
-  for (Value v : getAliasInfoRef(operand.get())) {
+  for (Value v : getAliases(operand.get())) {
     LDBG("-----------examine: " << printValueInfo(v) << '\n');
     if (auto bbArg = v.dyn_cast<BlockArgument>()) {
       if (getInPlace(bbArg) == InPlaceSpec::True) {
@@ -948,7 +937,7 @@ bool BufferizationAliasInfo::aliasesNonWriteableBuffer(
 bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
   LDBG("----Start aliasesInPlaceWrite\n");
   LDBG("-------for : " << printValueInfo(value) << '\n');
-  for (Value v : getAliasInfoRef(value)) {
+  for (Value v : getAliases(value)) {
     for (auto &use : v.getUses()) {
       if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
         LDBG("-----------wants to bufferize to inPlace write: "
@@ -967,8 +956,7 @@ void BufferizationAliasInfo::bufferizeInPlace(OpResult result,
                                               OpOperand &operand,
                                               BufferRelation bufferRelation) {
   setInPlaceOpResult(result, InPlaceSpec::True);
-  if (mergeAliases(result, operand.get()))
-    mergeAliasesToFixedPoint();
+  aliasInfo.unionSets(result, operand.get());
   // Dump the updated alias analysis.
   LLVM_DEBUG(dumpAliases());
   if (bufferRelation == BufferRelation::Equivalent)
@@ -1009,7 +997,7 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
   // opToBufferize is not yet inplace, we want to determine if it can be inplace
   // so we also consider all its write uses, not just the inplace ones.
   DenseSet<OpOperand *> usesWrite;
-  for (Value vWrite : getAliasInfoRef(root)) {
+  for (Value vWrite : getAliases(root)) {
     for (auto &uWrite : vWrite.getUses()) {
       if (!bufferizesToMemoryWrite(uWrite))
         continue;
@@ -1018,7 +1006,7 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
         usesWrite.insert(&uWrite);
     }
   }
-  for (Value vWrite : getAliasInfoRef(result))
+  for (Value vWrite : getAliases(result))
     for (auto &uWrite : vWrite.getUses())
       if (bufferizesToMemoryWrite(uWrite, InPlaceSpec::True))
         usesWrite.insert(&uWrite);
@@ -1027,12 +1015,12 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
   // opToBufferize is not yet inplace, we want to determine if it can be inplace
   // so we also consider all read uses of its result.
   DenseSet<OpOperand *> usesRead;
-  auto &aliasListRead = getAliasInfoRef(root);
+  auto aliasListRead = getAliases(root);
   for (Value vRead : aliasListRead)
     for (auto &uRead : vRead.getUses())
       if (bufferizesToMemoryRead(uRead))
         usesRead.insert(&uRead);
-  for (Value vRead : getAliasInfoRef(result))
+  for (Value vRead : getAliases(result))
     for (auto &uRead : vRead.getUses())
       if (bufferizesToMemoryRead(uRead))
         usesRead.insert(&uRead);
@@ -1116,16 +1104,21 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
 }
 
 void BufferizationAliasInfo::printAliases(raw_ostream &os) const {
-  os << "\n/========================== AliasInfo "
-        "==========================\n";
-  for (auto it : aliasInfo) {
-    os << "|\n| -- source: " << printValueInfo(it.getFirst(), /*prefix=*/false)
+  os << "\n/===================== AliasInfo =====================\n";
+  for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) {
+    if (!it->isLeader())
+      continue;
+    Value leader = it->getData();
+    os << "|\n| -- leader: " << printValueInfo(leader, /*prefix=*/false)
        << '\n';
-    for (auto v : it.getSecond())
-      os << "| ---- target: " << printValueInfo(v, /*prefix=*/false) << '\n';
+    for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
+         mit != meit; ++mit) {
+      Value v = static_cast<Value>(*mit);
+      os << "| ---- equivalent member: " << printValueInfo(v, /*prefix=*/false)
+         << '\n';
+    }
   }
-  os << "|\n\\====================== End AliasInfo "
-        "======================\n\n";
+  os << "\n/===================== End AliasInfo =====================\n\n";
 }
 
 void BufferizationAliasInfo::printEquivalences(raw_ostream &os) const {
@@ -1148,37 +1141,16 @@ void BufferizationAliasInfo::printEquivalences(raw_ostream &os) const {
   os << "|\n\\***************** End Equivalent Buffers *****************\n\n";
 }
 
-DenseSet<Value> &BufferizationAliasInfo::getAliasInfoRef(Value v) {
-  auto it = aliasInfo.find(v);
-  if (it == aliasInfo.end())
-    llvm_unreachable("Missing alias");
-  return it->getSecond();
-}
-
-/// Union all the aliasing sets of all aliases of v1 and v2.
-bool BufferizationAliasInfo::mergeAliases(Value v1, Value v2) {
-  // Avoid invalidation of iterators by pre unioning the aliases for v1 and v2.
-  bool changed = set_union(getAliasInfoRef(v1), getAliasInfoRef(v2)) ||
-                 set_union(getAliasInfoRef(v2), getAliasInfoRef(v1));
-  for (auto v : getAliasInfoRef(v1))
-    if (v != v1)
-      changed |= set_union(getAliasInfoRef(v), getAliasInfoRef(v2));
-  for (auto v : getAliasInfoRef(v2))
-    if (v != v2)
-      changed |= set_union(getAliasInfoRef(v), getAliasInfoRef(v1));
-  return changed;
-}
-
-/// Iteratively merge alias sets until a fixed-point.
-void BufferizationAliasInfo::mergeAliasesToFixedPoint() {
-  while (true) {
-    bool changed = false;
-    for (auto it : aliasInfo)
-      for (auto v : it.getSecond())
-        changed |= mergeAliases(it.getFirst(), v);
-    if (!changed)
-      break;
+BufferizationAliasInfo::EquivalenceClassRangeType
+BufferizationAliasInfo::getAliases(Value v) const {
+  DenseSet<Value> res;
+  auto it = aliasInfo.findValue(aliasInfo.getLeaderValue(v));
+  for (auto mit = aliasInfo.member_begin(it), meit = aliasInfo.member_end();
+       mit != meit; ++mit) {
+    res.insert(static_cast<Value>(*mit));
   }
+  return BufferizationAliasInfo::EquivalenceClassRangeType(
+      aliasInfo.member_begin(it), aliasInfo.member_end());
 }
 
 /// This is one particular type of relationship between ops on tensors that


        


More information about the Mlir-commits mailing list