[llvm] [llvm][aarch64] Fix Arm64EC name mangling algorithm (PR #115567)

Daniel Paoliello via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 13 12:24:40 PST 2024


https://github.com/dpaoliello updated https://github.com/llvm/llvm-project/pull/115567

>From e0bcc78b521672f01465a9e0abeb2c5c8e52063a Mon Sep 17 00:00:00 2001
From: "Daniel Paoliello (HE/HIM)" <danpao at microsoft.com>
Date: Fri, 8 Nov 2024 15:40:11 -0800
Subject: [PATCH] Fix Arm64EC name mangling algorithm

---
 llvm/include/llvm/Demangle/Demangle.h         |  4 +
 .../include/llvm/Demangle/MicrosoftDemangle.h |  4 +
 llvm/include/llvm/IR/Mangler.h                |  2 +-
 llvm/lib/Demangle/MicrosoftDemangle.cpp       | 19 +++++
 llvm/lib/IR/Mangler.cpp                       | 19 ++---
 llvm/unittests/IR/ManglerTest.cpp             | 77 +++++++++++++++++++
 6 files changed, 113 insertions(+), 12 deletions(-)

diff --git a/llvm/include/llvm/Demangle/Demangle.h b/llvm/include/llvm/Demangle/Demangle.h
index fe129603c0785d..132e5088b55148 100644
--- a/llvm/include/llvm/Demangle/Demangle.h
+++ b/llvm/include/llvm/Demangle/Demangle.h
@@ -10,6 +10,7 @@
 #define LLVM_DEMANGLE_DEMANGLE_H
 
 #include <cstddef>
+#include <optional>
 #include <string>
 #include <string_view>
 
@@ -54,6 +55,9 @@ enum MSDemangleFlags {
 char *microsoftDemangle(std::string_view mangled_name, size_t *n_read,
                         int *status, MSDemangleFlags Flags = MSDF_None);
 
+std::optional<size_t>
+getArm64ECInsertionPointInMangledName(std::string_view MangledName);
+
 // Demangles a Rust v0 mangled symbol.
 char *rustDemangle(std::string_view MangledName);
 
diff --git a/llvm/include/llvm/Demangle/MicrosoftDemangle.h b/llvm/include/llvm/Demangle/MicrosoftDemangle.h
index 6891185a28e57f..276efa7603690e 100644
--- a/llvm/include/llvm/Demangle/MicrosoftDemangle.h
+++ b/llvm/include/llvm/Demangle/MicrosoftDemangle.h
@@ -9,6 +9,7 @@
 #ifndef LLVM_DEMANGLE_MICROSOFTDEMANGLE_H
 #define LLVM_DEMANGLE_MICROSOFTDEMANGLE_H
 
+#include "llvm/Demangle/Demangle.h"
 #include "llvm/Demangle/MicrosoftDemangleNodes.h"
 
 #include <cassert>
@@ -141,6 +142,9 @@ enum class FunctionIdentifierCodeGroup { Basic, Under, DoubleUnder };
 // It has a set of functions to parse mangled symbols into Type instances.
 // It also has a set of functions to convert Type instances to strings.
 class Demangler {
+  friend std::optional<size_t>
+  llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName);
+
 public:
   Demangler() = default;
   virtual ~Demangler() = default;
diff --git a/llvm/include/llvm/IR/Mangler.h b/llvm/include/llvm/IR/Mangler.h
index 3c3f0c6dce80fa..6c8ebf5f072f28 100644
--- a/llvm/include/llvm/IR/Mangler.h
+++ b/llvm/include/llvm/IR/Mangler.h
@@ -64,7 +64,7 @@ std::optional<std::string> getArm64ECDemangledFunctionName(StringRef Name);
 /// Check if an ARM64EC function name is mangled.
 bool inline isArm64ECMangledFunctionName(StringRef Name) {
   return Name[0] == '#' ||
-         (Name[0] == '?' && Name.find("$$h") != StringRef::npos);
+         (Name[0] == '?' && Name.find("@$$h") != StringRef::npos);
 }
 
 } // End llvm namespace
diff --git a/llvm/lib/Demangle/MicrosoftDemangle.cpp b/llvm/lib/Demangle/MicrosoftDemangle.cpp
index aa65f3be29da77..6be8b0fe739967 100644
--- a/llvm/lib/Demangle/MicrosoftDemangle.cpp
+++ b/llvm/lib/Demangle/MicrosoftDemangle.cpp
@@ -24,6 +24,7 @@
 #include <array>
 #include <cctype>
 #include <cstdio>
+#include <optional>
 #include <string_view>
 #include <tuple>
 
@@ -2428,6 +2429,24 @@ void Demangler::dumpBackReferences() {
     std::printf("\n");
 }
 
+std::optional<size_t>
+llvm::getArm64ECInsertionPointInMangledName(std::string_view MangledName) {
+  std::string_view ProcessedName{MangledName};
+
+  // We only support this for MSVC-style C++ symbols.
+  if (!consumeFront(ProcessedName, '?'))
+    return std::nullopt;
+
+  // The insertion point is just after the name of the symbol, so parse that to
+  // remove it from the processed name.
+  Demangler D;
+  D.demangleFullyQualifiedSymbolName(ProcessedName);
+  if (D.Error)
+    return std::nullopt;
+
+  return MangledName.length() - ProcessedName.length();
+}
+
 char *llvm::microsoftDemangle(std::string_view MangledName, size_t *NMangled,
                               int *Status, MSDemangleFlags Flags) {
   Demangler D;
diff --git a/llvm/lib/IR/Mangler.cpp b/llvm/lib/IR/Mangler.cpp
index 15a4debf191a5b..3b9c00cf993f38 100644
--- a/llvm/lib/IR/Mangler.cpp
+++ b/llvm/lib/IR/Mangler.cpp
@@ -14,6 +14,7 @@
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Demangle/Demangle.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
@@ -299,21 +300,17 @@ std::optional<std::string> llvm::getArm64ECMangledFunctionName(StringRef Name) {
     return std::optional<std::string>(("#" + Name).str());
   }
 
-  // Insert the ARM64EC "$$h" tag after the mangled function name.
+  // If the name contains $$h, then it is already mangled.
   if (Name.contains("$$h"))
     return std::nullopt;
-  size_t InsertIdx = Name.find("@@");
-  size_t ThreeAtSignsIdx = Name.find("@@@");
-  if (InsertIdx != std::string::npos && InsertIdx != ThreeAtSignsIdx) {
-    InsertIdx += 2;
-  } else {
-    InsertIdx = Name.find("@");
-    if (InsertIdx != std::string::npos)
-      InsertIdx++;
-  }
+
+  // Ask the demangler where we should insert "$$h".
+  auto InsertIdx = getArm64ECInsertionPointInMangledName(Name);
+  if (!InsertIdx)
+    return std::nullopt;
 
   return std::optional<std::string>(
-      (Name.substr(0, InsertIdx) + "$$h" + Name.substr(InsertIdx)).str());
+      (Name.substr(0, *InsertIdx) + "$$h" + Name.substr(*InsertIdx)).str());
 }
 
 std::optional<std::string>
diff --git a/llvm/unittests/IR/ManglerTest.cpp b/llvm/unittests/IR/ManglerTest.cpp
index 5ac784b7e89ac6..a2b4e81690310b 100644
--- a/llvm/unittests/IR/ManglerTest.cpp
+++ b/llvm/unittests/IR/ManglerTest.cpp
@@ -172,4 +172,81 @@ TEST(ManglerTest, GOFF) {
             "L#foo");
 }
 
+TEST(ManglerTest, Arm64EC) {
+  constexpr std::string_view Arm64ECNames[] = {
+      // Basic C name.
+      "#Foo",
+
+      // Basic C++ name.
+      "?foo@@$$hYAHXZ",
+
+      // Regression test: https://github.com/llvm/llvm-project/issues/115231
+      "?GetValue@?$Wrapper@UA@@@@$$hQEBAHXZ",
+
+      // Symbols from:
+      // ```
+      // namespace A::B::C::D {
+      // struct Base {
+      //   virtual int f() { return 0; }
+      // };
+      // }
+      // struct Derived : public A::B::C::D::Base {
+      //   virtual int f() override { return 1; }
+      // };
+      // A::B::C::D::Base* MakeObj() { return new Derived(); }
+      // ```
+      // void * __cdecl operator new(unsigned __int64)
+      "??2@$$hYAPEAX_K@Z",
+      // public: virtual int __cdecl A::B::C::D::Base::f(void)
+      "?f@Base@D@C@B@A@@$$hUEAAHXZ",
+      // public: __cdecl A::B::C::D::Base::Base(void)
+      "??0Base@D@C@B@A@@$$hQEAA@XZ",
+      // public: virtual int __cdecl Derived::f(void)
+      "?f@Derived@@$$hUEAAHXZ",
+      // public: __cdecl Derived::Derived(void)
+      "??0Derived@@$$hQEAA@XZ",
+      // struct A::B::C::D::Base * __cdecl MakeObj(void)
+      "?MakeObj@@$$hYAPEAUBase@D@C@B@A@@XZ",
+
+      // Symbols from:
+      // ```
+      // template <typename T> struct WW { struct Z{}; };
+      // template <typename X> struct Wrapper {
+      //   int GetValue(typename WW<X>::Z) const;
+      // };
+      // struct A { };
+      // template <typename X> int Wrapper<X>::GetValue(typename WW<X>::Z) const
+      // { return 3; }
+      // template class Wrapper<A>;
+      // ```
+      // public: int __cdecl Wrapper<struct A>::GetValue(struct WW<struct
+      // A>::Z)const
+      "?GetValue@?$Wrapper@UA@@@@$$hQEBAHUZ@?$WW@UA@@@@@Z",
+  };
+
+  for (const auto &Arm64ECName : Arm64ECNames) {
+    // Check that this is a mangled name.
+    EXPECT_TRUE(isArm64ECMangledFunctionName(Arm64ECName))
+        << "Test case: " << Arm64ECName;
+    // Refuse to mangle it again.
+    EXPECT_FALSE(getArm64ECMangledFunctionName(Arm64ECName).has_value())
+        << "Test case: " << Arm64ECName;
+
+    // Demangle.
+    auto Arm64Name = getArm64ECDemangledFunctionName(Arm64ECName);
+    EXPECT_TRUE(Arm64Name.has_value()) << "Test case: " << Arm64ECName;
+    // Check that it is not mangled.
+    EXPECT_FALSE(isArm64ECMangledFunctionName(Arm64Name.value()))
+        << "Test case: " << Arm64ECName;
+    // Refuse to demangle it again.
+    EXPECT_FALSE(getArm64ECDemangledFunctionName(Arm64Name.value()).has_value())
+        << "Test case: " << Arm64ECName;
+
+    // Round-trip.
+    auto RoundTripArm64ECName =
+        getArm64ECMangledFunctionName(Arm64Name.value());
+    EXPECT_EQ(RoundTripArm64ECName, Arm64ECName);
+  }
+}
+
 } // end anonymous namespace



More information about the llvm-commits mailing list