[llvm] llvm module splitter (PR #121543)

weiwei chen via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 2 20:31:52 PST 2025


https://github.com/weiweichen created https://github.com/llvm/llvm-project/pull/121543

None

>From 9bc7b85383642995231afc4030457d25bc5f27b5 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Sat, 30 Nov 2024 21:18:37 -0500
Subject: [PATCH 1/3] Add files.

---
 llvm/include/llvm/Support/ModuleSplitter.h | 19 +++++++++++++++++++
 llvm/lib/Support/CMakeLists.txt            |  1 +
 llvm/lib/Support/ModuleSplitter.cpp        | 14 ++++++++++++++
 3 files changed, 34 insertions(+)
 create mode 100644 llvm/include/llvm/Support/ModuleSplitter.h
 create mode 100644 llvm/lib/Support/ModuleSplitter.cpp

diff --git a/llvm/include/llvm/Support/ModuleSplitter.h b/llvm/include/llvm/Support/ModuleSplitter.h
new file mode 100644
index 00000000000000..9f01bac925d88c
--- /dev/null
+++ b/llvm/include/llvm/Support/ModuleSplitter.h
@@ -0,0 +1,19 @@
+//===- ModuleSplitter.h - Module Splitter Functions -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+///
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_MODULESPLITTER_H
+#define LLVM_SUPPORT_MODULESPLITTER_H
+namespace llvm {
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
index 2ecaea4b02bf61..4694e3102dd036 100644
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -214,6 +214,7 @@ add_llvm_component_library(LLVMSupport
   MemoryBuffer.cpp
   MemoryBufferRef.cpp
   ModRef.cpp
+  ModuleSplitter.cpp
   MD5.cpp
   MSP430Attributes.cpp
   MSP430AttributeParser.cpp
diff --git a/llvm/lib/Support/ModuleSplitter.cpp b/llvm/lib/Support/ModuleSplitter.cpp
new file mode 100644
index 00000000000000..ea3a37656bcbc9
--- /dev/null
+++ b/llvm/lib/Support/ModuleSplitter.cpp
@@ -0,0 +1,14 @@
+//===--- ModuleSplitter.cpp - Module Splitter -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/ModRef.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace llvm;

>From 220c5788fd11bb4d0f69e5aaa04d854132dec65e Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Sat, 30 Nov 2024 21:19:00 -0500
Subject: [PATCH 2/3] Add ModuleSplitter.h

---
 llvm/include/llvm/Support/ModuleSplitter.h | 59 ++++++++++++++++++++++
 llvm/lib/Support/ModuleSplitter.cpp        |  3 +-
 2 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Support/ModuleSplitter.h b/llvm/include/llvm/Support/ModuleSplitter.h
index 9f01bac925d88c..b09db1ee022b0f 100644
--- a/llvm/include/llvm/Support/ModuleSplitter.h
+++ b/llvm/include/llvm/Support/ModuleSplitter.h
@@ -12,8 +12,67 @@
 
 #ifndef LLVM_SUPPORT_MODULESPLITTER_H
 #define LLVM_SUPPORT_MODULESPLITTER_H
+
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Error.h"
 namespace llvm {
 
+//===----------------------------------------------------------------------===//
+// LLVMModuleAndContext
+//===----------------------------------------------------------------------===//
+
+/// A pair of an LLVM module and the LLVM context that holds ownership of the
+/// objects. This is a useful class for parallelizing LLVM and managing
+/// ownership of LLVM instances.
+class LLVMModuleAndContext {
+public:
+  /// Expose the underlying LLVM context to create the module. This is the only
+  /// way to access the LLVM context to prevent accidental sharing.
+  Error create(
+      function_ref<ErrorOr<std::unique_ptr<llvm::Module>>(llvm::LLVMContext &)>
+          CreateModule);
+
+  llvm::Module &operator*() { return *Module; }
+  llvm::Module *operator->() { return Module.get(); }
+
+  void reset();
+
+private:
+  /// LLVM context stored in a unique pointer so that we can move this type.
+  std::unique_ptr<llvm::LLVMContext> CTX =
+      std::make_unique<llvm::LLVMContext>();
+  /// The paired LLVM module.
+  std::unique_ptr<llvm::Module> Module;
+};
+
+//===----------------------------------------------------------------------===//
+// Module Splitter
+//===----------------------------------------------------------------------===//
+
+using LLVMSplitProcessFn =
+    function_ref<void(llvm::unique_function<LLVMModuleAndContext()>,
+                      std::optional<int64_t>, unsigned)>;
+
+/// Helper to create a lambda that just forwards a preexisting Module.
+inline llvm::unique_function<LLVMModuleAndContext()>
+forwardModule(LLVMModuleAndContext &&Module) {
+  return [Module = std::move(Module)]() mutable { return std::move(Module); };
+}
+
+/// Support for splitting an LLVM module into multiple parts using anchored
+/// functions (e.g. exported functions), pulling all dependencies on the
+/// call stack into one module.
+void splitPerAnchored(LLVMModuleAndContext Module,
+                      LLVMSplitProcessFn ProcessFn,
+                      llvm::SmallVectorImpl<llvm::Function>& Anchors);
+
+/// Support for splitting an LLVM module into multiple parts, where each part
+/// contains only one function.
+void splitPerFunction(
+    LLVMModuleAndContext Module, LLVMSplitProcessFn ProcessFn);
+
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/Support/ModuleSplitter.cpp b/llvm/lib/Support/ModuleSplitter.cpp
index ea3a37656bcbc9..be85707386b0d4 100644
--- a/llvm/lib/Support/ModuleSplitter.cpp
+++ b/llvm/lib/Support/ModuleSplitter.cpp
@@ -8,7 +8,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/Support/ModRef.h"
+#include "llvm/Support/ModuleSplitter.h"
 #include "llvm/ADT/STLExtras.h"
 
+
 using namespace llvm;

>From 30c4c1d9aead7900298953e86b2b901fe2eafcb6 Mon Sep 17 00:00:00 2001
From: Weiwei Chen <weiwei.chen at modular.com>
Date: Sat, 30 Nov 2024 23:45:09 -0500
Subject: [PATCH 3/3] checkpoint.

---
 llvm/include/llvm/Support/ModuleSplitter.h |   6 +-
 llvm/lib/Support/ModuleSplitter.cpp        | 811 ++++++++++++++++++++-
 2 files changed, 813 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/Support/ModuleSplitter.h b/llvm/include/llvm/Support/ModuleSplitter.h
index b09db1ee022b0f..912d8edb7c189d 100644
--- a/llvm/include/llvm/Support/ModuleSplitter.h
+++ b/llvm/include/llvm/Support/ModuleSplitter.h
@@ -30,8 +30,8 @@ class LLVMModuleAndContext {
 public:
   /// Expose the underlying LLVM context to create the module. This is the only
   /// way to access the LLVM context to prevent accidental sharing.
-  Error create(
-      function_ref<ErrorOr<std::unique_ptr<llvm::Module>>(llvm::LLVMContext &)>
+  Expected<bool> create(
+      function_ref<Expected<std::unique_ptr<llvm::Module>>(llvm::LLVMContext &)>
           CreateModule);
 
   llvm::Module &operator*() { return *Module; }
@@ -41,7 +41,7 @@ class LLVMModuleAndContext {
 
 private:
   /// LLVM context stored in a unique pointer so that we can move this type.
-  std::unique_ptr<llvm::LLVMContext> CTX =
+  std::unique_ptr<llvm::LLVMContext> Ctx =
       std::make_unique<llvm::LLVMContext>();
   /// The paired LLVM module.
   std::unique_ptr<llvm::Module> Module;
diff --git a/llvm/lib/Support/ModuleSplitter.cpp b/llvm/lib/Support/ModuleSplitter.cpp
index be85707386b0d4..110062a6990b37 100644
--- a/llvm/lib/Support/ModuleSplitter.cpp
+++ b/llvm/lib/Support/ModuleSplitter.cpp
@@ -9,7 +9,816 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/ModuleSplitter.h"
-#include "llvm/ADT/STLExtras.h"
 
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Bitcode/BitcodeReader.h"
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/GlobalStatus.h"
+#include "llvm/Transforms/Utils/SplitModule.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
 
 using namespace llvm;
+#define DEBUG_TYPE "llvm-module-split"
+
+//===----------------------------------------------------------------------===//
+// LLVMModuleAndContext
+//===----------------------------------------------------------------------===//
+
+/// Create the paired module by running `CreateModule` against the context
+/// owned by this object.
+///
+/// \param CreateModule callback that builds a module in the given context.
+/// \returns `true` once the module has been stored, or the error produced by
+/// `CreateModule`.
+Expected<bool> LLVMModuleAndContext::create(
+    function_ref<Expected<std::unique_ptr<llvm::Module>>(llvm::LLVMContext &)>
+        CreateModule) {
+  assert(!Module && "already have a module");
+  auto ModuleOr = CreateModule(*Ctx);
+  // Forward the failure as an rvalue: `Error` is move-only, and returning a
+  // named `Error` relies on implicit-move rules that older compilers lack.
+  if (!ModuleOr)
+    return ModuleOr.takeError();
+
+  Module = std::move(*ModuleOr);
+  return true;
+}
+
+void LLVMModuleAndContext::reset() {
+  Module.reset(); // Drop the module before the context it was created in.
+  Ctx.reset(); // NOTE(review): `create` dereferences `Ctx`; don't reuse after.
+}
+
+//===----------------------------------------------------------------------===//
+// StringConstantTable
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Large strings are very inefficiently encoded in LLVM bitcode (each `char` is
+/// encoded as a `uint64_t`). The LLVM bitcode reader also reads strings back
+/// very inefficiently, performing 3 ultimate copies of the data. This is made
+/// worse by the fact that `getLazyBitcodeModule` does not lazily parse
+/// constants from the LLVM bitcode. Thus, when per-function splitting a module
+/// with N functions and M large string constants, we form 3*M*N copies of the
+/// large strings.
+///
+/// This class is part of a workaround of this inefficiency. When processing a
+/// module for splitting, we track any string global constants and their indices
+/// in this table. If a module is going to be roundtripped through bitcode to be
+/// lazily loaded, we externalize the strings by setting the corresponding
+/// constants to `zeroinitializer` in the module before it is written to
+/// bitcode. As we materialize constants on the other side, we check for a
+/// materialized global variable that matches an entry in the string table and
+/// directly copy the data over into the new LLVM context.
+///
+/// We can generalize this optimization to other large data types as necessary.
+///
+/// This class is reference counted so that it can be shared across threads.
+class StringConstantTable
+    : public ThreadSafeRefCountedBase<StringConstantTable> {
+  /// An entry in the string table consists of a global variable, its module
+  /// index, and a reference to the string data. Because the string data is
+  /// owned by the original LLVM context, we have to ensure it stays alive.
+  struct Entry {
+    unsigned Idx;                    // Index of the global in the module.
+    const llvm::GlobalVariable *Var; // The recorded global variable.
+    StringRef Value;                 // Non-owning view of the string data.
+  };
+
+public:
+  /// If `Value` denotes a string constant, record the data at index `GvIdx`.
+  void recordIfStringConstant(unsigned GvIdx, const llvm::GlobalValue &Value) {
+    auto Var = dyn_cast<llvm::GlobalVariable>(&Value);
+    if (Var && Var->isConstant() && Var->hasInternalLinkage()) {
+      auto *Init =
+          dyn_cast<llvm::ConstantDataSequential>(Var->getInitializer());
+      if (Init && Init->isCString())
+        StringConstants.push_back(Entry{GvIdx, Var, Init->getAsString()});
+    }
+  }
+
+  /// Before writing the main Module to bitcode, externalize large string
+  /// constants by stubbing out their values. Take ownership of the main Module
+  /// so the string data stays alive.
+  llvm::Module &externalizeStrings(LLVMModuleAndContext &&Module) {
+    MainModule = std::move(Module);
+    // Stub the initializers. The global variable is an internal constant, so it
+    // must have an initializer.
+    for (Entry &E : StringConstants) {
+      auto *Stub =
+          llvm::Constant::getNullValue(E.Var->getInitializer()->getType());
+      // `const_cast` is OK because we own the module now.
+      const_cast<llvm::GlobalVariable *>(E.Var)->setInitializer(Stub);
+    }
+    return *MainModule;
+  }
+
+  /// This is an iterator over the entries in the string table.
+  class Injector {
+    using const_iterator = std::vector<Entry>::const_iterator;
+
+  public:
+    /// Given a global variable in a materialized module and its index, if it is
+    /// a string constant found in the table, copy the data over into the new
+    /// LLVM context and set the initializer.
+    void materializeIfStringConstant(unsigned GvIdx,
+                                     llvm::GlobalVariable &Var) {
+      while (It != Et && It->Idx < GvIdx)
+        ++It; // Forward-only: callers must query in ascending index order.
+      if (It == Et || It->Idx != GvIdx)
+        return;
+      Var.setInitializer(llvm::ConstantDataArray::getString(
+          Var.getType()->getContext(), It->Value, /*AddNull=*/false));
+    }
+
+  private:
+    explicit Injector(const_iterator It, const_iterator Et) : It(It), Et(Et) {}
+
+    const_iterator It, Et;
+
+    friend class StringConstantTable;
+  };
+
+  Injector begin() const {
+    return Injector(StringConstants.begin(), StringConstants.end());
+  }
+
+private:
+  std::vector<Entry> StringConstants;
+  LLVMModuleAndContext MainModule;
+};
+
+//===----------------------------------------------------------------------===//
+// Module Splitter
+//===----------------------------------------------------------------------===//
+
+class LLVMModuleSplitterImpl {
+public:
+  explicit LLVMModuleSplitterImpl(LLVMModuleAndContext Module)
+      : MainModule(std::move(Module)) {}
+
+  /// Split the LLVM module into multiple modules using the provided process
+  /// function.
+  void split(LLVMSplitProcessFn ProcessFn,
+             llvm::SmallVectorImpl<llvm::Function> &Anchors);
+
+private:
+  struct ValueInfo {
+    /// The immediate global value dependencies of a value.
+    SmallVector<const llvm::GlobalValue *> Dependencies;
+    /// The index of this global value in the module. We will use this to
+    /// materialize global values from bitcode.
+    unsigned GvIdx;
+  };
+
+  struct TransitiveDeps {
+    /// The transitive dependencies.
+    llvm::MapVector<const llvm::GlobalValue *, unsigned> Deps;
+    /// True if computation is complete.
+    bool Complete = false;
+    /// The assigned bucket index, if any; buckets become modules.
+    std::optional<unsigned> MutIdx;
+  };
+
+  /// Collect the immediate global value dependencies of `Value`. `Orig` is the
+  /// original transitive value, which is not equal to `Value` when it is used
+  /// in a constant.
+  void collectImmediateDependencies(const llvm::Value *Value,
+                                    const llvm::GlobalValue *Orig);
+
+  /// The main LLVM module being split.
+  LLVMModuleAndContext MainModule;
+
+  /// The value info for each global value in the module.
+  llvm::DenseMap<const llvm::Value *, ValueInfo> Infos;
+
+  /// The transitive dependencies of each root (exported) global value.
+  llvm::MapVector<const llvm::GlobalValue *, TransitiveDeps> TransDeps;
+
+  /// Users of split "anchors". These are global values where we don't want
+  /// their users to be split into different modules because it will cause the
+  /// symbol to be duplicated.
+  llvm::MapVector<const llvm::GlobalValue *, llvm::SetVector<TransitiveDeps *>>
+      SplitAnchorUsers;
+};
+} // namespace
+
+/// Parse the bitcode in `Buf` lazily into a fresh context and materialize only
+/// the global values whose indices appear in `Set`.
+///
+/// Every other global variable and function is reduced to an external,
+/// non-DSO-local declaration, and declarations left without uses are erased.
+/// Materialized globals that match an entry in `Strtab` get their string
+/// initializer restored from the table.
+///
+/// \param Buf bitcode of the main module.
+/// \param Set maps each global value to keep to its index in the main module.
+/// \param Strtab table of recorded string constants.
+/// \returns the sliced module together with the context that owns it.
+static LLVMModuleAndContext readAndMaterializeDependencies(
+    MemoryBuffer &Buf,
+    const llvm::MapVector<const llvm::GlobalValue *, unsigned> &Set,
+    const StringConstantTable &Strtab) {
+
+  // First, create a lazy module with an internal bitcode materializer.
+  // TODO: Not sure how to make lazy loading metadata work.
+  LLVMModuleAndContext Result;
+  // `create` returns an `Expected`, which has to be checked before it is
+  // destroyed; a `(void)` cast does not do that. `cantFail` checks it and
+  // aborts if the bitcode cannot be parsed.
+  llvm::cantFail(Result.create(
+      [&](llvm::LLVMContext &Ctx) -> Expected<std::unique_ptr<Module>> {
+        return llvm::getLazyBitcodeModule(
+            llvm::MemoryBufferRef(Buf.getBuffer(), "<split-module>"), Ctx,
+            /*ShouldLazyLoadMetadata=*/false);
+      }));
+  Result->setModuleInlineAsm("");
+
+  SmallVector<unsigned> SortIndices =
+      llvm::to_vector(llvm::make_second_range(Set));
+  llvm::sort(SortIndices);
+  auto IdxIt = SortIndices.begin();
+  auto IdxEnd = SortIndices.end();
+
+  // The global value indices go from globals, functions, then aliases. This
+  // mirrors the order in which global values are deleted by LLVM's GlobalDCE.
+  unsigned CurIdx = 0;
+  StringConstantTable::Injector It = Strtab.begin();
+  // We need to keep the IR "valid" for the verifier because `materializeAll`
+  // may invoke it. It doesn't matter since we're deleting the globals anyway.
+  for (llvm::GlobalVariable &Global : Result->globals()) {
+    if (IdxIt != IdxEnd && CurIdx == *IdxIt) {
+      ++IdxIt;
+      llvm::cantFail(Global.materialize());
+      It.materializeIfStringConstant(CurIdx, Global);
+    } else {
+      Global.setInitializer(nullptr);
+      Global.setComdat(nullptr);
+      Global.setLinkage(llvm::GlobalValue::ExternalLinkage);
+      // External link should not be DSOLocal anymore,
+      // otherwise position independent code generates
+      // `R_X86_64_PC32` instead of `R_X86_64_REX_GOTPCRELX`
+      // for these symbols and building shared library from
+      // a static archive of this module will error with an `fPIC` confusion.
+      Global.setDSOLocal(false);
+    }
+    ++CurIdx;
+  }
+  for (llvm::Function &Func : Result->functions()) {
+    if (IdxIt != IdxEnd && CurIdx == *IdxIt) {
+      ++IdxIt;
+      llvm::cantFail(Func.materialize());
+    } else {
+      Func.deleteBody();
+      Func.setComdat(nullptr);
+      Func.setLinkage(llvm::GlobalValue::ExternalLinkage);
+      // As for global variables above: an external declaration should not
+      // be DSOLocal anymore, otherwise position independent code generates
+      // `R_X86_64_PC32` instead of `R_X86_64_REX_GOTPCRELX` for these
+      // symbols and building a shared library from a static archive of this
+      // module will error with an `fPIC` confusion.
+      Func.setDSOLocal(false);
+    }
+    ++CurIdx;
+  }
+
+  // Finalize materialization of the module.
+  llvm::cantFail(Result->materializeAll());
+
+  // Now that the module is materialized, we can start deleting stuff. Just
+  // delete declarations with no uses.
+  for (llvm::GlobalVariable &Global :
+       llvm::make_early_inc_range(Result->globals())) {
+    if (Global.isDeclaration() && Global.use_empty())
+      Global.eraseFromParent();
+  }
+  for (llvm::Function &Func : llvm::make_early_inc_range(Result->functions())) {
+    if (Func.isDeclaration() && Func.use_empty())
+      Func.eraseFromParent();
+  }
+  return Result;
+}
+
+/// Split an LLVM module into multiple parts using exported functions as
+/// anchors, pulling every dependency on the call stack into the same module.
+///
+/// \param Module the module to split; ownership is taken.
+/// \param ProcessFn callback invoked for each resulting module.
+/// \param Anchors forwarded unchanged to the splitter implementation.
+void llvm::splitPerAnchored(LLVMModuleAndContext Module,
+                            LLVMSplitProcessFn ProcessFn,
+                            llvm::SmallVectorImpl<llvm::Function> &Anchors) {
+  // The definition is qualified so that it defines the function declared in
+  // ModuleSplitter.h instead of introducing an unrelated global-scope one.
+  LLVMModuleSplitterImpl Impl(std::move(Module));
+  Impl.split(ProcessFn, Anchors);
+}
+
+/// Split the main module into slices and hand each one to `processFn`.
+///
+/// Every exported global value is a root. Roots that (transitively) use the
+/// same mutable global, or that use another root, are merged into one slice so
+/// that no symbol is duplicated. If at most one slice results, the main module
+/// is forwarded unchanged.
+///
+/// \param processFn invoked once per slice with a callback producing the
+/// module, the slice index (`std::nullopt` when the module is forwarded
+/// whole), and the summed sizes of the slices handed out before it.
+/// \param Anchors currently not read by this function.
+void LLVMModuleSplitterImpl::split(
+    LLVMSplitProcessFn processFn,
+    llvm::SmallVectorImpl<llvm::Function> &Anchors) {
+  // The use-def list is sparse. Use it to build a sparse dependency graph
+  // between global values.
+  auto strtab = llvm::makeIntrusiveRefCnt<StringConstantTable>();
+  unsigned gvIdx = 0;
+  auto computeDeps = [&](const llvm::GlobalValue &value) {
+    strtab->recordIfStringConstant(gvIdx, value);
+    Infos[&value].GvIdx = gvIdx++;
+    collectImmediateDependencies(&value, &value);
+  };
+  // NOTE: The visitation of globals then functions has to line up with
+  // `readAndMaterializeDependencies`.
+  for (const llvm::GlobalVariable &global : MainModule->globals()) {
+    computeDeps(global);
+    if (!global.hasInternalLinkage() && !global.hasPrivateLinkage())
+      TransDeps[&global];
+  }
+  for (const llvm::Function &fn : MainModule->functions()) {
+    computeDeps(fn);
+    if (!fn.isDeclaration() && (fn.hasExternalLinkage() || fn.hasWeakLinkage()))
+      TransDeps[&fn];
+  }
+
+  // If there is only one (or fewer) exported functions, forward the main
+  // module.
+  if (TransDeps.size() <= 1)
+    return processFn(forwardModule(std::move(MainModule)), std::nullopt,
+                     /*numFunctionBase=*/0);
+
+  // Now for each export'd global value, compute the transitive set of
+  // dependencies using DFS.
+  SmallVector<const llvm::GlobalValue *> worklist;
+  for (auto &[value, deps] : TransDeps) {
+    worklist.clear();
+    worklist.push_back(value);
+    while (!worklist.empty()) {
+      const llvm::GlobalValue *it = worklist.pop_back_val();
+
+      auto [iter, inserted] = deps.Deps.insert({it, -1});
+      if (!inserted) {
+        // Already visited.
+        continue;
+      }
+      // Pay the cost of the name lookup only on a miss.
+      const ValueInfo &info = Infos.at(it);
+      iter->second = info.GvIdx;
+
+      // If this value depends on another value that is going to be split, we
+      // don't want to duplicate the symbol. Keep all the users together.
+      if (it != value) {
+        if (auto depIt = TransDeps.find(it); depIt != TransDeps.end()) {
+          auto &users = SplitAnchorUsers[it];
+          users.insert(&deps);
+          // Make sure to include the other value in its own user list.
+          users.insert(&depIt->second);
+          // We don't have to recurse since the subgraph will get processed.
+          continue;
+        }
+      }
+
+      // If this value depends on a mutable global, keep track of it. We have to
+      // put all users of a mutable global in the same module.
+      if (auto *global = dyn_cast<llvm::GlobalVariable>(it);
+          global && !global->isConstant())
+        SplitAnchorUsers[global].insert(&deps);
+
+      // Recurse on dependencies.
+      llvm::append_range(worklist, info.Dependencies);
+    }
+
+    deps.Complete = true;
+  }
+
+  // For each mutable global, grab all the transitive users and put them in one
+  // module. If global A has user set A* and global B has user set B* where
+  // A* and B* have an empty intersection, all values in A* will be assigned 0
+  // and all values in B* will be assigned 1. If global C has user set C* that
+  // overlaps both A* and B*, it will overwrite both to 2.
+  SmallVector<SmallVector<TransitiveDeps *>> bucketing(SplitAnchorUsers.size());
+  for (auto [curMutIdx, bucket, users] :
+       llvm::enumerate(bucketing, llvm::make_second_range(SplitAnchorUsers))) {
+    for (TransitiveDeps *deps : users) {
+      if (deps->MutIdx && *deps->MutIdx != curMutIdx) {
+        // Merge the bucket this value was previously assigned to into the
+        // current one; `deps` itself is part of that other bucket.
+        auto &otherBucket = bucketing[*deps->MutIdx];
+        for (TransitiveDeps *other : otherBucket) {
+          bucket.push_back(other);
+          other->MutIdx = curMutIdx;
+        }
+        otherBucket.clear();
+        assert(*deps->MutIdx == curMutIdx);
+      } else {
+        bucket.push_back(deps);
+        deps->MutIdx = curMutIdx;
+      }
+    }
+  }
+
+  // Now that we have assigned buckets to each value, merge the transitive
+  // dependency sets of all values belonging to the same set.
+  SmallVector<llvm::MapVector<const llvm::GlobalValue *, unsigned>> buckets(
+      bucketing.size());
+  for (auto [deps, bucket] : llvm::zip(bucketing, buckets)) {
+    for (TransitiveDeps *dep : deps) {
+      for (auto &namedValue : dep->Deps)
+        bucket.insert(namedValue);
+    }
+  }
+
+  SmallVector<llvm::MapVector<const llvm::GlobalValue *, unsigned> *>
+      setsToProcess;
+  setsToProcess.reserve(buckets.size() + TransDeps.size());
+
+  // Clone each mutable global bucket into its own module.
+  for (auto &bucket : buckets) {
+    if (bucket.empty())
+      continue;
+    setsToProcess.push_back(&bucket);
+  }
+
+  for (auto &[root, deps] : TransDeps) {
+    // Skip values included in another transitive dependency set and values
+    // included in mutable global sets.
+    if (!deps.MutIdx)
+      setsToProcess.push_back(&deps.Deps);
+  }
+
+  if (setsToProcess.size() <= 1)
+    return processFn(forwardModule(std::move(MainModule)), std::nullopt,
+                     /*numFunctionBase=*/0);
+
+  // Sort the sets to schedule the larger modules first.
+  llvm::sort(setsToProcess,
+             [](auto *lhs, auto *rhs) { return lhs->size() > rhs->size(); });
+
+  // Prepare to materialize slices of the module by first writing the main
+  // module as bitcode to a shared buffer.
+  // NOTE(review): `WriteableBuffer`, `BufferRef` and `CompilerTimeTraceScope`
+  // are not declared anywhere in this patch; they still need LLVM
+  // equivalents before this function can build.
+  auto buf = WriteableBuffer::get();
+  {
+    CompilerTimeTraceScope traceScope("writeMainModuleBitcode");
+    llvm::Module &module = strtab->externalizeStrings(std::move(MainModule));
+    llvm::WriteBitcodeToFile(module, *buf);
+  }
+
+  unsigned numFunctions = 0;
+  for (auto [idx, set] : llvm::enumerate(setsToProcess)) {
+    // Read the size before the set is moved into the closure below.
+    unsigned next = numFunctions + set->size();
+    auto makeModule = [set = std::move(*set), buf = BufferRef(buf.copy()),
+                       strtab = strtab]() mutable {
+      return readAndMaterializeDependencies(std::move(buf), set, *strtab);
+    };
+    processFn(std::move(makeModule), idx, numFunctions);
+    numFunctions = next;
+  }
+}
+
+/// Record `orig` as an immediate dependency of every global value that uses
+/// `value`.
+///
+/// Users that are plain constants (not global values) are looked through
+/// recursively. An instruction user is attributed to its enclosing function.
+///
+/// \param value the value whose users are walked.
+/// \param orig the global value recorded as the dependency; it differs from
+/// `value` only while recursing through constant users.
+void LLVMModuleSplitterImpl::collectImmediateDependencies(
+    const llvm::Value *value, const llvm::GlobalValue *orig) {
+  for (const llvm::Value *user : value->users()) {
+    // Recurse into pure constant users.
+    if (isa<llvm::Constant>(user) && !isa<llvm::GlobalValue>(user)) {
+      collectImmediateDependencies(user, orig);
+      continue;
+    }
+
+    if (auto *inst = dyn_cast<llvm::Instruction>(user)) {
+      const llvm::Function *func = inst->getParent()->getParent();
+      Infos[func].Dependencies.push_back(orig);
+    } else if (auto *globalVal = dyn_cast<llvm::GlobalValue>(user)) {
+      Infos[globalVal].Dependencies.push_back(orig);
+    } else {
+      llvm_unreachable("unexpected user of global value");
+    }
+  }
+}
+
+namespace {
+/// This class provides support for splitting an LLVM module into multiple
+/// parts.
+/// TODO: Clean up the splitters here (some code duplication) when we can move
+/// to per function llvm compilation.
+class LLVMModulePerFunctionSplitterImpl {
+public:
+  /// Take ownership of the module to split. `explicit` prevents an accidental
+  /// implicit conversion from `LLVMModuleAndContext`.
+  explicit LLVMModulePerFunctionSplitterImpl(LLVMModuleAndContext module)
+      : mainModule(std::move(module)) {}
+
+  /// Split the LLVM module into multiple modules using the provided process
+  /// function.
+  void
+  split(LLVMSplitProcessFn processFn,
+        llvm::StringMap<llvm::GlobalValue::LinkageTypes> &symbolLinkageTypes,
+        unsigned numFunctionBase);
+
+private:
+  struct ValueInfo {
+    const llvm::Value *value = nullptr;
+    bool canBeSplit = true;
+    llvm::SmallPtrSet<const llvm::GlobalValue *, 4> dependencies;
+    llvm::SmallPtrSet<const llvm::GlobalValue *, 4> users;
+    /// The index of this global value in the module. We will use this to
+    /// materialize global values from bitcode.
+    unsigned gvIdx = 0;
+    bool userEmpty = true;
+  };
+
+  /// Collect all of the immediate global value users of `value`.
+  void collectValueUsers(const llvm::GlobalValue *value);
+
+  /// Propagate use information through the module.
+  void propagateUseInfo();
+
+  /// The main LLVM module being split.
+  LLVMModuleAndContext mainModule;
+
+  /// The value info for each global value in the module.
+  llvm::MapVector<const llvm::GlobalValue *, ValueInfo> valueInfos;
+};
+} // namespace
+
+/// Record in `dupFns` the name of every function in `set` whose name is
+/// already present in `seenFns`; names seen for the first time are added to
+/// `seenFns`. Entries of `set` that are not functions are ignored.
+static void
+checkDuplicates(llvm::MapVector<const llvm::GlobalValue *, unsigned> &set,
+                llvm::StringSet<> &seenFns, llvm::StringSet<> &dupFns) {
+  for (const auto &entry : set) {
+    const auto *func = dyn_cast<llvm::Function>(entry.first);
+    if (!func)
+      continue;
+    // `insert` reports whether the name was newly added.
+    const bool firstSighting = seenFns.insert(func->getName()).second;
+    if (!firstSighting)
+      dupFns.insert(func->getName());
+  }
+}
+
+/// Split an LLVM module into multiple parts, where each part contains only one
+/// function (with an exception for coroutine related functions).
+void KGEN::splitPerFunction( // NOTE(review): `KGEN` is undeclared here.
+    LLVMModuleAndContext module, LLVMSplitProcessFn processFn,
+    llvm::StringMap<llvm::GlobalValue::LinkageTypes> &symbolLinkageTypes,
+    unsigned numFunctionBase) { // NOTE(review): header declares 2 params.
+  CompilerTimeTraceScope traceScope("splitPerFunction");
+  LLVMModulePerFunctionSplitterImpl impl(std::move(module));
+  impl.split(processFn, symbolLinkageTypes, numFunctionBase);
+}
+
+/// Split the LLVM module into multiple modules using the provided process
+/// function.
+void LLVMModulePerFunctionSplitterImpl::split(
+    LLVMSplitProcessFn processFn,
+    llvm::StringMap<llvm::GlobalValue::LinkageTypes> &symbolLinkageTypes,
+    unsigned numFunctionBase) {
+  // Compute the value info for each global in the module.
+  // NOTE: The visitation of globals then functions has to line up with
+  // `readAndMaterializeDependencies`.
+  auto strtab = RCRef<StringConstantTable>::create();
+  unsigned gvIdx = 0;
+  auto computeUsers = [&](const llvm::GlobalValue &value) {
+    strtab->recordIfStringConstant(gvIdx, value);
+    valueInfos[&value].gvIdx = gvIdx++;
+    collectValueUsers(&value);
+  };
+  llvm::for_each(mainModule->globals(), computeUsers);
+  llvm::for_each(mainModule->functions(), computeUsers);
+
+  // With use information collected, propagate it to the dependencies.
+  propagateUseInfo();
+
+  // Now we can split the module.
+  // We split the module per function and cloning any necessary dependencies:
+  // - For function dependencies, only clone the declaration unless its
+  //   coroutine related.
+  // - For other internal values, clone as is.
+  // This is much fine-grained splitting, which enables significantly higher
+  // levels of parallelism (and smaller generated artifacts).
+  // LLVM LTO style optimization may suffer a bit here since we don't have
+  // the full callstack present anymore in each cloned module.
+  llvm::DenseSet<const llvm::Value *> splitValues;
+  SmallVector<llvm::MapVector<const llvm::GlobalValue *, unsigned>>
+      setsToProcess;
+
+  // Hoist these collections to re-use memory allocations.
+  llvm::ValueToValueMapTy valueMap;
+  SmallPtrSet<const llvm::Value *, 4> splitDeps;
+  auto splitValue = [&](const llvm::GlobalValue *root) {
+    // If the function is already split, e.g. if it was a dependency of
+    // another function, skip it.
+    if (splitValues.count(root))
+      return;
+
+    auto &valueInfo = valueInfos[root];
+    valueMap.clear();
+    splitDeps.clear();
+    auto shouldSplit = [&](const llvm::GlobalValue *globalVal,
+                           const ValueInfo &info) {
+      // Only clone root and the declaration of its dependencies.
+      if (globalVal == root) {
+        splitDeps.insert(globalVal);
+        return true;
+      }
+
+      if ((info.canBeSplit || info.userEmpty) &&
+          isa_and_nonnull<llvm::Function>(globalVal))
+        return false;
+
+      if (valueInfo.dependencies.contains(globalVal)) {
+        splitDeps.insert(globalVal);
+        return true;
+      }
+
+      return false;
+    };
+
+    auto &set = setsToProcess.emplace_back();
+    for (auto &[globalVal, info] : valueInfos) {
+      if (shouldSplit(globalVal, info))
+        set.insert({globalVal, info.gvIdx});
+    }
+    if (set.empty())
+      setsToProcess.pop_back();
+
+    // Record the split values.
+    splitValues.insert(splitDeps.begin(), splitDeps.end());
+  };
+
+  [[maybe_unused]] int64_t count = 0;
+  SmallVector<const llvm::GlobalValue *> toSplit;
+  unsigned unnamedGlobal = numFunctionBase;
+  for (auto &global : mainModule->globals()) {
+    if (global.hasInternalLinkage() || global.hasPrivateLinkage()) {
+      if (!global.hasName()) {
+        // Give each unnamed GlobalVariable a unique name so that MCLink does
+        // not get confused when naming them while generating linked code,
+        // since the IR values can differ between splits (for the X86
+        // backend). ASan builds insert these unnamed GlobalVariables.
+        global.setName("__mojo_unnamed" + Twine(unnamedGlobal++));
+      }
+
+      symbolLinkageTypes.insert({global.getName().str(), global.getLinkage()});
+      global.setLinkage(llvm::GlobalValue::WeakAnyLinkage);
+      continue;
+    }
+
+    if (global.hasExternalLinkage())
+      continue;
+
+    // TODO: Add special handling for `llvm.global_ctors` and
+    // `llvm.global_dtors`, because otherwise they end up tying almost all
+    // symbols into the same split.
+    LLVM_DEBUG(llvm::dbgs()
+                   << (count++) << ": split global: " << global << "\n";);
+    toSplit.emplace_back(&global);
+  }
+
+  for (auto &fn : mainModule->functions()) {
+    if (fn.isDeclaration())
+      continue;
+
+    ValueInfo &info = valueInfos[&fn];
+    if (fn.hasInternalLinkage() || fn.hasPrivateLinkage()) {
+      // Avoid renaming when linking in MCLink.
+      symbolLinkageTypes.insert({fn.getName().str(), fn.getLinkage()});
+      fn.setLinkage(llvm::Function::LinkageTypes::WeakAnyLinkage);
+    }
+
+    if (info.canBeSplit || info.userEmpty) {
+      LLVM_DEBUG(llvm::dbgs()
+                     << (count++) << ": split fn: " << fn.getName() << "\n";);
+      toSplit.emplace_back(&fn);
+    }
+  }
+
+  // Run this now since we just changed the linkages.
+  for (const llvm::GlobalValue *value : toSplit)
+    splitValue(value);
+
+  if (setsToProcess.size() <= 1)
+    return processFn(forwardModule(std::move(mainModule)), std::nullopt,
+                     numFunctionBase);
+
+  auto duplicatedFns = std::move(mainModule.duplicatedFns);
+
+  // Prepare to materialize slices of the module by first writing the main
+  // module as bitcode to a shared buffer.
+  auto buf = WriteableBuffer::get();
+  {
+    CompilerTimeTraceScope traceScope("writeMainModuleBitcode");
+    llvm::Module &module = strtab->externalizeStrings(std::move(mainModule));
+    llvm::WriteBitcodeToFile(module, *buf);
+  }
+
+  unsigned numFunctions = numFunctionBase;
+  llvm::StringSet<> seenFns;
+  for (auto [idx, set] : llvm::enumerate(setsToProcess)) {
+    // Giving each function a unique ID across all splits for proper MC level
+    // linking and codegen into one object file where duplicated functions
+    // in each split will be deduplicated (with the linking).
+    llvm::StringSet<> currDuplicatedFns = duplicatedFns;
+    checkDuplicates(set, seenFns, currDuplicatedFns);
+
+    unsigned next = numFunctions + set.size();
+    auto makeModule = [set = std::move(set), buf = BufferRef(buf.copy()),
+                       strtab = strtab.copy(), currDuplicatedFns]() mutable {
+      return readAndMaterializeDependencies(std::move(buf), set, *strtab,
+                                            currDuplicatedFns);
+    };
+    processFn(std::move(makeModule), idx, numFunctions);
+    numFunctions = next;
+  }
+}
+
+/// Record every global value that directly uses `value`, looking through
+/// any intermediate constants that are not themselves global values.
+///
+/// An instruction user is attributed to the function that contains it. Each
+/// recorded user is added to `valueInfos[value].users` and is also given its
+/// own entry in `valueInfos`.
+void LLVMModulePerFunctionSplitterImpl::collectValueUsers(
+    const llvm::GlobalValue *value) {
+  SmallVector<const llvm::User *> pending(value->users());
+
+  while (!pending.empty()) {
+    const llvm::User *user = pending.pop_back_val();
+
+    // A constant that is not a global value is only an intermediate node:
+    // keep walking up to whatever uses it.
+    const bool isPureConstant =
+        isa<llvm::Constant>(user) && !isa<llvm::GlobalValue>(user);
+    if (isPureConstant) {
+      pending.append(user->user_begin(), user->user_end());
+      continue;
+    }
+
+    // Work out which global value this use belongs to.
+    const llvm::GlobalValue *owner = nullptr;
+    if (const auto *inst = dyn_cast<llvm::Instruction>(user)) {
+      // Instructions are charged to their enclosing function.
+      owner = inst->getParent()->getParent();
+    } else if (const auto *gv = dyn_cast<llvm::GlobalValue>(user)) {
+      owner = gv;
+    } else {
+      llvm_unreachable("unexpected user of global value");
+    }
+
+    valueInfos[value].users.insert(owner);
+    // Make sure the user has its own entry: propagateUseInfo() looks users
+    // up with find() and does not check for a miss.
+    valueInfos[owner];
+  }
+
+  // A global variable can be split only when it is constant.
+  if (const auto *globalVar = dyn_cast<llvm::GlobalVariable>(value))
+    valueInfos[value].canBeSplit = globalVar->isConstant();
+}
+
+/// Compute the dependency sets and the final `canBeSplit` / `userEmpty`
+/// flags for the entries of `valueInfos`.
+///
+/// A worklist is run until nothing changes: dependencies flow from a value
+/// to its users, and "cannot be split" spreads from a value to its users.
+void LLVMModulePerFunctionSplitterImpl::propagateUseInfo() {
+  std::vector<ValueInfo *> pending;
+
+  // Seed: every value other than a function declaration depends on itself.
+  for (auto &[gv, seed] : valueInfos) {
+    const auto *fn = llvm::dyn_cast<llvm::Function>(gv);
+    if (fn && fn->isDeclaration())
+      continue;
+
+    seed.dependencies.insert(gv);
+    seed.value = gv;
+    pending.push_back(&seed);
+
+    // A value that cannot be split also depends on all of its users.
+    if (!seed.canBeSplit)
+      llvm::set_union(seed.dependencies, seed.users);
+  }
+
+  while (!pending.empty()) {
+    ValueInfo *cur = pending.back();
+    pending.pop_back();
+
+    // Neither predicate changes while visiting the users below, because
+    // every entry modified there is a different entry from `cur`.
+    const bool curIsSplitFn =
+        cur->canBeSplit && llvm::isa_and_nonnull<llvm::Function>(cur->value);
+    const bool curPinned = !cur->canBeSplit;
+
+    // Push this value's state to each of its users.
+    for (const llvm::GlobalValue *user : cur->users) {
+      ValueInfo &userInfo = valueInfos.find(user)->second;
+      if (&userInfo == cur)
+        continue;
+
+      bool dirty = false;
+
+      // A function that will be split into its own module does not hand its
+      // dependencies to its users; every other value does.
+      if (!curIsSplitFn &&
+          llvm::set_union(userInfo.dependencies, cur->dependencies))
+        dirty = true;
+
+      // Being unsplittable spreads to users, which then also depend on
+      // their own users.
+      if (curPinned && userInfo.canBeSplit) {
+        userInfo.canBeSplit = false;
+        llvm::set_union(userInfo.dependencies, userInfo.users);
+        dirty = true;
+      }
+
+      if (!dirty)
+        continue;
+
+      // Something changed: revisit the user so it can propagate further.
+      userInfo.value = user;
+      pending.push_back(&userInfo);
+    }
+
+    // For a value that cannot be split, propagate its dependencies up to
+    // its dependencies.
+    // NOTE(review): `value` is assigned a GlobalValue before any entry is
+    // pushed onto the worklist, so the isa_and_nonnull test below appears to
+    // always hold, which would make this loop unreachable. Confirm whether
+    // a narrower type was intended.
+    if (!cur->canBeSplit && !isa_and_nonnull<llvm::GlobalValue>(cur->value)) {
+      for (const llvm::GlobalValue *dep : cur->dependencies) {
+        ValueInfo &depInfo = valueInfos.find(dep)->second;
+        if (&depInfo == cur)
+          continue;
+        if (!llvm::set_union(depInfo.dependencies, cur->dependencies))
+          continue;
+
+        depInfo.value = dep;
+        pending.push_back(&depInfo);
+      }
+    }
+  }
+
+  // A value counts as having no users when nothing uses it, or when its
+  // only user is itself.
+  for (auto &[gv, info] : valueInfos) {
+    const auto &users = info.users;
+    const bool onlySelf = users.size() == 1 && users.contains(gv);
+    info.userEmpty = users.empty() || onlySelf;
+  }
+}



More information about the llvm-commits mailing list