[Mlir-commits] [llvm] [mlir] [ADT] Make DenseMap/DenseSet more resilient against OOM situations (PR #107251)

Marc Auberer llvmlistbot at llvm.org
Wed Sep 11 12:17:11 PDT 2024


https://github.com/marcauberer updated https://github.com/llvm/llvm-project/pull/107251

>From 6c9cc94a92432f3428801ba37e68500c7c0c98d5 Mon Sep 17 00:00:00 2001
From: Marc Auberer <marc.auberer at sap.com>
Date: Wed, 4 Sep 2024 14:07:01 +0000
Subject: [PATCH] [ADT] Make DenseMap/DenseSet more resilient against OOM
 situations

---
 llvm/include/llvm/ADT/DenseMap.h      | 435 +++++++++++++++++++-------
 llvm/include/llvm/ADT/DenseSet.h      |  68 +++-
 llvm/include/llvm/TextAPI/SymbolSet.h |  14 +-
 mlir/include/mlir/Support/LLVM.h      |  10 +-
 4 files changed, 389 insertions(+), 138 deletions(-)

diff --git a/llvm/include/llvm/ADT/DenseMap.h b/llvm/include/llvm/ADT/DenseMap.h
index fb267cf5cee1c6..6cc857f906e910 100644
--- a/llvm/include/llvm/ADT/DenseMap.h
+++ b/llvm/include/llvm/ADT/DenseMap.h
@@ -16,6 +16,7 @@
 
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/EpochTracker.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/AlignOf.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/MathExtras.h"
@@ -36,8 +37,6 @@ namespace llvm {
 
 namespace detail {
 
-// We extend a pair to allow users to override the bucket type with their own
-// implementation without requiring two members.
 template <typename KeyT, typename ValueT>
 struct DenseMapPair : public std::pair<KeyT, ValueT> {
   using std::pair<KeyT, ValueT>::pair;
@@ -48,16 +47,72 @@ struct DenseMapPair : public std::pair<KeyT, ValueT> {
   const ValueT &getSecond() const { return std::pair<KeyT, ValueT>::second; }
 };
 
+// OOM-safe DenseMap bucket.
+// Carries flags that indicate whether the key and the value are constructed.
+// This is necessary because assigning the empty or tombstone key can throw for
+// some key types.
+
+template <typename KeyT, typename ValueT, typename Enable = void>
+struct DenseMapPairImpl;
+
+// part I: helpers to decide which specialization of DenseMapPairImpl to use
+template <typename KeyT, typename ValueT>
+using EnableIfTriviallyDestructibleMap =
+    typename std::enable_if_t<std::is_trivially_destructible_v<KeyT> &&
+                              std::is_trivially_destructible_v<ValueT>>;
+template <typename KeyT, typename ValueT>
+using EnableIfNotTriviallyDestructibleMap =
+    typename std::enable_if_t<!std::is_trivially_destructible_v<KeyT> ||
+                              !std::is_trivially_destructible_v<ValueT>>;
+
+// OOM safe DenseMap bucket
+// part II: specialization for trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+                        EnableIfTriviallyDestructibleMap<KeyT, ValueT>>
+    : public DenseMapPair<KeyT, ValueT> {
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const { return false; }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
+};
+
+// OOM safe DenseMap bucket
+// part III: specialization for non-trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+                        EnableIfNotTriviallyDestructibleMap<KeyT, ValueT>>
+    : DenseMapPair<KeyT, ValueT> {
+private:
+  bool m_keyConstructed;
+  bool m_valueConstructed;
+
+public:
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool constructed) {
+    m_keyConstructed = constructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const {
+    return m_keyConstructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool constructed) {
+    m_valueConstructed = constructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const {
+    return m_valueConstructed;
+  }
+};
+
 } // end namespace detail
 
 template <typename KeyT, typename ValueT,
           typename KeyInfoT = DenseMapInfo<KeyT>,
-          typename Bucket = llvm::detail::DenseMapPair<KeyT, ValueT>,
+          typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT>,
+          typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>,
           bool IsConst = false>
 class DenseMapIterator;
 
 template <typename DerivedT, typename KeyT, typename ValueT, typename KeyInfoT,
-          typename BucketT>
+          typename BucketT, typename BucketBaseT>
 class DenseMapBase : public DebugEpochBase {
   template <typename T>
   using const_arg_type_t = typename const_pointer_or_const_ref<T>::type;
@@ -66,11 +121,12 @@ class DenseMapBase : public DebugEpochBase {
   using size_type = unsigned;
   using key_type = KeyT;
   using mapped_type = ValueT;
-  using value_type = BucketT;
+  using value_type = BucketBaseT;
 
-  using iterator = DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT>;
+  using iterator =
+      DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
   using const_iterator =
-      DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, true>;
+      DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT, true>;
 
   inline iterator begin() {
     // When the map is empty, avoid the overhead of advancing/retreating past
@@ -120,24 +176,20 @@ class DenseMapBase : public DebugEpochBase {
     }
 
     const KeyT EmptyKey = getEmptyKey();
-    if (std::is_trivially_destructible<ValueT>::value) {
+    if constexpr (std::is_trivially_destructible_v<ValueT>) {
       // Use a simpler loop when values don't need destruction.
       for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P)
         P->getFirst() = EmptyKey;
     } else {
-      const KeyT TombstoneKey = getTombstoneKey();
-      unsigned NumEntries = getNumEntries();
       for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
         if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey)) {
-          if (!KeyInfoT::isEqual(P->getFirst(), TombstoneKey)) {
+          if (P->isValueConstructed()) {
             P->getSecond().~ValueT();
-            --NumEntries;
+            P->setValueConstructed(false);
           }
           P->getFirst() = EmptyKey;
         }
       }
-      assert(NumEntries == 0 && "Node count imbalance!");
-      (void)NumEntries;
     }
     setNumEntries(0);
     setNumTombstones(0);
@@ -340,17 +392,22 @@ class DenseMapBase : public DebugEpochBase {
       return false; // not in map.
 
     TheBucket->getSecond().~ValueT();
+    TheBucket->setValueConstructed(false);
     TheBucket->getFirst() = getTombstoneKey();
     decrementNumEntries();
     incrementNumTombstones();
     return true;
   }
   void erase(iterator I) {
-    BucketT *TheBucket = &*I;
-    TheBucket->getSecond().~ValueT();
-    TheBucket->getFirst() = getTombstoneKey();
-    decrementNumEntries();
-    incrementNumTombstones();
+    BucketT *TheBucket = static_cast<BucketT *>(&*I);
+    // Iterator can point to nullptr in case of memory malfunctions
+    if (TheBucket != nullptr) {
+      TheBucket->getSecond().~ValueT();
+      TheBucket->setValueConstructed(false);
+      TheBucket->getFirst() = getTombstoneKey();
+      decrementNumEntries();
+      incrementNumTombstones();
+    }
   }
 
   LLVM_DEPRECATED("Use [Key] instead", "[Key]")
@@ -403,15 +460,24 @@ class DenseMapBase : public DebugEpochBase {
   DenseMapBase() = default;
 
   void destroyAll() {
-    if (getNumBuckets() == 0) // Nothing to do.
-      return;
-
-    const KeyT EmptyKey = getEmptyKey(), TombstoneKey = getTombstoneKey();
     for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
-      if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey) &&
-          !KeyInfoT::isEqual(P->getFirst(), TombstoneKey))
+      if (P->isValueConstructed()) {
         P->getSecond().~ValueT();
-      P->getFirst().~KeyT();
+        P->setValueConstructed(false);
+      }
+      if (P->isKeyConstructed()) {
+        P->getFirst().~KeyT();
+        P->setKeyConstructed(false);
+      }
+    }
+  }
+
+  void initUnitialized() {
+    BucketT *B = getBuckets();
+    BucketT *E = getBucketsEnd();
+    for (; B != E; ++B) {
+      B->setKeyConstructed(false);
+      B->setValueConstructed(false);
     }
   }
 
@@ -422,8 +488,16 @@ class DenseMapBase : public DebugEpochBase {
     assert((getNumBuckets() & (getNumBuckets() - 1)) == 0 &&
            "# initial buckets must be a power of two!");
     const KeyT EmptyKey = getEmptyKey();
-    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B)
+#ifndef NDEBUG
+    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
+      assert(!B->isKeyConstructed());
+      assert(!B->isValueConstructed());
+    }
+#endif
+    for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
       ::new (&B->getFirst()) KeyT(EmptyKey);
+      B->setKeyConstructed(true);
+    }
   }
 
   /// Returns the number of buckets to allocate to ensure that the DenseMap can
@@ -453,18 +527,23 @@ class DenseMapBase : public DebugEpochBase {
         assert(!FoundVal && "Key already in new map?");
         DestBucket->getFirst() = std::move(B->getFirst());
         ::new (&DestBucket->getSecond()) ValueT(std::move(B->getSecond()));
+        DestBucket->setValueConstructed(true);
         incrementNumEntries();
-
-        // Free the value.
+      }
+      if (B->isValueConstructed()) {
         B->getSecond().~ValueT();
+        B->setValueConstructed(false);
+      }
+      if (B->isKeyConstructed()) {
+        B->getFirst().~KeyT();
+        B->setKeyConstructed(false);
       }
-      B->getFirst().~KeyT();
     }
   }
 
   template <typename OtherBaseT>
-  void copyFrom(
-      const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT> &other) {
+  void copyFrom(const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT,
+                                   BucketBaseT> &other) {
     assert(&other != this);
     assert(getNumBuckets() == other.getNumBuckets());
 
@@ -479,10 +558,13 @@ class DenseMapBase : public DebugEpochBase {
       for (size_t i = 0; i < getNumBuckets(); ++i) {
         ::new (&getBuckets()[i].getFirst())
             KeyT(other.getBuckets()[i].getFirst());
+        getBuckets()[i].setKeyConstructed(true);
         if (!KeyInfoT::isEqual(getBuckets()[i].getFirst(), getEmptyKey()) &&
-            !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey()))
+            !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey())) {
           ::new (&getBuckets()[i].getSecond())
               ValueT(other.getBuckets()[i].getSecond());
+          getBuckets()[i].setValueConstructed(true);
+        }
       }
   }
 
@@ -570,26 +652,25 @@ class DenseMapBase : public DebugEpochBase {
   template <typename KeyArg, typename... ValueArgs>
   BucketT *InsertIntoBucket(BucketT *TheBucket, KeyArg &&Key,
                             ValueArgs &&...Values) {
-    TheBucket = InsertIntoBucketImpl(Key, Key, TheBucket);
-
+    TheBucket = InsertIntoBucketImpl(Key, TheBucket);
     TheBucket->getFirst() = std::forward<KeyArg>(Key);
     ::new (&TheBucket->getSecond()) ValueT(std::forward<ValueArgs>(Values)...);
+    TheBucket->setValueConstructed(true);
     return TheBucket;
   }
 
   template <typename LookupKeyT>
   BucketT *InsertIntoBucketWithLookup(BucketT *TheBucket, KeyT &&Key,
                                       ValueT &&Value, LookupKeyT &Lookup) {
-    TheBucket = InsertIntoBucketImpl(Key, Lookup, TheBucket);
-
+    TheBucket = InsertIntoBucketImpl(Lookup, TheBucket);
     TheBucket->getFirst() = std::move(Key);
     ::new (&TheBucket->getSecond()) ValueT(std::move(Value));
+    TheBucket->setValueConstructed(true);
     return TheBucket;
   }
 
   template <typename LookupKeyT>
-  BucketT *InsertIntoBucketImpl(const KeyT &Key, const LookupKeyT &Lookup,
-                                BucketT *TheBucket) {
+  BucketT *InsertIntoBucketImpl(const LookupKeyT &Lookup, BucketT *TheBucket) {
     incrementEpoch();
 
     // If the load of the hash table is more than 3/4, or if fewer than 1/8 of
@@ -709,6 +790,48 @@ class DenseMapBase : public DebugEpochBase {
     }
   }
 
+protected:
+  // helper class to guarantee deallocation of buffer
+  class ReleaseOldBuffer {
+    BucketT *m_buckets;
+    unsigned m_numBuckets;
+
+  public:
+    ReleaseOldBuffer(BucketT *buckets, unsigned numBuckets)
+        : m_buckets(buckets), m_numBuckets(numBuckets) {}
+    ~ReleaseOldBuffer() {
+#ifndef NDEBUG
+      memset((void *)m_buckets, 0x5a, sizeof(BucketT) * m_numBuckets);
+#endif
+      // Free the old table.
+      size_t const alignment{alignof(BucketT)};
+      deallocate_buffer(static_cast<void *>(m_buckets),
+                        sizeof(BucketT) * m_numBuckets, alignment);
+    }
+  };
+
+  // helper class to guarantee destruction of bucket content
+  class ReleaseOldBuckets {
+    BucketT *m_buckets;
+    unsigned m_numBuckets;
+
+  public:
+    ReleaseOldBuckets(BucketT *buckets, unsigned numBuckets)
+        : m_buckets(buckets), m_numBuckets(numBuckets) {}
+    ~ReleaseOldBuckets() {
+      for (BucketT *B = m_buckets, *E = m_buckets + m_numBuckets; B != E; ++B) {
+        if (B->isValueConstructed()) {
+          B->getSecond().~ValueT();
+          B->setValueConstructed(false);
+        }
+        if (B->isKeyConstructed()) {
+          B->getFirst().~KeyT();
+          B->setKeyConstructed(false);
+        }
+      }
+    }
+  };
+
 public:
   /// Return the approximate size (in bytes) of the actual map.
   /// This is just the raw memory used by DenseMap.
@@ -724,10 +847,11 @@ class DenseMapBase : public DebugEpochBase {
 /// Equivalent to N calls to RHS.find and N value comparisons. Amortized
 /// complexity is linear, worst case is O(N^2) (if every hash collides).
 template <typename DerivedT, typename KeyT, typename ValueT, typename KeyInfoT,
-          typename BucketT>
-bool operator==(
-    const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT> &LHS,
-    const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT> &RHS) {
+          typename BucketT, typename BucketBaseT>
+bool operator==(const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT,
+                                   BucketBaseT> &LHS,
+                const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT,
+                                   BucketBaseT> &RHS) {
   if (LHS.size() != RHS.size())
     return false;
 
@@ -744,23 +868,28 @@ bool operator==(
 ///
 /// Equivalent to !(LHS == RHS). See operator== for performance notes.
 template <typename DerivedT, typename KeyT, typename ValueT, typename KeyInfoT,
-          typename BucketT>
-bool operator!=(
-    const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT> &LHS,
-    const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT> &RHS) {
+          typename BucketT, typename BucketBaseT>
+bool operator!=(const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT,
+                                   BucketBaseT> &LHS,
+                const DenseMapBase<DerivedT, KeyT, ValueT, KeyInfoT, BucketT,
+                                   BucketBaseT> &RHS) {
   return !(LHS == RHS);
 }
 
 template <typename KeyT, typename ValueT,
           typename KeyInfoT = DenseMapInfo<KeyT>,
-          typename BucketT = llvm::detail::DenseMapPair<KeyT, ValueT>>
-class DenseMap : public DenseMapBase<DenseMap<KeyT, ValueT, KeyInfoT, BucketT>,
-                                     KeyT, ValueT, KeyInfoT, BucketT> {
-  friend class DenseMapBase<DenseMap, KeyT, ValueT, KeyInfoT, BucketT>;
+          typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT>,
+          typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>>
+class DenseMap : public DenseMapBase<
+                     DenseMap<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>,
+                     KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT> {
+  friend class DenseMapBase<DenseMap, KeyT, ValueT, KeyInfoT, BucketT,
+                            BucketBaseT>;
 
   // Lift some types from the dependent base class into this class for
   // simplicity of referring to them.
-  using BaseT = DenseMapBase<DenseMap, KeyT, ValueT, KeyInfoT, BucketT>;
+  using BaseT =
+      DenseMapBase<DenseMap, KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
 
   BucketT *Buckets;
   unsigned NumEntries;
@@ -773,28 +902,45 @@ class DenseMap : public DenseMapBase<DenseMap<KeyT, ValueT, KeyInfoT, BucketT>,
   explicit DenseMap(unsigned InitialReserve = 0) { init(InitialReserve); }
 
   DenseMap(const DenseMap &other) : BaseT() {
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
     init(0);
     copyFrom(other);
+    CtorGuard.release();
   }
 
   DenseMap(DenseMap &&other) : BaseT() {
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
     init(0);
     swap(other);
+    CtorGuard.release();
   }
 
   template <typename InputIt> DenseMap(const InputIt &I, const InputIt &E) {
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
     init(std::distance(I, E));
     this->insert(I, E);
+    CtorGuard.release();
   }
 
   DenseMap(std::initializer_list<typename BaseT::value_type> Vals) {
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
     init(Vals.size());
     this->insert(Vals.begin(), Vals.end());
+    CtorGuard.release();
   }
 
-  ~DenseMap() {
+  ~DenseMap() { destroy(); }
+
+  void destroy() {
     this->destroyAll();
+    deallocateBuckets();
+  }
+
+  void deallocateBuckets() {
     deallocate_buffer(Buckets, sizeof(BucketT) * NumBuckets, alignof(BucketT));
+    Buckets = nullptr;
+    NumBuckets = 0;
+    NumTombstones = 0;
   }
 
   void swap(DenseMap &RHS) {
@@ -813,16 +959,14 @@ class DenseMap : public DenseMapBase<DenseMap<KeyT, ValueT, KeyInfoT, BucketT>,
   }
 
   DenseMap &operator=(DenseMap &&other) {
-    this->destroyAll();
-    deallocate_buffer(Buckets, sizeof(BucketT) * NumBuckets, alignof(BucketT));
+    destroy();
     init(0);
     swap(other);
     return *this;
   }
 
   void copyFrom(const DenseMap &other) {
-    this->destroyAll();
-    deallocate_buffer(Buckets, sizeof(BucketT) * NumBuckets, alignof(BucketT));
+    destroy();
     if (allocateBuckets(other.NumBuckets)) {
       this->BaseT::copyFrom(other);
     } else {
@@ -852,16 +996,16 @@ class DenseMap : public DenseMapBase<DenseMap<KeyT, ValueT, KeyInfoT, BucketT>,
       this->BaseT::initEmpty();
       return;
     }
-
+    // Ensure that old buckets are always released, in the normal case as well
+    // as in case of an exception. Then insert all the old elements.
+    typename BaseT::ReleaseOldBuffer ReleaseOldBuffer(OldBuckets,
+                                                      OldNumBuckets);
+    typename BaseT::ReleaseOldBuckets ReleaseOldBuckets(OldBuckets,
+                                                        OldNumBuckets);
     this->moveFromOldBuckets(OldBuckets, OldBuckets + OldNumBuckets);
-
-    // Free the old table.
-    deallocate_buffer(OldBuckets, sizeof(BucketT) * OldNumBuckets,
-                      alignof(BucketT));
   }
 
   void shrink_and_clear() {
-    unsigned OldNumBuckets = NumBuckets;
     unsigned OldNumEntries = NumEntries;
     this->destroyAll();
 
@@ -874,8 +1018,7 @@ class DenseMap : public DenseMapBase<DenseMap<KeyT, ValueT, KeyInfoT, BucketT>,
       return;
     }
 
-    deallocate_buffer(Buckets, sizeof(BucketT) * OldNumBuckets,
-                      alignof(BucketT));
+    deallocateBuckets();
     init(NewNumBuckets);
   }
 
@@ -893,30 +1036,36 @@ class DenseMap : public DenseMapBase<DenseMap<KeyT, ValueT, KeyInfoT, BucketT>,
   unsigned getNumBuckets() const { return NumBuckets; }
 
   bool allocateBuckets(unsigned Num) {
-    NumBuckets = Num;
-    if (NumBuckets == 0) {
+    if (Num == 0) {
       Buckets = nullptr;
+      NumBuckets = 0;
       return false;
     }
 
     Buckets = static_cast<BucketT *>(
-        allocate_buffer(sizeof(BucketT) * NumBuckets, alignof(BucketT)));
+        allocate_buffer(sizeof(BucketT) * Num, alignof(BucketT)));
+    // update NumBuckets only if reallocation is successful
+    NumBuckets = Num;
+    this->BaseT::initUnitialized();
     return true;
   }
 };
 
 template <typename KeyT, typename ValueT, unsigned InlineBuckets = 4,
           typename KeyInfoT = DenseMapInfo<KeyT>,
-          typename BucketT = llvm::detail::DenseMapPair<KeyT, ValueT>>
+          typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT>,
+          typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>>
 class SmallDenseMap
-    : public DenseMapBase<
-          SmallDenseMap<KeyT, ValueT, InlineBuckets, KeyInfoT, BucketT>, KeyT,
-          ValueT, KeyInfoT, BucketT> {
-  friend class DenseMapBase<SmallDenseMap, KeyT, ValueT, KeyInfoT, BucketT>;
+    : public DenseMapBase<SmallDenseMap<KeyT, ValueT, InlineBuckets, KeyInfoT,
+                                        BucketT, BucketBaseT>,
+                          KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT> {
+  friend class DenseMapBase<SmallDenseMap, KeyT, ValueT, KeyInfoT, BucketT,
+                            BucketBaseT>;
 
   // Lift some types from the dependent base class into this class for
   // simplicity of referring to them.
-  using BaseT = DenseMapBase<SmallDenseMap, KeyT, ValueT, KeyInfoT, BucketT>;
+  using BaseT =
+      DenseMapBase<SmallDenseMap, KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
 
   static_assert(isPowerOf2_64(InlineBuckets),
                 "InlineBuckets must be a power of 2.");
@@ -932,35 +1081,46 @@ class SmallDenseMap
 
   /// A "union" of an inline bucket array and the struct representing
   /// a large bucket. This union will be discriminated by the 'Small' bit.
-  AlignedCharArrayUnion<BucketT[InlineBuckets], LargeRep> storage;
+  AlignedCharArrayUnion<BucketT[InlineBuckets]> storage;
+  LargeRep largeRep;
 
 public:
   explicit SmallDenseMap(unsigned NumInitBuckets = 0) {
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
     if (NumInitBuckets > InlineBuckets)
       NumInitBuckets = llvm::bit_ceil(NumInitBuckets);
-    init(NumInitBuckets);
+    initFromCtor(NumInitBuckets);
+    CtorGuard.release();
   }
 
   SmallDenseMap(const SmallDenseMap &other) : BaseT() {
-    init(0);
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
+    initFromCtor(0);
     copyFrom(other);
+    CtorGuard.release();
   }
 
   SmallDenseMap(SmallDenseMap &&other) : BaseT() {
-    init(0);
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
+    initFromCtor(0);
     swap(other);
+    CtorGuard.release();
   }
 
   template <typename InputIt>
   SmallDenseMap(const InputIt &I, const InputIt &E) {
-    init(NextPowerOf2(std::distance(I, E)));
+    auto CtorGuard = llvm::make_scope_exit([&] { destroy(); });
+    initFromCtor(NextPowerOf2(std::distance(I, E)));
     this->insert(I, E);
+    CtorGuard.release();
   }
 
   SmallDenseMap(std::initializer_list<typename BaseT::value_type> Vals)
       : SmallDenseMap(Vals.begin(), Vals.end()) {}
 
-  ~SmallDenseMap() {
+  ~SmallDenseMap() { destroy(); }
+
+  void destroy() {
     this->destroyAll();
     deallocateBuckets();
   }
@@ -994,10 +1154,14 @@ class SmallDenseMap
         std::swap(LHSB->getFirst(), RHSB->getFirst());
         if (hasLHSValue) {
           ::new (&RHSB->getSecond()) ValueT(std::move(LHSB->getSecond()));
+          RHSB->setValueConstructed(true);
           LHSB->getSecond().~ValueT();
+          LHSB->setValueConstructed(false);
         } else if (hasRHSValue) {
           ::new (&LHSB->getSecond()) ValueT(std::move(RHSB->getSecond()));
+          LHSB->setValueConstructed(true);
           RHSB->getSecond().~ValueT();
+          RHSB->setValueConstructed(false);
         }
       }
       return;
@@ -1023,18 +1187,22 @@ class SmallDenseMap
       BucketT *NewB = &LargeSide.getInlineBuckets()[i],
               *OldB = &SmallSide.getInlineBuckets()[i];
       ::new (&NewB->getFirst()) KeyT(std::move(OldB->getFirst()));
+      NewB->setKeyConstructed(true);
       OldB->getFirst().~KeyT();
+      OldB->setKeyConstructed(false);
       if (!KeyInfoT::isEqual(NewB->getFirst(), EmptyKey) &&
           !KeyInfoT::isEqual(NewB->getFirst(), TombstoneKey)) {
         ::new (&NewB->getSecond()) ValueT(std::move(OldB->getSecond()));
+        NewB->setValueConstructed(true);
         OldB->getSecond().~ValueT();
+        OldB->setValueConstructed(false);
       }
     }
 
     // The hard part of moving the small buckets across is done, just move
     // the TmpRep into its new home.
+    ::new (SmallSide.getLargeRep()) LargeRep(std::move(TmpRep));
     SmallSide.Small = false;
-    new (SmallSide.getLargeRep()) LargeRep(std::move(TmpRep));
   }
 
   SmallDenseMap &operator=(const SmallDenseMap &other) {
@@ -1044,33 +1212,34 @@ class SmallDenseMap
   }
 
   SmallDenseMap &operator=(SmallDenseMap &&other) {
-    this->destroyAll();
-    deallocateBuckets();
+    destroy();
     init(0);
     swap(other);
     return *this;
   }
 
   void copyFrom(const SmallDenseMap &other) {
-    this->destroyAll();
-    deallocateBuckets();
-    Small = true;
+    destroy();
     if (other.getNumBuckets() > InlineBuckets) {
-      Small = false;
-      new (getLargeRep()) LargeRep(allocateBuckets(other.getNumBuckets()));
+      allocateBuckets(other.getNumBuckets());
     }
     this->BaseT::copyFrom(other);
   }
 
-  void init(unsigned InitBuckets) {
+  void initFromCtor(unsigned InitBuckets) {
     Small = true;
+    this->BaseT::initUnitialized();
+    init(InitBuckets);
+  }
+
+  void init(unsigned InitBuckets) {
     if (InitBuckets > InlineBuckets) {
-      Small = false;
-      new (getLargeRep()) LargeRep(allocateBuckets(InitBuckets));
+      allocateBuckets(InitBuckets);
     }
     this->BaseT::initEmpty();
   }
 
+public:
   void grow(unsigned AtLeast) {
     if (AtLeast > InlineBuckets)
       AtLeast = std::max<unsigned>(64, NextPowerOf2(AtLeast - 1));
@@ -1085,44 +1254,61 @@ class SmallDenseMap
       // temporary storage. Have the loop move the TmpEnd forward as it goes.
       const KeyT EmptyKey = this->getEmptyKey();
       const KeyT TombstoneKey = this->getTombstoneKey();
+
+      // ensure that old buckets are always released, in normal case as well as
+      // in case of an exception
+      typename BaseT::ReleaseOldBuckets ReleaseOldBuckets(TmpBegin,
+                                                          InlineBuckets);
+
+      for (BucketT *P = TmpBegin, *E = P + InlineBuckets; P != E; ++P) {
+        P->setKeyConstructed(false);
+        P->setValueConstructed(false);
+      }
+
       for (BucketT *P = getBuckets(), *E = P + InlineBuckets; P != E; ++P) {
         if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey) &&
             !KeyInfoT::isEqual(P->getFirst(), TombstoneKey)) {
           assert(size_t(TmpEnd - TmpBegin) < InlineBuckets &&
                  "Too many inline buckets!");
           ::new (&TmpEnd->getFirst()) KeyT(std::move(P->getFirst()));
+          TmpEnd->setKeyConstructed(true);
           ::new (&TmpEnd->getSecond()) ValueT(std::move(P->getSecond()));
+          TmpEnd->setValueConstructed(true);
           ++TmpEnd;
+        }
+        if (P->isValueConstructed()) {
           P->getSecond().~ValueT();
+          P->setValueConstructed(false);
+        }
+        if (P->isKeyConstructed()) {
+          P->getFirst().~KeyT();
+          P->setKeyConstructed(false);
         }
-        P->getFirst().~KeyT();
       }
 
       // AtLeast == InlineBuckets can happen if there are many tombstones,
       // and grow() is used to remove them. Usually we always switch to the
       // large rep here.
       if (AtLeast > InlineBuckets) {
-        Small = false;
-        new (getLargeRep()) LargeRep(allocateBuckets(AtLeast));
+        allocateBuckets(AtLeast);
       }
       this->moveFromOldBuckets(TmpBegin, TmpEnd);
       return;
     }
-
     LargeRep OldRep = std::move(*getLargeRep());
     getLargeRep()->~LargeRep();
-    if (AtLeast <= InlineBuckets) {
-      Small = true;
-    } else {
-      new (getLargeRep()) LargeRep(allocateBuckets(AtLeast));
+    Small = true;
+    // ensure that old buckets are always released, in normal case as well as in
+    // case of an exception
+    typename BaseT::ReleaseOldBuffer ReleaseOldBuffer(OldRep.Buckets,
+                                                      OldRep.NumBuckets);
+    typename BaseT::ReleaseOldBuckets ReleaseOldBuckets(OldRep.Buckets,
+                                                        OldRep.NumBuckets);
+    if (AtLeast > InlineBuckets) {
+      allocateBuckets(AtLeast);
     }
-
     this->moveFromOldBuckets(OldRep.Buckets,
                              OldRep.Buckets + OldRep.NumBuckets);
-
-    // Free the old table.
-    deallocate_buffer(OldRep.Buckets, sizeof(BucketT) * OldRep.NumBuckets,
-                      alignof(BucketT));
   }
 
   void shrink_and_clear() {
@@ -1173,9 +1359,8 @@ class SmallDenseMap
   }
 
   const LargeRep *getLargeRep() const {
-    assert(!Small);
     // Note, same rule about aliasing as with getInlineBuckets.
-    return reinterpret_cast<const LargeRep *>(&storage);
+    return reinterpret_cast<const LargeRep *>(&largeRep);
   }
 
   LargeRep *getLargeRep() {
@@ -1197,6 +1382,7 @@ class SmallDenseMap
   }
 
   void deallocateBuckets() {
+    NumEntries = 0;
     if (Small)
       return;
 
@@ -1204,38 +1390,45 @@ class SmallDenseMap
                       sizeof(BucketT) * getLargeRep()->NumBuckets,
                       alignof(BucketT));
     getLargeRep()->~LargeRep();
+    Small = true;
   }
 
-  LargeRep allocateBuckets(unsigned Num) {
+  void allocateBuckets(unsigned Num) {
     assert(Num > InlineBuckets && "Must allocate more buckets than are inline");
-    LargeRep Rep = {static_cast<BucketT *>(allocate_buffer(
-                        sizeof(BucketT) * Num, alignof(BucketT))),
-                    Num};
-    return Rep;
+    largeRep = {static_cast<BucketT *>(
+                    allocate_buffer(sizeof(BucketT) * Num, alignof(BucketT))),
+                Num};
+    Small = false;
+    this->BaseT::initUnitialized();
   }
 };
 
-template <typename KeyT, typename ValueT, typename KeyInfoT, typename Bucket,
-          bool IsConst>
+template <typename KeyT, typename ValueT, typename KeyInfoT, typename BucketT,
+          typename BucketBaseT, bool IsConst>
 class DenseMapIterator : DebugEpochBase::HandleBase {
-  friend class DenseMapIterator<KeyT, ValueT, KeyInfoT, Bucket, true>;
-  friend class DenseMapIterator<KeyT, ValueT, KeyInfoT, Bucket, false>;
+  friend class DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT,
+                                true>;
+  friend class DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT,
+                                false>;
 
 public:
   using difference_type = ptrdiff_t;
-  using value_type = std::conditional_t<IsConst, const Bucket, Bucket>;
+  using inc_type = std::conditional_t<IsConst, const BucketT, BucketT>;
+  using inc_pointer = inc_type *;
+  using value_type =
+      std::conditional_t<IsConst, const BucketBaseT, BucketBaseT>;
   using pointer = value_type *;
   using reference = value_type &;
   using iterator_category = std::forward_iterator_tag;
 
 private:
-  pointer Ptr = nullptr;
-  pointer End = nullptr;
+  inc_pointer Ptr = nullptr;
+  inc_pointer End = nullptr;
 
 public:
   DenseMapIterator() = default;
 
-  DenseMapIterator(pointer Pos, pointer E, const DebugEpochBase &Epoch,
+  DenseMapIterator(inc_pointer Pos, inc_pointer E, const DebugEpochBase &Epoch,
                    bool NoAdvance = false)
       : DebugEpochBase::HandleBase(&Epoch), Ptr(Pos), End(E) {
     assert(isHandleInSync() && "invalid construction!");
@@ -1254,8 +1447,8 @@ class DenseMapIterator : DebugEpochBase::HandleBase {
   // constructor.
   template <bool IsConstSrc,
             typename = std::enable_if_t<!IsConstSrc && IsConst>>
-  DenseMapIterator(
-      const DenseMapIterator<KeyT, ValueT, KeyInfoT, Bucket, IsConstSrc> &I)
+  DenseMapIterator(const DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT,
+                                          BucketBaseT, IsConstSrc> &I)
       : DebugEpochBase::HandleBase(I), Ptr(I.Ptr), End(I.End) {}
 
   reference operator*() const {
diff --git a/llvm/include/llvm/ADT/DenseSet.h b/llvm/include/llvm/ADT/DenseSet.h
index a307bd8c9198b9..b9a1b0d5087646 100644
--- a/llvm/include/llvm/ADT/DenseSet.h
+++ b/llvm/include/llvm/ADT/DenseSet.h
@@ -27,11 +27,25 @@ namespace llvm {
 
 namespace detail {
 
+// Flag to indicate whether the key has been constructed.
+// This is necessary because assigning the empty key or tombstone key throws
+// for some key types.
+
+// Helpers to decide which specialization of DenseSetPairImpl to use
+template <typename KeyT>
+using EnableIfTriviallyDestructibleSet =
+    typename std::enable_if_t<std::is_trivially_destructible_v<KeyT>>;
+template <typename KeyT>
+using EnableIfNotTriviallyDestructibleSet =
+    typename std::enable_if_t<!std::is_trivially_destructible_v<KeyT>>;
+
 struct DenseSetEmpty {};
 
 // Use the empty base class trick so we can create a DenseMap where the buckets
 // contain only a single item.
+
 template <typename KeyT> class DenseSetPair : public DenseSetEmpty {
+private:
   KeyT key;
 
 public:
@@ -39,6 +53,43 @@ template <typename KeyT> class DenseSetPair : public DenseSetEmpty {
   const KeyT &getFirst() const { return key; }
   DenseSetEmpty &getSecond() { return *this; }
   const DenseSetEmpty &getSecond() const { return *this; }
+
+public:
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const { return false; }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
+};
+
+template <typename KeyT, typename Enabled = void> class DenseSetPairImpl;
+
+// First specialization for trivially destructible types
+template <typename KeyT>
+class DenseSetPairImpl<KeyT, EnableIfTriviallyDestructibleSet<KeyT>>
+    : public DenseSetPair<KeyT> {
+public:
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const { return false; }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
+};
+
+// Second specialization for non-trivially destructible types
+template <typename KeyT>
+class DenseSetPairImpl<KeyT, EnableIfNotTriviallyDestructibleSet<KeyT>>
+    : public DenseSetPair<KeyT> {
+public:
+  bool KeyConstructed;
+
+public:
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool Constructed) {
+    KeyConstructed = Constructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const {
+    return KeyConstructed;
+  }
+  LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+  LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
 };
 
 /// Base class for DenseSet and DenseSmallSet.
@@ -52,8 +103,6 @@ template <typename KeyT> class DenseSetPair : public DenseSetEmpty {
 /// DenseMapInfo "concept".
 template <typename ValueT, typename MapTy, typename ValueInfoT>
 class DenseSetImpl {
-  static_assert(sizeof(typename MapTy::value_type) == sizeof(ValueT),
-                "DenseMap buckets unexpectedly large!");
   MapTy TheMap;
 
   template <typename T>
@@ -274,13 +323,14 @@ template <typename ValueT, typename ValueInfoT = DenseMapInfo<ValueT>>
 class DenseSet : public detail::DenseSetImpl<
                      ValueT,
                      DenseMap<ValueT, detail::DenseSetEmpty, ValueInfoT,
+                              detail::DenseSetPairImpl<ValueT>,
                               detail::DenseSetPair<ValueT>>,
                      ValueInfoT> {
-  using BaseT =
-      detail::DenseSetImpl<ValueT,
-                           DenseMap<ValueT, detail::DenseSetEmpty, ValueInfoT,
-                                    detail::DenseSetPair<ValueT>>,
-                           ValueInfoT>;
+  using BaseT = detail::DenseSetImpl<
+      ValueT,
+      DenseMap<ValueT, detail::DenseSetEmpty, ValueInfoT,
+               detail::DenseSetPairImpl<ValueT>, detail::DenseSetPair<ValueT>>,
+      ValueInfoT>;
 
 public:
   using BaseT::BaseT;
@@ -294,11 +344,13 @@ class SmallDenseSet
     : public detail::DenseSetImpl<
           ValueT,
           SmallDenseMap<ValueT, detail::DenseSetEmpty, InlineBuckets,
-                        ValueInfoT, detail::DenseSetPair<ValueT>>,
+                        ValueInfoT, detail::DenseSetPairImpl<ValueT>,
+                        detail::DenseSetPair<ValueT>>,
           ValueInfoT> {
   using BaseT = detail::DenseSetImpl<
       ValueT,
       SmallDenseMap<ValueT, detail::DenseSetEmpty, InlineBuckets, ValueInfoT,
+                    detail::DenseSetPairImpl<ValueT>,
                     detail::DenseSetPair<ValueT>>,
       ValueInfoT>;
 
diff --git a/llvm/include/llvm/TextAPI/SymbolSet.h b/llvm/include/llvm/TextAPI/SymbolSet.h
index 6ccabb90772087..e97b63fc5946f5 100644
--- a/llvm/include/llvm/TextAPI/SymbolSet.h
+++ b/llvm/include/llvm/TextAPI/SymbolSet.h
@@ -48,11 +48,12 @@ template <> struct DenseMapInfo<SymbolsMapKey> {
   }
 };
 
-template <typename DerivedT, typename KeyInfoT, typename BucketT>
+template <typename DerivedT, typename KeyInfoT, typename BucketT,
+          typename BucketBaseT>
 bool operator==(const DenseMapBase<DerivedT, SymbolsMapKey, MachO::Symbol *,
-                                   KeyInfoT, BucketT> &LHS,
+                                   KeyInfoT, BucketT, BucketBaseT> &LHS,
                 const DenseMapBase<DerivedT, SymbolsMapKey, MachO::Symbol *,
-                                   KeyInfoT, BucketT> &RHS) {
+                                   KeyInfoT, BucketT, BucketBaseT> &RHS) {
   if (LHS.size() != RHS.size())
     return false;
   for (const auto &KV : LHS) {
@@ -63,11 +64,12 @@ bool operator==(const DenseMapBase<DerivedT, SymbolsMapKey, MachO::Symbol *,
   return true;
 }
 
-template <typename DerivedT, typename KeyInfoT, typename BucketT>
+template <typename DerivedT, typename KeyInfoT, typename BucketT,
+          typename BucketBaseT>
 bool operator!=(const DenseMapBase<DerivedT, SymbolsMapKey, MachO::Symbol *,
-                                   KeyInfoT, BucketT> &LHS,
+                                   KeyInfoT, BucketT, BucketBaseT> &LHS,
                 const DenseMapBase<DerivedT, SymbolsMapKey, MachO::Symbol *,
-                                   KeyInfoT, BucketT> &RHS) {
+                                   KeyInfoT, BucketT, BucketBaseT> &RHS) {
   return !(LHS == RHS);
 }
 
diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h
index 020c0fba726c84..8e8eb9c8f99ca1 100644
--- a/mlir/include/mlir/Support/LLVM.h
+++ b/mlir/include/mlir/Support/LLVM.h
@@ -50,8 +50,11 @@ class BitVector;
 namespace detail {
 template <typename KeyT, typename ValueT>
 struct DenseMapPair;
+template <typename KeyT, typename ValueT, typename Enable>
+struct DenseMapPairImpl;
 } // namespace detail
-template <typename KeyT, typename ValueT, typename KeyInfoT, typename BucketT>
+template <typename KeyT, typename ValueT, typename KeyInfoT, typename BucketT,
+          typename BucketBaseT>
 class DenseMap;
 template <typename T, typename Enable>
 struct DenseMapInfo;
@@ -122,8 +125,9 @@ template <typename T, typename Enable = void>
 using DenseMapInfo = llvm::DenseMapInfo<T, Enable>;
 template <typename KeyT, typename ValueT,
           typename KeyInfoT = DenseMapInfo<KeyT>,
-          typename BucketT = llvm::detail::DenseMapPair<KeyT, ValueT>>
-using DenseMap = llvm::DenseMap<KeyT, ValueT, KeyInfoT, BucketT>;
+          typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT, void>,
+          typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>>
+using DenseMap = llvm::DenseMap<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
 template <typename ValueT, typename ValueInfoT = DenseMapInfo<ValueT>>
 using DenseSet = llvm::DenseSet<ValueT, ValueInfoT>;
 template <typename T, typename Vector = llvm::SmallVector<T, 0>,



More information about the Mlir-commits mailing list