[clang] Support VFE in thinLTO (PR #69735)

Manman Ren via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 20 10:49:28 PDT 2023


https://github.com/manman-ren updated https://github.com/llvm/llvm-project/pull/69735

>From 1adce63e663203f858de86cfa231527ee2284505 Mon Sep 17 00:00:00 2001
From: Manman Ren <mren at fb.com>
Date: Thu, 5 Oct 2023 10:56:53 -0700
Subject: [PATCH] Support VFE in thinLTO

We add a run of GlobalDCEPass with ImportSummary. When ImportSummary is non-null, we remove
virtual functions in vtables whose VCallVisibility is not Public. In this run, the regular
GlobalDCEPass::AddVirtualFunctionDependencies is bypassed; instead, we use
ImportSummary to decide which virtual functions to remove.

Discussion points:
1> FuncsWithNonVtableRef: this currently lives in ModuleSummaryIndex; would it
   make more sense as part of FunctionSummary?
   std::set<GlobalValue::GUID> FuncsWithNonVtableRef;
2> ComputeDependencies is duplicated from GlobalDCE into ModuleSummaryAnalysis. In
   GlobalDCE, ConstantDependenciesCache is a member variable; the copy takes it as a parameter.
3> Matching summary entries to Functions in IR:
   the resolution is saved in std::set<GlobalValue::GUID> VFuncsToBeRemoved.
   Search VFuncsToBeRemoved with F->getGUID(), falling back to
   GlobalValue::getGUID(F->getName()) for functions that were internalized.

TODO:
1> support hybrid Regular/ThinLTO
2> support legacy thinLTO backend (ThinLTOCodeGenerator)
---
 clang/lib/Driver/ToolChains/Clang.cpp         |   4 +-
 clang/test/CodeGenCXX/vfe-thin.cpp            | 123 +++++++++
 .../Driver/virtual-function-elimination.cpp   |   3 +-
 llvm/include/llvm/AsmParser/LLParser.h        |   1 +
 llvm/include/llvm/AsmParser/LLToken.h         |   1 +
 llvm/include/llvm/Bitcode/LLVMBitCodes.h      |   1 +
 llvm/include/llvm/IR/ModuleSummaryIndex.h     |  22 ++
 llvm/include/llvm/Transforms/IPO/GlobalDCE.h  |  15 +-
 llvm/lib/Analysis/ModuleSummaryAnalysis.cpp   |  65 ++++-
 llvm/lib/AsmParser/LLParser.cpp               |  16 ++
 llvm/lib/Bitcode/Reader/BitcodeReader.cpp     |   5 +
 llvm/lib/Bitcode/Writer/BitcodeWriter.cpp     |  15 +
 llvm/lib/IR/AsmWriter.cpp                     |  13 +
 llvm/lib/LTO/LTO.cpp                          |   8 +
 llvm/lib/Passes/PassBuilderPipelines.cpp      |   6 +
 llvm/lib/Transforms/IPO/GlobalDCE.cpp         | 261 +++++++++++++++++-
 16 files changed, 551 insertions(+), 8 deletions(-)
 create mode 100644 clang/test/CodeGenCXX/vfe-thin.cpp

diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp
index c74f6ff447261dc..d59eb98eb2777e1 100644
--- a/clang/lib/Driver/ToolChains/Clang.cpp
+++ b/clang/lib/Driver/ToolChains/Clang.cpp
@@ -7378,10 +7378,10 @@ void Clang::ConstructJob(Compilation &C, const JobAction &JA,
   if (VirtualFunctionElimination) {
     // VFE requires full LTO (currently, this might be relaxed to allow ThinLTO
     // in the future).
-    if (LTOMode != LTOK_Full)
+    if (LTOMode != LTOK_Full && LTOMode != LTOK_Thin)
       D.Diag(diag::err_drv_argument_only_allowed_with)
           << "-fvirtual-function-elimination"
-          << "-flto=full";
+          << "-flto";
 
     CmdArgs.push_back("-fvirtual-function-elimination");
   }
diff --git a/clang/test/CodeGenCXX/vfe-thin.cpp b/clang/test/CodeGenCXX/vfe-thin.cpp
new file mode 100644
index 000000000000000..b05d196060353c1
--- /dev/null
+++ b/clang/test/CodeGenCXX/vfe-thin.cpp
@@ -0,0 +1,123 @@
+// REQUIRES: aarch64-registered-target
+// Checks that ThinLTO VFE removes A::bar, whose only reference is _ZTV1A.
+// RUN: rm -rf %t.dir
+// RUN: split-file %s %t.dir
+
+// RUN: %clang -flto=thin -target arm64-apple-ios16 -emit-llvm -c -fwhole-program-vtables -fvirtual-function-elimination -mllvm -enable-vfe-summary -o %t.main.bc %t.dir/vfe-main.cpp
+// RUN: llvm-dis %t.main.bc -o - | FileCheck %s --check-prefix=CHECK
+// RUN: %clang -flto=thin -target arm64-apple-ios16 -emit-llvm -c -fwhole-program-vtables -fvirtual-function-elimination -mllvm -enable-vfe-summary -o %t.input.bc %t.dir/vfe-input.cpp
+// RUN: llvm-dis %t.input.bc -o - | FileCheck %s --check-prefix=INPUT
+
+// RUN: llvm-lto2 run %t.main.bc %t.input.bc -save-temps -enable-vfe-on-thinlto -o %t.out \
+// RUN:   -r=%t.main.bc,__Z6test_1P1A,pl \
+// RUN:   -r=%t.main.bc,__Z4testv,pl \
+// RUN:   -r=%t.main.bc,__Znwm, \
+// RUN:   -r=%t.main.bc,__ZN1AC1Ev,pl \
+// RUN:   -r=%t.main.bc,___gxx_personality_v0, \
+// RUN:   -r=%t.main.bc,__ZN1AC2Ev,pl \
+// RUN:   -r=%t.main.bc,__ZN1BC2Ev,pl \
+// RUN:   -r=%t.main.bc,__ZN1BC1Ev,pl \
+// RUN:   -r=%t.main.bc,__ZN1A3fooEv,pl \
+// RUN:   -r=%t.main.bc,__ZN1A3barEv,pl \
+// RUN:   -r=%t.main.bc,__Z6test_2P1B,pl \
+// RUN:   -r=%t.main.bc,_main,plx \
+// RUN:   -r=%t.main.bc,__ZdlPv \
+// RUN:   -r=%t.main.bc,__ZTV1A,pl \
+// RUN:   -r=%t.main.bc,__ZTV1B, \
+// RUN:   -r=%t.main.bc,__ZTVN10__cxxabiv117__class_type_infoE, \
+// RUN:   -r=%t.main.bc,__ZTS1A,pl \
+// RUN:   -r=%t.main.bc,__ZTI1A,pl \
+// RUN:   -r=%t.input.bc,__ZN1CC2Ev,pl \
+// RUN:   -r=%t.input.bc,__ZN1CC1Ev,pl \
+// RUN:   -r=%t.input.bc,__Z6test_3P1C,pl \
+// RUN:   -r=%t.input.bc,__Z6test_4P1C,pl \
+// RUN:   -r=%t.input.bc,__Z6test_5P1CMS_FvvE,pl \
+// RUN:   -r=%t.input.bc,__ZTV1C,
+
+// RUN: llvm-dis %t.out.1.0.preopt.bc -o - | FileCheck %s --check-prefix=ORIGINAL
+// RUN: llvm-dis %t.out.1.4.opt.bc -o - | FileCheck %s --check-prefix=AFTERVFE
+// ORIGINAL: define{{.*}}_ZN1A3barEv
+// AFTERVFE-NOT: define{{.*}}_ZN1A3barEv
+
+//--- vfe-main.cpp
+// CHECK: @_ZTV1A = {{.*}}constant {{.*}}@_ZTI1A{{.*}}@_ZN1A3fooEv{{.*}}_ZN1A3barEv{{.*}}!type [[A16:![0-9]+]]
+// CHECK-NOT: @_ZTV1B = {{.*}}!type
+struct __attribute__((visibility("hidden"))) A {
+  A();
+  virtual void foo();
+  virtual void bar();
+};
+
+__attribute__((used)) void test_1(A *p) {
+  // CHECK-LABEL: define{{.*}} void @_Z6test_1P1A
+  // CHECK: [[VTABLE:%.+]] = load
+  // CHECK: @llvm.type.checked.load(ptr {{%.+}}, i32 0, metadata !"_ZTS1A")
+  p->foo();
+}
+
+__attribute__((used)) A *test() {
+  return new (A)();
+}
+
+struct __attribute__((visibility("hidden"))) B {
+  B();
+  virtual void foo();
+};
+
+A::A() {}
+B::B() {}
+void A::foo() {}
+void A::bar() {}
+void test_2(B *p) {
+  // CHECK-LABEL: define{{.*}} void @_Z6test_2P1B
+  // CHECK: [[VTABLE:%.+]] = load
+  // CHECK: @llvm.type.checked.load(ptr {{%.+}}, i32 0, metadata !"_ZTS1B")
+  p->foo();
+}
+
+// INPUT-NOT: @_ZTV1C = {{.*}}!type
+// INPUT-LABEL: define{{.*}} void @_Z6test_3P1C
+// INPUT: [[LOAD:%.+]] = {{.*}}call { ptr, i1 } @llvm.type.checked.load(ptr {{%.+}}, i32 0, metadata !"_ZTS1C")
+// INPUT: [[FN_PTR:%.+]] = extractvalue { ptr, i1 } [[LOAD]], 0
+// INPUT: call void [[FN_PTR]](
+
+// INPUT-LABEL: define{{.*}} void @_Z6test_4P1C
+// INPUT: [[LOAD:%.+]] = {{.*}}call { ptr, i1 } @llvm.type.checked.load(ptr {{%.+}}, i32 8, metadata !"_ZTS1C")
+// INPUT: [[FN_PTR:%.+]] = extractvalue { ptr, i1 } [[LOAD]], 0
+// INPUT: call void [[FN_PTR]](
+
+int main() {
+}
+
+// CHECK: [[BAR:\^[0-9]+]] = gv: (name: "_ZN1A3barEv", {{.*}}
+// CHECK: FuncsWithNonVtableRef
+// CHECK-NOT: [[BAR]]
+//--- vfe-input.cpp
+struct __attribute__((visibility("hidden"))) C {
+  C();
+  virtual void foo();
+  virtual void bar();
+};
+
+C::C() {}
+void test_3(C *p) {
+  // C has hidden visibility, so we generate type.checked.load to allow VFE.
+  p->foo();
+}
+
+void test_4(C *p) {
+  // When using type.checked.load, we pass the vtable offset to the intrinsic,
+  // rather than adding it to the pointer with a GEP.
+  p->bar();
+}
+
+void test_5(C *p, void (C::*q)(void)) {
+  // We also use type.checked.load for the virtual side of member function
+  // pointer calls. We use a GEP to calculate the address to load from and pass
+  // 0 as the offset to the intrinsic, because we know that the load must be
+  // from exactly the point marked by one of the function-type metadatas (in
+  // this case "_ZTSM1CFvvE.virtual"). If we passed the offset from the member
+  // function pointer to the intrinsic, this information would be lost. No
+  // codegen changes on the non-virtual side.
+  (p->*q)();
+}
diff --git a/clang/test/Driver/virtual-function-elimination.cpp b/clang/test/Driver/virtual-function-elimination.cpp
index b65fac2393b3ca6..493bf9399f96033 100644
--- a/clang/test/Driver/virtual-function-elimination.cpp
+++ b/clang/test/Driver/virtual-function-elimination.cpp
@@ -1,6 +1,5 @@
 // RUN: not %clang -target x86_64-unknown-linux -fvirtual-function-elimination -### %s 2>&1 | FileCheck --check-prefix=BAD-LTO %s
-// RUN: not %clang -target x86_64-unknown-linux -fvirtual-function-elimination -flto=thin -### %s 2>&1 | FileCheck --check-prefix=BAD-LTO %s
-// BAD-LTO: invalid argument '-fvirtual-function-elimination' only allowed with '-flto=full'
+// BAD-LTO: invalid argument '-fvirtual-function-elimination' only allowed with '-flto'
 
 // RUN: %clang -target x86_64-unknown-linux -fvirtual-function-elimination -flto -### %s 2>&1 | FileCheck --check-prefix=GOOD %s
 // RUN: %clang -target x86_64-unknown-linux -fvirtual-function-elimination -flto=full -### %s 2>&1 | FileCheck --check-prefix=GOOD %s
diff --git a/llvm/include/llvm/AsmParser/LLParser.h b/llvm/include/llvm/AsmParser/LLParser.h
index eca908a24aac7b2..22a446984893bda 100644
--- a/llvm/include/llvm/AsmParser/LLParser.h
+++ b/llvm/include/llvm/AsmParser/LLParser.h
@@ -362,6 +362,7 @@ namespace llvm {
     bool parseGVReference(ValueInfo &VI, unsigned &GVId);
     bool parseSummaryIndexFlags();
     bool parseBlockCount();
+    bool parseFuncsWithNonVtableRefEntry(unsigned ID);
     bool parseGVEntry(unsigned ID);
     bool parseFunctionSummary(std::string Name, GlobalValue::GUID, unsigned ID);
     bool parseVariableSummary(std::string Name, GlobalValue::GUID, unsigned ID);
diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h
index 673dc58ce6451e3..f707caf3d42c834 100644
--- a/llvm/include/llvm/AsmParser/LLToken.h
+++ b/llvm/include/llvm/AsmParser/LLToken.h
@@ -399,6 +399,7 @@ enum Kind {
   kw_args,
   kw_typeid,
   kw_typeidCompatibleVTable,
+  kw_FuncsWithNonVtableRef,
   kw_summary,
   kw_typeTestRes,
   kw_kind,
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 52e76356a892e45..68d96ed4b4fb4eb 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -319,6 +319,7 @@ enum GlobalValueSummarySymtabCodes {
   //  numver x version]
   FS_COMBINED_ALLOC_INFO = 29,
   FS_STACK_IDS = 30,
+  OBJC_FUNC_NON_VTABLE_REF = 31,
 };
 
 enum MetadataCodes {
diff --git a/llvm/include/llvm/IR/ModuleSummaryIndex.h b/llvm/include/llvm/IR/ModuleSummaryIndex.h
index cd02c71adddfc25..39d492fe6ef7827 100644
--- a/llvm/include/llvm/IR/ModuleSummaryIndex.h
+++ b/llvm/include/llvm/IR/ModuleSummaryIndex.h
@@ -1362,6 +1362,8 @@ class ModuleSummaryIndex {
   // Temporary map while building StackIds list. Clear when index is completely
   // built via releaseTemporaryMemory.
   std::map<uint64_t, unsigned> StackIdToIndex;
+  std::set<GlobalValue::GUID> FuncsWithNonVtableRef;
+  std::set<GlobalValue::GUID> VFuncsToBeRemoved; // no need to serialize
 
   // YAML I/O support.
   friend yaml::MappingTraits<ModuleSummaryIndex>;
@@ -1811,6 +1813,26 @@ class ModuleSummaryIndex {
     }
   }
 
+  /// VFE state: functions referenced outside vtables, and vfuncs to remove.
+  bool isFuncWithNonVtableRef(GlobalValue::GUID GUID) const {
+    return FuncsWithNonVtableRef.count(GUID) != 0;
+  }
+  void addFuncWithNonVtableRef(GlobalValue::GUID GUID) {
+    FuncsWithNonVtableRef.insert(GUID);
+  }
+  const std::set<GlobalValue::GUID> &funcsWithNonVtableRef() const {
+    return FuncsWithNonVtableRef;
+  }
+  bool isFuncToBeRemoved(GlobalValue::GUID GUID) const {
+    return VFuncsToBeRemoved.count(GUID) != 0;
+  }
+  void addFuncToBeRemoved(GlobalValue::GUID GUID) {
+    VFuncsToBeRemoved.insert(GUID);
+  }
+  const std::set<GlobalValue::GUID> &funcsToBeRemoved() const {
+    return VFuncsToBeRemoved;
+  }
+
   /// Print to an output stream.
   void print(raw_ostream &OS, bool IsForDebug = false) const;
 
diff --git a/llvm/include/llvm/Transforms/IPO/GlobalDCE.h b/llvm/include/llvm/Transforms/IPO/GlobalDCE.h
index 92c30d4b54a2612..7b953ad45940526 100644
--- a/llvm/include/llvm/Transforms/IPO/GlobalDCE.h
+++ b/llvm/include/llvm/Transforms/IPO/GlobalDCE.h
@@ -35,7 +35,16 @@ class Value;
 /// Pass to remove unused function declarations.
 class GlobalDCEPass : public PassInfoMixin<GlobalDCEPass> {
 public:
-  GlobalDCEPass(bool InLTOPostLink = false) : InLTOPostLink(InLTOPostLink) {}
+  GlobalDCEPass(bool InLTOPostLink = false)
+      : GlobalDCEPass(nullptr, nullptr, InLTOPostLink) {}
+  /// At most one of ExportSummary / ImportSummary may be non-null.
+  GlobalDCEPass(ModuleSummaryIndex *ExportSummary,
+                const ModuleSummaryIndex *ImportSummary,
+                bool InLTOPostLink = false)
+      : InLTOPostLink(InLTOPostLink), ExportSummary(ExportSummary),
+        ImportSummary(ImportSummary) {
+    assert(!(ExportSummary && ImportSummary));
+  }
 
   PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
 
@@ -44,6 +53,8 @@ class GlobalDCEPass : public PassInfoMixin<GlobalDCEPass> {
 
 private:
   bool InLTOPostLink = false;
+  ModuleSummaryIndex *ExportSummary;
+  const ModuleSummaryIndex *ImportSummary;
 
   SmallPtrSet<GlobalValue*, 32> AliveGlobals;
 
@@ -78,6 +89,8 @@ class GlobalDCEPass : public PassInfoMixin<GlobalDCEPass> {
   void ComputeDependencies(Value *V, SmallPtrSetImpl<GlobalValue *> &U);
 };
 
+void runVFEOnIndex(ModuleSummaryIndex &Summary,
+                   function_ref<bool(GlobalValue::GUID)> isRetained);
 }
 
 #endif // LLVM_TRANSFORMS_IPO_GLOBALDCE_H
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
index a88622efa12db8c..d5b3be2664c44c4 100644
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -54,6 +54,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cstdint>
+#include <unordered_map>
 #include <vector>
 
 using namespace llvm;
@@ -80,6 +81,9 @@ static cl::opt<std::string> ModuleSummaryDotFile(
     "module-summary-dot-file", cl::Hidden, cl::value_desc("filename"),
     cl::desc("File to emit dot graph of new summary into"));
 
+static cl::opt<bool> EnableVFESummary("enable-vfe-summary", cl::init(false),
+                                      cl::Hidden, cl::ZeroOrMore,
+                                      cl::desc("enable VFE with ThinLTO"));
 extern cl::opt<bool> ScalePartialSampleProfileWorkingSetSize;
 
 // Walk through the operands of a given User via worklist iteration and populate
@@ -229,7 +233,11 @@ static void addIntrinsicToSummary(
     for (auto &Call : DevirtCalls)
       addVCallToSet(Call, Guid, TypeCheckedLoadVCalls,
                     TypeCheckedLoadConstVCalls);
-
+    if (EnableVFESummary && !HasNonCallUses &&
+        DevirtCalls.empty()) { // for VFE with thinLTO
+      auto *Offset = cast<ConstantInt>(CI->getArgOperand(1));
+      TypeCheckedLoadVCalls.insert({Guid, Offset->getZExtValue()});
+    }
     break;
   }
   default:
@@ -775,6 +783,58 @@ static void computeVariableSummary(ModuleSummaryIndex &Index,
   Index.addGlobalValueSummary(V, std::move(GVarSummary));
 }
 
+static void ComputeDependencies(
+    Value *V, SmallPtrSetImpl<const GlobalValue *> &Deps,
+    std::unordered_map<Constant *, SmallPtrSet<const GlobalValue *, 8>>
+        &ConstantDependenciesCache) {
+  // An instruction use makes its enclosing function a dependency.
+  if (auto *Inst = dyn_cast<Instruction>(V)) {
+    Deps.insert(Inst->getFunction());
+    return;
+  }
+  if (auto *G = dyn_cast<GlobalValue>(V)) {
+    Deps.insert(G);
+    return;
+  }
+  auto *C = dyn_cast<Constant>(V);
+  if (!C)
+    return;
+  // Memoize per constant so large expression trees are only walked once.
+  auto [It, Inserted] = ConstantDependenciesCache.try_emplace(C);
+  SmallPtrSetImpl<const GlobalValue *> &CDeps = It->second;
+  if (Inserted)
+    for (User *CUser : C->users())
+      ComputeDependencies(CUser, CDeps, ConstantDependenciesCache);
+  Deps.insert(CDeps.begin(), CDeps.end());
+}
+
+static void updateNonVtableRef(const Module &M, ModuleSummaryIndex &Index) {
+  if (!EnableVFESummary)
+    return;
+  // VFE may only drop a virtual function that is referenced solely from
+  // vtables; record the GUID of every function referenced from elsewhere.
+  std::unordered_map<Constant *, SmallPtrSet<const GlobalValue *, 8>>
+      ConstantDependenciesCache;
+  // Any global variable carrying !type metadata is treated as a vtable.
+  auto IsVTable = [](const GlobalValue *GV) {
+    const auto *Var = dyn_cast<GlobalVariable>(GV);
+    return Var && Var->hasMetadata(LLVMContext::MD_type);
+  };
+  for (const Function &F : M) {
+    SmallPtrSet<const GlobalValue *, 8> Referrers;
+    for (const User *U : F.users())
+      ComputeDependencies(const_cast<User *>(U), Referrers,
+                          ConstantDependenciesCache);
+    // References from F to itself do not count.
+    Referrers.erase(&F);
+    const bool HasNonVTableRef =
+        std::any_of(Referrers.begin(), Referrers.end(),
+                    [&](const GlobalValue *GV) { return !IsVTable(GV); });
+    if (HasNonVTableRef)
+      Index.addFuncWithNonVtableRef(F.getGUID());
+  }
+}
+
 static void computeAliasSummary(ModuleSummaryIndex &Index, const GlobalAlias &A,
                                 DenseSet<GlobalValue::GUID> &CantBePromoted) {
   // Skip summary for indirect function aliases as summary for aliasee will not
@@ -933,6 +993,9 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex(
                            !LocalsUsed.empty() || HasLocalInlineAsmSymbol,
                            CantBePromoted, IsThinLTO, GetSSICallback);
   }
+  // If a function has a reference from a non-vtable, it is not safe
+  // to be eliminated by VFE.
+  updateNonVtableRef(M, Index);
 
   // Compute summaries for all variables defined in module, and save in the
   // index.
diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp
index 04eabc94cfc6abe..cdf083d707f1f8f 100644
--- a/llvm/lib/AsmParser/LLParser.cpp
+++ b/llvm/lib/AsmParser/LLParser.cpp
@@ -950,6 +950,9 @@ bool LLParser::parseSummaryEntry() {
   case lltok::kw_blockcount:
     result = parseBlockCount();
     break;
+  case lltok::kw_FuncsWithNonVtableRef:
+    result = parseFuncsWithNonVtableRefEntry(SummaryID);
+    break;
   default:
     result = error(Lex.getLoc(), "unexpected summary kind");
     break;
@@ -8203,6 +8206,33 @@ bool LLParser::parseTypeIdSummary(TypeIdSummary &TIS) {
 static ValueInfo EmptyVI =
     ValueInfo(false, (GlobalValueSummaryMapTy::value_type *)-8);
 
+/// FuncsWithNonVtableRefEntry
+///   ::= 'FuncsWithNonVtableRef' ':' '(' GVReference (',' GVReference)* ')'
+bool LLParser::parseFuncsWithNonVtableRefEntry(unsigned ID) {
+  assert(Lex.getKind() == lltok::kw_FuncsWithNonVtableRef);
+  Lex.Lex();
+  if (parseToken(lltok::colon, "expected ':' here") ||
+      parseToken(lltok::lparen, "expected '(' here"))
+    return true;
+  // Consume the whole '^N, ^M, ...' list emitted by the AsmWriter, recording
+  // each referenced function's GUID in the index.
+  do {
+    ValueInfo VI;
+    unsigned GVId;
+    LocTy Loc = Lex.getLoc();
+    if (parseGVReference(VI, GVId))
+      return true;
+    // NOTE(review): assumes unresolved forward references carry the same -8
+    // sentinel as EmptyVI; confirm against parseGVReference.
+    if (VI.getRef() == EmptyVI.getRef())
+      return error(Loc, "FuncsWithNonVtableRef entries must reference "
+                        "previously defined gv entries");
+    if (Index)
+      Index->addFuncWithNonVtableRef(VI.getGUID());
+  } while (EatIfPresent(lltok::comma));
+  return parseToken(lltok::rparen, "expected ')' here");
+}
+
 /// TypeIdCompatibleVtableEntry
 ///   ::= 'typeidCompatibleVTable' ':' '(' 'name' ':' STRINGCONSTANT ','
 ///   TypeIdCompatibleVtableInfo
diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
index 1d1ec988a93d847..52b5926ee284ea1 100644
--- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -7497,6 +7497,11 @@ Error ModuleSummaryIndexBitcodeReader::parseEntireSummary(unsigned ID) {
       LastSeenGUID = 0;
       break;
     }
+    case bitc::OBJC_FUNC_NON_VTABLE_REF: {
+      for (unsigned I = 0; I != Record.size(); I++)
+        TheIndex.addFuncWithNonVtableRef(Record[I]);
+      break;
+    }
     case bitc::FS_TYPE_TESTS:
       assert(PendingTypeTests.empty());
       llvm::append_range(PendingTypeTests, Record);
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index e991d055f33474b..22d7f9b639f0405 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -3906,6 +3906,19 @@ static void writeTypeIdSummaryRecord(SmallVector<uint64_t, 64> &NameVals,
                                       W.second);
 }
 
+/// Emit List's GUIDs as one OBJC_FUNC_NON_VTABLE_REF record (none if empty).
+static void writeFuncWithNonVTableRef(BitstreamWriter &Stream,
+                                      const std::set<uint64_t> &List,
+                                      SmallVectorImpl<uint64_t> &Record) {
+  if (List.empty())
+    return;
+  // std::set iterates in ascending GUID order, so the record is deterministic.
+  assert(Record.empty() && "Record is a scratch buffer and must start empty");
+  Record.append(List.begin(), List.end());
+  Stream.EmitRecord(bitc::OBJC_FUNC_NON_VTABLE_REF, Record);
+  Record.clear();
+}
+
 static void writeTypeIdCompatibleVtableSummaryRecord(
     SmallVector<uint64_t, 64> &NameVals, StringTableBuilder &StrtabBuilder,
     const std::string &Id, const TypeIdCompatibleVtableInfo &Summary,
@@ -4246,6 +4259,8 @@ void ModuleBitcodeWriterBase::writePerModuleGlobalValueSummary() {
     Stream.EmitRecord(bitc::FS_ALIAS, NameVals, FSAliasAbbrev);
     NameVals.clear();
   }
+  SmallVector<uint64_t, 64> Record;
+  writeFuncWithNonVTableRef(Stream, Index->funcsWithNonVtableRef(), Record);
 
   for (auto &S : Index->typeIdCompatibleVtableMap()) {
     writeTypeIdCompatibleVtableSummaryRecord(NameVals, StrtabBuilder, S.first,
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
index e190d82127908db..80610387860f68a 100644
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -2941,6 +2941,19 @@ void AssemblyWriter::printModuleSummaryIndex() {
     Out << ") ; guid = " << GUID << "\n";
   }
 
+  // Print FuncsWithNonVtableRef.
+  // Need to find a case where this is set.
+  if (!TheIndex->funcsWithNonVtableRef().empty()) {
+    Out << "^" << NumSlots << " = FuncsWithNonVtableRef: (";
+    FieldSeparator FS;
+    for (auto FID : TheIndex->funcsWithNonVtableRef()) {
+      Out << FS;
+      Out << "^" << Machine.getGUIDSlot(FID);
+    }
+    Out << ")\n";
+    ++NumSlots;
+  }
+
   // Don't emit flags when it's not really needed (value is zero by default).
   if (TheIndex->getFlags()) {
     Out << "^" << NumSlots << " = flags: " << TheIndex->getFlags() << "\n";
diff --git a/llvm/lib/LTO/LTO.cpp b/llvm/lib/LTO/LTO.cpp
index 4a64aa4593d543c..5ce4f30215d4eac 100644
--- a/llvm/lib/LTO/LTO.cpp
+++ b/llvm/lib/LTO/LTO.cpp
@@ -51,6 +51,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetOptions.h"
 #include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/IPO/GlobalDCE.h"
 #include "llvm/Transforms/IPO/MemProfContextDisambiguation.h"
 #include "llvm/Transforms/IPO/WholeProgramDevirt.h"
 #include "llvm/Transforms/Utils/FunctionImportUtils.h"
@@ -75,6 +76,7 @@ cl::opt<bool> EnableLTOInternalization(
     "enable-lto-internalization", cl::init(true), cl::Hidden,
     cl::desc("Enable global value internalization in LTO"));
 }
+extern cl::opt<bool> EnableVFEOnThinLTO;
 
 /// Indicate we are linking with an allocator that supports hot/cold operator
 /// new interfaces.
@@ -1736,6 +1738,12 @@ Error LTO::runThinLTO(AddStreamFn AddStream, FileCache Cache,
   std::map<ValueInfo, std::vector<VTableSlotSummary>> LocalWPDTargetsMap;
   runWholeProgramDevirtOnIndex(ThinLTO.CombinedIndex, ExportedGUIDs,
                                LocalWPDTargetsMap);
+  if (EnableVFEOnThinLTO) {
+    auto isRetained = [&](GlobalValue::GUID CalleeGUID) {
+      return GUIDPreservedSymbols.count(CalleeGUID);
+    };
+    runVFEOnIndex(ThinLTO.CombinedIndex, isRetained);
+  }
 
   auto isPrevailing = [&](GlobalValue::GUID GUID, const GlobalValueSummary *S) {
     return ThinLTO.PrevailingModuleForGUID[GUID] == S->modulePath();
diff --git a/llvm/lib/Passes/PassBuilderPipelines.cpp b/llvm/lib/Passes/PassBuilderPipelines.cpp
index 78e0e6353056343..348a11bcd17631d 100644
--- a/llvm/lib/Passes/PassBuilderPipelines.cpp
+++ b/llvm/lib/Passes/PassBuilderPipelines.cpp
@@ -281,6 +281,9 @@ cl::opt<bool> EnableMemProfContextDisambiguation(
     cl::ZeroOrMore, cl::desc("Enable MemProf context disambiguation"));
 
 extern cl::opt<bool> EnableInferAlignmentPass;
+cl::opt<bool> EnableVFEOnThinLTO("enable-vfe-on-thinlto", cl::init(false),
+                                 cl::Hidden, cl::ZeroOrMore,
+                                 cl::desc("enable VFE with ThinLTO"));
 
 PipelineTuningOptions::PipelineTuningOptions() {
   LoopInterleaving = true;
@@ -1583,6 +1586,9 @@ ModulePassManager PassBuilder::buildThinLTODefaultPipeline(
     OptimizationLevel Level, const ModuleSummaryIndex *ImportSummary) {
   ModulePassManager MPM;
 
+  if (EnableVFEOnThinLTO && ImportSummary)
+    MPM.addPass(GlobalDCEPass(nullptr, ImportSummary));
+
   if (ImportSummary) {
     // For ThinLTO we must apply the context disambiguation decisions early, to
     // ensure we can correctly match the callsites to summary data.
diff --git a/llvm/lib/Transforms/IPO/GlobalDCE.cpp b/llvm/lib/Transforms/IPO/GlobalDCE.cpp
index e36d524d7667ab3..17da4ed7912b3b5 100644
--- a/llvm/lib/Transforms/IPO/GlobalDCE.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalDCE.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/TypeMetadataUtils.h"
+#include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
@@ -26,6 +27,11 @@
 #include "llvm/Transforms/Utils/CtorUtils.h"
 #include "llvm/Transforms/Utils/GlobalStatus.h"
 
+#include "llvm/IR/ModuleSummaryIndex.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/GlobPattern.h"
+#include "llvm/Support/MemoryBuffer.h"
+
 using namespace llvm;
 
 #define DEBUG_TYPE "globaldce"
@@ -34,12 +40,223 @@ static cl::opt<bool>
     ClEnableVFE("enable-vfe", cl::Hidden, cl::init(true),
                 cl::desc("Enable virtual function elimination"));
 
+static cl::opt<std::string> ClReadSummary(
+    "globaldce-read-summary",
+    cl::desc("Read summary from given bitcode before running pass"),
+    cl::Hidden);
+
 STATISTIC(NumAliases  , "Number of global aliases removed");
 STATISTIC(NumFunctions, "Number of functions removed");
 STATISTIC(NumIFuncs,    "Number of indirect functions removed");
 STATISTIC(NumVariables, "Number of global variables removed");
 STATISTIC(NumVFuncs,    "Number of virtual functions removed");
 
+namespace llvm {
+
+// Returns a representative summary for the vtable P.VTableVI and sets isSafe.
+// Outcomes: {unsafe, nullptr} when no usable copy exists or several local
+// copies exist; {unsafe, summary} when the vcall visibility is Public;
+// {safe, nullptr} when the representative is not live; {safe, summary}
+// otherwise.
+static const GlobalVarSummary *
+getVTableFuncsForTId(const TypeIdOffsetVtableInfo &P, bool &isSafe) {
+  const GlobalVarSummary *Rep = nullptr;
+  bool SeenLocalCopy = false;
+  for (const auto &Copy : P.VTableVI.getSummaryList()) {
+    const bool IsLocal = GlobalValue::isLocalLinkage(Copy->linkage());
+    // More than one local copy is ambiguous; give up on this vtable.
+    if (IsLocal && SeenLocalCopy) {
+      isSafe = false;
+      return nullptr;
+    }
+    SeenLocalCopy |= IsLocal;
+
+    auto *Candidate = cast<GlobalVarSummary>(Copy->getBaseObject());
+    // An available_externally copy with no vtable funcs is not a usable
+    // representative.
+    if (Candidate->vTableFuncs().empty() &&
+        GlobalValue::isAvailableExternallyLinkage(Copy->linkage()))
+      continue;
+    Rep = Candidate;
+    if (Rep->getVCallVisibility() == GlobalObject::VCallVisibilityPublic) {
+      isSafe = false;
+      return Rep;
+    }
+  }
+
+  // Safe iff a usable copy was found; a non-live one yields no summary.
+  isSafe = Rep != nullptr;
+  if (Rep && !Rep->isLive())
+    return nullptr;
+  return Rep;
+}
+
+static void collectSafeVTables(
+    ModuleSummaryIndex &Summary,
+    DenseMap<GlobalValue::GUID, std::vector<StringRef>> &NameByGUID,
+    std::map<ValueInfo, std::vector<VirtFuncOffset>> &VFESafeVTablesAndFns) {
+  // Index every type id name by its GUID so that GUIDs recorded in function
+  // summaries can be mapped back to typeIdCompatibleVtableMap entries.
+  for (auto &P : Summary.typeIdCompatibleVtableMap()) {
+    NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
+    LLVM_DEBUG(dbgs() << "TId " << GlobalValue::getGUID(P.first) << " "
+                      << P.first << "\n");
+  }
+  LLVM_DEBUG(dbgs() << "VFEThinLTO number of TIds: " << NameByGUID.size()
+                    << "\n");
+  // Per vtable, the GUIDs of the virtual functions already recorded, so that
+  // a vtable reached through several type ids contributes each slot once.
+  std::map<ValueInfo, std::set<GlobalValue::GUID>> vFuncSet;
+  [[maybe_unused]] unsigned numSafeVFuncs = 0;
+  // Debug-only stats for unsafe vtables: public visibility vs. the rest.
+  std::set<ValueInfo> vTablePublicVis;
+  std::set<ValueInfo> vTableOther;
+  for (auto &TidSummary : Summary.typeIdCompatibleVtableMap()) {
+    for (const TypeIdOffsetVtableInfo &P : TidSummary.second) {
+      LLVM_DEBUG(dbgs() << "TId-vTable " << TidSummary.first << " "
+                        << P.VTableVI.name() << " " << P.AddressPointOffset
+                        << "\n");
+      bool isSafe = false;
+      const GlobalVarSummary *VS = getVTableFuncsForTId(P, isSafe);
+      if (!VS) { // Unsafe with no representative, or safe but not live.
+        vTableOther.insert(P.VTableVI);
+        continue;
+      }
+      if (!isSafe) { // Public vcall visibility.
+        vTablePublicVis.insert(P.VTableVI);
+        continue;
+      }
+      for (const VirtFuncOffset &VTP : VS->vTableFuncs()) {
+        if (!vFuncSet[P.VTableVI].insert(VTP.FuncVI.getGUID()).second)
+          continue;
+        VFESafeVTablesAndFns[P.VTableVI].push_back(VTP);
+        LLVM_DEBUG(dbgs() << "vTable " << P.VTableVI.name() << " "
+                          << VTP.FuncVI.name() << " " << VTP.VTableOffset
+                          << "\n");
+        ++numSafeVFuncs;
+      }
+    }
+  }
+  LLVM_DEBUG(dbgs() << "VFEThinLTO number of vTables: " << vFuncSet.size()
+                    << " " << vTablePublicVis.size() << " "
+                    << vTableOther.size() << "\n"
+                    << "VFEThinLTO numSafeVFuncs: " << numSafeVFuncs << "\n");
+}
+
+static void checkVTableLoadsIndex(
+    ModuleSummaryIndex &Summary,
+    DenseMap<GlobalValue::GUID, std::vector<StringRef>> &NameByGUID,
+    std::map<ValueInfo, std::vector<VirtFuncOffset>> &VFESafeVTablesAndFns,
+    std::map<ValueInfo, std::vector<ValueInfo>> &VFuncsAndCallers) {
+  // Apply Fn to each vtable compatible with any type id whose GUID is TId. All
+  // names sharing the GUID are visited; a collision only adds conservatism.
+  auto ForEachCompatibleVTable = [&](GlobalValue::GUID TId, auto &&Fn) {
+    auto It = NameByGUID.find(TId);
+    if (It == NameByGUID.end())
+      return;
+    for (StringRef Name : It->second)
+      if (const auto *TidSummary =
+              Summary.getTypeIdCompatibleVtableSummary(Name))
+        for (const TypeIdOffsetVtableInfo &P : *TidSummary)
+          Fn(P);
+  };
+
+  // Go through function summaries for type tests and type.checked.load calls
+  // to erase unsafe entries from VFESafeVTablesAndFns and record callers.
+  for (auto &PI : Summary) {
+    for (auto &S : PI.second.SummaryList) {
+      auto *FS = dyn_cast<FunctionSummary>(S.get());
+      if (!FS)
+        continue;
+      // We should ignore Tid if there is a type.checked.load with Offset not
+      // ConstantInt. Currently ModuleSummaryAnalysis will update TypeTests.
+      for (GlobalValue::GUID G : FS->type_tests())
+        ForEachCompatibleVTable(G, [&](const TypeIdOffsetVtableInfo &P) {
+          LLVM_DEBUG(dbgs() << "unsafe-vtable due to type_tests: "
+                            << P.VTableVI.name() << "\n");
+          VFESafeVTablesAndFns.erase(P.VTableVI);
+        });
+      // Record PI as a caller of the vfunc in each compatible vtable at the
+      // loaded slot; a vtable lacking that slot is unsafe.
+      auto CheckVLoad = [&](GlobalValue::GUID TId, uint64_t Offset) {
+        ForEachCompatibleVTable(TId, [&](const TypeIdOffsetVtableInfo &P) {
+          uint64_t SlotOffset = P.AddressPointOffset + Offset;
+          bool isSafe = false;
+          const GlobalVarSummary *VS = getVTableFuncsForTId(P, isSafe);
+          if (!isSafe || !VS)
+            return;
+          bool Found = false;
+          for (const VirtFuncOffset &VTP : VS->vTableFuncs()) {
+            if (VTP.VTableOffset != SlotOffset)
+              continue;
+            assert(!Found && "two vfuncs at the same vtable offset");
+            Found = true;
+            VFuncsAndCallers[VTP.FuncVI].push_back(Summary.getValueInfo(PI));
+          }
+          if (!Found) {
+            VFESafeVTablesAndFns.erase(P.VTableVI);
+            LLVM_DEBUG(dbgs() << "unsafe-vtable can't find offset "
+                              << P.VTableVI.name() << " " << SlotOffset
+                              << "\n");
+          }
+        });
+      };
+      for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls())
+        CheckVLoad(VF.GUID, VF.Offset);
+      for (const FunctionSummary::ConstVCall &VC :
+           FS->type_checked_load_const_vcalls())
+        CheckVLoad(VC.VFunc.GUID, VC.VFunc.Offset);
+    }
+  }
+}
+
+void runVFEOnIndex(ModuleSummaryIndex &Summary,
+                   function_ref<bool(GlobalValue::GUID)> isRetained) {
+  if (Summary.typeIdCompatibleVtableMap().empty())
+    return;
+
+  // Collect vtables whose slots VFE may reason about, with their vfuncs.
+  DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
+  std::map<ValueInfo, std::vector<VirtFuncOffset>> VFESafeVTablesAndFns;
+  collectSafeVTables(Summary, NameByGUID, VFESafeVTablesAndFns);
+
+  // Go through function summaries for type tests and type.checked.load calls:
+  // this drops unsafe vtables from VFESafeVTablesAndFns and records, for each
+  // vfunc that may be loaded, the functions containing such loads.
+  std::map<ValueInfo, std::vector<ValueInfo>> VFuncsAndCallers;
+  checkVTableLoadsIndex(Summary, NameByGUID, VFESafeVTablesAndFns,
+                        VFuncsAndCallers);
+  LLVM_DEBUG(dbgs() << "VFEThinLTO number of vTables: "
+                    << VFESafeVTablesAndFns.size() << "\n");
+
+  // Candidates: vfuncs of the remaining safe vtables, minus any function that
+  // is also referenced from a non-vtable user.
+  std::set<ValueInfo> candidateSet;
+  for (const auto &Entry : VFESafeVTablesAndFns) {
+    for (const VirtFuncOffset &VFunc : Entry.second) {
+      if (Summary.isFuncWithNonVtableRef(VFunc.FuncVI.getGUID())) {
+        LLVM_DEBUG(dbgs() << "unsafe-vfunc with non-vtable ref "
+                          << VFunc.FuncVI.name() << "\n");
+        continue;
+      }
+      candidateSet.insert(VFunc.FuncVI);
+    }
+  }
+
+  // It is possible to use a workList to recursively mark vfuncs as removable.
+  // For now, only remove candidates with no recorded virtual-call caller that
+  // isRetained does not report as retained.
+  [[maybe_unused]] unsigned NumRemovable = 0;
+  for (const ValueInfo &VFunc : candidateSet) {
+    if (VFuncsAndCallers.count(VFunc) || isRetained(VFunc.getGUID()))
+      continue;
+    Summary.addFuncToBeRemoved(VFunc.getGUID());
+    ++NumRemovable;
+  }
+  LLVM_DEBUG(dbgs() << "VFEThinLTO removable " << NumRemovable << "\n");
+}
+} // end namespace llvm
+
 /// Returns true if F is effectively empty.
 static bool isEmptyFunction(Function *F) {
   // Skip external functions.
@@ -220,6 +437,8 @@ void GlobalDCEPass::ScanTypeCheckedLoadIntrinsics(Module &M) {
 }
 
 void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) {
+  if (ImportSummary)
+    return; // only use results from Index with ThinLTO
   if (!ClEnableVFE)
     return;
 
@@ -247,6 +466,23 @@ void GlobalDCEPass::AddVirtualFunctionDependencies(Module &M) {
 }
 
 PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
+  // Handle the command-line summary arguments. This code is for testing
+  // purposes only, so we handle errors directly.
+  std::unique_ptr<ModuleSummaryIndex> Summary =
+      std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false);
+  if (!ClReadSummary.empty()) {
+    ExitOnError ExitOnErr("-globaldce-read-summary: " + ClReadSummary + ": ");
+    auto ReadSummaryFile =
+        ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
+    if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
+            getModuleSummaryIndex(*ReadSummaryFile)) {
+      Summary = std::move(*SummaryOrErr);
+    }
+    ImportSummary = Summary.get();
+    auto isRetained = [&](GlobalValue::GUID CalleeGUID) { return false; };
+    runVFEOnIndex(*(const_cast<ModuleSummaryIndex *>(ImportSummary)),
+                  isRetained);
+  }
   bool Changed = false;
 
   // The algorithm first computes the set L of global variables that are
@@ -338,12 +574,33 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) {
 
   // The second pass drops the bodies of functions which are dead...
   std::vector<Function *> DeadFunctions;
-  for (Function &F : M)
+  std::set<Function *> DeadFunctionsSet;
+  auto funcRemovedInIndex = [&](Function *F) -> bool {
+    if (!ImportSummary)
+      return false;
+    auto &vfuncsRemoved = ImportSummary->funcsToBeRemoved();
+    // If function is internalized, its current GUID will be different
+    // from the GUID in funcsToBeRemoved. Treat it as a global and search
+    // again.
+    if (vfuncsRemoved.count(F->getGUID()))
+      return true;
+    if (vfuncsRemoved.count(GlobalValue::getGUID(F->getName())))
+      return true;
+    return false;
+  };
+  for (Function &F : M) {
     if (!AliveGlobals.count(&F)) {
-      DeadFunctions.push_back(&F);         // Keep track of dead globals
+      DeadFunctions.push_back(&F); // Keep track of dead globals
+      DeadFunctionsSet.insert(&F);
+      if (!F.isDeclaration())
+        F.deleteBody();
+    } else if (funcRemovedInIndex(&F)) {
+      DeadFunctions.push_back(&F); // Keep track of dead globals
+      DeadFunctionsSet.insert(&F);
       if (!F.isDeclaration())
         F.deleteBody();
     }
+  }
 
   // The third pass drops targets of aliases which are dead...
   std::vector<GlobalAlias*> DeadAliases;



More information about the cfe-commits mailing list