[llvm] Fix Arm64EC name mangling algorithm (PR #115567)

Daniel Paoliello via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 8 15:49:02 PST 2024


https://github.com/dpaoliello created https://github.com/llvm/llvm-project/pull/115567

Arm64EC uses a special name mangling mode that adds `$$h` between the symbol name and its type. In MSVC's name mangling, `@` is used to separate the name from the type, but it is also used for other purposes, such as separating the components of a fully qualified name.

This change fixes the algorithm to:
* Ignore the last 4 characters of the mangled name - these may contain a `@@` if the symbol is a function returning a type with a qualified name.
* Search backwards for `@@` - this finds the `@` at the end of the symbol's name if it is qualified (which includes symbols in the global namespace) together with the `@` used as the separator between the name and the type.
* If `@@` is not found, search for just the last `@` - this finds the `@` used as the separator between the name and the type for global symbols (such as `operator new()`), but it assumes that their return types are not types with a qualified name.

Also fixed `isArm64ECMangledFunctionName` to search for `@$$h` since the `$$h` must always be after a `@`.

Fixes #115231

From 5f96d58db9954a96b95151c288c05abec60a8503 Mon Sep 17 00:00:00 2001
From: "Daniel Paoliello (HE/HIM)" <danpao at microsoft.com>
Date: Fri, 8 Nov 2024 15:40:11 -0800
Subject: [PATCH] Fix Arm64EC name mangling algorithm

---
 llvm/include/llvm/IR/Mangler.h    |  2 +-
 llvm/lib/IR/Mangler.cpp           | 21 ++++++++---
 llvm/unittests/IR/ManglerTest.cpp | 62 +++++++++++++++++++++++++++++++
 3 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/IR/Mangler.h b/llvm/include/llvm/IR/Mangler.h
index 3c3f0c6dce80fa..6c8ebf5f072f28 100644
--- a/llvm/include/llvm/IR/Mangler.h
+++ b/llvm/include/llvm/IR/Mangler.h
@@ -64,7 +64,7 @@ std::optional<std::string> getArm64ECDemangledFunctionName(StringRef Name);
 /// Check if an ARM64EC function name is mangled.
 bool inline isArm64ECMangledFunctionName(StringRef Name) {
   return Name[0] == '#' ||
-         (Name[0] == '?' && Name.find("$$h") != StringRef::npos);
+         (Name[0] == '?' && Name.find("@$$h") != StringRef::npos);
 }
 
 } // End llvm namespace
diff --git a/llvm/lib/IR/Mangler.cpp b/llvm/lib/IR/Mangler.cpp
index 15a4debf191a5b..12be14156f6656 100644
--- a/llvm/lib/IR/Mangler.cpp
+++ b/llvm/lib/IR/Mangler.cpp
@@ -302,14 +302,23 @@ std::optional<std::string> llvm::getArm64ECMangledFunctionName(StringRef Name) {
   // Insert the ARM64EC "$$h" tag after the mangled function name.
   if (Name.contains("$$h"))
     return std::nullopt;
-  size_t InsertIdx = Name.find("@@");
-  size_t ThreeAtSignsIdx = Name.find("@@@");
-  if (InsertIdx != std::string::npos && InsertIdx != ThreeAtSignsIdx) {
+
+  // The last 4 characters of the symbol type may contain a `@@` if the symbol
+  // is returning a qualified type. We don't want to insert `$$h` at that point.
+  auto TrimmedName = Name.drop_back(4);
+
+  // The last `@@` is the separation between the qualified name of the symbol
+  // and its type, which is where we want to insert `$$h`.
+  auto InsertIdx = TrimmedName.rfind("@@");
+  if (InsertIdx != StringRef::npos) {
     InsertIdx += 2;
   } else {
-    InsertIdx = Name.find("@");
-    if (InsertIdx != std::string::npos)
-      InsertIdx++;
+    // If there is no `@@`, then this is a global symbol (e.g., `operator new`)
+    // so look for a `@` instead (since we assume that it will not return a
+    // qualified type).
+    InsertIdx = TrimmedName.find_last_of('@');
+    assert(InsertIdx != StringRef::npos && "Invalid mangled name");
+    InsertIdx += 1;
   }
 
   return std::optional<std::string>(
diff --git a/llvm/unittests/IR/ManglerTest.cpp b/llvm/unittests/IR/ManglerTest.cpp
index 5ac784b7e89ac6..017d8303551244 100644
--- a/llvm/unittests/IR/ManglerTest.cpp
+++ b/llvm/unittests/IR/ManglerTest.cpp
@@ -172,4 +172,66 @@ TEST(ManglerTest, GOFF) {
             "L#foo");
 }
 
+TEST(ManglerTest, Arm64EC) {
+  constexpr std::string_view Arm64ECNames[] = {
+      // Basic C name.
+      "#Foo",
+
+      // Basic C++ name.
+      "?foo@@$$hYAHXZ",
+
+      // Regression test: https://github.com/llvm/llvm-project/issues/115231
+      "?GetValue@?$Wrapper@UA@@@@$$hQEBAHXZ",
+
+      // Symbols from:
+      // ```
+      // namespace A::B::C::D {
+      // struct Base {
+      //   virtual int f() { return 0; }
+      // };
+      // }
+      // struct Derived : public A::B::C::D::Base {
+      //   virtual int f() override { return 1; }
+      // };
+      // A::B::C::D::Base* MakeObj() { return new Derived(); }
+      // ```
+      // void * __cdecl operator new(unsigned __int64)
+      "??2@$$hYAPEAX_K@Z",
+      // public: virtual int __cdecl A::B::C::D::Base::f(void)
+      "?f@Base@D@C@B@A@@$$hUEAAHXZ",
+      // public: __cdecl A::B::C::D::Base::Base(void)
+      "??0Base@D@C@B@A@@$$hQEAA@XZ",
+      // public: virtual int __cdecl Derived::f(void)
+      "?f@Derived@@$$hUEAAHXZ",
+      // public: __cdecl Derived::Derived(void)
+      "??0Derived@@$$hQEAA@XZ",
+      // struct A::B::C::D::Base * __cdecl MakeObj(void)
+      "?MakeObj@@$$hYAPEAUBase@D@C@B@A@@XZ",
+  };
+
+  for (const auto &Arm64ECName : Arm64ECNames) {
+    // Check that this is a mangled name.
+    EXPECT_TRUE(isArm64ECMangledFunctionName(Arm64ECName))
+        << "Test case: " << Arm64ECName;
+    // Refuse to mangle it again.
+    EXPECT_FALSE(getArm64ECMangledFunctionName(Arm64ECName).has_value())
+        << "Test case: " << Arm64ECName;
+
+    // Demangle.
+    auto Arm64Name = getArm64ECDemangledFunctionName(Arm64ECName);
+    EXPECT_TRUE(Arm64Name.has_value()) << "Test case: " << Arm64ECName;
+    // Check that it is not mangled.
+    EXPECT_FALSE(isArm64ECMangledFunctionName(Arm64Name.value()))
+        << "Test case: " << Arm64ECName;
+    // Refuse to demangle it again.
+    EXPECT_FALSE(getArm64ECDemangledFunctionName(Arm64Name.value()).has_value())
+        << "Test case: " << Arm64ECName;
+
+    // Round-trip.
+    auto RoundTripArm64ECName =
+        getArm64ECMangledFunctionName(Arm64Name.value());
+    EXPECT_EQ(RoundTripArm64ECName, Arm64ECName);
+  }
+}
+
 } // end anonymous namespace



More information about the llvm-commits mailing list