[llvm] [mlir] [ADT] Make DenseMap/DenseSet more resilient against OOM situations (PR #107251)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 10 07:53:49 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-adt
Author: Marc Auberer (marcauberer)
<details>
<summary>Changes</summary>
This patch hardens DenseMap and DenseSet against OOM scenarios, as outlined in [this RFC](https://discourse.llvm.org/t/rfc-malfunction-safe-densemap-denseset/81036). Feel free to post any comments, suggestions, or objections there.
As discussed in the RFC, we would eventually like to have malfunction tests to ensure malfunction-safety. This is still under internal discussion on our side, so for now we would like to contribute the improvements with only the existing functional tests.
---
Patch is 53.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107251.diff
4 Files Affected:
- (modified) llvm/include/llvm/ADT/DenseMap.h (+378-220)
- (modified) llvm/include/llvm/ADT/DenseSet.h (+91-30)
- (modified) llvm/include/llvm/TextAPI/SymbolSet.h (+8-6)
- (modified) mlir/include/mlir/Support/LLVM.h (+7-3)
``````````diff
diff --git a/llvm/include/llvm/ADT/DenseMap.h b/llvm/include/llvm/ADT/DenseMap.h
index 00290c9dd0a585..6cc857f906e910 100644
--- a/llvm/include/llvm/ADT/DenseMap.h
+++ b/llvm/include/llvm/ADT/DenseMap.h
@@ -16,6 +16,7 @@
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/EpochTracker.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/AlignOf.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/MathExtras.h"
@@ -36,8 +37,6 @@ namespace llvm {
namespace detail {
-// We extend a pair to allow users to override the bucket type with their own
-// implementation without requiring two members.
template <typename KeyT, typename ValueT>
struct DenseMapPair : public std::pair<KeyT, ValueT> {
using std::pair<KeyT, ValueT>::pair;
@@ -48,16 +47,72 @@ struct DenseMapPair : public std::pair<KeyT, ValueT> {
const ValueT &getSecond() const { return std::pair<KeyT, ValueT>::second; }
};
+// OOM safe DenseMap bucket
+// flags to indicate if key and value are constructed
+// this is necessary since assigning empty key or tombstone key throws for some
+// key types
+
+template <typename KeyT, typename ValueT, typename Enable = void>
+struct DenseMapPairImpl;
+
+// part I: helpers to decide which specialization of DenseMapPairImpl to use
+template <typename KeyT, typename ValueT>
+using EnableIfTriviallyDestructibleMap =
+ typename std::enable_if_t<std::is_trivially_destructible_v<KeyT> &&
+ std::is_trivially_destructible_v<ValueT>>;
+template <typename KeyT, typename ValueT>
+using EnableIfNotTriviallyDestructibleMap =
+ typename std::enable_if_t<!std::is_trivially_destructible_v<KeyT> ||
+ !std::is_trivially_destructible_v<ValueT>>;
+
+// OOM safe DenseMap bucket
+// part II: specialization for trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+ EnableIfTriviallyDestructibleMap<KeyT, ValueT>>
+ : public DenseMapPair<KeyT, ValueT> {
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool) {}
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const { return false; }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool) {}
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const { return false; }
+};
+
+// OOM safe DenseMap bucket
+// part III: specialization for non-trivially destructible types
+template <typename KeyT, typename ValueT>
+struct DenseMapPairImpl<KeyT, ValueT,
+ EnableIfNotTriviallyDestructibleMap<KeyT, ValueT>>
+ : DenseMapPair<KeyT, ValueT> {
+private:
+ bool m_keyConstructed;
+ bool m_valueConstructed;
+
+public:
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setKeyConstructed(bool constructed) {
+ m_keyConstructed = constructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isKeyConstructed() const {
+ return m_keyConstructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE void setValueConstructed(bool constructed) {
+ m_valueConstructed = constructed;
+ }
+ LLVM_ATTRIBUTE_ALWAYS_INLINE bool isValueConstructed() const {
+ return m_valueConstructed;
+ }
+};
+
} // end namespace detail
template <typename KeyT, typename ValueT,
typename KeyInfoT = DenseMapInfo<KeyT>,
- typename Bucket = llvm::detail::DenseMapPair<KeyT, ValueT>,
+ typename BucketT = llvm::detail::DenseMapPairImpl<KeyT, ValueT>,
+ typename BucketBaseT = llvm::detail::DenseMapPair<KeyT, ValueT>,
bool IsConst = false>
class DenseMapIterator;
template <typename DerivedT, typename KeyT, typename ValueT, typename KeyInfoT,
- typename BucketT>
+ typename BucketT, typename BucketBaseT>
class DenseMapBase : public DebugEpochBase {
template <typename T>
using const_arg_type_t = typename const_pointer_or_const_ref<T>::type;
@@ -66,11 +121,12 @@ class DenseMapBase : public DebugEpochBase {
using size_type = unsigned;
using key_type = KeyT;
using mapped_type = ValueT;
- using value_type = BucketT;
+ using value_type = BucketBaseT;
- using iterator = DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT>;
+ using iterator =
+ DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT>;
using const_iterator =
- DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, true>;
+ DenseMapIterator<KeyT, ValueT, KeyInfoT, BucketT, BucketBaseT, true>;
inline iterator begin() {
// When the map is empty, avoid the overhead of advancing/retreating past
@@ -109,7 +165,8 @@ class DenseMapBase : public DebugEpochBase {
void clear() {
incrementEpoch();
- if (getNumEntries() == 0 && getNumTombstones() == 0) return;
+ if (getNumEntries() == 0 && getNumTombstones() == 0)
+ return;
// If the capacity of the array is huge, and the # elements used is small,
// shrink the array.
@@ -119,24 +176,20 @@ class DenseMapBase : public DebugEpochBase {
}
const KeyT EmptyKey = getEmptyKey();
- if (std::is_trivially_destructible<ValueT>::value) {
+ if constexpr (std::is_trivially_destructible_v<ValueT>) {
// Use a simpler loop when values don't need destruction.
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P)
P->getFirst() = EmptyKey;
} else {
- const KeyT TombstoneKey = getTombstoneKey();
- unsigned NumEntries = getNumEntries();
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey)) {
- if (!KeyInfoT::isEqual(P->getFirst(), TombstoneKey)) {
+ if (P->isValueConstructed()) {
P->getSecond().~ValueT();
- --NumEntries;
+ P->setValueConstructed(false);
}
P->getFirst() = EmptyKey;
}
}
- assert(NumEntries == 0 && "Node count imbalance!");
- (void)NumEntries;
}
setNumEntries(0);
setNumTombstones(0);
@@ -172,15 +225,14 @@ class DenseMapBase : public DebugEpochBase {
/// The DenseMapInfo is responsible for supplying methods
/// getHashValue(LookupKeyT) and isEqual(LookupKeyT, KeyT) for each key
/// type used.
- template<class LookupKeyT>
- iterator find_as(const LookupKeyT &Val) {
+ template <class LookupKeyT> iterator find_as(const LookupKeyT &Val) {
if (BucketT *Bucket = doFind(Val))
return makeIterator(
Bucket, shouldReverseIterate<KeyT>() ? getBuckets() : getBucketsEnd(),
*this, true);
return end();
}
- template<class LookupKeyT>
+ template <class LookupKeyT>
const_iterator find_as(const LookupKeyT &Val) const {
if (const BucketT *Bucket = doFind(Val))
return makeConstIterator(
@@ -223,7 +275,7 @@ class DenseMapBase : public DebugEpochBase {
// The value is constructed in-place if the key is not in the map, otherwise
// it is not moved.
template <typename... Ts>
- std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&... Args) {
+ std::pair<iterator, bool> try_emplace(KeyT &&Key, Ts &&...Args) {
BucketT *TheBucket;
if (LookupBucketFor(Key, TheBucket))
return std::make_pair(makeIterator(TheBucket,
@@ -248,7 +300,7 @@ class DenseMapBase : public DebugEpochBase {
// The value is constructed in-place if the key is not in the map, otherwise
// it is not moved.
template <typename... Ts>
- std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&... Args) {
+ std::pair<iterator, bool> try_emplace(const KeyT &Key, Ts &&...Args) {
BucketT *TheBucket;
if (LookupBucketFor(Key, TheBucket))
return std::make_pair(makeIterator(TheBucket,
@@ -297,8 +349,7 @@ class DenseMapBase : public DebugEpochBase {
}
/// insert - Range insertion of pairs.
- template<typename InputIt>
- void insert(InputIt I, InputIt E) {
+ template <typename InputIt> void insert(InputIt I, InputIt E) {
for (; I != E; ++I)
insert(*I);
}
@@ -341,17 +392,22 @@ class DenseMapBase : public DebugEpochBase {
return false; // not in map.
TheBucket->getSecond().~ValueT();
+ TheBucket->setValueConstructed(false);
TheBucket->getFirst() = getTombstoneKey();
decrementNumEntries();
incrementNumTombstones();
return true;
}
void erase(iterator I) {
- BucketT *TheBucket = &*I;
- TheBucket->getSecond().~ValueT();
- TheBucket->getFirst() = getTombstoneKey();
- decrementNumEntries();
- incrementNumTombstones();
+ BucketT *TheBucket = static_cast<BucketT *>(&*I);
+ // Iterator can point to nullptr in case of memory malfunctions
+ if (TheBucket != nullptr) {
+ TheBucket->getSecond().~ValueT();
+ TheBucket->setValueConstructed(false);
+ TheBucket->getFirst() = getTombstoneKey();
+ decrementNumEntries();
+ incrementNumTombstones();
+ }
}
LLVM_DEPRECATED("Use [Key] instead", "[Key]")
@@ -404,15 +460,24 @@ class DenseMapBase : public DebugEpochBase {
DenseMapBase() = default;
void destroyAll() {
- if (getNumBuckets() == 0) // Nothing to do.
- return;
-
- const KeyT EmptyKey = getEmptyKey(), TombstoneKey = getTombstoneKey();
for (BucketT *P = getBuckets(), *E = getBucketsEnd(); P != E; ++P) {
- if (!KeyInfoT::isEqual(P->getFirst(), EmptyKey) &&
- !KeyInfoT::isEqual(P->getFirst(), TombstoneKey))
+ if (P->isValueConstructed()) {
P->getSecond().~ValueT();
- P->getFirst().~KeyT();
+ P->setValueConstructed(false);
+ }
+ if (P->isKeyConstructed()) {
+ P->getFirst().~KeyT();
+ P->setKeyConstructed(false);
+ }
+ }
+ }
+
+ void initUnitialized() {
+ BucketT *B = getBuckets();
+ BucketT *E = getBucketsEnd();
+ for (; B != E; ++B) {
+ B->setKeyConstructed(false);
+ B->setValueConstructed(false);
}
}
@@ -420,11 +485,19 @@ class DenseMapBase : public DebugEpochBase {
setNumEntries(0);
setNumTombstones(0);
- assert((getNumBuckets() & (getNumBuckets()-1)) == 0 &&
+ assert((getNumBuckets() & (getNumBuckets() - 1)) == 0 &&
"# initial buckets must be a power of two!");
const KeyT EmptyKey = getEmptyKey();
- for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B)
+#ifndef NDEBUG
+ for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
+ assert(!B->isKeyConstructed());
+ assert(!B->isValueConstructed());
+ }
+#endif
+ for (BucketT *B = getBuckets(), *E = getBucketsEnd(); B != E; ++B) {
::new (&B->getFirst()) KeyT(EmptyKey);
+ B->setKeyConstructed(true);
+ }
}
/// Returns the number of buckets to allocate to ensure that the DenseMap can
@@ -454,18 +527,23 @@ class DenseMapBase : public DebugEpochBase {
assert(!FoundVal && "Key already in new map?");
DestBucket->getFirst() = std::move(B->getFirst());
::new (&DestBucket->getSecond()) ValueT(std::move(B->getSecond()));
+ DestBucket->setValueConstructed(true);
incrementNumEntries();
-
- // Free the value.
+ }
+ if (B->isValueConstructed()) {
B->getSecond().~ValueT();
+ B->setValueConstructed(false);
+ }
+ if (B->isKeyConstructed()) {
+ B->getFirst().~KeyT();
+ B->setKeyConstructed(false);
}
- B->getFirst().~KeyT();
}
}
template <typename OtherBaseT>
- void copyFrom(
- const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT> &other) {
+ void copyFrom(const DenseMapBase<OtherBaseT, KeyT, ValueT, KeyInfoT, BucketT,
+ BucketBaseT> &other) {
assert(&other != this);
assert(getNumBuckets() == other.getNumBuckets());
@@ -480,10 +558,13 @@ class DenseMapBase : public DebugEpochBase {
for (size_t i = 0; i < getNumBuckets(); ++i) {
::new (&getBuckets()[i].getFirst())
KeyT(other.getBuckets()[i].getFirst());
+ getBuckets()[i].setKeyConstructed(true);
if (!KeyInfoT::isEqual(getBuckets()[i].getFirst(), getEmptyKey()) &&
- !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey()))
+ !KeyInfoT::isEqual(getBuckets()[i].getFirst(), getTombstoneKey())) {
::new (&getBuckets()[i].getSecond())
ValueT(other.getBuckets()[i].getSecond());
+ getBuckets()[i].setValueConstructed(true);
+ }
}
}
@@ -491,7 +572,7 @@ class DenseMapBase : public DebugEpochBase {
return KeyInfoT::getHashValue(Val);
}
- template<typename LookupKeyT>
+ template <typename LookupKeyT>
static unsigned getHashValue(const LookupKeyT &Val) {
return KeyInfoT::getHashValue(Val);
}
@@ -502,14 +583,11 @@ class DenseMapBase : public DebugEpochBase {
return KeyInfoT::getEmptyKey();
}
- static const KeyT getTombstoneKey() {
- return KeyInfoT::getTombstoneKey();
- }
+ static const KeyT getTombstoneKey() { return KeyInfoT::getTombstoneKey(); }
private:
- iterator makeIterator(BucketT *P, BucketT *E,
- DebugEpochBase &Epoch,
- bool NoAdvance=false) {
+ iterator makeIterator(BucketT *P, BucketT *E, DebugEpochBase &Epoch,
+ bool NoAdvance = false) {
if (shouldReverseIterate<KeyT>()) {
BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
return iterator(B, E, Epoch, NoAdvance);
@@ -519,7 +597,7 @@ class DenseMapBase : public DebugEpochBase {
const_iterator makeConstIterator(const BucketT *P, const BucketT *E,
const DebugEpochBase &Epoch,
- const bool NoAdvance=false) const {
+ const bool NoAdvance = false) const {
if (shouldReverseIterate<KeyT>()) {
const BucketT *B = P == getBucketsEnd() ? getBuckets() : P + 1;
return const_iterator(B, E, Epoch, NoAdvance);
@@ -535,13 +613,9 @@ class DenseMapBase : public DebugEpochBase {
static_cast<DerivedT *>(this)->setNumEntries(Num);
}
- void incrementNumEntries() {
- setNumEntries(getNumEntries() + 1);
- }
+ void incrementNumEntries() { setNumEntries(getNumEntries() + 1); }
- void decrementNumEntries() {
- setNumEntries(getNumEntries() - 1);
- }
+ void decrementNumEntries() { setNumEntries(getNumEntries() - 1); }
unsigned getNumTombstones() const {
return static_cast<const DerivedT *>(this)->getNumTombstones();
@@ -551,65 +625,52 @@ class DenseMapBase : public DebugEpochBase {
static_cast<DerivedT *>(this)->setNumTombstones(Num);
}
- void incrementNumTombstones() {
- setNumTombstones(getNumTombstones() + 1);
- }
+ void incrementNumTombstones() { setNumTombstones(getNumTombstones() + 1); }
- void decrementNumTombstones() {
- setNumTombstones(getNumTombstones() - 1);
- }
+ void decrementNumTombstones() { setNumTombstones(getNumTombstones() - 1); }
const BucketT *getBuckets() const {
return static_cast<const DerivedT *>(this)->getBuckets();
}
- BucketT *getBuckets() {
- return static_cast<DerivedT *>(this)->getBuckets();
- }
+ BucketT *getBuckets() { return static_cast<DerivedT *>(this)->getBuckets(); }
unsigned getNumBuckets() const {
return static_cast<const DerivedT *>(this)->getNumBuckets();
}
- BucketT *getBucketsEnd() {
- return getBuckets() + getNumBuckets();
- }
+ BucketT *getBucketsEnd() { return getBuckets() + getNumBuckets(); }
const BucketT *getBucketsEnd() const {
return getBuckets() + getNumBuckets();
}
- void grow(unsigned AtLeast) {
- static_cast<DerivedT *>(this)->grow(AtLeast);
- }
+ void grow(unsigned AtLeast) { static_cast<DerivedT *>(this)->grow(AtLeast); }
- void shrink_and_clear() {
- static_cast<DerivedT *>(this)->shrink_and_clear();
- }
+ void shrink_and_clear() { static_cast<DerivedT *>(this)->shrink_and_clear(); }
template <typename KeyArg, typename... ValueArgs>
BucketT *InsertIntoBucket(BucketT *TheBucket, KeyArg &&Key,
- ValueArgs &&... Values) {
- TheBucket = InsertIntoBucketImpl(Key, Key, TheBucket);
-
+ ValueArgs &&...Values) {
+ TheBucket = InsertIntoBucketImpl(Key, TheBucket);
TheBucket->getFirst() = std::forward<KeyArg>(Key);
::new (&TheBucket->getSecond()) ValueT(std::forward<ValueArgs>(Values)...);
+ TheBucket->setValueConstructed(true);
return TheBucket;
}
template <typename LookupKeyT>
BucketT *InsertIntoBucketWithLookup(BucketT *TheBucket, KeyT &&Key,
ValueT &&Value, LookupKeyT &Lookup) {
- TheBucket = InsertIntoBucketImpl(Key, Lookup, TheBucket);
-
+ TheBucket = InsertIntoBucketImpl(Lookup, TheBucket);
TheBucket->getFirst() = std::move(Key);
::new (&TheBucket->getSecond()) ValueT(std::move(Value));
+ TheBucket->setValueConstructed(true);
return TheBucket;
}
template <typename LookupKeyT>
- BucketT *InsertIntoBucketImpl(const KeyT &Key, const LookupKeyT &Lookup,
- BucketT *TheBucket) {
+ BucketT *InsertIntoBucketImpl(const LookupKeyT &Lookup, BucketT *TheBucket) {
incrementEpoch();
// If the load of the hash table is more than 3/4, or if fewer than 1/8 of
@@ -627,8 +688,9 @@ class DenseMapBase : public DebugEpochBase {
this->grow(NumBuckets * 2);
LookupBucketFor(Lookup, TheBucket);
NumBuckets = getNumBuckets();
- } else if (LLVM_UNLIKELY(NumBuckets-(NewNumEntries+getNumTombstones()) <=
- NumBuckets/8)) {
+ } else if (LLVM_UNLIKELY(NumBuckets -
+ (NewNumEntries + getNumTombstones()) <=
+ NumBuckets / 8)) {
this->grow(NumBuckets);
LookupBucketFor(Lookup, TheBucket);
}
@@ -696,7 +758,7 @@ class DenseMapBase : public DebugEpochBase {
!KeyInfoT::isEqual(Val, TombstoneKey) &&
"Empty/Tombstone value shouldn't be inserted into map!");
- unsigned BucketNo = getHashValue(Val) & (NumBuckets-1);
+ unsigned BucketNo = getHashValue(Val) & (NumBuckets - 1);
unsigned ProbeAmt = 1;
while (true) {
BucketT *ThisBucket = BucketsPtr + BucketNo;
@@ -719,23 +781,63 @@ class DenseMapBase : public DebugEpochBase {
// prefer to return it than something that would require more probing.
if (KeyInfoT::isEqual(ThisBucket->getFirst(), TombstoneKey) &&
!FoundTombstone)
- FoundTombstone = ThisBucket; // Remember the first tombstone found.
+ FoundTombstone = ThisBucket; // Remember the first tombstone found.
// Otherwise, it's a hash collision or a tombstone, continue quadratic
// probing.
BucketNo += ProbeAmt++;
- BucketNo &= (NumBuckets-1);
+ BucketNo &= (NumBuckets - 1);
}
}
+protected:
+ // helper class to guarantee deallocation of buffer
+ class ReleaseOldBuffer {
+ BucketT *m_buckets;
+ unsigned m_numBuckets;
+
+ public:
+ ReleaseOldBuffer(BucketT *buckets, unsigned numBuckets)
+ : m_buckets(buckets), m_numBuckets(numBuckets) {}
+ ~ReleaseOldBuffer() {
+#ifndef NDEBUG
+ memset((void *)m_buckets, 0x5a, sizeof(BucketT) * m_numBuckets);
+#endif
+ // Free the old table.
+ size_t const alignment{alignof(BucketT)};
+ deallocate_buffer(static_cast<void *>(m_buckets),
+ sizeof(BucketT) * m_numBuckets, alignment);
+ }
+ };
+
+ // helper class to guarantee destruction of bucket content
+ class ReleaseOldBuckets {
+ BucketT *m_buckets;
+ unsigned m_numBuckets;
+
+ public:
+ ReleaseOldBuckets(BucketT *buckets, unsigned numBuckets)
+ : m_buckets(buckets), m_numBuckets(numBuckets) {}
+ ~ReleaseOldBuckets() {
+ for (BucketT *B = m_buckets, *E = m_buckets + m_numBuckets; B != E; ++B) {
+ if (B->isValueConstructed()) {
+ B->getSecond().~ValueT();
+ B->setValueConstructed(false);
+ }
+ if (B->isKeyConstructed()) {
+ B->getFirst().~KeyT();
+ B->setKeyConstructed(false);
+ }
+ }
+ }
+ };
+
public:
/// Return the approximate size (in bytes) of the actual map.
/// This is just the raw memory used by DenseMap.
/// If entries are pointers to objects, the size of the referenced objects
/// are not included.
- size_t getMemorySize() const {
- return getNumBuckets() * sizeof(BucketT);
- }
+ size...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/107251
More information about the llvm-commits
mailing list