[llvm] [Offload] Cache symbols in program (PR #148209)

Ross Brunton via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 15 02:10:57 PDT 2025


https://github.com/RossBrunton updated https://github.com/llvm/llvm-project/pull/148209

>From 88c9a1897a524dda957c4cd171fcda2ab55adeea Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Fri, 11 Jul 2025 11:52:22 +0100
Subject: [PATCH 1/3] [Offload] Cache symbols in program

When creating a new symbol, first check whether it already exists. If it
does, return the existing handle rather than building a new symbol structure.
---
 offload/liboffload/src/OffloadImpl.cpp        | 34 +++++++++++++------
 .../OffloadAPI/symbol/olGetSymbol.cpp         | 18 ++++++++++
 2 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index af07a6786cfea..40ad1dd7ff617 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -85,16 +85,18 @@ struct ol_program_impl_t {
   plugin::DeviceImageTy *Image;
   std::unique_ptr<llvm::MemoryBuffer> ImageData;
   std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
+  std::mutex SymbolListMutex;
   __tgt_device_image DeviceImage;
 };
 
 struct ol_symbol_impl_t {
-  ol_symbol_impl_t(GenericKernelTy *Kernel)
-      : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
-  ol_symbol_impl_t(GlobalTy &&Global)
-      : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {}
+  ol_symbol_impl_t(const char *Name, GenericKernelTy *Kernel)
+      : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL), Name(Name) {}
+  ol_symbol_impl_t(const char *Name, GlobalTy &&Global)
+      : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
   std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
   ol_symbol_kind_t Kind;
+  const char *Name;
 };
 
 namespace llvm {
@@ -714,6 +716,18 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
                        ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
   auto &Device = Program->Image->getDevice();
 
+  std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
+
+  // If it already exists, return an existing handle
+  auto Check = std::find_if(
+      Program->Symbols.begin(), Program->Symbols.end(), [&](auto &Sym) {
+        return Sym->Kind == Kind && !std::strcmp(Sym->Name, Name);
+      });
+  if (Check != Program->Symbols.end()) {
+    *Symbol = Check->get();
+    return Error::success();
+  }
+
   switch (Kind) {
   case OL_SYMBOL_KIND_KERNEL: {
     auto KernelImpl = Device.constructKernel(Name);
@@ -723,10 +737,10 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
     if (auto Err = KernelImpl->init(Device, *Program->Image))
       return Err;
 
-    *Symbol =
-        Program->Symbols
-            .emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
-            .get();
+    *Symbol = Program->Symbols
+                  .emplace_back(std::make_unique<ol_symbol_impl_t>(
+                      KernelImpl->getName(), &*KernelImpl))
+                  .get();
     return Error::success();
   }
   case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
@@ -736,8 +750,8 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
       return Res;
 
     *Symbol = Program->Symbols
-                  .emplace_back(
-                      std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj)))
+                  .emplace_back(std::make_unique<ol_symbol_impl_t>(
+                      GlobalObj.getName().c_str(), std::move(GlobalObj)))
                   .get();
 
     return Error::success();
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp
index 5e87ab5b29621..1f496b9c6e1ae 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp
@@ -41,6 +41,14 @@ TEST_P(olGetSymbolKernelTest, Success) {
   ASSERT_NE(Kernel, nullptr);
 }
 
+TEST_P(olGetSymbolKernelTest, SuccessSamePtr) {
+  ol_symbol_handle_t KernelA = nullptr;
+  ol_symbol_handle_t KernelB = nullptr;
+  ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelA));
+  ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &KernelB));
+  ASSERT_EQ(KernelA, KernelB);
+}
+
 TEST_P(olGetSymbolKernelTest, InvalidNullProgram) {
   ol_symbol_handle_t Kernel = nullptr;
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
@@ -72,6 +80,16 @@ TEST_P(olGetSymbolGlobalTest, Success) {
   ASSERT_NE(Global, nullptr);
 }
 
+TEST_P(olGetSymbolGlobalTest, SuccessSamePtr) {
+  ol_symbol_handle_t GlobalA = nullptr;
+  ol_symbol_handle_t GlobalB = nullptr;
+  ASSERT_SUCCESS(
+      olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalA));
+  ASSERT_SUCCESS(
+      olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &GlobalB));
+  ASSERT_EQ(GlobalA, GlobalB);
+}
+
 TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) {
   ol_symbol_handle_t Global = nullptr;
   ASSERT_ERROR(

>From 9529b332306dc9276594c4bd97f03ffc927b01bb Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Mon, 14 Jul 2025 10:02:17 +0100
Subject: [PATCH 2/3] Use StringRef and SmallVector for symbol lookup

---
 offload/liboffload/src/OffloadImpl.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 40ad1dd7ff617..3a376e2d6aebb 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -84,7 +84,7 @@ struct ol_program_impl_t {
         DeviceImage(DeviceImage) {}
   plugin::DeviceImageTy *Image;
   std::unique_ptr<llvm::MemoryBuffer> ImageData;
-  std::vector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
+  llvm::SmallVector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
   std::mutex SymbolListMutex;
   __tgt_device_image DeviceImage;
 };
@@ -96,7 +96,7 @@ struct ol_symbol_impl_t {
       : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE), Name(Name) {}
   std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
   ol_symbol_kind_t Kind;
-  const char *Name;
+  llvm::StringRef Name;
 };
 
 namespace llvm {
@@ -719,10 +719,9 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
   std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
 
   // If it already exists, return an existing handle
-  auto Check = std::find_if(
-      Program->Symbols.begin(), Program->Symbols.end(), [&](auto &Sym) {
-        return Sym->Kind == Kind && !std::strcmp(Sym->Name, Name);
-      });
+  auto Check = llvm::find_if(Program->Symbols, [&](auto &Sym) {
+    return Sym->Kind == Kind && Sym->Name == Name;
+  });
   if (Check != Program->Symbols.end()) {
     *Symbol = Check->get();
     return Error::success();

>From 0fd6a69a42be217e86763409a8855571954153f7 Mon Sep 17 00:00:00 2001
From: Ross Brunton <ross at codeplay.com>
Date: Tue, 15 Jul 2025 10:10:33 +0100
Subject: [PATCH 3/3] Use (two) maps instead of vector

---
 offload/liboffload/src/OffloadImpl.cpp | 41 +++++++++++++++++---------
 1 file changed, 27 insertions(+), 14 deletions(-)

diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 3a376e2d6aebb..5466ee4e8b79f 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -84,9 +84,10 @@ struct ol_program_impl_t {
         DeviceImage(DeviceImage) {}
   plugin::DeviceImageTy *Image;
   std::unique_ptr<llvm::MemoryBuffer> ImageData;
-  llvm::SmallVector<std::unique_ptr<ol_symbol_impl_t>> Symbols;
   std::mutex SymbolListMutex;
   __tgt_device_image DeviceImage;
+  llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> KernelSymbols;
+  llvm::StringMap<std::unique_ptr<ol_symbol_impl_t>> GlobalSymbols;
 };
 
 struct ol_symbol_impl_t {
@@ -719,16 +720,20 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
   std::lock_guard<std::mutex> Lock{Program->SymbolListMutex};
 
   // If it already exists, return an existing handle
-  auto Check = llvm::find_if(Program->Symbols, [&](auto &Sym) {
-    return Sym->Kind == Kind && Sym->Name == Name;
-  });
-  if (Check != Program->Symbols.end()) {
-    *Symbol = Check->get();
-    return Error::success();
-  }
+  auto CheckCache = [&](StringMap<std::unique_ptr<ol_symbol_impl_t>> &Map)
+      -> std::optional<ol_symbol_handle_t> {
+    if (auto It = Map.find(Name); It != Map.end()) // Single, non-inserting lookup.
+      return It->getValue().get();
+    return std::nullopt;
+  };
 
   switch (Kind) {
   case OL_SYMBOL_KIND_KERNEL: {
+    if (auto Cache = CheckCache(Program->KernelSymbols)) {
+      *Symbol = *Cache;
+      return Error::success();
+    }
+
     auto KernelImpl = Device.constructKernel(Name);
     if (!KernelImpl)
       return KernelImpl.takeError();
@@ -736,21 +741,29 @@ Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
     if (auto Err = KernelImpl->init(Device, *Program->Image))
       return Err;
 
-    *Symbol = Program->Symbols
-                  .emplace_back(std::make_unique<ol_symbol_impl_t>(
-                      KernelImpl->getName(), &*KernelImpl))
+    *Symbol = Program->KernelSymbols
+                  .insert({Name, std::make_unique<ol_symbol_impl_t>(
+                                     KernelImpl->getName(), &*KernelImpl)})
+                  .first->getValue()
                   .get();
     return Error::success();
   }
   case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
+    if (auto Cache = CheckCache(Program->GlobalSymbols)) {
+      *Symbol = *Cache;
+      return Error::success();
+    }
+
     GlobalTy GlobalObj{Name};
     if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
             Device, *Program->Image, GlobalObj))
       return Res;
 
-    *Symbol = Program->Symbols
-                  .emplace_back(std::make_unique<ol_symbol_impl_t>(
-                      GlobalObj.getName().c_str(), std::move(GlobalObj)))
+    *Symbol = Program->GlobalSymbols
+                  .insert({Name, std::make_unique<ol_symbol_impl_t>(
+                                     GlobalObj.getName().c_str(),
+                                     std::move(GlobalObj))})
+                  .first->getValue()
                   .get();
 
     return Error::success();



More information about the llvm-commits mailing list