[llvm] [BOLT] Add support for Linux kernel static keys jump table (PR #86090)

Maksim Panchenko via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 20 23:02:02 PDT 2024


https://github.com/maksfb created https://github.com/llvm/llvm-project/pull/86090

None

>From b61ae32db1a1c182ff80c19f3bcad1c93b2f90e8 Mon Sep 17 00:00:00 2001
From: Maksim Panchenko <maks at fb.com>
Date: Wed, 28 Feb 2024 12:57:29 -0800
Subject: [PATCH] [BOLT] Add support for Linux kernel static keys jump table

---
 bolt/include/bolt/Core/MCPlus.h          |   1 +
 bolt/include/bolt/Core/MCPlusBuilder.h   |  25 ++
 bolt/lib/Core/BinaryContext.cpp          |   8 +-
 bolt/lib/Core/BinaryFunction.cpp         |  17 +-
 bolt/lib/Core/MCPlusBuilder.cpp          |  22 ++
 bolt/lib/Passes/BinaryPasses.cpp         |  20 +-
 bolt/lib/Rewrite/LinuxKernelRewriter.cpp | 385 +++++++++++++++++++++++
 bolt/lib/Target/X86/X86MCPlusBuilder.cpp |  13 +
 bolt/test/X86/linux-static-keys.s        |  67 ++++
 9 files changed, 555 insertions(+), 3 deletions(-)
 create mode 100644 bolt/test/X86/linux-static-keys.s

diff --git a/bolt/include/bolt/Core/MCPlus.h b/bolt/include/bolt/Core/MCPlus.h
index b6a9e73f2347e7..1d2360c180335f 100644
--- a/bolt/include/bolt/Core/MCPlus.h
+++ b/bolt/include/bolt/Core/MCPlus.h
@@ -73,6 +73,7 @@ class MCAnnotation {
     kOffset,              /// Offset in the function.
     kLabel,               /// MCSymbol pointing to this instruction.
     kSize,                /// Size of the instruction.
+    kDynamicBranch,       /// JIT instruction patched at runtime.
     kGeneric              /// First generic annotation.
   };
 
diff --git a/bolt/include/bolt/Core/MCPlusBuilder.h b/bolt/include/bolt/Core/MCPlusBuilder.h
index 96b58f54162344..0b6b3dcd769262 100644
--- a/bolt/include/bolt/Core/MCPlusBuilder.h
+++ b/bolt/include/bolt/Core/MCPlusBuilder.h
@@ -1199,6 +1199,16 @@ class MCPlusBuilder {
   /// Set instruction size.
   void setSize(MCInst &Inst, uint32_t Size) const;
 
+  /// Check if the branch instruction could be modified at runtime.
+  bool isDynamicBranch(const MCInst &Inst) const;
+
+  /// Return ID for runtime-modifiable instruction.
+  std::optional<uint32_t> getDynamicBranchID(const MCInst &Inst) const;
+
+  /// Mark instruction as a dynamic branch, i.e. a branch that can be
+  /// overwritten at runtime.
+  void setDynamicBranch(MCInst &Inst, uint32_t ID) const;
+
   /// Return MCSymbol that represents a target of this instruction at a given
   /// operand number \p OpNum. If there's no symbol associated with
   /// the operand - return nullptr.
@@ -1208,6 +1218,14 @@ class MCPlusBuilder {
     return nullptr;
   }
 
+  /// Return MCSymbol that represents a target of this instruction at a given
+  /// operand number \p OpNum. If there's no symbol associated with
+  /// the operand - return nullptr.
+  virtual MCSymbol *getTargetSymbol(MCInst &Inst, unsigned OpNum = 0) const {
+    return const_cast<MCSymbol *>(
+        getTargetSymbol(const_cast<const MCInst &>(Inst), OpNum));
+  }
+
   /// Return MCSymbol extracted from a target expression
   virtual const MCSymbol *getTargetSymbol(const MCExpr *Expr) const {
     return &cast<const MCSymbolRefExpr>(Expr)->getSymbol();
@@ -1688,6 +1706,13 @@ class MCPlusBuilder {
     llvm_unreachable("not implemented");
   }
 
+  /// Create long conditional branch with a target-specific conditional code
+  /// \p CC.
+  virtual void createLongCondBranch(MCInst &Inst, const MCSymbol *Target,
+                                    unsigned CC, MCContext *Ctx) const {
+    llvm_unreachable("not implemented");
+  }
+
   /// Reverses the branch condition in Inst and update its taken target to TBB.
   ///
   /// Returns true on success.
diff --git a/bolt/lib/Core/BinaryContext.cpp b/bolt/lib/Core/BinaryContext.cpp
index b29ebbbfa18c4b..267f43f65e206e 100644
--- a/bolt/lib/Core/BinaryContext.cpp
+++ b/bolt/lib/Core/BinaryContext.cpp
@@ -1939,7 +1939,13 @@ void BinaryContext::printInstruction(raw_ostream &OS, const MCInst &Instruction,
     OS << Endl;
     return;
   }
-  InstPrinter->printInst(&Instruction, 0, "", *STI, OS);
+  if (std::optional<uint32_t> DynamicID =
+          MIB->getDynamicBranchID(Instruction)) {
+    OS << "\tjit\t" << MIB->getTargetSymbol(Instruction)->getName()
+       << " # ID: " << DynamicID;
+  } else {
+    InstPrinter->printInst(&Instruction, 0, "", *STI, OS);
+  }
   if (MIB->isCall(Instruction)) {
     if (MIB->isTailCall(Instruction))
       OS << " # TAILCALL ";
diff --git a/bolt/lib/Core/BinaryFunction.cpp b/bolt/lib/Core/BinaryFunction.cpp
index ce4dd29f542b0d..fdadef9dcd3848 100644
--- a/bolt/lib/Core/BinaryFunction.cpp
+++ b/bolt/lib/Core/BinaryFunction.cpp
@@ -3350,6 +3350,16 @@ void BinaryFunction::fixBranches() {
 
       // Eliminate unnecessary conditional branch.
       if (TSuccessor == FSuccessor) {
+        // FIXME: at the moment, we cannot safely remove static key branches.
+        if (MIB->isDynamicBranch(*CondBranch)) {
+          if (opts::Verbosity) {
+            BC.outs()
+                << "BOLT-INFO: unable to remove redundant dynamic branch in "
+                << *this << '\n';
+          }
+          continue;
+        }
+
         BB->removeDuplicateConditionalSuccessor(CondBranch);
         if (TSuccessor != NextBB)
           BB->addBranchInstruction(TSuccessor);
@@ -3358,8 +3368,13 @@ void BinaryFunction::fixBranches() {
 
       // Reverse branch condition and swap successors.
       auto swapSuccessors = [&]() {
-        if (MIB->isUnsupportedBranch(*CondBranch))
+        if (MIB->isUnsupportedBranch(*CondBranch)) {
+          if (opts::Verbosity) {
+            BC.outs() << "BOLT-INFO: unable to swap successors in " << *this
+                      << '\n';
+          }
           return false;
+        }
         std::swap(TSuccessor, FSuccessor);
         BB->swapConditionalSuccessors();
         auto L = BC.scopeLock();
diff --git a/bolt/lib/Core/MCPlusBuilder.cpp b/bolt/lib/Core/MCPlusBuilder.cpp
index bd9bd0c45922a5..5b14ad5cdb880f 100644
--- a/bolt/lib/Core/MCPlusBuilder.cpp
+++ b/bolt/lib/Core/MCPlusBuilder.cpp
@@ -303,6 +303,28 @@ void MCPlusBuilder::setSize(MCInst &Inst, uint32_t Size) const {
   setAnnotationOpValue(Inst, MCAnnotation::kSize, Size);
 }
 
+bool MCPlusBuilder::isDynamicBranch(const MCInst &Inst) const {
+  if (!hasAnnotation(Inst, MCAnnotation::kDynamicBranch))
+    return false;
+  assert(isBranch(Inst) && "Branch expected.");
+  return true;
+}
+
+std::optional<uint32_t>
+MCPlusBuilder::getDynamicBranchID(const MCInst &Inst) const {
+  if (std::optional<int64_t> Value =
+          getAnnotationOpValue(Inst, MCAnnotation::kDynamicBranch)) {
+    assert(isBranch(Inst) && "Branch expected.");
+    return static_cast<uint32_t>(*Value);
+  }
+  return std::nullopt;
+}
+
+void MCPlusBuilder::setDynamicBranch(MCInst &Inst, uint32_t ID) const {
+  assert(isBranch(Inst) && "Branch expected.");
+  setAnnotationOpValue(Inst, MCAnnotation::kDynamicBranch, ID);
+}
+
 bool MCPlusBuilder::hasAnnotation(const MCInst &Inst, unsigned Index) const {
   return (bool)getAnnotationOpValue(Inst, Index);
 }
diff --git a/bolt/lib/Passes/BinaryPasses.cpp b/bolt/lib/Passes/BinaryPasses.cpp
index bf1c2ddd37dd24..c0ba73108f5778 100644
--- a/bolt/lib/Passes/BinaryPasses.cpp
+++ b/bolt/lib/Passes/BinaryPasses.cpp
@@ -107,6 +107,12 @@ static cl::opt<unsigned>
                   cl::desc("print statistics about basic block ordering"),
                   cl::init(0), cl::cat(BoltOptCategory));
 
+static cl::opt<bool> PrintLargeFunctions(
+    "print-large-functions",
+    cl::desc("print functions that could not be overwritten due to excessive "
+             "size"),
+    cl::init(false), cl::cat(BoltOptCategory));
+
 static cl::list<bolt::DynoStats::Category>
     PrintSortedBy("print-sorted-by", cl::CommaSeparated,
                   cl::desc("print functions sorted by order of dyno stats"),
@@ -570,8 +576,12 @@ Error CheckLargeFunctions::runOnFunctions(BinaryContext &BC) {
     uint64_t HotSize, ColdSize;
     std::tie(HotSize, ColdSize) =
         BC.calculateEmittedSize(BF, /*FixBranches=*/false);
-    if (HotSize > BF.getMaxSize())
+    if (HotSize > BF.getMaxSize()) {
+      if (opts::PrintLargeFunctions)
+        BC.outs() << "BOLT-INFO: " << BF << " size exceeds allocated space by "
+                  << (HotSize - BF.getMaxSize()) << " bytes\n";
       BF.setSimple(false);
+    }
   };
 
   ParallelUtilities::PredicateTy SkipFunc = [&](const BinaryFunction &BF) {
@@ -852,6 +862,10 @@ uint64_t SimplifyConditionalTailCalls::fixTailCalls(BinaryFunction &BF) {
       assert(Result && "internal error analyzing conditional branch");
       assert(CondBranch && "conditional branch expected");
 
+      // Skip dynamic branches for now.
+      if (BF.getBinaryContext().MIB->isDynamicBranch(*CondBranch))
+        continue;
+
       // It's possible that PredBB is also a successor to BB that may have
       // been processed by a previous iteration of the SCTC loop, in which
       // case it may have been marked invalid.  We should skip rewriting in
@@ -1012,6 +1026,10 @@ uint64_t ShortenInstructions::shortenInstructions(BinaryFunction &Function) {
   const BinaryContext &BC = Function.getBinaryContext();
   for (BinaryBasicBlock &BB : Function) {
     for (MCInst &Inst : BB) {
+      // Skip shortening instructions with Size annotation.
+      if (BC.MIB->getSize(Inst))
+        continue;
+
       MCInst OriginalInst;
       if (opts::Verbosity > 2)
         OriginalInst = Inst;
diff --git a/bolt/lib/Rewrite/LinuxKernelRewriter.cpp b/bolt/lib/Rewrite/LinuxKernelRewriter.cpp
index a2bfd45a64e304..faa71a97d909ee 100644
--- a/bolt/lib/Rewrite/LinuxKernelRewriter.cpp
+++ b/bolt/lib/Rewrite/LinuxKernelRewriter.cpp
@@ -14,7 +14,9 @@
 #include "bolt/Rewrite/MetadataRewriter.h"
 #include "bolt/Rewrite/MetadataRewriters.h"
 #include "bolt/Utils/CommandLineOpts.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/MC/MCDisassembler/MCDisassembler.h"
 #include "llvm/Support/BinaryStreamWriter.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -65,6 +67,16 @@ static cl::opt<bool> DumpStaticCalls("dump-static-calls",
                                      cl::init(false), cl::Hidden,
                                      cl::cat(BoltCategory));
 
+static cl::opt<bool>
+    DumpStaticKeys("dump-static-keys",
+                   cl::desc("dump Linux kernel static keys jump table"),
+                   cl::init(false), cl::Hidden, cl::cat(BoltCategory));
+
+static cl::opt<bool> LongJumpLabels(
+    "long-jump-labels",
+    cl::desc("always use long jumps/nops for Linux kernel static keys"),
+    cl::init(false), cl::Hidden, cl::cat(BoltCategory));
+
 static cl::opt<bool>
     PrintORC("print-orc",
              cl::desc("print ORC unwind information for instructions"),
@@ -151,6 +163,20 @@ class LinuxKernelRewriter final : public MetadataRewriter {
   /// Number of entries in the input file ORC sections.
   uint64_t NumORCEntries = 0;
 
+  /// Section containing static keys jump table.
+  ErrorOr<BinarySection &> StaticKeysJumpSection = std::errc::bad_address;
+  uint64_t StaticKeysJumpTableAddress = 0;
+  static constexpr size_t STATIC_KEYS_JUMP_ENTRY_SIZE = 8;
+
+  struct JumpInfoEntry {
+    bool Likely;
+    bool InitValue;
+  };
+  SmallVector<JumpInfoEntry, 16> JumpInfo;
+
+  /// Static key entries that need nop conversion.
+  DenseSet<uint32_t> NopIDs;
+
   /// Section containing static call table.
   ErrorOr<BinarySection &> StaticCallSection = std::errc::bad_address;
   uint64_t StaticCallTableAddress = 0;
@@ -235,6 +261,11 @@ class LinuxKernelRewriter final : public MetadataRewriter {
   /// Read .pci_fixup
   Error readPCIFixupTable();
 
+  /// Handle static keys jump table.
+  Error readStaticKeysJumpTable();
+  Error rewriteStaticKeysJumpTable();
+  Error updateStaticKeysJumpTablePostEmit();
+
   /// Mark instructions referenced by kernel metadata.
   Error markInstructions();
 
@@ -268,6 +299,9 @@ class LinuxKernelRewriter final : public MetadataRewriter {
     if (Error E = readPCIFixupTable())
       return E;
 
+    if (Error E = readStaticKeysJumpTable())
+      return E;
+
     return Error::success();
   }
 
@@ -290,12 +324,18 @@ class LinuxKernelRewriter final : public MetadataRewriter {
     if (Error E = rewriteStaticCalls())
       return E;
 
+    if (Error E = rewriteStaticKeysJumpTable())
+      return E;
+
     return Error::success();
   }
 
   Error postEmitFinalizer() override {
     updateLKMarkers();
 
+    if (Error E = updateStaticKeysJumpTablePostEmit())
+      return E;
+
     return Error::success();
   }
 };
@@ -1343,6 +1383,351 @@ Error LinuxKernelRewriter::readPCIFixupTable() {
   return Error::success();
 }
 
+/// Runtime code modification used by static keys is the most ubiquitous
+/// self-modifying feature of the Linux kernel. The idea is to eliminate
+/// the condition check and associated conditional jump on a hot path if that
+/// condition (based on a boolean value of a static key) does not change often.
+/// Whenever the condition changes, the kernel runtime modifies all code
+/// paths associated with that key flipping the code between nop and
+/// (unconditional) jump. The information about the code is stored in a static
+/// key jump table and contains the list of entries of the following type from
+/// include/linux/jump_label.h:
+///
+///   struct jump_entry {
+///     s32 code;
+///     s32 target;
+///     long key; // key may be far away from the core kernel under KASLR
+///   };
+///
+/// The list does not have to be stored in any sorted way, but it is sorted at
+/// boot time (or module initialization time) first by "key" and then by "code".
+/// jump_label_sort_entries() is responsible for sorting the table.
+///
+/// The key in jump_entry structure uses lower two bits of the key address
+/// (which itself is aligned) to store extra information. We are interested in
+/// the lower bit which indicates if the key is likely to be set on the code
+/// path associated with this jump_entry.
+///
+/// static_key_{enable,disable}() functions modify the code based on key and
+/// jump table entries.
+///
+/// jump_label_update() updates all code entries for a given key. Batch mode is
+/// used for x86.
+///
+/// The actual patching happens in text_poke_bp_batch() that overwrites the
+/// byte of the sequence with int3 before proceeding with actual code
+/// replacement.
+Error LinuxKernelRewriter::readStaticKeysJumpTable() {
+  const BinaryData *StaticKeysJumpTable =
+      BC.getBinaryDataByName("__start___jump_table");
+  if (!StaticKeysJumpTable)
+    return Error::success();
+
+  StaticKeysJumpTableAddress = StaticKeysJumpTable->getAddress();
+
+  const BinaryData *Stop = BC.getBinaryDataByName("__stop___jump_table");
+  if (!Stop)
+    return createStringError(errc::executable_format_error,
+                             "missing __stop___jump_table symbol");
+
+  ErrorOr<BinarySection &> ErrorOrSection =
+      BC.getSectionForAddress(StaticKeysJumpTableAddress);
+  if (!ErrorOrSection)
+    return createStringError(errc::executable_format_error,
+                             "no section matching __start___jump_table");
+
+  StaticKeysJumpSection = *ErrorOrSection;
+  if (!StaticKeysJumpSection->containsAddress(Stop->getAddress() - 1))
+    return createStringError(errc::executable_format_error,
+                             "__stop___jump_table not in the same section "
+                             "as __start___jump_table");
+
+  if ((Stop->getAddress() - StaticKeysJumpTableAddress) %
+      STATIC_KEYS_JUMP_ENTRY_SIZE)
+    return createStringError(errc::executable_format_error,
+                             "static keys jump table size error");
+
+  const uint64_t SectionAddress = StaticKeysJumpSection->getAddress();
+  DataExtractor DE(StaticKeysJumpSection->getContents(),
+                   BC.AsmInfo->isLittleEndian(),
+                   BC.AsmInfo->getCodePointerSize());
+  DataExtractor::Cursor Cursor(StaticKeysJumpTableAddress - SectionAddress);
+  uint32_t EntryID = 0;
+  while (Cursor && Cursor.tell() < Stop->getAddress() - SectionAddress) {
+    const uint64_t JumpAddress =
+        SectionAddress + Cursor.tell() + (int32_t)DE.getU32(Cursor);
+    const uint64_t TargetAddress =
+        SectionAddress + Cursor.tell() + (int32_t)DE.getU32(Cursor);
+    const uint64_t KeyAddress =
+        SectionAddress + Cursor.tell() + (int64_t)DE.getU64(Cursor);
+
+    // Consume the status of the cursor.
+    if (!Cursor)
+      return createStringError(
+          errc::executable_format_error,
+          "out of bounds while reading static keys jump table: %s",
+          toString(Cursor.takeError()).c_str());
+
+    ++EntryID;
+
+    JumpInfo.push_back(JumpInfoEntry());
+    JumpInfoEntry &Info = JumpInfo.back();
+    Info.Likely = KeyAddress & 1;
+
+    if (opts::DumpStaticKeys) {
+      BC.outs() << "Static key jump entry: " << EntryID
+                << "\n\tJumpAddress:   0x" << Twine::utohexstr(JumpAddress)
+                << "\n\tTargetAddress: 0x" << Twine::utohexstr(TargetAddress)
+                << "\n\tKeyAddress:    0x" << Twine::utohexstr(KeyAddress)
+                << "\n\tIsLikely:      " << Info.Likely << '\n';
+    }
+
+    BinaryFunction *BF = BC.getBinaryFunctionContainingAddress(JumpAddress);
+    if (!BF && opts::Verbosity) {
+      BC.outs()
+          << "BOLT-INFO: no function matches address 0x"
+          << Twine::utohexstr(JumpAddress)
+          << " of jump instruction referenced from static keys jump table\n";
+    }
+
+    if (!BF || !BC.shouldEmit(*BF))
+      continue;
+
+    MCInst *Inst = BF->getInstructionAtOffset(JumpAddress - BF->getAddress());
+    if (!Inst)
+      return createStringError(
+          errc::executable_format_error,
+          "no instruction at static keys jump site address 0x%" PRIx64,
+          JumpAddress);
+
+    if (!BF->containsAddress(TargetAddress))
+      return createStringError(
+          errc::executable_format_error,
+          "invalid target of static keys jump at 0x%" PRIx64 " : 0x%" PRIx64,
+          JumpAddress, TargetAddress);
+
+    const bool IsBranch = BC.MIB->isBranch(*Inst);
+    if (!IsBranch && !BC.MIB->isNoop(*Inst))
+      return createStringError(errc::executable_format_error,
+                               "jump or nop expected at address 0x%" PRIx64,
+                               JumpAddress);
+
+    const uint64_t Size = BC.computeInstructionSize(*Inst);
+    if (Size != 2 && Size != 5) {
+      return createStringError(
+          errc::executable_format_error,
+          "unexpected static keys jump size at address 0x%" PRIx64,
+          JumpAddress);
+    }
+
+    MCSymbol *Target = BF->registerBranch(JumpAddress, TargetAddress);
+    MCInst StaticKeyBranch;
+
+    // Create a conditional branch instruction. The actual conditional code type
+    // should not matter as long as it's a valid code. The instruction should be
+    // treated as a conditional branch for control-flow purposes. Before we emit
+    // the code, it will be converted to a different instruction in
+    // rewriteStaticKeysJumpTable().
+    //
+    // NB: for older kernels, under LongJumpLabels option, we create long
+    //     conditional branch to guarantee that code size estimation takes
+    //     into account the extra bytes needed for long branch that will be used
+    //     by the kernel patching code. Newer kernels can work with both short
+    //     and long branches. The code for long conditional branch is larger
+    //     than unconditional one, so we are pessimistic in our estimations.
+    if (opts::LongJumpLabels)
+      BC.MIB->createLongCondBranch(StaticKeyBranch, Target, 0, BC.Ctx.get());
+    else
+      BC.MIB->createCondBranch(StaticKeyBranch, Target, 0, BC.Ctx.get());
+    BC.MIB->moveAnnotations(std::move(*Inst), StaticKeyBranch);
+    BC.MIB->setDynamicBranch(StaticKeyBranch, EntryID);
+    *Inst = StaticKeyBranch;
+
+    // IsBranch = InitialValue ^ LIKELY
+    //
+    //    0 0 0
+    //    1 0 1
+    //    1 1 0
+    //    0 1 1
+    //
+    // => InitialValue = IsBranch ^ LIKELY
+    Info.InitValue = IsBranch ^ Info.Likely;
+
+    // Add annotations to facilitate manual code analysis.
+    BC.MIB->addAnnotation(*Inst, "Likely", Info.Likely);
+    BC.MIB->addAnnotation(*Inst, "InitValue", Info.InitValue);
+    if (!BC.MIB->getSize(*Inst))
+      BC.MIB->setSize(*Inst, Size);
+
+    if (opts::LongJumpLabels)
+      BC.MIB->setSize(*Inst, 5);
+  }
+
+  BC.outs() << "BOLT-INFO: parsed " << EntryID << " static keys jump entries\n";
+
+  return Error::success();
+}
+
+// Pre-emit pass. Convert dynamic branch instructions into jumps that could be
+// relaxed. In post-emit pass we will convert those jumps into nops when
+// necessary. We do the unconditional conversion into jumps so that the jumps
+// can be relaxed and the optimal size of jump/nop instruction is selected.
+Error LinuxKernelRewriter::rewriteStaticKeysJumpTable() {
+  if (!StaticKeysJumpSection)
+    return Error::success();
+
+  uint64_t NumShort = 0;
+  uint64_t NumLong = 0;
+  for (BinaryFunction &BF : llvm::make_second_range(BC.getBinaryFunctions())) {
+    if (!BC.shouldEmit(BF))
+      continue;
+
+    for (BinaryBasicBlock &BB : BF) {
+      for (MCInst &Inst : BB) {
+        if (!BC.MIB->isDynamicBranch(Inst))
+          continue;
+
+        const uint32_t EntryID = *BC.MIB->getDynamicBranchID(Inst);
+        MCSymbol *Target =
+            const_cast<MCSymbol *>(BC.MIB->getTargetSymbol(Inst));
+        assert(Target && "Target symbol should be set.");
+
+        const JumpInfoEntry &Info = JumpInfo[EntryID - 1];
+        const bool IsBranch = Info.Likely ^ Info.InitValue;
+
+        uint32_t Size = *BC.MIB->getSize(Inst);
+        if (Size == 2)
+          ++NumShort;
+        else if (Size == 5)
+          ++NumLong;
+        else
+          llvm_unreachable("Wrong size for static keys jump instruction.");
+
+        MCInst NewInst;
+        // Replace the instruction with unconditional jump even if it needs to
+        // be nop in the binary.
+        if (opts::LongJumpLabels) {
+          BC.MIB->createLongUncondBranch(NewInst, Target, BC.Ctx.get());
+        } else {
+          // Newer kernels can handle short and long jumps for static keys.
+          // Optimistically, emit short jump and check if it gets relaxed into
+          // a long one during post-emit. Only then convert the jump to a nop.
+          BC.MIB->createUncondBranch(NewInst, Target, BC.Ctx.get());
+        }
+
+        BC.MIB->moveAnnotations(std::move(Inst), NewInst);
+        Inst = NewInst;
+
+        // Mark the instruction for nop conversion.
+        if (!IsBranch)
+          NopIDs.insert(EntryID);
+
+        MCSymbol *Label =
+            BC.MIB->getOrCreateInstLabel(Inst, "__SK_", BC.Ctx.get());
+
+        // Create a relocation against the label.
+        const uint64_t EntryOffset = StaticKeysJumpTableAddress -
+                                     StaticKeysJumpSection->getAddress() +
+                                     (EntryID - 1) * 16;
+        StaticKeysJumpSection->addRelocation(EntryOffset, Label,
+                                             ELF::R_X86_64_PC32,
+                                             /*Addend*/ 0);
+        StaticKeysJumpSection->addRelocation(EntryOffset + 4, Target,
+                                             ELF::R_X86_64_PC32, /*Addend*/ 0);
+      }
+    }
+  }
+
+  BC.outs() << "BOLT-INFO: the input contains " << NumShort << " short and "
+            << NumLong << " long static keys jumps in optimized functions\n";
+
+  return Error::success();
+}
+
+// Post-emit pass of static keys jump section. Convert jumps to nops.
+Error LinuxKernelRewriter::updateStaticKeysJumpTablePostEmit() {
+  if (!StaticKeysJumpSection || !StaticKeysJumpSection->isFinalized())
+    return Error::success();
+
+  const uint64_t SectionAddress = StaticKeysJumpSection->getAddress();
+  DataExtractor DE(StaticKeysJumpSection->getOutputContents(),
+                   BC.AsmInfo->isLittleEndian(),
+                   BC.AsmInfo->getCodePointerSize());
+  DataExtractor::Cursor Cursor(StaticKeysJumpTableAddress - SectionAddress);
+  const BinaryData *Stop = BC.getBinaryDataByName("__stop___jump_table");
+  uint32_t EntryID = 0;
+  uint64_t NumShort = 0;
+  uint64_t NumLong = 0;
+  while (Cursor && Cursor.tell() < Stop->getAddress() - SectionAddress) {
+    const uint64_t JumpAddress =
+        SectionAddress + Cursor.tell() + (int32_t)DE.getU32(Cursor);
+    const uint64_t TargetAddress =
+        SectionAddress + Cursor.tell() + (int32_t)DE.getU32(Cursor);
+    const uint64_t KeyAddress =
+        SectionAddress + Cursor.tell() + (int64_t)DE.getU64(Cursor);
+
+    // Consume the status of the cursor.
+    if (!Cursor)
+      return createStringError(errc::executable_format_error,
+                               "out of bounds while updating static keys: %s",
+                               toString(Cursor.takeError()).c_str());
+
+    ++EntryID;
+
+    LLVM_DEBUG({
+      dbgs() << "\n\tJumpAddress:   0x" << Twine::utohexstr(JumpAddress)
+             << "\n\tTargetAddress: 0x" << Twine::utohexstr(TargetAddress)
+             << "\n\tKeyAddress:    0x" << Twine::utohexstr(KeyAddress) << '\n';
+    });
+
+    BinaryFunction *BF =
+        BC.getBinaryFunctionContainingAddress(JumpAddress,
+                                              /*CheckPastEnd*/ false,
+                                              /*UseMaxSize*/ true);
+    assert(BF && "Cannot get function for modified static key.");
+
+    if (!BF->isEmitted())
+      continue;
+
+    // Disassemble instruction to collect stats even if nop-conversion is
+    // unnecessary.
+    MutableArrayRef<uint8_t> Contents = MutableArrayRef<uint8_t>(
+        reinterpret_cast<uint8_t *>(BF->getImageAddress()), BF->getImageSize());
+    assert(Contents.size() && "Non-empty function image expected.");
+
+    MCInst Inst;
+    uint64_t Size;
+    const uint64_t JumpOffset = JumpAddress - BF->getAddress();
+    if (!BC.DisAsm->getInstruction(Inst, Size, Contents.slice(JumpOffset), 0,
+                                   nulls())) {
+      llvm_unreachable("Unable to disassemble jump instruction.");
+    }
+    assert(BC.MIB->isBranch(Inst) && "Branch instruction expected.");
+
+    if (Size == 2)
+      ++NumShort;
+    else if (Size == 5)
+      ++NumLong;
+    else
+      llvm_unreachable("Unexpected size for static keys jump instruction.");
+
+    // Check if we need to convert jump instruction into a nop.
+    if (!NopIDs.contains(EntryID))
+      continue;
+
+    SmallString<15> NopCode;
+    raw_svector_ostream VecOS(NopCode);
+    BC.MAB->writeNopData(VecOS, Size, BC.STI.get());
+    for (uint64_t I = 0; I < Size; ++I)
+      Contents[JumpOffset + I] = NopCode[I];
+  }
+
+  BC.outs() << "BOLT-INFO: written " << NumShort << " short and " << NumLong
+            << " long static keys jumps in optimized functions\n";
+
+  return Error::success();
+}
+
 } // namespace
 
 std::unique_ptr<MetadataRewriter>
diff --git a/bolt/lib/Target/X86/X86MCPlusBuilder.cpp b/bolt/lib/Target/X86/X86MCPlusBuilder.cpp
index de55fbe51764dd..15f95f82177765 100644
--- a/bolt/lib/Target/X86/X86MCPlusBuilder.cpp
+++ b/bolt/lib/Target/X86/X86MCPlusBuilder.cpp
@@ -336,6 +336,9 @@ class X86MCPlusBuilder : public MCPlusBuilder {
   }
 
   bool isUnsupportedBranch(const MCInst &Inst) const override {
+    if (isDynamicBranch(Inst))
+      return true;
+
     switch (Inst.getOpcode()) {
     default:
       return false;
@@ -2728,6 +2731,7 @@ class X86MCPlusBuilder : public MCPlusBuilder {
 
   void createUncondBranch(MCInst &Inst, const MCSymbol *TBB,
                           MCContext *Ctx) const override {
+    Inst.clear();
     Inst.setOpcode(X86::JMP_1);
     Inst.clear();
     Inst.addOperand(MCOperand::createExpr(
@@ -2776,6 +2780,15 @@ class X86MCPlusBuilder : public MCPlusBuilder {
     Inst.addOperand(MCOperand::createImm(CC));
   }
 
+  void createLongCondBranch(MCInst &Inst, const MCSymbol *Target, unsigned CC,
+                            MCContext *Ctx) const override {
+    Inst.setOpcode(X86::JCC_4);
+    Inst.clear();
+    Inst.addOperand(MCOperand::createExpr(
+        MCSymbolRefExpr::create(Target, MCSymbolRefExpr::VK_None, *Ctx)));
+    Inst.addOperand(MCOperand::createImm(CC));
+  }
+
   bool reverseBranchCondition(MCInst &Inst, const MCSymbol *TBB,
                               MCContext *Ctx) const override {
     unsigned InvCC = getInvertedCondCode(getCondCode(Inst));
diff --git a/bolt/test/X86/linux-static-keys.s b/bolt/test/X86/linux-static-keys.s
new file mode 100644
index 00000000000000..08454bf9763193
--- /dev/null
+++ b/bolt/test/X86/linux-static-keys.s
@@ -0,0 +1,67 @@
+# REQUIRES: system-linux
+
+## Check that BOLT correctly updates the Linux kernel static keys jump table.
+
+# RUN: llvm-mc -filetype=obj -triple x86_64-unknown-unknown %s -o %t.o
+# RUN: %clang %cflags -nostdlib %t.o -o %t.exe \
+# RUN:   -Wl,--image-base=0xffffffff80000000,--no-dynamic-linker,--no-eh-frame-hdr
+
+## Verify static keys jump bindings to instructions.
+
+# RUN: llvm-bolt %t.exe --print-normalized -o %t.out --keep-nops=0 \
+# RUN:   --bolt-info=0 |& FileCheck %s
+
+## Verify the bindings again on the rewritten binary with nops removed.
+
+# RUN: llvm-bolt %t.out -o %t.out.1 --print-normalized |& FileCheck %s
+
+# CHECK:      BOLT-INFO: Linux kernel binary detected
+# CHECK:      BOLT-INFO: parsed 2 static keys jump entries
+
+  .text
+  .globl _start
+  .type _start, %function
+_start:
+# CHECK: Binary Function "_start"
+  nop
+.L0:
+  jmp .L1
+# CHECK:      jit
+# CHECK-SAME: # ID: 1 {{.*}} # Likely: 0 # InitValue: 1
+  nop
+.L1:
+  .nops 5
+# CHECK:      jit
+# CHECK-SAME: # ID: 2 {{.*}} # Likely: 1 # InitValue: 1
+.L2:
+  nop
+  .size _start, .-_start
+
+  .globl foo
+  .type foo, %function
+foo:
+  ret
+  .size foo, .-foo
+
+
+## Static keys jump table.
+  .rodata
+  .globl __start___jump_table
+  .type __start___jump_table, %object
+__start___jump_table:
+
+  .long .L0 - . # Jump address
+  .long .L1 - . # Target address
+  .quad 1       # Key address
+
+  .long .L1 - . # Jump address
+  .long .L2 - . # Target address
+  .quad 0       # Key address
+
+  .globl __stop___jump_table
+  .type __stop___jump_table, %object
+__stop___jump_table:
+
+## Fake Linux Kernel sections.
  .section __ksymtab,"a",@progbits
  .section __ksymtab_gpl,"a",@progbits



More information about the llvm-commits mailing list