[llvm] 753a219 - [ImmutableSet] Use IntrusiveRefCntPtr to eliminate some manual refcounting

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Thu May 21 10:14:34 PDT 2020


Author: Benjamin Kramer
Date: 2020-05-21T19:10:22+02:00
New Revision: 753a21928413f8a7e76978cb1354e09150e114e0

URL: https://github.com/llvm/llvm-project/commit/753a21928413f8a7e76978cb1354e09150e114e0
DIFF: https://github.com/llvm/llvm-project/commit/753a21928413f8a7e76978cb1354e09150e114e0.diff

LOG: [ImmutableSet] Use IntrusiveRefCntPtr to eliminate some manual refcounting

Still not ideal as the refcounting leaks to users, but better than
before. NFCI (no functional change intended).

Added: 
    

Modified: 
    llvm/include/llvm/ADT/ImmutableMap.h
    llvm/include/llvm/ADT/ImmutableSet.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ADT/ImmutableMap.h b/llvm/include/llvm/ADT/ImmutableMap.h
index 86fd7fefaec3..30689d2274a8 100644
--- a/llvm/include/llvm/ADT/ImmutableMap.h
+++ b/llvm/include/llvm/ADT/ImmutableMap.h
@@ -70,33 +70,14 @@ class ImmutableMap {
   using TreeTy = ImutAVLTree<ValInfo>;
 
 protected:
-  TreeTy* Root;
+  IntrusiveRefCntPtr<TreeTy> Root;
 
 public:
   /// Constructs a map from a pointer to a tree root.  In general one
   /// should use a Factory object to create maps instead of directly
   /// invoking the constructor, but there are cases where make this
   /// constructor public is useful.
-  explicit ImmutableMap(const TreeTy* R) : Root(const_cast<TreeTy*>(R)) {
-    if (Root) { Root->retain(); }
-  }
-
-  ImmutableMap(const ImmutableMap &X) : Root(X.Root) {
-    if (Root) { Root->retain(); }
-  }
-
-  ~ImmutableMap() {
-    if (Root) { Root->release(); }
-  }
-
-  ImmutableMap &operator=(const ImmutableMap &X) {
-    if (Root != X.Root) {
-      if (X.Root) { X.Root->retain(); }
-      if (Root) { Root->release(); }
-      Root = X.Root;
-    }
-    return *this;
-  }
+  explicit ImmutableMap(const TreeTy *R) : Root(const_cast<TreeTy *>(R)) {}
 
   class Factory {
     typename TreeTy::Factory F;
@@ -115,12 +96,12 @@ class ImmutableMap {
 
     LLVM_NODISCARD ImmutableMap add(ImmutableMap Old, key_type_ref K,
                                     data_type_ref D) {
-      TreeTy *T = F.add(Old.Root, std::pair<key_type,data_type>(K,D));
+      TreeTy *T = F.add(Old.Root.get(), std::pair<key_type, data_type>(K, D));
       return ImmutableMap(Canonicalize ? F.getCanonicalTree(T): T);
     }
 
     LLVM_NODISCARD ImmutableMap remove(ImmutableMap Old, key_type_ref K) {
-      TreeTy *T = F.remove(Old.Root,K);
+      TreeTy *T = F.remove(Old.Root.get(), K);
       return ImmutableMap(Canonicalize ? F.getCanonicalTree(T): T);
     }
 
@@ -134,19 +115,20 @@ class ImmutableMap {
   }
 
   bool operator==(const ImmutableMap &RHS) const {
-    return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
+    return Root && RHS.Root ? Root->isEqual(*RHS.Root.get()) : Root == RHS.Root;
   }
 
   bool operator!=(const ImmutableMap &RHS) const {
-    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
+    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root.get())
+                            : Root != RHS.Root;
   }
 
   TreeTy *getRoot() const {
     if (Root) { Root->retain(); }
-    return Root;
+    return Root.get();
   }
 
-  TreeTy *getRootWithoutRetain() const { return Root; }
+  TreeTy *getRootWithoutRetain() const { return Root.get(); }
 
   void manualRetain() {
     if (Root) Root->retain();
@@ -217,7 +199,7 @@ class ImmutableMap {
     data_type_ref getData() const { return (*this)->second; }
   };
 
-  iterator begin() const { return iterator(Root); }
+  iterator begin() const { return iterator(Root.get()); }
   iterator end() const { return iterator(); }
 
   data_type* lookup(key_type_ref K) const {
@@ -243,7 +225,7 @@ class ImmutableMap {
   unsigned getHeight() const { return Root ? Root->getHeight() : 0; }
 
   static inline void Profile(FoldingSetNodeID& ID, const ImmutableMap& M) {
-    ID.AddPointer(M.Root);
+    ID.AddPointer(M.Root.get());
   }
 
   inline void Profile(FoldingSetNodeID& ID) const {
@@ -266,7 +248,7 @@ class ImmutableMapRef {
   using FactoryTy = typename TreeTy::Factory;
 
 protected:
-  TreeTy *Root;
+  IntrusiveRefCntPtr<TreeTy> Root;
   FactoryTy *Factory;
 
 public:
@@ -274,44 +256,12 @@ class ImmutableMapRef {
   /// should use a Factory object to create maps instead of directly
   /// invoking the constructor, but there are cases where make this
   /// constructor public is useful.
-  explicit ImmutableMapRef(const TreeTy *R, FactoryTy *F)
-      : Root(const_cast<TreeTy *>(R)), Factory(F) {
-    if (Root) {
-      Root->retain();
-    }
-  }
+  ImmutableMapRef(const TreeTy *R, FactoryTy *F)
+      : Root(const_cast<TreeTy *>(R)), Factory(F) {}
 
-  explicit ImmutableMapRef(const ImmutableMap<KeyT, ValT> &X,
-                           typename ImmutableMap<KeyT, ValT>::Factory &F)
-    : Root(X.getRootWithoutRetain()),
-      Factory(F.getTreeFactory()) {
-    if (Root) { Root->retain(); }
-  }
-
-  ImmutableMapRef(const ImmutableMapRef &X) : Root(X.Root), Factory(X.Factory) {
-    if (Root) {
-      Root->retain();
-    }
-  }
-
-  ~ImmutableMapRef() {
-    if (Root)
-      Root->release();
-  }
-
-  ImmutableMapRef &operator=(const ImmutableMapRef &X) {
-    if (Root != X.Root) {
-      if (X.Root)
-        X.Root->retain();
-
-      if (Root)
-        Root->release();
-
-      Root = X.Root;
-      Factory = X.Factory;
-    }
-    return *this;
-  }
+  ImmutableMapRef(const ImmutableMap<KeyT, ValT> &X,
+                  typename ImmutableMap<KeyT, ValT>::Factory &F)
+      : Root(X.getRootWithoutRetain()), Factory(F.getTreeFactory()) {}
 
   static inline ImmutableMapRef getEmptyMap(FactoryTy *F) {
     return ImmutableMapRef(0, F);
@@ -326,12 +276,13 @@ class ImmutableMapRef {
   }
 
   ImmutableMapRef add(key_type_ref K, data_type_ref D) const {
-    TreeTy *NewT = Factory->add(Root, std::pair<key_type, data_type>(K, D));
+    TreeTy *NewT =
+        Factory->add(Root.get(), std::pair<key_type, data_type>(K, D));
     return ImmutableMapRef(NewT, Factory);
   }
 
   ImmutableMapRef remove(key_type_ref K) const {
-    TreeTy *NewT = Factory->remove(Root, K);
+    TreeTy *NewT = Factory->remove(Root.get(), K);
     return ImmutableMapRef(NewT, Factory);
   }
 
@@ -340,15 +291,16 @@ class ImmutableMapRef {
   }
 
   ImmutableMap<KeyT, ValT> asImmutableMap() const {
-    return ImmutableMap<KeyT, ValT>(Factory->getCanonicalTree(Root));
+    return ImmutableMap<KeyT, ValT>(Factory->getCanonicalTree(Root.get()));
   }
 
   bool operator==(const ImmutableMapRef &RHS) const {
-    return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
+    return Root && RHS.Root ? Root->isEqual(*RHS.Root.get()) : Root == RHS.Root;
   }
 
   bool operator!=(const ImmutableMapRef &RHS) const {
-    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
+    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root.get())
+                            : Root != RHS.Root;
   }
 
   bool isEmpty() const { return !Root; }
@@ -377,7 +329,7 @@ class ImmutableMapRef {
     data_type_ref getData() const { return (*this)->second; }
   };
 
-  iterator begin() const { return iterator(Root); }
+  iterator begin() const { return iterator(Root.get()); }
   iterator end() const { return iterator(); }
 
   data_type *lookup(key_type_ref K) const {

diff  --git a/llvm/include/llvm/ADT/ImmutableSet.h b/llvm/include/llvm/ADT/ImmutableSet.h
index bab98f4b187a..f19913f8dcdd 100644
--- a/llvm/include/llvm/ADT/ImmutableSet.h
+++ b/llvm/include/llvm/ADT/ImmutableSet.h
@@ -15,6 +15,7 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/FoldingSet.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/Support/Allocator.h"
@@ -357,6 +358,12 @@ class ImutAVLTree {
   }
 };
 
+template <typename ImutInfo>
+struct IntrusiveRefCntPtrInfo<ImutAVLTree<ImutInfo>> {
+  static void retain(ImutAVLTree<ImutInfo> *Tree) { Tree->retain(); }
+  static void release(ImutAVLTree<ImutInfo> *Tree) { Tree->release(); }
+};
+
 //===----------------------------------------------------------------------===//
 // Immutable AVL-Tree Factory class.
 //===----------------------------------------------------------------------===//
@@ -961,33 +968,14 @@ class ImmutableSet {
   using TreeTy = ImutAVLTree<ValInfo>;
 
 private:
-  TreeTy *Root;
+  IntrusiveRefCntPtr<TreeTy> Root;
 
 public:
   /// Constructs a set from a pointer to a tree root.  In general one
   /// should use a Factory object to create sets instead of directly
   /// invoking the constructor, but there are cases where make this
   /// constructor public is useful.
-  explicit ImmutableSet(TreeTy* R) : Root(R) {
-    if (Root) { Root->retain(); }
-  }
-
-  ImmutableSet(const ImmutableSet &X) : Root(X.Root) {
-    if (Root) { Root->retain(); }
-  }
-
-  ~ImmutableSet() {
-    if (Root) { Root->release(); }
-  }
-
-  ImmutableSet &operator=(const ImmutableSet &X) {
-    if (Root != X.Root) {
-      if (X.Root) { X.Root->retain(); }
-      if (Root) { Root->release(); }
-      Root = X.Root;
-    }
-    return *this;
-  }
+  explicit ImmutableSet(TreeTy *R) : Root(R) {}
 
   class Factory {
     typename TreeTy::Factory F;
@@ -1016,7 +1004,7 @@ class ImmutableSet {
     ///  The memory allocated to represent the set is released when the
     ///  factory object that created the set is destroyed.
     LLVM_NODISCARD ImmutableSet add(ImmutableSet Old, value_type_ref V) {
-      TreeTy *NewT = F.add(Old.Root, V);
+      TreeTy *NewT = F.add(Old.Root.get(), V);
       return ImmutableSet(Canonicalize ? F.getCanonicalTree(NewT) : NewT);
     }
 
@@ -1028,7 +1016,7 @@ class ImmutableSet {
     ///  The memory allocated to represent the set is released when the
     ///  factory object that created the set is destroyed.
     LLVM_NODISCARD ImmutableSet remove(ImmutableSet Old, value_type_ref V) {
-      TreeTy *NewT = F.remove(Old.Root, V);
+      TreeTy *NewT = F.remove(Old.Root.get(), V);
       return ImmutableSet(Canonicalize ? F.getCanonicalTree(NewT) : NewT);
     }
 
@@ -1047,21 +1035,20 @@ class ImmutableSet {
   }
 
   bool operator==(const ImmutableSet &RHS) const {
-    return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
+    return Root && RHS.Root ? Root->isEqual(*RHS.Root.get()) : Root == RHS.Root;
   }
 
   bool operator!=(const ImmutableSet &RHS) const {
-    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
+    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root.get())
+                            : Root != RHS.Root;
   }
 
   TreeTy *getRoot() {
     if (Root) { Root->retain(); }
-    return Root;
+    return Root.get();
   }
 
-  TreeTy *getRootWithoutRetain() const {
-    return Root;
-  }
+  TreeTy *getRootWithoutRetain() const { return Root.get(); }
 
   /// isEmpty - Return true if the set contains no elements.
   bool isEmpty() const { return !Root; }
@@ -1082,7 +1069,7 @@ class ImmutableSet {
 
   using iterator = ImutAVLValueIterator<ImmutableSet>;
 
-  iterator begin() const { return iterator(Root); }
+  iterator begin() const { return iterator(Root.get()); }
   iterator end() const { return iterator(); }
 
   //===--------------------------------------------------===//
@@ -1092,7 +1079,7 @@ class ImmutableSet {
   unsigned getHeight() const { return Root ? Root->getHeight() : 0; }
 
   static void Profile(FoldingSetNodeID &ID, const ImmutableSet &S) {
-    ID.AddPointer(S.Root);
+    ID.AddPointer(S.Root.get());
   }
 
   void Profile(FoldingSetNodeID &ID) const { return Profile(ID, *this); }
@@ -1114,7 +1101,7 @@ class ImmutableSetRef {
   using FactoryTy = typename TreeTy::Factory;
 
 private:
-  TreeTy *Root;
+  IntrusiveRefCntPtr<TreeTy> Root;
   FactoryTy *Factory;
 
 public:
@@ -1122,42 +1109,18 @@ class ImmutableSetRef {
   /// should use a Factory object to create sets instead of directly
   /// invoking the constructor, but there are cases where make this
   /// constructor public is useful.
-  explicit ImmutableSetRef(TreeTy* R, FactoryTy *F)
-    : Root(R),
-      Factory(F) {
-    if (Root) { Root->retain(); }
-  }
-
-  ImmutableSetRef(const ImmutableSetRef &X)
-    : Root(X.Root),
-      Factory(X.Factory) {
-    if (Root) { Root->retain(); }
-  }
-
-  ~ImmutableSetRef() {
-    if (Root) { Root->release(); }
-  }
-
-  ImmutableSetRef &operator=(const ImmutableSetRef &X) {
-    if (Root != X.Root) {
-      if (X.Root) { X.Root->retain(); }
-      if (Root) { Root->release(); }
-      Root = X.Root;
-      Factory = X.Factory;
-    }
-    return *this;
-  }
+  ImmutableSetRef(TreeTy *R, FactoryTy *F) : Root(R), Factory(F) {}
 
   static ImmutableSetRef getEmptySet(FactoryTy *F) {
     return ImmutableSetRef(0, F);
   }
 
   ImmutableSetRef add(value_type_ref V) {
-    return ImmutableSetRef(Factory->add(Root, V), Factory);
+    return ImmutableSetRef(Factory->add(Root.get(), V), Factory);
   }
 
   ImmutableSetRef remove(value_type_ref V) {
-    return ImmutableSetRef(Factory->remove(Root, V), Factory);
+    return ImmutableSetRef(Factory->remove(Root.get(), V), Factory);
   }
 
   /// Returns true if the set contains the specified value.
@@ -1166,20 +1129,19 @@ class ImmutableSetRef {
   }
 
   ImmutableSet<ValT> asImmutableSet(bool canonicalize = true) const {
-    return ImmutableSet<ValT>(canonicalize ?
-                              Factory->getCanonicalTree(Root) : Root);
+    return ImmutableSet<ValT>(
+        canonicalize ? Factory->getCanonicalTree(Root.get()) : Root.get());
   }
 
-  TreeTy *getRootWithoutRetain() const {
-    return Root;
-  }
+  TreeTy *getRootWithoutRetain() const { return Root.get(); }
 
   bool operator==(const ImmutableSetRef &RHS) const {
-    return Root && RHS.Root ? Root->isEqual(*RHS.Root) : Root == RHS.Root;
+    return Root && RHS.Root ? Root->isEqual(*RHS.Root.get()) : Root == RHS.Root;
   }
 
   bool operator!=(const ImmutableSetRef &RHS) const {
-    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root) : Root != RHS.Root;
+    return Root && RHS.Root ? Root->isNotEqual(*RHS.Root.get())
+                            : Root != RHS.Root;
   }
 
   /// isEmpty - Return true if the set contains no elements.
@@ -1195,7 +1157,7 @@ class ImmutableSetRef {
 
   using iterator = ImutAVLValueIterator<ImmutableSetRef>;
 
-  iterator begin() const { return iterator(Root); }
+  iterator begin() const { return iterator(Root.get()); }
   iterator end() const { return iterator(); }
 
   //===--------------------------------------------------===//
@@ -1205,7 +1167,7 @@ class ImmutableSetRef {
   unsigned getHeight() const { return Root ? Root->getHeight() : 0; }
 
   static void Profile(FoldingSetNodeID &ID, const ImmutableSetRef &S) {
-    ID.AddPointer(S.Root);
+    ID.AddPointer(S.Root.get());
   }
 
   void Profile(FoldingSetNodeID &ID) const { return Profile(ID, *this); }


        


More information about the llvm-commits mailing list