[llvm] [offload][SYCL] Add SYCL Module splitting (PR #119713)
Joseph Huber via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 10 09:15:54 PDT 2025
================
@@ -0,0 +1,511 @@
+//===-------- SYCLSplitModule.cpp - Split a module into call graphs -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// See comments in the header.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/SYCLSplitModule.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Bitcode/BitcodeWriterPass.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PassManagerImpl.h"
+#include "llvm/IRPrinter/IRPrintingPasses.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Transforms/IPO/GlobalDCE.h"
+#include "llvm/Transforms/IPO/StripDeadPrototypes.h"
+#include "llvm/Transforms/IPO/StripSymbols.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/SYCLUtils.h"
+
+#include <map>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "sycl-split-module"
+
+/// Returns true if \p F uses one of the kernel calling conventions recognized
+/// here: SPIR_KERNEL or AMDGPU_KERNEL.
+static bool isKernel(const Function &F) {
+  switch (F.getCallingConv()) {
+  case CallingConv::SPIR_KERNEL:
+  case CallingConv::AMDGPU_KERNEL:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Returns true if \p F is an entry point, i.e. a kernel that has a definition.
+static bool isEntryPoint(const Function &F) {
+  // Declarations are skipped: they should not be included into a vector of
+  // entry point groups, or otherwise we will end up with an incorrectly
+  // generated list of symbols. Among definitions, kernels are always
+  // considered to be entry points.
+  return !F.isDeclaration() && isKernel(F);
+}
+
+namespace {
+
+// Insertion-ordered set of all entry point functions in a split module.
+using EntryPointSet = SetVector<const Function *>;
+
+/// Represents a named group of entry points.
+struct EntryPointGroup {
+  /// Name of the group.
+  std::string GroupName;
+  /// Entry point functions that belong to the group.
+  EntryPointSet Functions;
+
+  // The group is default-constructible, copyable and movable.
+  EntryPointGroup() = default;
+  EntryPointGroup(const EntryPointGroup &) = default;
+  EntryPointGroup(EntryPointGroup &&) = default;
+  EntryPointGroup &operator=(const EntryPointGroup &) = default;
+  EntryPointGroup &operator=(EntryPointGroup &&) = default;
+
+  /// Creates a group named \p GroupName holding \p Functions (empty by
+  /// default). The set is taken by value and moved into the group.
+  EntryPointGroup(StringRef GroupName,
+                  EntryPointSet Functions = EntryPointSet())
+      : GroupName(GroupName.str()), Functions(std::move(Functions)) {}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Prints the group name and the name of every entry point to the debug
+  /// stream, each line indented by a fixed amount.
+  LLVM_DUMP_METHOD void dump() const {
+    constexpr size_t IndentWidth = 4;
+    raw_ostream &OS = dbgs();
+
+    OS.indent(IndentWidth) << "ENTRY POINTS"
+                           << " " << GroupName << " {\n";
+    for (const Function *EntryPoint : Functions)
+      OS.indent(IndentWidth) << " " << EntryPoint->getName() << "\n";
+    OS.indent(IndentWidth) << "}\n";
+  }
+#endif
+};
+
+/// Annotates an llvm::Module with information necessary to perform and track
+/// the result of device code (llvm::Module instances) splitting:
+/// - entry points group from the module.
+class ModuleDesc {
+ std::unique_ptr<Module> M;
+ EntryPointGroup EntryPoints;
+
+public:
+ ModuleDesc() = delete;
+ ModuleDesc(const ModuleDesc &) = delete;
+ ModuleDesc &operator=(const ModuleDesc &) = delete;
+ ModuleDesc(ModuleDesc &&) = default;
+ ModuleDesc &operator=(ModuleDesc &&) = default;
+
+ ModuleDesc(std::unique_ptr<Module> M,
+ EntryPointGroup EntryPoints = EntryPointGroup())
+ : M(std::move(M)), EntryPoints(std::move(EntryPoints)) {
+ assert(this->M && "Module should be non-null");
+ }
+
+ const EntryPointSet &entries() const { return EntryPoints.Functions; }
+ const EntryPointGroup &getEntryPointGroup() const { return EntryPoints; }
+ EntryPointSet &entries() { return EntryPoints.Functions; }
+ Module &getModule() { return *M; }
+ const Module &getModule() const { return *M; }
+
+ // Cleans up module IR - removes dead globals, debug info etc.
+ void cleanup() {
+ ModuleAnalysisManager MAM;
+ MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+ ModulePassManager MPM;
+ MPM.addPass(GlobalDCEPass()); // Delete unreachable globals.
+ MPM.addPass(StripDeadDebugInfoPass()); // Remove dead debug info.
+ MPM.addPass(StripDeadPrototypesPass()); // Remove dead func decls.
+ MPM.run(*M, MAM);
+ }
+
+ std::string makeSymbolTable() const {
+ SmallString<128> ST;
----------------
jhuber6 wrote:
I would suggest
```
SmallString<0> Data;
raw_svector_ostream OS(Data);
OS << F->getName() << "\n";
```
Then just return the `SmallString<0>`, unless you really need a `std::string`.
https://github.com/llvm/llvm-project/pull/119713
More information about the llvm-commits
mailing list