[llvm] [CodeGen][ARM64EC] Add support for hybrid_patchable attribute. (PR #92965)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 21 14:09:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Jacek Caban (cjacek)

<details>
<summary>Changes</summary>

This PR implements the LLVM part of hybrid_patchable support. A prototype of the clang part is here: https://github.com/cjacek/llvm-project/commit/e4b1ffc1d76721af47a7185ace89ec9cc91a280e. The attribute is mentioned in the official documentation: https://learn.microsoft.com/en-us/windows/arm/arm64ec-abi. I described how patchable functions work in more detail here: https://wiki.winehq.org/ARM64ECToolchain#Patchable_functions.

---
Full diff: https://github.com/llvm/llvm-project/pull/92965.diff


8 Files Affected:

- (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1) 
- (modified) llvm/include/llvm/IR/Attributes.td (+3) 
- (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+2) 
- (modified) llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (+115-2) 
- (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+12-4) 
- (modified) llvm/lib/Target/AArch64/AArch64CallingConvention.td (+1-1) 
- (modified) llvm/lib/Transforms/Utils/CodeExtractor.cpp (+1) 
- (added) llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll (+77) 


``````````diff
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a..1e6a1cbc856a7 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -744,6 +744,7 @@ enum AttributeKindCodes {
   ATTR_KIND_CORO_ONLY_DESTROY_WHEN_COMPLETE = 90,
   ATTR_KIND_DEAD_ON_UNWIND = 91,
   ATTR_KIND_RANGE = 92,
+  ATTR_KIND_HYBRID_PATCHABLE = 93,
 };
 
 enum ComdatSelectionKindCodes {
diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td
index cef8b17769f0d..3eebd6d018730 100644
--- a/llvm/include/llvm/IR/Attributes.td
+++ b/llvm/include/llvm/IR/Attributes.td
@@ -109,6 +109,9 @@ def ElementType : TypeAttr<"elementtype", [ParamAttr]>;
 /// symbol.
 def FnRetThunkExtern : EnumAttr<"fn_ret_thunk_extern", [FnAttr]>;
 
+/// Function has a hybrid patchable thunk.
+def HybridPatchable : EnumAttr<"hybrid_patchable", [FnAttr]>;
+
 /// Pass structure in an alloca.
 def InAlloca : TypeAttr<"inalloca", [ParamAttr]>;
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index c4cea3d6eef2d..0816278ad040d 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -707,6 +707,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
     return bitc::ATTR_KIND_HOT;
   case Attribute::ElementType:
     return bitc::ATTR_KIND_ELEMENTTYPE;
+  case Attribute::HybridPatchable:
+    return bitc::ATTR_KIND_HYBRID_PATCHABLE;
   case Attribute::InlineHint:
     return bitc::ATTR_KIND_INLINE_HINT;
   case Attribute::InReg:
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
index 0ec15d34cd4a9..a8dd376e7416d 100644
--- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -57,6 +57,7 @@ class AArch64Arm64ECCallLowering : public ModulePass {
   Function *buildEntryThunk(Function *F);
   void lowerCall(CallBase *CB);
   Function *buildGuestExitThunk(Function *F);
+  Function *buildPatchableThunk(Function *F);
   bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
   bool runOnModule(Module &M) override;
 
@@ -64,8 +65,11 @@ class AArch64Arm64ECCallLowering : public ModulePass {
   int cfguard_module_flag = 0;
   FunctionType *GuardFnType = nullptr;
   PointerType *GuardFnPtrType = nullptr;
+  FunctionType *DispatchFnType = nullptr;
+  PointerType *DispatchFnPtrType = nullptr;
   Constant *GuardFnCFGlobal = nullptr;
   Constant *GuardFnGlobal = nullptr;
+  Constant *DispatchFnGlobal = nullptr;
   Module *M = nullptr;
 
   Type *PtrTy;
@@ -615,6 +619,78 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
   return GuestExit;
 }
 
+Function *AArch64Arm64ECCallLowering::buildPatchableThunk(Function *F) {
+  llvm::raw_null_ostream NullThunkName;
+  FunctionType *Arm64Ty, *X64Ty;
+  getThunkType(F->getFunctionType(), F->getAttributes(),
+               Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
+  auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
+  assert(MangledName && "Can't guest exit to function that's already native");
+  std::string ThunkName = *MangledName;
+  if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
+    ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
+  } else {
+    ThunkName.append("$hybpatch_thunk");
+  }
+
+  Function *GuestExit =
+      Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
+  GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
+  GuestExit->setSection(".wowthk$aa");
+  GuestExit->setMetadata(
+      "arm64ec_unmangled_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), F->getName())));
+  GuestExit->setMetadata(
+      "arm64ec_ecmangled_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), *MangledName)));
+  GuestExit->setMetadata(
+      "arm64ec_exp_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), "EXP+" + *MangledName)));
+  F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
+  IRBuilder<> B(BB);
+
+  // Load the global symbol as a pointer to the dispatch function.
+  LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
+  Value *TargetFn =
+      M->getOrInsertFunction(*MangledName + "$hp_target", F->getFunctionType())
+          .getCallee();
+
+  // Create new dispatch call instruction.
+  Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
+  CallInst *Dispatch = B.CreateCall(DispatchFnType, DispatchLoad,
+                                    {B.CreateBitCast(F, B.getPtrTy()),
+                                     B.CreateBitCast(Thunk, B.getPtrTy()),
+                                     B.CreateBitCast(TargetFn, B.getPtrTy())});
+
+  // Ensure that the first arguments are passed in the correct registers.
+  Dispatch->setCallingConv(CallingConv::CFGuard_Check);
+
+  Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
+  SmallVector<Value *> Args;
+  for (Argument &Arg : GuestExit->args())
+    Args.push_back(&Arg);
+  CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
+  Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
+
+  if (Call->getType()->isVoidTy())
+    B.CreateRetVoid();
+  else
+    B.CreateRet(Call);
+
+  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
+  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
+  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
+    GuestExit->addParamAttr(0, SRetAttr);
+    Call->addParamAttr(0, SRetAttr);
+  }
+
+  return GuestExit;
+}
+
 // Lower an indirect call with inline code.
 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
@@ -670,10 +746,40 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
 
   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
   GuardFnPtrType = PointerType::get(GuardFnType, 0);
+  DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
+  DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
   GuardFnCFGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
   GuardFnGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
+  DispatchFnGlobal =
+      M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
+
+  // Rename hybrid patchable functions and change callers to use an external
+  // linkage function call instead.
+  SetVector<Function *> PatchableFns;
+  for (Function &F : Mod) {
+    if (!F.hasFnAttribute(Attribute::HybridPatchable) ||
+        F.getName().ends_with("$hp_target"))
+      continue;
+
+    if (F.isDeclaration() || F.hasLocalLinkage()) {
+      F.removeFnAttr(Attribute::HybridPatchable);
+      continue;
+    }
+
+    if (std::optional<std::string> MangledName =
+            getArm64ECMangledFunctionName(F.getName().str())) {
+      std::string OrigName(F.getName());
+      F.setName(MangledName.value() + "$hp_target");
+
+      Function *EF = Function::Create(
+          F.getFunctionType(), GlobalValue::ExternalLinkage, 0, OrigName, M);
+      EF->copyAttributesFrom(&F);
+      F.replaceAllUsesWith(EF);
+      PatchableFns.insert(EF);
+    }
+  }
 
   SetVector<Function *> DirectCalledFns;
   for (Function &F : Mod)
@@ -702,10 +808,15 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
     ThunkMapping.push_back(
         {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
          Arm64ECThunkType::Exit});
+    assert(!F->hasFnAttribute(Attribute::HybridPatchable));
     if (!F->hasDLLImportStorageClass())
       ThunkMapping.push_back(
           {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
   }
+  for (Function *F : PatchableFns) {
+    Function *Thunk = buildPatchableThunk(F);
+    ThunkMapping.push_back({Thunk, F, Arm64ECThunkType::GuestExit});
+  }
 
   if (!ThunkMapping.empty()) {
     SmallVector<Constant *> ThunkMappingArrayElems;
@@ -738,7 +849,8 @@ bool AArch64Arm64ECCallLowering::processFunction(
   // name (emitting the definition) can grab it from the metadata.
   //
   // FIXME: Handle functions with weak linkage?
-  if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
+  if ((!F.hasLocalLinkage() || F.hasAddressTaken()) &&
+      !F.hasFnAttribute(Attribute::HybridPatchable)) {
     if (std::optional<std::string> MangledName =
             getArm64ECMangledFunctionName(F.getName().str())) {
       F.setMetadata("arm64ec_unmangled_name",
@@ -773,7 +885,8 @@ bool AArch64Arm64ECCallLowering::processFunction(
       // unprototyped functions in C)
       if (Function *F = CB->getCalledFunction()) {
         if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
-            F->isIntrinsic() || !F->isDeclaration())
+            F->isIntrinsic() || !F->isDeclaration() ||
+            F->hasFnAttribute(Attribute::HybridPatchable))
           continue;
 
         DirectCalledFns.insert(F);
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index bdc3fc630a4e3..d4b88a35a8286 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -1167,8 +1167,9 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
       !MF->getFunction().hasLocalLinkage()) {
     // For ARM64EC targets, a function definition's name is mangled differently
     // from the normal symbol, emit required aliases here.
-    auto emitFunctionAlias = [&](MCSymbol *Src, MCSymbol *Dst) {
-      OutStreamer->emitSymbolAttribute(Src, MCSA_WeakAntiDep);
+    auto emitFunctionAlias = [&](MCSymbol *Src, MCSymbol *Dst,
+                                 MCSymbolAttr Attr = MCSA_WeakAntiDep) {
+      OutStreamer->emitSymbolAttribute(Src, Attr);
       OutStreamer->emitAssignment(
           Src, MCSymbolRefExpr::create(Dst, MCSymbolRefExpr::VK_None,
                                        MMI->getContext()));
@@ -1186,8 +1187,15 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
     if (MCSymbol *UnmangledSym =
             getSymbolFromMetadata("arm64ec_unmangled_name")) {
       MCSymbol *ECMangledSym = getSymbolFromMetadata("arm64ec_ecmangled_name");
-
-      if (ECMangledSym) {
+      MCSymbol *ExpSym = getSymbolFromMetadata("arm64ec_exp_name");
+
+      if (ExpSym) {
+        // A hybrid patchable function, emit the alias from the unmangled
+        // symbol to x64 thunk and the alias from the mangled symbol to
+        // patchable guest exit thunk.
+        emitFunctionAlias(ECMangledSym, CurrentFnSym, MCSA_Weak);
+        emitFunctionAlias(UnmangledSym, ExpSym, MCSA_Weak);
+      } else if (ECMangledSym) {
         // An external function, emit the alias from the unmangled symbol to
         // mangled symbol name and the alias from the mangled symbol to guest
         // exit thunk.
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index 8e67f0f5c8815..5061605364c21 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -333,7 +333,7 @@ def CC_AArch64_Win64_CFGuard_Check : CallingConv<[
 
 let Entry = 1 in
 def CC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[
-  CCIfType<[i64], CCAssignToReg<[X11, X10]>>
+  CCIfType<[i64], CCAssignToReg<[X11, X10, X9]>>
 ]>;
 
 let Entry = 1 in
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index f2672b8e9118f..e49309014ca71 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -932,6 +932,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
       case Attribute::DisableSanitizerInstrumentation:
       case Attribute::FnRetThunkExtern:
       case Attribute::Hot:
+      case Attribute::HybridPatchable:
       case Attribute::NoRecurse:
       case Attribute::InlineHint:
       case Attribute::MinSize:
diff --git a/llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll b/llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll
new file mode 100644
index 0000000000000..5b1cac173a82a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll
@@ -0,0 +1,77 @@
+; RUN: llc -mtriple=arm64ec-pc-windows-msvc < %s | FileCheck %s
+
+define dso_local i32 @func() hybrid_patchable nounwind {
+; CHECK-LABEL:     .def    "#func$hp_target";
+; CHECK:           .section        .text,"xr",discard,"#func$hp_target"
+; CHECK-NEXT:      .globl  "#func$hp_target"               // -- Begin function #func$hp_target
+; CHECK-NEXT:      .p2align        2
+; CHECK-NEXT:  "#func$hp_target":                      // @"#func$hp_target"
+; CHECK-NEXT:      // %bb.0:
+; CHECK-NEXT:      mov     w0, #1                          // =0x1
+; CHECK-NEXT:      ret
+  ret i32 1
+}
+
+; hybrid_patchable attribute is ignored on internal functions
+define internal i32 @static_func() hybrid_patchable nounwind {
+; CHECK-LABEL:     .def    static_func;
+; CHECK:       static_func:                            // @static_func
+; CHECK-NEXT:      // %bb.0:
+; CHECK-NEXT:      mov     w0, #2                          // =0x2
+; CHECK-NEXT:      ret
+  ret i32 2
+}
+
+define dso_local void @caller() nounwind {
+; CHECK-LABEL:     .def    "#caller";
+; CHECK:           .section        .text,"xr",discard,"#caller"
+; CHECK-NEXT:      .globl  "#caller"                       // -- Begin function #caller
+; CHECK-NEXT:      .p2align        2
+; CHECK-NEXT:  "#caller":                              // @"#caller"
+; CHECK-NEXT:      .weak_anti_dep  caller
+; CHECK-NEXT:  .set caller, "#caller"{{$}}
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:      str     x30, [sp, #-16]!                // 8-byte Folded Spill
+; CHECK-NEXT:      bl      "#func"
+; CHECK-NEXT:      bl      static_func
+; CHECK-NEXT:      ldr     x30, [sp], #16                  // 8-byte Folded Reload
+; CHECK-NEXT:      ret
+  %1 = call i32 @func()
+  %2 = call i32 @static_func()
+  ret void
+}
+
+; CHECK: .def    $ientry_thunk$cdecl$i8$v;
+; CHECK: .def    $ientry_thunk$cdecl$v$v;
+; CHECK: .def    $iexit_thunk$cdecl$i8$v;
+
+; CHECK-LABEL:       def    "#func$hybpatch_thunk";
+; CHECK:            .section        .wowthk$aa,"xr",discard,"#func$hybpatch_thunk"
+; CHECK-NEXT:       .globl  "#func$hybpatch_thunk"          // -- Begin function #func$hybpatch_thunk
+; CHECK-NEXT:       .p2align        2
+; CHECK-NEXT:   "#func$hybpatch_thunk":                 // @"#func$hybpatch_thunk"
+; CHECK-NEXT:       .weak  "#func"
+; CHECK-NEXT:   .set "#func", "#func$hybpatch_thunk"{{$}}
+; CHECK-NEXT:       .weak  func
+; CHECK-NEXT:   .set func, "EXP+#func"{{$}}
+; CHECK-NEXT:   .seh_proc "#func$hybpatch_thunk"
+; CHECK-NEXT:   // %bb.0:
+; CHECK-NEXT:       str     x30, [sp, #-16]!                // 8-byte Folded Spill
+; CHECK-NEXT:       .seh_save_reg_x x30, 16
+; CHECK-NEXT:       .seh_endprologue
+; CHECK-NEXT:       adrp    x8, __os_arm64x_dispatch_call
+; CHECK-NEXT:       adrp    x11, func
+; CHECK-NEXT:       add     x11, x11, :lo12:func
+; CHECK-NEXT:       ldr     x8, [x8, :lo12:__os_arm64x_dispatch_call]
+; CHECK-NEXT:       adrp    x10, ($iexit_thunk$cdecl$i8$v)
+; CHECK-NEXT:       add     x10, x10, :lo12:($iexit_thunk$cdecl$i8$v)
+; CHECK-NEXT:       adrp    x9, "#func$hp_target"
+; CHECK-NEXT:       add     x9, x9, :lo12:"#func$hp_target"
+; CHECK-NEXT:       blr     x8
+; CHECK-NEXT:       .seh_startepilogue
+; CHECK-NEXT:       ldr     x30, [sp], #16                  // 8-byte Folded Reload
+; CHECK-NEXT:       .seh_save_reg_x x30, 16
+; CHECK-NEXT:       .seh_endepilogue
+; CHECK-NEXT:       br      x11
+; CHECK-NEXT:       .seh_endfunclet
+; CHECK-NEXT:       .seh_endproc

``````````

</details>


https://github.com/llvm/llvm-project/pull/92965


More information about the llvm-commits mailing list