[llvm] 56c72c7 - [ORC] Add a public unsafe-operations helper for SymbolStringPtr.

Lang Hames via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 27 09:49:05 PST 2023


Author: Lang Hames
Date: 2023-11-27T09:48:56-08:00
New Revision: 56c72c7f339cd6ff780a0d6aa1a0ac8bfc1487aa

URL: https://github.com/llvm/llvm-project/commit/56c72c7f339cd6ff780a0d6aa1a0ac8bfc1487aa
DIFF: https://github.com/llvm/llvm-project/commit/56c72c7f339cd6ff780a0d6aa1a0ac8bfc1487aa.diff

LOG: [ORC] Add a public unsafe-operations helper for SymbolStringPtr.

SymbolStringPoolEntryUnsafe provides unsafe access to SymbolStringPtr objects,
allowing clients to manually retain and release pool entries, or consume or
create SymbolStringPtr instances without affecting an entry's ref-count. This
can be useful when writing C APIs that need to handle SymbolStringPtrs.

As part of this patch, the LLVM-C API implementation is updated to use the new
utility rather than the SymbolStringPtr helpers in the private OrcV2CAPIHelper
class, which are removed.

Added: 
    

Modified: 
    llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h
    llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
    llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h
index 497e29da98bd59d..f47956a65f2e784 100644
--- a/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h
+++ b/llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h
@@ -32,6 +32,7 @@ class NonOwningSymbolStringPtr;
 class SymbolStringPool {
   friend class SymbolStringPoolTest;
   friend class SymbolStringPtrBase;
+  friend class SymbolStringPoolEntryUnsafe;
 
   // Implemented in DebugUtils.h.
   friend raw_ostream &operator<<(raw_ostream &OS, const SymbolStringPool &SSP);
@@ -134,8 +135,8 @@ class SymbolStringPtrBase {
 
 /// Pointer to a pooled string representing a symbol name.
 class SymbolStringPtr : public SymbolStringPtrBase {
-  friend class OrcV2CAPIHelper;
   friend class SymbolStringPool;
+  friend class SymbolStringPoolEntryUnsafe;
   friend struct DenseMapInfo<SymbolStringPtr>;
 
 public:
@@ -189,6 +190,47 @@ class SymbolStringPtr : public SymbolStringPtrBase {
   }
 };
 
+/// Provides unsafe access to ownership operations on SymbolStringPtr.
+/// This class can be used to manage SymbolStringPtr instances from C.
+class SymbolStringPoolEntryUnsafe {
+public:
+  using PoolEntry = SymbolStringPool::PoolMapEntry;
+  /// Wrap a raw pool entry pointer (may be null) without retaining it.
+  SymbolStringPoolEntryUnsafe(PoolEntry *E) : E(E) {}
+
+  /// Create an unsafe pool entry ref without changing the ref-count.
+  static SymbolStringPoolEntryUnsafe from(const SymbolStringPtr &S) {
+    return S.S;
+  }
+
+  /// Consumes the given SymbolStringPtr without releasing the pool entry.
+  static SymbolStringPoolEntryUnsafe take(SymbolStringPtr &&S) {
+    PoolEntry *E = nullptr;
+    std::swap(E, S.S);
+    return E;
+  }
+  /// Returns the wrapped entry pointer. Does not change the ref-count.
+  PoolEntry *rawPtr() const { return E; }
+
+  /// Creates a SymbolStringPtr for this entry, with the SymbolStringPtr
+  /// retaining the entry as usual.
+  SymbolStringPtr copyToSymbolStringPtr() { return SymbolStringPtr(E); }
+
+  /// Creates a SymbolStringPtr for this entry *without* performing a retain
+  /// operation during construction. Leaves this object holding null.
+  SymbolStringPtr moveToSymbolStringPtr() {
+    SymbolStringPtr S;
+    std::swap(S.S, E);
+    return S;
+  }
+  /// Manually adjust the ref-count. Both are no-ops for a null entry.
+  void retain() { if (E) ++E->getValue(); }
+  void release() { if (E) --E->getValue(); }
+
+private:
+  PoolEntry *E = nullptr;
+};
+
 /// Non-owning SymbolStringPool entry pointer. Instances are comparable with
 /// SymbolStringPtr instances and guaranteed to have the same hash, but do not
 /// affect the ref-count of the pooled string (and are therefore cheaper to

diff  --git a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
index a73aec6d98c64c9..72314cceedf33f6 100644
--- a/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
+++ b/llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp
@@ -27,42 +27,6 @@ class InProgressLookupState;
 
 class OrcV2CAPIHelper {
 public:
-  using PoolEntry = SymbolStringPtr::PoolEntry;
-  using PoolEntryPtr = SymbolStringPtr::PoolEntryPtr;
-
-  // Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
-  static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S) {
-    PoolEntryPtr Result = nullptr;
-    std::swap(Result, S.S);
-    return Result;
-  }
-
-  // Move from a PoolEntryPtr to a SymbolStringPtr (no change in ref count).
-  static SymbolStringPtr moveToSymbolStringPtr(PoolEntryPtr P) {
-    SymbolStringPtr S;
-    S.S = P;
-    return S;
-  }
-
-  // Copy a pool entry to a SymbolStringPtr (increments ref count).
-  static SymbolStringPtr copyToSymbolStringPtr(PoolEntryPtr P) {
-    return SymbolStringPtr(P);
-  }
-
-  static PoolEntryPtr getRawPoolEntryPtr(const SymbolStringPtr &S) {
-    return S.S;
-  }
-
-  static void retainPoolEntry(PoolEntryPtr P) {
-    SymbolStringPtr S(P);
-    S.S = nullptr;
-  }
-
-  static void releasePoolEntry(PoolEntryPtr P) {
-    SymbolStringPtr S;
-    S.S = P;
-  }
-
   static InProgressLookupState *extractLookupState(LookupState &LS) {
     return LS.IPLS.release();
   }
@@ -75,10 +39,16 @@ class OrcV2CAPIHelper {
 } // namespace orc
 } // namespace llvm
 
+inline LLVMOrcSymbolStringPoolEntryRef wrap(SymbolStringPoolEntryUnsafe E) {
+  return reinterpret_cast<LLVMOrcSymbolStringPoolEntryRef>(E.rawPtr());
+}
+// wrap/unwrap only reinterpret the raw PoolMapEntry*; ref-counts are untouched.
+inline SymbolStringPoolEntryUnsafe unwrap(LLVMOrcSymbolStringPoolEntryRef E) {
+  return reinterpret_cast<SymbolStringPoolEntryUnsafe::PoolEntry *>(E);
+}
+
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionSession, LLVMOrcExecutionSessionRef)
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(SymbolStringPool, LLVMOrcSymbolStringPoolRef)
-DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcV2CAPIHelper::PoolEntry,
-                                   LLVMOrcSymbolStringPoolEntryRef)
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationUnit,
                                    LLVMOrcMaterializationUnitRef)
 DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationResponsibility,
@@ -136,7 +106,7 @@ class OrcCAPIMaterializationUnit : public llvm::orc::MaterializationUnit {
 
 private:
   void discard(const JITDylib &JD, const SymbolStringPtr &Name) override {
-    Discard(Ctx, wrap(&JD), wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
+    Discard(Ctx, wrap(&JD), wrap(SymbolStringPoolEntryUnsafe::from(Name)));
   }
 
   std::string Name;
@@ -184,7 +154,7 @@ static SymbolMap toSymbolMap(LLVMOrcCSymbolMapPairs Syms, size_t NumPairs) {
   SymbolMap SM;
   for (size_t I = 0; I != NumPairs; ++I) {
     JITSymbolFlags Flags = toJITSymbolFlags(Syms[I].Sym.Flags);
-    SM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] = {
+    SM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] = {
         ExecutorAddr(Syms[I].Sym.Address), Flags};
   }
   return SM;
@@ -199,7 +169,7 @@ toSymbolDependenceMap(LLVMOrcCDependenceMapPairs Pairs, size_t NumPairs) {
 
     for (size_t J = 0; J != Pairs[I].Names.Length; ++J) {
       auto Sym = Pairs[I].Names.Symbols[J];
-      Names.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Sym)));
+      Names.insert(unwrap(Sym).moveToSymbolStringPtr());
     }
     SDM[JD] = Names;
   }
@@ -309,7 +279,7 @@ class CAPIDefinitionGenerator final : public DefinitionGenerator {
     CLookupSet.reserve(LookupSet.size());
     for (auto &KV : LookupSet) {
       LLVMOrcSymbolStringPoolEntryRef Name =
-          ::wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first));
+          ::wrap(SymbolStringPoolEntryUnsafe::from(KV.first));
       LLVMOrcSymbolLookupFlags SLF = fromSymbolLookupFlags(KV.second);
       CLookupSet.push_back({Name, SLF});
     }
@@ -353,8 +323,7 @@ void LLVMOrcSymbolStringPoolClearDeadEntries(LLVMOrcSymbolStringPoolRef SSP) {
 
 LLVMOrcSymbolStringPoolEntryRef
 LLVMOrcExecutionSessionIntern(LLVMOrcExecutionSessionRef ES, const char *Name) {
-  return wrap(
-      OrcV2CAPIHelper::moveFromSymbolStringPtr(unwrap(ES)->intern(Name)));
+  return wrap(SymbolStringPoolEntryUnsafe::take(unwrap(ES)->intern(Name)));
 }
 
 void LLVMOrcExecutionSessionLookup(
@@ -374,7 +343,7 @@ void LLVMOrcExecutionSessionLookup(
 
   SymbolLookupSet SLS;
   for (size_t I = 0; I != SymbolsSize; ++I)
-    SLS.add(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I].Name)),
+    SLS.add(unwrap(Symbols[I].Name).moveToSymbolStringPtr(),
             toSymbolLookupFlags(Symbols[I].LookupFlags));
 
   unwrap(ES)->lookup(
@@ -384,7 +353,7 @@ void LLVMOrcExecutionSessionLookup(
           SmallVector<LLVMOrcCSymbolMapPair> CResult;
           for (auto &KV : *Result)
             CResult.push_back(LLVMOrcCSymbolMapPair{
-                wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first)),
+                wrap(SymbolStringPoolEntryUnsafe::from(KV.first)),
                 fromExecutorSymbolDef(KV.second)});
           HandleResult(LLVMErrorSuccess, CResult.data(), CResult.size(), Ctx);
         } else
@@ -394,15 +363,15 @@ void LLVMOrcExecutionSessionLookup(
 }
 
 void LLVMOrcRetainSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) {
-  OrcV2CAPIHelper::retainPoolEntry(unwrap(S));
+  unwrap(S).retain();
 }
 
 void LLVMOrcReleaseSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) {
-  OrcV2CAPIHelper::releasePoolEntry(unwrap(S));
+  unwrap(S).release();
 }
 
 const char *LLVMOrcSymbolStringPoolEntryStr(LLVMOrcSymbolStringPoolEntryRef S) {
-  return unwrap(S)->getKey().data();
+  return unwrap(S).rawPtr()->getKey().data();
 }
 
 LLVMOrcResourceTrackerRef
@@ -452,10 +421,10 @@ LLVMOrcMaterializationUnitRef LLVMOrcCreateCustomMaterializationUnit(
     LLVMOrcMaterializationUnitDestroyFunction Destroy) {
   SymbolFlagsMap SFM;
   for (size_t I = 0; I != NumSyms; ++I)
-    SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] =
+    SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] =
         toJITSymbolFlags(Syms[I].Flags);
 
-  auto IS = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(InitSym));
+  auto IS = unwrap(InitSym).moveToSymbolStringPtr();
 
   return wrap(new OrcCAPIMaterializationUnit(
       Name, std::move(SFM), std::move(IS), Ctx, Materialize, Discard, Destroy));
@@ -476,9 +445,8 @@ LLVMOrcMaterializationUnitRef LLVMOrcLazyReexports(
   for (size_t I = 0; I != NumPairs; ++I) {
     auto pair = CallableAliases[I];
     JITSymbolFlags Flags = toJITSymbolFlags(pair.Entry.Flags);
-    SymbolStringPtr Name =
-        OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Entry.Name));
-    SAM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Name))] =
+    SymbolStringPtr Name = unwrap(pair.Entry.Name).moveToSymbolStringPtr();
+    SAM[unwrap(pair.Name).moveToSymbolStringPtr()] =
         SymbolAliasMapEntry(Name, Flags);
   }
 
@@ -511,7 +479,7 @@ LLVMOrcCSymbolFlagsMapPairs LLVMOrcMaterializationResponsibilityGetSymbols(
       safe_malloc(Symbols.size() * sizeof(LLVMOrcCSymbolFlagsMapPair)));
   size_t I = 0;
   for (auto const &pair : Symbols) {
-    auto Name = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(pair.first));
+    auto Name = wrap(SymbolStringPoolEntryUnsafe::from(pair.first));
     auto Flags = pair.second;
     Result[I] = {Name, fromJITSymbolFlags(Flags)};
     I++;
@@ -528,7 +496,7 @@ LLVMOrcSymbolStringPoolEntryRef
 LLVMOrcMaterializationResponsibilityGetInitializerSymbol(
     LLVMOrcMaterializationResponsibilityRef MR) {
   auto Sym = unwrap(MR)->getInitializerSymbol();
-  return wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Sym));
+  return wrap(SymbolStringPoolEntryUnsafe::from(Sym));
 }
 
 LLVMOrcSymbolStringPoolEntryRef *
@@ -541,7 +509,7 @@ LLVMOrcMaterializationResponsibilityGetRequestedSymbols(
           Symbols.size() * sizeof(LLVMOrcSymbolStringPoolEntryRef)));
   size_t I = 0;
   for (auto &Name : Symbols) {
-    Result[I] = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name));
+    Result[I] = wrap(SymbolStringPoolEntryUnsafe::from(Name));
     I++;
   }
   *NumSymbols = Symbols.size();
@@ -569,7 +537,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDefineMaterializing(
     LLVMOrcCSymbolFlagsMapPairs Syms, size_t NumSyms) {
   SymbolFlagsMap SFM;
   for (size_t I = 0; I != NumSyms; ++I)
-    SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] =
+    SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] =
         toJITSymbolFlags(Syms[I].Flags);
 
   return wrap(unwrap(MR)->defineMaterializing(std::move(SFM)));
@@ -588,7 +556,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDelegate(
     LLVMOrcMaterializationResponsibilityRef *Result) {
   SymbolNameSet Syms;
   for (size_t I = 0; I != NumSymbols; I++) {
-    Syms.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I])));
+    Syms.insert(unwrap(Symbols[I]).moveToSymbolStringPtr());
   }
   auto OtherMR = unwrap(MR)->delegate(Syms);
 
@@ -605,7 +573,7 @@ void LLVMOrcMaterializationResponsibilityAddDependencies(
     LLVMOrcCDependenceMapPairs Dependencies, size_t NumPairs) {
 
   SymbolDependenceMap SDM = toSymbolDependenceMap(Dependencies, NumPairs);
-  auto Sym = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Name));
+  auto Sym = unwrap(Name).moveToSymbolStringPtr();
   unwrap(MR)->addDependencies(Sym, SDM);
 }
 
@@ -698,7 +666,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
   DynamicLibrarySearchGenerator::SymbolPredicate Pred;
   if (Filter)
     Pred = [=](const SymbolStringPtr &Name) -> bool {
-      return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
+      return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name)));
     };
 
   auto ProcessSymsGenerator =
@@ -724,7 +692,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForPath(
   DynamicLibrarySearchGenerator::SymbolPredicate Pred;
   if (Filter)
     Pred = [=](const SymbolStringPtr &Name) -> bool {
-      return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
+      return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name)));
     };
 
   auto LibrarySymsGenerator =
@@ -992,7 +960,7 @@ char LLVMOrcLLJITGetGlobalPrefix(LLVMOrcLLJITRef J) {
 
 LLVMOrcSymbolStringPoolEntryRef
 LLVMOrcLLJITMangleAndIntern(LLVMOrcLLJITRef J, const char *UnmangledName) {
-  return wrap(OrcV2CAPIHelper::moveFromSymbolStringPtr(
+  return wrap(SymbolStringPoolEntryUnsafe::take(
       unwrap(J)->mangleAndIntern(UnmangledName)));
 }
 

diff  --git a/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp b/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp
index fc864ab1131b2a5..cd1cecd3244d638 100644
--- a/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp
+++ b/llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp
@@ -142,4 +142,42 @@ TEST_F(SymbolStringPoolTest, NonOwningPointerRefCounts) {
         << "Copy-assignment of NonOwningSymbolStringPtr changed ref-count";
   }
 }
+
+TEST_F(SymbolStringPoolTest, SymbolStringPoolEntryUnsafe) {
+  // Check the ref-count effect of each SymbolStringPoolEntryUnsafe operation.
+  auto A = SP.intern("a");
+  EXPECT_EQ(getRefCount(A), 1U);
+
+  {
+    // Try creating an unsafe pool entry ref from the given SymbolStringPtr.
+    // This should not affect the ref-count.
+    auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A);
+    EXPECT_EQ(getRefCount(A), 1U);
+
+    // Create a new SymbolStringPtr from the unsafe ref. This should increment
+    // the ref-count.
+    auto ACopy = AUnsafe.copyToSymbolStringPtr();
+    EXPECT_EQ(getRefCount(A), 2U);
+  }
+
+  {
+    // Copy the original string (ref-count 2), move the copy into an unsafe
+    // ref, and then move it back. Neither move should affect the ref-count.
+    auto ACopy = A;
+    EXPECT_EQ(getRefCount(A), 2U);
+    auto AUnsafe = SymbolStringPoolEntryUnsafe::take(std::move(ACopy));
+    EXPECT_EQ(getRefCount(A), 2U);
+    ACopy = AUnsafe.moveToSymbolStringPtr();
+    EXPECT_EQ(getRefCount(A), 2U);
+  }
+
+  // Test manual retain / release; from() itself must not retain.
+  auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A);
+  EXPECT_EQ(getRefCount(A), 1U);
+  AUnsafe.retain();
+  EXPECT_EQ(getRefCount(A), 2U);
+  AUnsafe.release();
+  EXPECT_EQ(getRefCount(A), 1U);
+}
+
 } // namespace


        


More information about the llvm-commits mailing list