[llvm] [x64][win] Add compiler support for x64 import call optimization (equivalent to MSVC /d2guardretpoline) (PR #126631)

Daniel Paoliello via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 11 12:38:49 PST 2025


https://github.com/dpaoliello updated https://github.com/llvm/llvm-project/pull/126631

>From b4ab8af255bbd15b45618a164bd32e875534468b Mon Sep 17 00:00:00 2001
From: Daniel Paoliello <danpao at microsoft.com>
Date: Thu, 23 Jan 2025 10:45:02 -0800
Subject: [PATCH] Implement Import Call Optimization for x64

---
 llvm/include/llvm/Transforms/CFGuard.h        |   3 +
 llvm/lib/MC/MCObjectFileInfo.cpp              |   5 +
 llvm/lib/Target/X86/X86AsmPrinter.cpp         |  32 ++++
 llvm/lib/Target/X86/X86AsmPrinter.h           |  33 +++-
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  18 +-
 llvm/lib/Target/X86/X86ISelLowering.h         |   8 +-
 llvm/lib/Target/X86/X86ISelLoweringCall.cpp   |   7 +-
 llvm/lib/Target/X86/X86InstrCompiler.td       |   2 +
 llvm/lib/Target/X86/X86InstrFragments.td      |   3 +
 llvm/lib/Target/X86/X86MCInstLower.cpp        | 169 ++++++++++++++++--
 llvm/lib/Transforms/CFGuard/CFGuard.cpp       |  15 +-
 .../win-import-call-optimization-cfguard.ll   |  34 ++++
 .../win-import-call-optimization-jumptable.ll |  83 +++++++++
 .../win-import-call-optimization-nocalls.ll   |  21 +++
 .../X86/win-import-call-optimization.ll       |  65 +++++++
 .../MC/X86/win-import-call-optimization.s     |  69 +++++++
 16 files changed, 540 insertions(+), 27 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/win-import-call-optimization-cfguard.ll
 create mode 100644 llvm/test/CodeGen/X86/win-import-call-optimization-jumptable.ll
 create mode 100644 llvm/test/CodeGen/X86/win-import-call-optimization-nocalls.ll
 create mode 100644 llvm/test/CodeGen/X86/win-import-call-optimization.ll
 create mode 100644 llvm/test/MC/X86/win-import-call-optimization.s

diff --git a/llvm/include/llvm/Transforms/CFGuard.h b/llvm/include/llvm/Transforms/CFGuard.h
index caf822a2ec9fb..b81db8f487965 100644
--- a/llvm/include/llvm/Transforms/CFGuard.h
+++ b/llvm/include/llvm/Transforms/CFGuard.h
@@ -16,6 +16,7 @@
 namespace llvm {
 
 class FunctionPass;
+class GlobalValue;
 
 class CFGuardPass : public PassInfoMixin<CFGuardPass> {
 public:
@@ -34,6 +35,8 @@ FunctionPass *createCFGuardCheckPass();
 /// Insert Control FLow Guard dispatches on indirect function calls.
 FunctionPass *createCFGuardDispatchPass();
 
+bool isCFGuardFunction(const GlobalValue *GV);
+
 } // namespace llvm
 
 #endif
diff --git a/llvm/lib/MC/MCObjectFileInfo.cpp b/llvm/lib/MC/MCObjectFileInfo.cpp
index 150e38a94db6a..334673c4dba79 100644
--- a/llvm/lib/MC/MCObjectFileInfo.cpp
+++ b/llvm/lib/MC/MCObjectFileInfo.cpp
@@ -599,6 +599,11 @@ void MCObjectFileInfo::initCOFFMCObjectFileInfo(const Triple &T) {
   if (T.getArch() == Triple::aarch64) {
     ImportCallSection =
         Ctx->getCOFFSection(".impcall", COFF::IMAGE_SCN_LNK_INFO);
+  } else if (T.getArch() == Triple::x86_64) {
+    // Import Call Optimization on x64 leverages the same metadata as the
+    // retpoline mitigation, hence the unusual section name.
+    ImportCallSection =
+        Ctx->getCOFFSection(".retplne", COFF::IMAGE_SCN_LNK_INFO);
   }
 
   // Debug info.
diff --git a/llvm/lib/Target/X86/X86AsmPrinter.cpp b/llvm/lib/Target/X86/X86AsmPrinter.cpp
index f01e47b41cf5e..52f8280b25965 100644
--- a/llvm/lib/Target/X86/X86AsmPrinter.cpp
+++ b/llvm/lib/Target/X86/X86AsmPrinter.cpp
@@ -920,6 +920,9 @@ void X86AsmPrinter::emitStartOfAsmFile(Module &M) {
     OutStreamer->emitSymbolAttribute(S, MCSA_Global);
     OutStreamer->emitAssignment(
         S, MCConstantExpr::create(Feat00Value, MMI->getContext()));
+
+    if (M.getModuleFlag("import-call-optimization"))
+      EnableImportCallOptimization = true;
   }
   OutStreamer->emitSyntaxDirective();
 
@@ -1021,6 +1024,35 @@ void X86AsmPrinter::emitEndOfAsmFile(Module &M) {
     // safe to set.
     OutStreamer->emitAssemblerFlag(MCAF_SubsectionsViaSymbols);
   } else if (TT.isOSBinFormatCOFF()) {
+    // If import call optimization is enabled, emit the appropriate section.
+    // We do this whether or not we recorded any items.
+    if (EnableImportCallOptimization) {
+      OutStreamer->switchSection(getObjFileLowering().getImportCallSection());
+
+      // Section always starts with some magic.
+      constexpr char ImpCallMagic[12] = "RetpolineV1";
+      OutStreamer->emitBytes(StringRef{ImpCallMagic, sizeof(ImpCallMagic)});
+
+      // Layout of this section is:
+      // Per section that contains an item to record:
+      //  uint32_t SectionSize: Size in bytes for information in this section.
+      //  uint32_t Section Number
+      //  Per recorded item in section:
+      //    uint32_t Kind: the kind of item.
+      //    uint32_t InstOffset: the offset of the instr in its parent section.
+      for (auto &[Section, CallsToImportedFuncs] :
+           SectionToImportedFunctionCalls) {
+        unsigned SectionSize =
+            sizeof(uint32_t) * (2 + 2 * CallsToImportedFuncs.size());
+        OutStreamer->emitInt32(SectionSize);
+        OutStreamer->emitCOFFSecNumber(Section->getBeginSymbol());
+        for (auto &[CallsiteSymbol, Kind] : CallsToImportedFuncs) {
+          OutStreamer->emitInt32(Kind);
+          OutStreamer->emitCOFFSecOffset(CallsiteSymbol);
+        }
+      }
+    }
+
     if (usesMSVCFloatingPoint(TT, M)) {
       // In Windows' libcmt.lib, there is a file which is linked in only if the
       // symbol _fltused is referenced. Linking this in causes some
diff --git a/llvm/lib/Target/X86/X86AsmPrinter.h b/llvm/lib/Target/X86/X86AsmPrinter.h
index 693021eca3295..47e82c4dfcea5 100644
--- a/llvm/lib/Target/X86/X86AsmPrinter.h
+++ b/llvm/lib/Target/X86/X86AsmPrinter.h
@@ -31,6 +31,26 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
   bool EmitFPOData = false;
   bool ShouldEmitWeakSwiftAsyncExtendedFramePointerFlags = false;
   bool IndCSPrefix = false;
+  bool EnableImportCallOptimization = false;
+
+  enum ImportCallKind : unsigned {
+    IMAGE_RETPOLINE_AMD64_IMPORT_BR = 0x02,
+    IMAGE_RETPOLINE_AMD64_IMPORT_CALL = 0x03,
+    IMAGE_RETPOLINE_AMD64_INDIR_BR = 0x04,
+    IMAGE_RETPOLINE_AMD64_INDIR_CALL = 0x05,
+    IMAGE_RETPOLINE_AMD64_INDIR_BR_REX = 0x06,
+    IMAGE_RETPOLINE_AMD64_CFG_BR = 0x08,
+    IMAGE_RETPOLINE_AMD64_CFG_CALL = 0x09,
+    IMAGE_RETPOLINE_AMD64_CFG_BR_REX = 0x0A,
+    IMAGE_RETPOLINE_AMD64_SWITCHTABLE_FIRST = 0x010,
+    IMAGE_RETPOLINE_AMD64_SWITCHTABLE_LAST = 0x01F,
+  };
+  struct ImportCallInfo {
+    MCSymbol *CalleeSymbol;
+    ImportCallKind Kind;
+  };
+  DenseMap<MCSection *, std::vector<ImportCallInfo>>
+      SectionToImportedFunctionCalls;
 
   // This utility class tracks the length of a stackmap instruction's 'shadow'.
   // It is used by the X86AsmPrinter to ensure that the stackmap shadow
@@ -45,7 +65,7 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
     void startFunction(MachineFunction &MF) {
       this->MF = &MF;
     }
-    void count(MCInst &Inst, const MCSubtargetInfo &STI,
+    void count(const MCInst &Inst, const MCSubtargetInfo &STI,
                MCCodeEmitter *CodeEmitter);
 
     // Called to signal the start of a shadow of RequiredSize bytes.
@@ -126,6 +146,17 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
   void emitMachOIFuncStubHelperBody(Module &M, const GlobalIFunc &GI,
                                     MCSymbol *LazyPointer) override;
 
+  void emitCallInstruction(const llvm::MCInst &MCI);
+
+  // Emits a label to mark the next instruction as being relevant to Import Call
+  // Optimization.
+  void emitLabelAndRecordForImportCallOptimization(ImportCallKind Kind);
+
+  // Ensure that rax is used as the operand for the given instruction.
+  //
+  // NOTE: This assumes that it is safe to clobber rax.
+  void ensureRaxUsedForOperand(MCInst &TmpInst);
+
 public:
   X86AsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer);
 
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6cf6061deba70..30a98a5a13ebf 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -18922,7 +18922,7 @@ SDValue X86TargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {
 
 SDValue X86TargetLowering::LowerExternalSymbol(SDValue Op,
                                                SelectionDAG &DAG) const {
-  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
+  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false, nullptr);
 }
 
 SDValue
@@ -18950,7 +18950,8 @@ X86TargetLowering::LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const {
 /// Creates target global address or external symbol nodes for calls or
 /// other uses.
 SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
-                                                 bool ForCall) const {
+                                                 bool ForCall,
+                                                 bool *IsImpCall) const {
   // Unpack the global address or external symbol.
   SDLoc dl(Op);
   const GlobalValue *GV = nullptr;
@@ -19000,6 +19001,16 @@ SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
   if (ForCall && !NeedsLoad && !HasPICReg && Offset == 0)
     return Result;
 
+  // If Import Call Optimization is enabled and this is an imported function
+  // then make a note of it and return the global address without wrapping.
+  if (IsImpCall && (OpFlags == X86II::MO_DLLIMPORT) &&
+      Mod.getModuleFlag("import-call-optimization")) {
+    assert(ForCall && "Should only enable import call optimization if we are "
+                      "lowering a call");
+    *IsImpCall = true;
+    return Result;
+  }
+
   Result = DAG.getNode(getGlobalWrapperKind(GV, OpFlags), dl, PtrVT, Result);
 
   // With PIC, the address is actually $g + Offset.
@@ -19025,7 +19036,7 @@ SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
 
 SDValue
 X86TargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
-  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
+  return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false, nullptr);
 }
 
 static SDValue GetTLSADDR(SelectionDAG &DAG, GlobalAddressSDNode *GA,
@@ -34562,6 +34573,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FST)
   NODE_NAME_CASE(CALL)
   NODE_NAME_CASE(CALL_RVMARKER)
+  NODE_NAME_CASE(IMP_CALL)
   NODE_NAME_CASE(BT)
   NODE_NAME_CASE(CMP)
   NODE_NAME_CASE(FCMP)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index fe79fefeed631..6324cc65398a0 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -81,6 +81,10 @@ namespace llvm {
     // marker instruction.
     CALL_RVMARKER,
 
+    // Pseudo for a call to an imported function to ensure the correct machine
+    // instruction is emitted for Import Call Optimization.
+    IMP_CALL,
+
     /// X86 compare and logical compare instructions.
     CMP,
     FCMP,
@@ -1733,8 +1737,8 @@ namespace llvm {
 
     /// Creates target global address or external symbol nodes for calls or
     /// other uses.
-    SDValue LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
-                                  bool ForCall) const;
+    SDValue LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG, bool ForCall,
+                                  bool *IsImpCall) const;
 
     SDValue LowerSINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerUINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
index 6835c7e336a5c..cbbdf37a3fb75 100644
--- a/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
+++ b/llvm/lib/Target/X86/X86ISelLoweringCall.cpp
@@ -2402,6 +2402,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     InGlue = Chain.getValue(1);
   }
 
+  bool IsImpCall = false;
   if (DAG.getTarget().getCodeModel() == CodeModel::Large) {
     assert(Is64Bit && "Large code model is only legal in 64-bit mode.");
     // In the 64-bit large code model, we have to make all calls
@@ -2414,7 +2415,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     // ForCall to true here has the effect of removing WrapperRIP when possible
     // to allow direct calls to be selected without first materializing the
     // address into a register.
-    Callee = LowerGlobalOrExternal(Callee, DAG, /*ForCall=*/true);
+    Callee = LowerGlobalOrExternal(Callee, DAG, /*ForCall=*/true, &IsImpCall);
   } else if (Subtarget.isTarget64BitILP32() &&
              Callee.getValueType() == MVT::i32) {
     // Zero-extend the 32-bit Callee address into a 64-bit according to x32 ABI
@@ -2536,7 +2537,9 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   // Returns a chain & a glue for retval copy to use.
   SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
-  if (HasNoCfCheck && IsCFProtectionSupported && IsIndirectCall) {
+  if (IsImpCall) {
+    Chain = DAG.getNode(X86ISD::IMP_CALL, dl, NodeTys, Ops);
+  } else if (HasNoCfCheck && IsCFProtectionSupported && IsIndirectCall) {
     Chain = DAG.getNode(X86ISD::NT_CALL, dl, NodeTys, Ops);
   } else if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
     // Calls with a "clang.arc.attachedcall" bundle are special. They should be
diff --git a/llvm/lib/Target/X86/X86InstrCompiler.td b/llvm/lib/Target/X86/X86InstrCompiler.td
index 9687ae29f1c78..5f603de695906 100644
--- a/llvm/lib/Target/X86/X86InstrCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrCompiler.td
@@ -1309,6 +1309,8 @@ def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 texternalsym:$dst)),
 def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 tglobaladdr:$dst)),
           (CALL64pcrel32_RVMARKER tglobaladdr:$rvfunc, tglobaladdr:$dst)>;
 
+def : Pat<(X86imp_call (i64 tglobaladdr:$dst)),
+          (CALL64pcrel32 tglobaladdr:$dst)>;
 
 // Tailcall stuff. The TCRETURN instructions execute after the epilog, so they
 // can never use callee-saved registers. That is the purpose of the GR64_TC
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index ddbc7c55a6113..3ab820de78efc 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -210,6 +210,9 @@ def X86call_rvmarker  : SDNode<"X86ISD::CALL_RVMARKER",     SDT_X86Call,
                         [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
                          SDNPVariadic]>;
 
+def X86imp_call  : SDNode<"X86ISD::IMP_CALL",     SDT_X86Call,
+                        [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
+                         SDNPVariadic]>;
 
 def X86NoTrackCall : SDNode<"X86ISD::NT_CALL", SDT_X86Call,
                             [SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
diff --git a/llvm/lib/Target/X86/X86MCInstLower.cpp b/llvm/lib/Target/X86/X86MCInstLower.cpp
index 0f8fbf5be1c95..f265093a60d12 100644
--- a/llvm/lib/Target/X86/X86MCInstLower.cpp
+++ b/llvm/lib/Target/X86/X86MCInstLower.cpp
@@ -47,6 +47,7 @@
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/CFGuard.h"
 #include "llvm/Transforms/Instrumentation/AddressSanitizer.h"
 #include "llvm/Transforms/Instrumentation/AddressSanitizerCommon.h"
 #include <string>
@@ -112,7 +113,7 @@ struct NoAutoPaddingScope {
 static void emitX86Nops(MCStreamer &OS, unsigned NumBytes,
                         const X86Subtarget *Subtarget);
 
-void X86AsmPrinter::StackMapShadowTracker::count(MCInst &Inst,
+void X86AsmPrinter::StackMapShadowTracker::count(const MCInst &Inst,
                                                  const MCSubtargetInfo &STI,
                                                  MCCodeEmitter *CodeEmitter) {
   if (InShadow) {
@@ -2193,6 +2194,27 @@ static void addConstantComments(const MachineInstr *MI,
   }
 }
 
+bool isImportedFunction(const MachineOperand &MO) {
+  return MO.isGlobal() && (MO.getTargetFlags() == X86II::MO_DLLIMPORT);
+}
+
+bool isCallToCFGuardFunction(const MachineInstr *MI) {
+  assert(MI->getOpcode() == X86::TAILJMPm64_REX ||
+         MI->getOpcode() == X86::CALL64m);
+  const MachineOperand &MO = MI->getOperand(3);
+  return MO.isGlobal() && (MO.getTargetFlags() == X86II::MO_NO_FLAG) &&
+         isCFGuardFunction(MO.getGlobal());
+}
+
+bool hasJumpTableInfoInBlock(const llvm::MachineInstr *MI) {
+  const MachineBasicBlock &MBB = *MI->getParent();
+  for (auto I = MBB.instr_rbegin(), E = MBB.instr_rend(); I != E; ++I)
+    if (I->isJumpTableDebugInfo())
+      return true;
+
+  return false;
+}
+
 void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   // FIXME: Enable feature predicate checks once all the test pass.
   // X86_MC::verifyInstructionPredicates(MI->getOpcode(),
@@ -2271,20 +2293,64 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case X86::TAILJMPd64:
     if (IndCSPrefix && MI->hasRegisterImplicitUseOperand(X86::R11))
       EmitAndCountInstruction(MCInstBuilder(X86::CS_PREFIX));
-    [[fallthrough]];
-  case X86::TAILJMPr:
+
+    if (EnableImportCallOptimization && isImportedFunction(MI->getOperand(0))) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_IMPORT_BR);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    break;
+  case X86::TAILJMPm64_REX:
+    if (EnableImportCallOptimization && isCallToCFGuardFunction(MI)) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_CFG_BR_REX);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    break;
   case X86::TAILJMPm:
   case X86::TAILJMPd:
   case X86::TAILJMPd_CC:
-  case X86::TAILJMPr64:
   case X86::TAILJMPm64:
   case X86::TAILJMPd64_CC:
-  case X86::TAILJMPr64_REX:
-  case X86::TAILJMPm64_REX:
     // Lower these as normal, but add some comments.
     OutStreamer->AddComment("TAILCALL");
     break;
 
+  case X86::TAILJMPr:
+  case X86::TAILJMPr64:
+  case X86::TAILJMPr64_REX: {
+    MCInst TmpInst;
+    MCInstLowering.Lower(MI, TmpInst);
+
+    if (EnableImportCallOptimization) {
+      // Import call optimization requires all indirect calls go via RAX.
+      ensureRaxUsedForOperand(TmpInst);
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_INDIR_BR);
+    }
+
+    // Lower these as normal, but add some comments.
+    OutStreamer->AddComment("TAILCALL");
+    EmitAndCountInstruction(TmpInst);
+    return;
+  }
+
+  case X86::JMP64r:
+  case X86::JMP64m:
+    if (EnableImportCallOptimization && hasJumpTableInfoInBlock(MI)) {
+      uint16_t EncodedReg =
+          this->getSubtarget().getRegisterInfo()->getEncodingValue(
+              MI->getOperand(0).getReg().asMCReg());
+      emitLabelAndRecordForImportCallOptimization(
+          (ImportCallKind)(IMAGE_RETPOLINE_AMD64_SWITCHTABLE_FIRST +
+                           EncodedReg));
+    }
+    break;
+
   case X86::TLS_addr32:
   case X86::TLS_addr64:
   case X86::TLS_addrX32:
@@ -2469,7 +2535,49 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   case X86::CALL64pcrel32:
     if (IndCSPrefix && MI->hasRegisterImplicitUseOperand(X86::R11))
       EmitAndCountInstruction(MCInstBuilder(X86::CS_PREFIX));
+
+    if (EnableImportCallOptimization && isImportedFunction(MI->getOperand(0))) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_IMPORT_CALL);
+
+      MCInst TmpInst;
+      MCInstLowering.Lower(MI, TmpInst);
+
+      // For Import Call Optimization to work, we need the call instruction
+      // with a rex prefix, and a 5-byte nop after the call instruction.
+      EmitAndCountInstruction(MCInstBuilder(X86::REX64_PREFIX));
+      emitCallInstruction(TmpInst);
+      emitNop(*OutStreamer, 5, Subtarget);
+      return;
+    }
+
     break;
+  case X86::CALL64r:
+    if (EnableImportCallOptimization) {
+      MCInst TmpInst;
+      MCInstLowering.Lower(MI, TmpInst);
+
+      // Import call optimization requires all indirect calls go via RAX.
+      ensureRaxUsedForOperand(TmpInst);
+
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_INDIR_CALL);
+      emitCallInstruction(TmpInst);
+
+      // For Import Call Optimization to work, we need a 3-byte nop after the
+      // call instruction.
+      emitNop(*OutStreamer, 3, Subtarget);
+      return;
+    }
+
+    break;
+  case X86::CALL64m:
+    if (EnableImportCallOptimization && isCallToCFGuardFunction(MI)) {
+      emitLabelAndRecordForImportCallOptimization(
+          IMAGE_RETPOLINE_AMD64_CFG_CALL);
+    }
+    break;
+
   case X86::JCC_1:
     // Two instruction prefixes (2EH for branch not-taken and 3EH for branch
     // taken) are used as branch hints. Here we add branch taken prefix for
@@ -2490,20 +2598,47 @@ void X86AsmPrinter::emitInstruction(const MachineInstr *MI) {
   MCInst TmpInst;
   MCInstLowering.Lower(MI, TmpInst);
 
-  // Stackmap shadows cannot include branch targets, so we can count the bytes
-  // in a call towards the shadow, but must ensure that the no thread returns
-  // in to the stackmap shadow.  The only way to achieve this is if the call
-  // is at the end of the shadow.
   if (MI->isCall()) {
-    // Count then size of the call towards the shadow
-    SMShadowTracker.count(TmpInst, getSubtargetInfo(), CodeEmitter.get());
-    // Then flush the shadow so that we fill with nops before the call, not
-    // after it.
-    SMShadowTracker.emitShadowPadding(*OutStreamer, getSubtargetInfo());
-    // Then emit the call
-    OutStreamer->emitInstruction(TmpInst, getSubtargetInfo());
+    emitCallInstruction(TmpInst);
     return;
   }
 
   EmitAndCountInstruction(TmpInst);
 }
+
+void X86AsmPrinter::emitCallInstruction(const llvm::MCInst &MCI) {
+  // Stackmap shadows cannot include branch targets, so we can count the bytes
+  // in a call towards the shadow, but must ensure that no thread returns
+  // into the stackmap shadow.  The only way to achieve this is if the call
+  // is at the end of the shadow.
+
+  // Count the size of the call towards the shadow
+  SMShadowTracker.count(MCI, getSubtargetInfo(), CodeEmitter.get());
+  // Then flush the shadow so that we fill with nops before the call, not
+  // after it.
+  SMShadowTracker.emitShadowPadding(*OutStreamer, getSubtargetInfo());
+  // Then emit the call
+  OutStreamer->emitInstruction(MCI, getSubtargetInfo());
+}
+
+void X86AsmPrinter::emitLabelAndRecordForImportCallOptimization(
+    ImportCallKind Kind) {
+  assert(EnableImportCallOptimization);
+
+  MCSymbol *CallSiteSymbol = MMI->getContext().createNamedTempSymbol("impcall");
+  OutStreamer->emitLabel(CallSiteSymbol);
+
+  SectionToImportedFunctionCalls[OutStreamer->getCurrentSectionOnly()]
+      .push_back({CallSiteSymbol, Kind});
+}
+
+void X86AsmPrinter::ensureRaxUsedForOperand(MCInst &TmpInst) {
+  assert(TmpInst.getNumOperands() == 1);
+
+  MCOperand &Op = TmpInst.getOperand(0);
+  if (Op.isReg() && Op.getReg() != X86::RAX) {
+    EmitAndCountInstruction(
+        MCInstBuilder(X86::MOV64rr).addReg(X86::RAX).addReg(Op.getReg()));
+    Op.setReg(X86::RAX);
+  }
+}
diff --git a/llvm/lib/Transforms/CFGuard/CFGuard.cpp b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
index 41d68b62eb8d7..cadc538021103 100644
--- a/llvm/lib/Transforms/CFGuard/CFGuard.cpp
+++ b/llvm/lib/Transforms/CFGuard/CFGuard.cpp
@@ -31,6 +31,9 @@ using OperandBundleDef = OperandBundleDefT<Value *>;
 
 STATISTIC(CFGuardCounter, "Number of Control Flow Guard checks added");
 
+constexpr StringRef GuardCheckFunctionName = "__guard_check_icall_fptr";
+constexpr StringRef GuardDispatchFunctionName = "__guard_dispatch_icall_fptr";
+
 namespace {
 
 /// Adds Control Flow Guard (CFG) checks on indirect function calls/invokes.
@@ -45,10 +48,10 @@ class CFGuardImpl {
     // Get or insert the guard check or dispatch global symbols.
     switch (GuardMechanism) {
     case Mechanism::Check:
-      GuardFnName = "__guard_check_icall_fptr";
+      GuardFnName = GuardCheckFunctionName;
       break;
     case Mechanism::Dispatch:
-      GuardFnName = "__guard_dispatch_icall_fptr";
+      GuardFnName = GuardDispatchFunctionName;
       break;
     }
   }
@@ -318,3 +321,11 @@ FunctionPass *llvm::createCFGuardCheckPass() {
 FunctionPass *llvm::createCFGuardDispatchPass() {
   return new CFGuard(CFGuardPass::Mechanism::Dispatch);
 }
+
+bool llvm::isCFGuardFunction(const GlobalValue *GV) {
+  if (GV->getLinkage() != GlobalValue::ExternalLinkage)
+    return false;
+
+  StringRef Name = GV->getName();
+  return Name == GuardCheckFunctionName || Name == GuardDispatchFunctionName;
+}
diff --git a/llvm/test/CodeGen/X86/win-import-call-optimization-cfguard.ll b/llvm/test/CodeGen/X86/win-import-call-optimization-cfguard.ll
new file mode 100644
index 0000000000000..12be910d68ee9
--- /dev/null
+++ b/llvm/test/CodeGen/X86/win-import-call-optimization-cfguard.ll
@@ -0,0 +1,34 @@
+; RUN: llc -mtriple=x86_64-pc-windows-msvc < %s | FileCheck %s --check-prefix=CHECK
+
+define dso_local void @normal_call(ptr noundef readonly %func_ptr) local_unnamed_addr section "nc_sect" {
+entry:
+  call void %func_ptr()
+  ret void
+}
+; CHECK-LABEL:  normal_call:
+; CHECK:        .Limpcall0:
+; CHECK-NEXT:     callq   *__guard_dispatch_icall_fptr(%rip)
+
+define dso_local void @tail_call_fp(ptr noundef readonly %func_ptr) local_unnamed_addr section "tc_sect" {
+entry:
+  tail call void %func_ptr()
+  ret void
+}
+; CHECK-LABEL:  tail_call_fp:
+; CHECK:        .Limpcall1:
+; CHECK-NEXT:     rex64 jmpq      *__guard_dispatch_icall_fptr(%rip)
+
+; CHECK-LABEL  .section   .retplne,"yi"
+; CHECK-NEXT   .asciz  "RetpolineV1"
+; CHECK-NEXT   .long   16
+; CHECK-NEXT   .secnum tc_sect
+; CHECK-NEXT   .long   10
+; CHECK-NEXT   .secoffset      .Limpcall1
+; CHECK-NEXT   .long   16
+; CHECK-NEXT   .secnum nc_sect
+; CHECK-NEXT   .long   9
+; CHECK-NEXT   .secoffset      .Limpcall0
+
+!llvm.module.flags = !{!0, !1}
+!0 = !{i32 1, !"import-call-optimization", i32 1}
+!1 = !{i32 2, !"cfguard", i32 2}
diff --git a/llvm/test/CodeGen/X86/win-import-call-optimization-jumptable.ll b/llvm/test/CodeGen/X86/win-import-call-optimization-jumptable.ll
new file mode 100644
index 0000000000000..fe22b251685e6
--- /dev/null
+++ b/llvm/test/CodeGen/X86/win-import-call-optimization-jumptable.ll
@@ -0,0 +1,83 @@
+; RUN: llc -mtriple=x86_64-pc-windows-msvc < %s | FileCheck %s
+
+; CHECK-LABEL:  uses_rax:
+; CHECK:        .Limpcall0:
+; CHECK-NEXT:     jmpq    *%rax
+
+define void @uses_rax(i32 %x) {
+entry:
+  switch i32 %x, label %sw.epilog [
+    i32 0, label %sw.bb
+    i32 1, label %sw.bb1
+    i32 2, label %sw.bb2
+    i32 3, label %sw.bb3
+  ]
+
+sw.bb:
+  tail call void @g(i32 0) #2
+  br label %sw.epilog
+
+sw.bb1:
+  tail call void @g(i32 1) #2
+  br label %sw.epilog
+
+sw.bb2:
+  tail call void @g(i32 2) #2
+  br label %sw.epilog
+
+sw.bb3:
+  tail call void @g(i32 3) #2
+  br label %sw.epilog
+
+sw.epilog:
+  tail call void @g(i32 10) #2
+  ret void
+}
+
+; CHECK-LABEL:  uses_rcx:
+; CHECK:        .Limpcall1:
+; CHECK-NEXT:     jmpq    *%rcx
+
+define void @uses_rcx(i32 %x) {
+entry:
+  switch i32 %x, label %sw.epilog [
+    i32 10, label %sw.bb
+    i32 11, label %sw.bb1
+    i32 12, label %sw.bb2
+    i32 13, label %sw.bb3
+  ]
+
+sw.bb:
+  tail call void @g(i32 0) #2
+  br label %sw.epilog
+
+sw.bb1:
+  tail call void @g(i32 1) #2
+  br label %sw.epilog
+
+sw.bb2:
+  tail call void @g(i32 2) #2
+  br label %sw.epilog
+
+sw.bb3:
+  tail call void @g(i32 3) #2
+  br label %sw.epilog
+
+sw.epilog:
+  tail call void @g(i32 10) #2
+  ret void
+}
+
+declare void @g(i32)
+
+; CHECK-LABEL:  .section        .retplne,"yi"
+; CHECK-NEXT:   .asciz  "RetpolineV1"
+; CHECK-NEXT:   .long   24
+; CHECK-NEXT:   .secnum .text
+; CHECK-NEXT:   .long   16
+; CHECK-NEXT:   .secoffset      .Limpcall0
+; CHECK-NEXT:   .long   17
+; CHECK-NEXT:   .secoffset      .Limpcall1
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"import-call-optimization", i32 1}
diff --git a/llvm/test/CodeGen/X86/win-import-call-optimization-nocalls.ll b/llvm/test/CodeGen/X86/win-import-call-optimization-nocalls.ll
new file mode 100644
index 0000000000000..4ca7b85282f2e
--- /dev/null
+++ b/llvm/test/CodeGen/X86/win-import-call-optimization-nocalls.ll
@@ -0,0 +1,21 @@
+; RUN: llc -mtriple=x86_64-pc-windows-msvc < %s | FileCheck %s
+
+define dso_local void @normal_call() local_unnamed_addr {
+entry:
+  call void @a()
+  ret void
+}
+; CHECK-LABEL:  normal_call:
+; CHECK:        callq   a
+
+declare void @a() local_unnamed_addr
+
+; Even if there are no calls to imported functions, we still need to emit the
+; .retplne section.
+
+; CHECK-LABEL  .section   .retplne,"yi"
+; CHECK-NEXT   .asciz  "RetpolineV1"
+; CHECK-NOT    .secnum
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"import-call-optimization", i32 1}
diff --git a/llvm/test/CodeGen/X86/win-import-call-optimization.ll b/llvm/test/CodeGen/X86/win-import-call-optimization.ll
new file mode 100644
index 0000000000000..e669984ac37a4
--- /dev/null
+++ b/llvm/test/CodeGen/X86/win-import-call-optimization.ll
@@ -0,0 +1,65 @@
+; RUN: llc -mtriple=x86_64-pc-windows-msvc < %s | FileCheck %s --check-prefix=CHECK
+
+define dso_local void @normal_call(ptr noundef readonly %func_ptr) local_unnamed_addr section "nc_sect" {
+entry:
+  call void @a()
+  call void @a()
+  call void %func_ptr()
+  ret void
+}
+; CHECK-LABEL:  normal_call:
+; CHECK:        .Limpcall0:
+; CHECK-NEXT:     rex64
+; CHECK-NEXT:     callq   __imp_a
+; CHECK-NEXT:     nopl    8(%rax,%rax)
+; CHECK-NEXT:   .Limpcall1:
+; CHECK-NEXT:     rex64
+; CHECK-NEXT:     callq   __imp_a
+; CHECK-NEXT:     nopl    8(%rax,%rax)
+; CHECK-NEXT:     movq    %rsi, %rax
+; CHECK-NEXT:   .Limpcall2:
+; CHECK-NEXT:     callq   *%rax
+; CHECK-NEXT:     nopl    (%rax)
+; CHECK-NEXT:     nop
+
+define dso_local void @tail_call() local_unnamed_addr section "tc_sect" {
+entry:
+  tail call void @b()
+  ret void
+}
+; CHECK-LABEL:  tail_call:
+; CHECK:        .Limpcall3:
+; CHECK-NEXT:     jmp __imp_b
+
+define dso_local void @tail_call_fp(ptr noundef readonly %func_ptr) local_unnamed_addr section "tc_sect" {
+entry:
+  tail call void %func_ptr()
+  ret void
+}
+; CHECK-LABEL:  tail_call_fp:
+; CHECK:          movq    %rcx, %rax
+; CHECK-NEXT:   .Limpcall4:
+; CHECK-NEXT:     rex64 jmpq      *%rax
+
+declare dllimport void @a() local_unnamed_addr
+declare dllimport void @b() local_unnamed_addr
+
+; CHECK-LABEL  .section   .retplne,"yi"
+; CHECK-NEXT   .asciz  "RetpolineV1"
+; CHECK-NEXT   .long   24
+; CHECK-NEXT   .secnum tc_sect
+; CHECK-NEXT   .long   3
+; CHECK-NEXT   .secoffset      .Limpcall3
+; CHECK-NEXT   .long   5
+; CHECK-NEXT   .secoffset      .Limpcall4
+; CHECK-NEXT   .long   32
+; CHECK-NEXT   .secnum nc_sect
+; CHECK-NEXT   .long   3
+; CHECK-NEXT   .secoffset      .Limpcall0
+; CHECK-NEXT   .long   3
+; CHECK-NEXT   .secoffset      .Limpcall1
+; CHECK-NEXT   .long   5
+; CHECK-NEXT   .secoffset      .Limpcall2
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"import-call-optimization", i32 1}
diff --git a/llvm/test/MC/X86/win-import-call-optimization.s b/llvm/test/MC/X86/win-import-call-optimization.s
new file mode 100644
index 0000000000000..4f839a2bc6011
--- /dev/null
+++ b/llvm/test/MC/X86/win-import-call-optimization.s
@@ -0,0 +1,69 @@
+// RUN: llvm-mc -triple x86_64-windows-msvc -filetype obj -o %t.obj %s
+// RUN: llvm-readobj --sections --sd --relocs %t.obj | FileCheck %s
+
+.section        nc_sect,"xr"
+normal_call:
+.seh_proc normal_call
+# %bb.0:                                # %entry
+        subq    $40, %rsp
+        .seh_stackalloc 40
+        .seh_endprologue
+.Limpcall0:
+        rex64
+        callq   *__imp_a(%rip)
+        nopl    8(%rax,%rax)
+        nop
+        addq    $40, %rsp
+        retq
+        .seh_endproc
+
+.section        tc_sect,"xr"
+tail_call:
+.Limpcall1:
+        rex64
+        jmp     *__imp_b(%rip)
+
+.section        .retplne,"yi"
+.asciz  "RetpolineV1"
+.long   16
+.secnum tc_sect
+.long   2
+.secoffset .Limpcall1
+.long   16
+.secnum nc_sect
+.long   3
+.secoffset .Limpcall0
+
+// CHECK-LABEL: Name: .retplne (2E 72 65 74 70 6C 6E 65)
+// CHECK-NEXT:  VirtualSize: 0x0
+// CHECK-NEXT:  VirtualAddress: 0x0
+// CHECK-NEXT:  RawDataSize: 44
+// CHECK-NEXT:  PointerToRawData:
+// CHECK-NEXT:  PointerToRelocations:
+// CHECK-NEXT:  PointerToLineNumbers:
+// CHECK-NEXT:  RelocationCount: 0
+// CHECK-NEXT:  LineNumberCount: 0
+// CHECK-NEXT:  Characteristics [
+// CHECK-NEXT:    IMAGE_SCN_ALIGN_1BYTES
+// CHECK-NEXT:    IMAGE_SCN_LNK_INFO
+// CHECK-NEXT:  ]
+// CHECK-NEXT:  SectionData (
+// CHECK-NEXT:    52657470 6F6C696E 65563100 10000000  |RetpolineV1.....|
+// CHECK-NEXT:    0010:
+// CHECK-SAME:    [[#%.2X,TCSECT:]]000000
+// CHECK-SAME:    02000000
+// CHECK-SAME:    [[#%.2X,TCOFFSET:]]000000
+// CHECK-SAME:    10000000
+// CHECK-NEXT:    0020:
+// CHECK-SAME:    [[#%.2X,NCSECT:]]000000
+// CHECK-SAME:    03000000
+// CHECK-SAME:    [[#%.2X,NCOFFSET:]]000000
+// CHECK-NEXT:  )
+
+// CHECK-LABEL: Relocations [
+// CHECK-NEXT:     Section ([[#%u,NCSECT]]) nc_sect {
+// CHECK-NEXT:       0x[[#%x,NCOFFSET + 3]] IMAGE_REL_AMD64_REL32 __imp_a
+// CHECK-NEXT:     }
+// CHECK-NEXT:     Section ([[#%u,TCSECT]]) tc_sect {
+// CHECK-NEXT:       0x[[#%x,TCOFFSET + 3]] IMAGE_REL_AMD64_REL32 __imp_b
+// CHECK-NEXT:     }



More information about the llvm-commits mailing list