[llvm] 32808cf - [IR] Track users of comdats

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 6 00:14:07 PST 2022


Author: Nikita Popov
Date: 2022-01-06T09:13:58+01:00
New Revision: 32808cfb24b8d83a99223b7f797be1dbe5573c10

URL: https://github.com/llvm/llvm-project/commit/32808cfb24b8d83a99223b7f797be1dbe5573c10
DIFF: https://github.com/llvm/llvm-project/commit/32808cfb24b8d83a99223b7f797be1dbe5573c10.diff

LOG: [IR] Track users of comdats

Track all GlobalObjects that reference a given comdat, which allows
determining whether a function in a comdat is dead without scanning
the whole module.

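In particular, this makes filterDeadComdatFunctions() have complexity
O(#DeadFunctions + #MembersOfTheirComdats) rather than O(#SymbolsInModule),
since only the users of comdats that contain a candidate dead function
need to be inspected. This addresses half of the compile-time issue
exposed by D115545.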

Differential Revision: https://reviews.llvm.org/D115864

Added: 
    

Modified: 
    llvm/include/llvm/IR/Comdat.h
    llvm/include/llvm/IR/GlobalObject.h
    llvm/lib/IR/Comdat.cpp
    llvm/lib/IR/Globals.cpp
    llvm/lib/Transforms/Utils/ModuleUtils.cpp
    llvm/unittests/IR/ConstantsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Comdat.h b/llvm/include/llvm/IR/Comdat.h
index 01a047d36455c..1701802e69772 100644
--- a/llvm/include/llvm/IR/Comdat.h
+++ b/llvm/include/llvm/IR/Comdat.h
@@ -16,10 +16,12 @@
 #define LLVM_IR_COMDAT_H
 
 #include "llvm-c/Types.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/CBindingWrapping.h"
 
 namespace llvm {
 
+class GlobalObject;
 class raw_ostream;
 class StringRef;
 template <typename ValueTy> class StringMapEntry;
@@ -46,15 +48,21 @@ class Comdat {
   StringRef getName() const;
   void print(raw_ostream &OS, bool IsForDebug = false) const;
   void dump() const;
+  const SmallPtrSetImpl<GlobalObject *> &getUsers() const { return Users; }
 
 private:
   friend class Module;
+  friend class GlobalObject;
 
   Comdat();
+  void addUser(GlobalObject *GO);
+  void removeUser(GlobalObject *GO);
 
   // Points to the map in Module.
   StringMapEntry<Comdat> *Name = nullptr;
   SelectionKind SK = Any;
+  // Globals using this comdat.
+  SmallPtrSet<GlobalObject *, 2> Users;
 };
 
 // Create wrappers for C Binding types (see CBindingWrapping.h).

diff  --git a/llvm/include/llvm/IR/GlobalObject.h b/llvm/include/llvm/IR/GlobalObject.h
index e15cf718bb107..1f73c8540a4a2 100644
--- a/llvm/include/llvm/IR/GlobalObject.h
+++ b/llvm/include/llvm/IR/GlobalObject.h
@@ -48,6 +48,7 @@ class GlobalObject : public GlobalValue {
         ObjComdat(nullptr) {
     setGlobalValueSubClassData(0);
   }
+  ~GlobalObject();
 
   Comdat *ObjComdat;
   enum {
@@ -122,7 +123,7 @@ class GlobalObject : public GlobalValue {
   bool hasComdat() const { return getComdat() != nullptr; }
   const Comdat *getComdat() const { return ObjComdat; }
   Comdat *getComdat() { return ObjComdat; }
-  void setComdat(Comdat *C) { ObjComdat = C; }
+  void setComdat(Comdat *C);
 
   using Value::addMetadata;
   using Value::clearMetadata;

diff  --git a/llvm/lib/IR/Comdat.cpp b/llvm/lib/IR/Comdat.cpp
index 1a5d38d17bc0d..90d5c6e82e5c4 100644
--- a/llvm/lib/IR/Comdat.cpp
+++ b/llvm/lib/IR/Comdat.cpp
@@ -25,6 +25,10 @@ Comdat::Comdat() = default;
 
 StringRef Comdat::getName() const { return Name->first(); }
 
+void Comdat::addUser(GlobalObject *GO) { Users.insert(GO); }
+
+void Comdat::removeUser(GlobalObject *GO) { Users.erase(GO); }
+
 LLVMComdatRef LLVMGetOrInsertComdat(LLVMModuleRef M, const char *Name) {
   return wrap(unwrap(M)->getOrInsertComdat(Name));
 }

diff  --git a/llvm/lib/IR/Globals.cpp b/llvm/lib/IR/Globals.cpp
index b6bd25aa12341..99affa8a84e6d 100644
--- a/llvm/lib/IR/Globals.cpp
+++ b/llvm/lib/IR/Globals.cpp
@@ -95,6 +95,8 @@ void GlobalValue::eraseFromParent() {
   llvm_unreachable("not a global");
 }
 
+GlobalObject::~GlobalObject() { setComdat(nullptr); }
+
 bool GlobalValue::isInterposable() const {
   if (isInterposableLinkage(getLinkage()))
     return true;
@@ -182,6 +184,14 @@ const Comdat *GlobalValue::getComdat() const {
   return cast<GlobalObject>(this)->getComdat();
 }
 
+void GlobalObject::setComdat(Comdat *C) {
+  if (ObjComdat)
+    ObjComdat->removeUser(this);
+  ObjComdat = C;
+  if (C)
+    C->addUser(this);
+}
+
 StringRef GlobalValue::getPartition() const {
   if (!hasPartition())
     return "";

diff  --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
index bb5ff59cba4be..c8b9af3fd6db7 100644
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -179,65 +179,29 @@ llvm::getOrCreateSanitizerCtorAndInitFunctions(
 
 void llvm::filterDeadComdatFunctions(
     Module &M, SmallVectorImpl<Function *> &DeadComdatFunctions) {
-  // Build a map from the comdat to the number of entries in that comdat we
-  // think are dead. If this fully covers the comdat group, then the entire
-  // group is dead. If we find another entry in the comdat group though, we'll
-  // have to preserve the whole group.
-  SmallDenseMap<Comdat *, int, 16> ComdatEntriesCovered;
+  SmallPtrSet<Function *, 32> MaybeDeadFunctions;
+  SmallPtrSet<Comdat *, 32> MaybeDeadComdats;
   for (Function *F : DeadComdatFunctions) {
-    Comdat *C = F->getComdat();
-    assert(C && "Expected all input GVs to be in a comdat!");
-    ComdatEntriesCovered[C] += 1;
+    MaybeDeadFunctions.insert(F);
+    if (Comdat *C = F->getComdat())
+      MaybeDeadComdats.insert(C);
   }
 
-  auto CheckComdat = [&](Comdat &C) {
-    auto CI = ComdatEntriesCovered.find(&C);
-    if (CI == ComdatEntriesCovered.end())
-      return;
-
-    // If this could have been covered by a dead entry, just subtract one to
-    // account for it.
-    if (CI->second > 0) {
-      CI->second -= 1;
-      return;
-    }
-
-    // If we've already accounted for all the entries that were dead, the
-    // entire comdat is alive so remove it from the map.
-    ComdatEntriesCovered.erase(CI);
-  };
-
-  auto CheckAllComdats = [&] {
-    for (Function &F : M.functions())
-      if (Comdat *C = F.getComdat()) {
-        CheckComdat(*C);
-        if (ComdatEntriesCovered.empty())
-          return;
-      }
-    for (GlobalVariable &GV : M.globals())
-      if (Comdat *C = GV.getComdat()) {
-        CheckComdat(*C);
-        if (ComdatEntriesCovered.empty())
-          return;
-      }
-    for (GlobalAlias &GA : M.aliases())
-      if (Comdat *C = GA.getComdat()) {
-        CheckComdat(*C);
-        if (ComdatEntriesCovered.empty())
-          return;
-      }
-  };
-  CheckAllComdats();
-
-  if (ComdatEntriesCovered.empty()) {
-    DeadComdatFunctions.clear();
-    return;
+  // Find comdats for which all users are dead now.
+  SmallPtrSet<Comdat *, 32> DeadComdats;
+  for (Comdat *C : MaybeDeadComdats) {
+    auto IsUserDead = [&](GlobalObject *GO) {
+      auto *F = dyn_cast<Function>(GO);
+      return F && MaybeDeadFunctions.contains(F);
+    };
+    if (all_of(C->getUsers(), IsUserDead))
+      DeadComdats.insert(C);
   }
 
-  // Remove the entries that were not covering.
-  erase_if(DeadComdatFunctions, [&](GlobalValue *GV) {
-    return ComdatEntriesCovered.find(GV->getComdat()) ==
-           ComdatEntriesCovered.end();
+  // Only keep functions which have no comdat or a dead comdat.
+  erase_if(DeadComdatFunctions, [&](Function *F) {
+    Comdat *C = F->getComdat();
+    return C && !DeadComdats.contains(C);
   });
 }
 

diff  --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
index 155383d5a0c63..faf8502b19df3 100644
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -763,5 +763,32 @@ TEST(ConstantsTest, GetSplatValueRoundTrip) {
   }
 }
 
+TEST(ConstantsTest, ComdatUserTracking) {
+  LLVMContext Context;
+  Module M("MyModule", Context);
+
+  Comdat *C = M.getOrInsertComdat("comdat");
+  const SmallPtrSetImpl<GlobalObject *> &Users = C->getUsers();
+  EXPECT_TRUE(Users.size() == 0);
+
+  Type *Ty = Type::getInt8Ty(Context);
+  GlobalVariable *GV1 = cast<GlobalVariable>(M.getOrInsertGlobal("gv1", Ty));
+  GV1->setComdat(C);
+  EXPECT_TRUE(Users.size() == 1);
+  EXPECT_TRUE(Users.contains(GV1));
+
+  GlobalVariable *GV2 = cast<GlobalVariable>(M.getOrInsertGlobal("gv2", Ty));
+  GV2->setComdat(C);
+  EXPECT_TRUE(Users.size() == 2);
+  EXPECT_TRUE(Users.contains(GV2));
+
+  GV1->eraseFromParent();
+  EXPECT_TRUE(Users.size() == 1);
+  EXPECT_TRUE(Users.contains(GV2));
+
+  GV2->eraseFromParent();
+  EXPECT_TRUE(Users.size() == 0);
+}
+
 } // end anonymous namespace
 } // end namespace llvm


        


More information about the llvm-commits mailing list