[llvm] [BOLT][AArch64] Support for pointer authentication (PR #117578)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 09:01:09 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-llvm-adt

@llvm/pr-subscribers-bolt

Author: Gergely Bálint (bgergely0)

<details>
<summary>Changes</summary>

Currently, BOLT does not support AArch64 binaries that use pointer authentication, because it cannot correctly handle `.cfi_negate_ra_state`.

This has been raised in several issues:
- https://github.com/llvm/llvm-project/issues/74833
- https://github.com/llvm/llvm-project/issues/80992

## Where are OpNegateRAState CFIs needed?
These CFI instructions must be placed wherever one BasicBlock has a different return address state (signed or not signed) than the BasicBlock before it.

Because BasicBlocks move around during optimization, we cannot know where these CFIs belong until all optimizations are done.

For this reason, my implementation skips these CFIs when reading the binary and adds a new pass, `InsertNegateRAStatePass`, which identifies consecutive BBs with different RA states.
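
To illustrate why placement depends on the final block order, here is a minimal, self-contained sketch. It is not code from the patch: `RAState` and `stateBoundaries` are hypothetical names, and the real pass has extra rules (for example around blocks that contain the authenticating instruction). The sketch only reports the layout positions where the state differs from the previous block.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the per-block return address state.
enum class RAState { Signed, Unsigned };

// Given the states of the blocks in layout order, return every index I
// where block I has a different state than block I - 1. In this simplified
// model, these are the places where a negate-RA-state CFI would be needed.
std::vector<std::size_t> stateBoundaries(const std::vector<RAState> &Layout) {
  std::vector<std::size_t> Boundaries;
  for (std::size_t I = 1; I < Layout.size(); ++I)
    if (Layout[I] != Layout[I - 1])
      Boundaries.push_back(I);
  return Boundaries;
}
```

With the layout Unsigned, Signed, Signed, Unsigned this reports boundaries at indices 1 and 3. If reordering produces Unsigned, Unsigned, Signed, Signed from the same four blocks, the only boundary is at index 2, so CFIs placed for the old order would be wrong.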

## InsertNegateRAStatePass
- the pass marks every BB with one of three states: Unknown, Signed, or Unsigned.
- it explores the CFG to find successors of blocks that sign LR and marks them as Signed until it reaches a BB with an authenticating instruction. Successors of an authenticating BB, and BBs that come before signing, are marked as Unsigned.
- once the BB states are identified, it inserts `OpNegateRAState` wherever the state changes
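
The marking step can be sketched on a toy CFG as follows. This is a simplified illustration under stated assumptions, not the pass itself: `Block`, `RAState`, and `markStates` are hypothetical names, the walk is forward-only from the entry block, and it omits the predecessor handling and the restart from unreachable authenticating blocks that the patch implements.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

enum class RAState { Unknown, Signed, Unsigned };

// Toy basic block: flags say whether it contains a signing instruction
// (e.g. paciasp) or an authenticating instruction (e.g. autiasp).
struct Block {
  bool HasSign = false;
  bool HasAuth = false;
  std::vector<std::size_t> Succs; // indices of successor blocks
  RAState State = RAState::Unknown;
};

// Forward-only worklist walk starting at block 0, which is assumed to be
// entered with an unsigned return address.
void markStates(std::vector<Block> &CFG) {
  if (CFG.empty())
    return;
  // Each entry holds a block index and whether RA is signed on entry.
  std::vector<std::pair<std::size_t, bool>> Worklist;
  Worklist.push_back(std::make_pair(std::size_t(0), false));
  while (!Worklist.empty()) {
    std::pair<std::size_t, bool> Item = Worklist.back();
    Worklist.pop_back();
    assert(Item.first < CFG.size() && "successor index out of range");
    Block &BB = CFG[Item.first];
    if (BB.State != RAState::Unknown)
      continue; // already visited
    // A block is Signed if it is entered signed or signs LR itself.
    const bool IsSigned = Item.second || BB.HasSign;
    BB.State = IsSigned ? RAState::Signed : RAState::Unsigned;
    // Successors of an authenticating block start out unsigned again.
    const bool SignedOnExit = IsSigned && !BB.HasAuth;
    for (std::size_t Succ : BB.Succs)
      Worklist.push_back(std::make_pair(Succ, SignedOnExit));
  }
}
```

For a straight-line chain of four blocks where the second one signs and the third one authenticates, this marks the blocks Unsigned, Signed, Signed, Unsigned, which matches the description above: the authenticating block itself is still Signed, and only its successors become Unsigned.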

## Dependencies
- to find out which BasicBlocks have a signed or an unsigned state, we need a correct CFG.
- I have found an issue with CFG generation, which is described here:
https://github.com/llvm/llvm-project/issues/115154
- until I have a working PR to fix that (in the near future), I have cherry-picked a commit from @kbeyls' `bolt-gadget-scanner` fork as a workaround for testing this change
- this commit lets the user list suspected noreturn functions in a file passed with the `--noreturnfuncs-file=<file>` option.
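
As a rough illustration of the workaround, the sketch below loads a list of names and checks a callee against it. It is not the cherry-picked code: `readNoReturnNames` and `isAssumedNoReturn` are hypothetical helpers, and the file format (one function name per line, blank lines ignored) is an assumption made for this example only.

```cpp
#include <cassert>
#include <istream>
#include <set>
#include <sstream>
#include <string>

// Assumption for this sketch: the input holds one function name per line.
// Blank lines are skipped and a trailing carriage return is stripped.
std::set<std::string> readNoReturnNames(std::istream &In) {
  std::set<std::string> Names;
  std::string Line;
  while (std::getline(In, Line)) {
    if (!Line.empty() && Line.back() == '\r')
      Line.pop_back();
    if (!Line.empty())
      Names.insert(Line);
  }
  return Names;
}

// A call is treated as noreturn if its callee name is in the set. In the
// CFG builder, such a call would end the basic block with no fall-through.
bool isAssumedNoReturn(const std::set<std::string> &Names,
                       const std::string &Callee) {
  return Names.count(Callee) != 0;
}
```

Taking a `std::istream` rather than a path keeps the helper easy to exercise with a `std::istringstream`; a real caller would pass an opened `std::ifstream` instead.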

## WIP
Not all pointer signing and authenticating instructions are covered yet, only `paciasp` and `autiasp`. This is enough to test on binaries built with `-mbranch-protection=pac-ret`.

---

Patch is 26.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117578.diff


16 Files Affected:

- (modified) bolt/include/bolt/Core/BinaryBasicBlock.h (+16) 
- (modified) bolt/include/bolt/Core/MCPlusBuilder.h (+30) 
- (added) bolt/include/bolt/Passes/InsertNegateRAStatePass.h (+33) 
- (modified) bolt/include/bolt/Utils/CommandLineOpts.h (+7) 
- (modified) bolt/lib/Core/BinaryBasicBlock.cpp (+5-1) 
- (modified) bolt/lib/Core/BinaryFunction.cpp (+7-3) 
- (modified) bolt/lib/Core/Exceptions.cpp (+4-2) 
- (modified) bolt/lib/Passes/BinaryPasses.cpp (+12-14) 
- (modified) bolt/lib/Passes/CMakeLists.txt (+1) 
- (added) bolt/lib/Passes/InsertNegateRAStatePass.cpp (+247) 
- (modified) bolt/lib/Rewrite/BinaryPassManager.cpp (+3) 
- (modified) bolt/lib/Rewrite/RewriteInstance.cpp (+8-12) 
- (modified) bolt/lib/Target/AArch64/AArch64MCPlusBuilder.cpp (+6) 
- (modified) bolt/lib/Utils/CommandLineOpts.cpp (+23-1) 
- (added) bolt/test/gadget-scanner/gs-pacret-noreturn.s (+31) 
- (modified) llvm/include/llvm/ADT/StringRef.h (+5) 


``````````diff
diff --git a/bolt/include/bolt/Core/BinaryBasicBlock.h b/bolt/include/bolt/Core/BinaryBasicBlock.h
index 25cccc4edecf68..fe48d3b4f740d6 100644
--- a/bolt/include/bolt/Core/BinaryBasicBlock.h
+++ b/bolt/include/bolt/Core/BinaryBasicBlock.h
@@ -37,6 +37,11 @@ class JumpTable;
 
 class BinaryBasicBlock {
 public:
+  enum class RAStateEnum : char {
+    Unknown, /// Not discovered yet
+    Signed,
+    Unsigned,
+  };
   /// Profile execution information for a given edge in CFG.
   ///
   /// If MispredictedCount equals COUNT_INFERRED, then we have a profile
@@ -350,6 +355,17 @@ class BinaryBasicBlock {
                                                       BranchInfo.end());
   }
 
+  RAStateEnum RAState{RAStateEnum::Unknown};
+  void setRASigned() { RAState = RAStateEnum::Signed; }
+  bool isRAStateUnknown() { return RAState == RAStateEnum::Unknown; }
+  bool isRAStateSigned() { return RAState == RAStateEnum::Signed; }
+  /// Unsigned should only overwrite Unknown state, and not Signed
+  void setRAUnsigned() {
+    if (RAState == RAStateEnum::Unknown) {
+      RAState = RAStateEnum::Unsigned;
+    }
+  }
+
   /// Get instruction at given index.
   MCInst &getInstructionAtIndex(unsigned Index) { return Instructions[Index]; }
 
diff --git a/bolt/include/bolt/Core/MCPlusBuilder.h b/bolt/include/bolt/Core/MCPlusBuilder.h
index 3634fed9757ceb..00b69d65a3b21e 100644
--- a/bolt/include/bolt/Core/MCPlusBuilder.h
+++ b/bolt/include/bolt/Core/MCPlusBuilder.h
@@ -16,6 +16,7 @@
 
 #include "bolt/Core/MCPlus.h"
 #include "bolt/Core/Relocation.h"
+#include "bolt/Utils/CommandLineOpts.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/StringMap.h"
@@ -27,6 +28,7 @@
 #include "llvm/MC/MCInstrAnalysis.h"
 #include "llvm/MC/MCInstrDesc.h"
 #include "llvm/MC/MCInstrInfo.h"
+#include "llvm/MC/MCSymbol.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -546,6 +548,27 @@ class MCPlusBuilder {
     return Analysis->isCall(Inst) || isTailCall(Inst);
   }
 
+  virtual std::optional<StringRef> getCalleeName(const MCInst &Inst) const {
+    assert(isCall(Inst));
+    if (MCPlus::getNumPrimeOperands(Inst) != 1 || !Inst.getOperand(0).isExpr())
+      return {};
+
+    const MCSymbol *CalleeSymbol = getTargetSymbol(Inst);
+    assert(CalleeSymbol != nullptr);
+    return CalleeSymbol->getName();
+  }
+
+  virtual bool isNoReturnCall(const MCInst& Inst) const {
+    if (!isCall(Inst))
+      return false;
+    auto calleeName = getCalleeName(Inst);
+    if (calleeName)
+      for (std::string &Name : opts::AssumeNoReturnFunctions)
+        if (calleeName->equals(Name))
+          return true;
+    return false;
+  }
+
   virtual bool isReturn(const MCInst &Inst) const {
     return Analysis->isReturn(Inst);
   }
@@ -648,6 +671,13 @@ class MCPlusBuilder {
     llvm_unreachable("not implemented");
     return false;
   }
+  virtual bool isPAuth(MCInst &Inst) const {
+    llvm_unreachable("not implemented");
+  }
+
+  virtual bool isPSign(MCInst &Inst) const {
+    llvm_unreachable("not implemented");
+  }
 
   virtual bool isCleanRegXOR(const MCInst &Inst) const {
     llvm_unreachable("not implemented");
diff --git a/bolt/include/bolt/Passes/InsertNegateRAStatePass.h b/bolt/include/bolt/Passes/InsertNegateRAStatePass.h
new file mode 100644
index 00000000000000..ac1a2f4cf7327c
--- /dev/null
+++ b/bolt/include/bolt/Passes/InsertNegateRAStatePass.h
@@ -0,0 +1,33 @@
+#ifndef BOLT_PASSES_INSERT_NEGATE_RA_STATE_PASS
+#define BOLT_PASSES_INSERT_NEGATE_RA_STATE_PASS
+
+#include "bolt/Passes/BinaryPasses.h"
+#include <stack>
+
+namespace llvm {
+namespace bolt {
+
+class InsertNegateRAState : public BinaryFunctionPass {
+public:
+  explicit InsertNegateRAState() : BinaryFunctionPass(false) {}
+
+  const char *getName() const override { return "insert-negate-ra-state-pass"; }
+
+  /// Pass entry point
+  Error runOnFunctions(BinaryContext &BC) override;
+  void runOnFunction(BinaryFunction &BF);
+  bool addNegateRAStateAfterPacOrAuth(BinaryFunction &BF);
+  bool BBhasAUTH(BinaryContext &BC, BinaryBasicBlock *BB);
+  bool BBhasSIGN(BinaryContext &BC, BinaryBasicBlock *BB);
+  void explore_call_graph(BinaryContext &BC, BinaryBasicBlock *BB);
+  void process_signed_BB(BinaryContext &BC, BinaryBasicBlock *BB,
+                         std::stack<BinaryBasicBlock *> *SignedStack,
+                         std::stack<BinaryBasicBlock *> *UnsignedStack);
+  void process_unsigned_BB(BinaryContext &BC, BinaryBasicBlock *BB,
+                           std::stack<BinaryBasicBlock *> *SignedStack,
+                           std::stack<BinaryBasicBlock *> *UnsignedStack);
+};
+
+} // namespace bolt
+} // namespace llvm
+#endif
diff --git a/bolt/include/bolt/Utils/CommandLineOpts.h b/bolt/include/bolt/Utils/CommandLineOpts.h
index 04bf7db5de9527..dcdcb0ed84cfc9 100644
--- a/bolt/include/bolt/Utils/CommandLineOpts.h
+++ b/bolt/include/bolt/Utils/CommandLineOpts.h
@@ -62,6 +62,13 @@ extern llvm::cl::opt<bool> TimeOpts;
 extern llvm::cl::opt<bool> UseOldText;
 extern llvm::cl::opt<bool> UpdateDebugSections;
 
+extern llvm::cl::list<std::string> AssumeNoReturnFunctions;
+extern llvm::cl::opt<std::string> AssumeNoReturnFunctionsFile;
+
+/// Reads names from FunctionNamesFile and adds them to FunctionNames.
+void populateFunctionNames(const llvm::cl::opt<std::string> &FunctionNamesFile,
+                           llvm::cl::list<std::string> &FunctionNames);
+
 // The default verbosity level (0) is pretty terse, level 1 is fairly
 // verbose and usually prints some informational message for every
 // function processed.  Level 2 is for the noisiest of messages and
diff --git a/bolt/lib/Core/BinaryBasicBlock.cpp b/bolt/lib/Core/BinaryBasicBlock.cpp
index 2a2192b79bb4bf..bc1d3f112f2ed2 100644
--- a/bolt/lib/Core/BinaryBasicBlock.cpp
+++ b/bolt/lib/Core/BinaryBasicBlock.cpp
@@ -201,7 +201,11 @@ int32_t BinaryBasicBlock::getCFIStateAtInstr(const MCInst *Instr) const {
       InstrSeen = (&Inst == Instr);
       continue;
     }
-    if (Function->getBinaryContext().MIB->isCFI(Inst)) {
+    // Fix: ignoring OpNegateRAState CFIs here, as they dont have a "State"
+    // number associated with them.
+    if (Function->getBinaryContext().MIB->isCFI(Inst) &&
+        (Function->getCFIFor(Inst)->getOperation() !=
+         MCCFIInstruction::OpNegateRAState)) {
       LastCFI = &Inst;
       break;
     }
diff --git a/bolt/lib/Core/BinaryFunction.cpp b/bolt/lib/Core/BinaryFunction.cpp
index 5da777411ba7a1..842a1cb16e14ac 100644
--- a/bolt/lib/Core/BinaryFunction.cpp
+++ b/bolt/lib/Core/BinaryFunction.cpp
@@ -2155,7 +2155,7 @@ Error BinaryFunction::buildCFG(MCPlusBuilder::AllocatorIdTy AllocatorId) {
       addCFIPlaceholders(Offset, InsertBB);
     }
 
-    const bool IsBlockEnd = MIB->isTerminator(Instr);
+    const bool IsBlockEnd = MIB->isTerminator(Instr) || MIB->isNoReturnCall(Instr);
     IsLastInstrNop = MIB->isNoop(Instr);
     if (!IsLastInstrNop)
       LastInstrOffset = Offset;
@@ -2242,8 +2242,11 @@ Error BinaryFunction::buildCFG(MCPlusBuilder::AllocatorIdTy AllocatorId) {
       //
       // Conditional tail call is a special case since we don't add a taken
       // branch successor for it.
-      IsPrevFT = !MIB->isTerminator(*LastInstr) ||
-                 MIB->getConditionalTailCall(*LastInstr);
+      if (MIB->isNoReturnCall(*LastInstr))
+        IsPrevFT = false;
+      else
+        IsPrevFT = !MIB->isTerminator(*LastInstr) ||
+                   MIB->getConditionalTailCall(*LastInstr);
     } else if (BB->succ_size() == 1) {
       IsPrevFT = MIB->isConditionalBranch(*LastInstr);
     } else {
@@ -2596,6 +2599,7 @@ struct CFISnapshot {
   void advanceTo(int32_t State) {
     for (int32_t I = CurState, E = State; I != E; ++I) {
       const MCCFIInstruction &Instr = FDE[I];
+      assert(Instr.getOperation() != MCCFIInstruction::OpNegateRAState);
       if (Instr.getOperation() != MCCFIInstruction::OpRestoreState) {
         update(Instr, I);
         continue;
diff --git a/bolt/lib/Core/Exceptions.cpp b/bolt/lib/Core/Exceptions.cpp
index 0b2e63b8ca6a79..f528db8449dbd2 100644
--- a/bolt/lib/Core/Exceptions.cpp
+++ b/bolt/lib/Core/Exceptions.cpp
@@ -632,8 +632,10 @@ bool CFIReaderWriter::fillCFIInfoFor(BinaryFunction &Function) const {
       // DW_CFA_GNU_window_save and DW_CFA_GNU_NegateRAState just use the same
       // id but mean different things. The latter is used in AArch64.
       if (Function.getBinaryContext().isAArch64()) {
-        Function.addCFIInstruction(
-            Offset, MCCFIInstruction::createNegateRAState(nullptr));
+        // Fix: not adding OpNegateRAState since the location they are needed
+        // depends on the order of BasicBlocks, which changes during
+        // optimizations. They are generated in InsertNegateRAStatePass after
+        // optimizations instead.
         break;
       }
       if (opts::Verbosity >= 1)
diff --git a/bolt/lib/Passes/BinaryPasses.cpp b/bolt/lib/Passes/BinaryPasses.cpp
index 03d3dd75a03368..0cce08f650ae74 100644
--- a/bolt/lib/Passes/BinaryPasses.cpp
+++ b/bolt/lib/Passes/BinaryPasses.cpp
@@ -1852,17 +1852,16 @@ Error InlineMemcpy::runOnFunctions(BinaryContext &BC) {
       for (auto II = BB.begin(); II != BB.end(); ++II) {
         MCInst &Inst = *II;
 
-        if (!BC.MIB->isCall(Inst) || MCPlus::getNumPrimeOperands(Inst) != 1 ||
-            !Inst.getOperand(0).isExpr())
+        if (!BC.MIB->isCall(Inst))
           continue;
-
-        const MCSymbol *CalleeSymbol = BC.MIB->getTargetSymbol(Inst);
-        if (CalleeSymbol->getName() != "memcpy" &&
-            CalleeSymbol->getName() != "memcpy@PLT" &&
-            CalleeSymbol->getName() != "_memcpy8")
+        std::optional<StringRef> CalleeName = BC.MIB->getCalleeName(Inst);
+        if (!CalleeName)
+          continue;
+        if (*CalleeName != "memcpy" && *CalleeName != "memcpy@PLT" &&
+            *CalleeName != "_memcpy8")
           continue;
 
-        const bool IsMemcpy8 = (CalleeSymbol->getName() == "_memcpy8");
+        const bool IsMemcpy8 = (*CalleeName == "_memcpy8");
         const bool IsTailCall = BC.MIB->isTailCall(Inst);
 
         const InstructionListType NewCode =
@@ -1951,13 +1950,12 @@ Error SpecializeMemcpy1::runOnFunctions(BinaryContext &BC) {
       for (auto II = CurBB->begin(); II != CurBB->end(); ++II) {
         MCInst &Inst = *II;
 
-        if (!BC.MIB->isCall(Inst) || MCPlus::getNumPrimeOperands(Inst) != 1 ||
-            !Inst.getOperand(0).isExpr())
+        if (!BC.MIB->isCall(Inst))
           continue;
-
-        const MCSymbol *CalleeSymbol = BC.MIB->getTargetSymbol(Inst);
-        if (CalleeSymbol->getName() != "memcpy" &&
-            CalleeSymbol->getName() != "memcpy@PLT")
+        std::optional<StringRef> CalleeName = BC.MIB->getCalleeName(Inst);
+        if (!CalleeName)
+          continue;
+        if (*CalleeName != "memcpy" && *CalleeName != "memcpy@PLT")
           continue;
 
         if (BC.MIB->isTailCall(Inst))
diff --git a/bolt/lib/Passes/CMakeLists.txt b/bolt/lib/Passes/CMakeLists.txt
index 1c1273b3d2420d..d7864e30305116 100644
--- a/bolt/lib/Passes/CMakeLists.txt
+++ b/bolt/lib/Passes/CMakeLists.txt
@@ -17,6 +17,7 @@ add_llvm_library(LLVMBOLTPasses
   IdenticalCodeFolding.cpp
   IndirectCallPromotion.cpp
   Inliner.cpp
+  InsertNegateRAStatePass.cpp
   Instrumentation.cpp
   JTFootprintReduction.cpp
   LongJmp.cpp
diff --git a/bolt/lib/Passes/InsertNegateRAStatePass.cpp b/bolt/lib/Passes/InsertNegateRAStatePass.cpp
new file mode 100644
index 00000000000000..c5db6df3a2b606
--- /dev/null
+++ b/bolt/lib/Passes/InsertNegateRAStatePass.cpp
@@ -0,0 +1,247 @@
+#include "bolt/Passes/InsertNegateRAStatePass.h"
+#include "bolt/Core/BinaryFunction.h"
+#include "bolt/Core/ParallelUtilities.h"
+#include "bolt/Utils/CommandLineOpts.h"
+#include <cstdlib>
+#include <fstream>
+#include <iterator>
+
+using namespace llvm;
+
+namespace llvm {
+namespace bolt {
+
+void InsertNegateRAState::runOnFunction(BinaryFunction &BF) {
+  BinaryContext &BC = BF.getBinaryContext();
+
+  if (BF.getState() == BinaryFunction::State::Empty) {
+    return;
+  }
+
+  if (BF.getState() != BinaryFunction::State::CFG &&
+      BF.getState() != BinaryFunction::State::CFG_Finalized) {
+    BC.errs() << "BOLT-WARNING: No CFG for " << BF.getPrintName()
+              << " in InsertNegateRAStatePass\n";
+    return;
+  }
+
+  if (BF.getState() == BinaryFunction::State::CFG_Finalized) {
+    BC.errs() << "BOLT-WARNING: CFG finalized for " << BF.getPrintName()
+              << " in InsertNegateRAStatePass\n";
+    return;
+  }
+
+  if (BF.isIgnored())
+    return;
+
+  if (!addNegateRAStateAfterPacOrAuth(BF)) {
+    // none inserted, function doesn't need more work
+    return;
+  }
+
+  auto FirstBB = BF.begin();
+  explore_call_graph(BC, &(*FirstBB));
+
+  // We have to do the walk again, starting from any undiscovered autiasp
+  // instructions, because some autiasp might not be reachable because of
+  // indirect branches but we know that autiasp block should have a Signed
+  // state, so we can work out other Unkown states starting from these nodes.
+  for (BinaryBasicBlock &BB : BF) {
+    if (BBhasAUTH(BC, &BB) && BB.isRAStateUnknown()) {
+      BB.setRASigned();
+      explore_call_graph(BC, &BB);
+    }
+  }
+
+  // insert negateRAState-s where there is a State boundary:
+  // that is: two consecutive BBs have different RA State
+  BinaryFunction::iterator PrevBB;
+  bool FirstIter = true;
+  for (auto BB = BF.begin(); BB != BF.end(); ++BB) {
+    if (!FirstIter) {
+      if ((PrevBB->RAState == BinaryBasicBlock::RAStateEnum::Signed &&
+           (*BB).RAState == BinaryBasicBlock::RAStateEnum::Unsigned &&
+           !BBhasAUTH(BC, &(*PrevBB))) ||
+          (PrevBB->RAState == BinaryBasicBlock::RAStateEnum::Signed &&
+           (*BB).RAState == BinaryBasicBlock::RAStateEnum::Signed &&
+           BBhasAUTH(BC, &(*PrevBB)))) {
+        auto InstRevIter = PrevBB->getLastNonPseudo();
+        MCInst LastNonPseudo = *InstRevIter;
+        auto InstIter = InstRevIter.base();
+        BF.addCFIInstruction(&(*PrevBB), InstIter,
+                             MCCFIInstruction::createNegateRAState(nullptr));
+      }
+    } else {
+      FirstIter = false;
+    }
+    PrevBB = BB;
+  }
+}
+
+void InsertNegateRAState::explore_call_graph(BinaryContext &BC,
+                                             BinaryBasicBlock *BB) {
+  std::stack<BinaryBasicBlock *> SignedStack;
+  std::stack<BinaryBasicBlock *> UnsignedStack;
+
+  // start according to the first BB
+  if (BBhasSIGN(BC, BB)) {
+    SignedStack.push(BB);
+    process_signed_BB(BC, BB, &SignedStack, &UnsignedStack);
+  } else {
+    UnsignedStack.push(BB);
+    process_unsigned_BB(BC, BB, &SignedStack, &UnsignedStack);
+  }
+
+  while (!(SignedStack.empty() && UnsignedStack.empty())) {
+    if (!SignedStack.empty()) {
+      BB = SignedStack.top();
+      SignedStack.pop();
+      process_signed_BB(BC, BB, &SignedStack, &UnsignedStack);
+    } else if (!UnsignedStack.empty()) {
+      BB = UnsignedStack.top();
+      UnsignedStack.pop();
+      process_unsigned_BB(BC, BB, &SignedStack, &UnsignedStack);
+    }
+  }
+}
+void InsertNegateRAState::process_signed_BB(
+    BinaryContext &BC, BinaryBasicBlock *BB,
+    std::stack<BinaryBasicBlock *> *SignedStack,
+    std::stack<BinaryBasicBlock *> *UnsignedStack) {
+
+  BB->setRASigned();
+
+  if (BBhasAUTH(BC, BB)) {
+    // successors of block with autiasp are stored in the Unsigned Stack
+    for (BinaryBasicBlock *Succ : BB->successors()) {
+      if (Succ->getFunction() == BB->getFunction() &&
+          Succ->isRAStateUnknown()) {
+        UnsignedStack->push(Succ);
+      }
+    }
+  } else {
+    for (BinaryBasicBlock *Succ : BB->successors()) {
+      if (Succ->getFunction() == BB->getFunction() &&
+          !Succ->isRAStateSigned()) {
+        SignedStack->push(Succ);
+      }
+    }
+  }
+  // process predecessors
+  if (BBhasSIGN(BC, BB)) {
+    for (BinaryBasicBlock *Pred : BB->predecessors()) {
+      if (Pred->getFunction() == BB->getFunction() &&
+          Pred->isRAStateUnknown()) {
+        UnsignedStack->push(Pred);
+      }
+    }
+  } else {
+    for (BinaryBasicBlock *Pred : BB->predecessors()) {
+      if (Pred->getFunction() == BB->getFunction() &&
+          !Pred->isRAStateSigned()) {
+        SignedStack->push(Pred);
+      }
+    }
+  }
+}
+
+void InsertNegateRAState::process_unsigned_BB(
+    BinaryContext &BC, BinaryBasicBlock *BB,
+    std::stack<BinaryBasicBlock *> *SignedStack,
+    std::stack<BinaryBasicBlock *> *UnsignedStack) {
+
+  BB->setRAUnsigned();
+
+  if (BBhasSIGN(BC, BB)) {
+    BB->setRASigned();
+    // successors of block with paciasp are stored in the Signed Stack
+    for (BinaryBasicBlock *Succ : BB->successors()) {
+      if (Succ->getFunction() == BB->getFunction() &&
+          !Succ->isRAStateSigned()) {
+        SignedStack->push(Succ);
+      }
+    }
+  } else {
+    for (BinaryBasicBlock *Succ : BB->successors()) {
+      if (Succ->getFunction() == BB->getFunction() &&
+          Succ->isRAStateUnknown()) {
+        UnsignedStack->push(Succ);
+      }
+    }
+  }
+
+  // process predecessors
+  if (BBhasAUTH(BC, BB)) {
+    BB->setRASigned();
+    for (BinaryBasicBlock *Pred : BB->predecessors()) {
+      if (Pred->getFunction() == BB->getFunction() &&
+          !Pred->isRAStateSigned()) {
+        SignedStack->push(Pred);
+      }
+    }
+  } else {
+    for (BinaryBasicBlock *Pred : BB->predecessors()) {
+      if (Pred->getFunction() == BB->getFunction() &&
+          Pred->isRAStateUnknown()) {
+        UnsignedStack->push(Pred);
+      }
+    }
+  }
+}
+
+bool InsertNegateRAState::BBhasAUTH(BinaryContext &BC, BinaryBasicBlock *BB) {
+  for (auto Iter = BB->begin(); Iter != BB->end(); ++Iter) {
+    MCInst Inst = *Iter;
+    if (BC.MIB->isPAuth(Inst)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool InsertNegateRAState::BBhasSIGN(BinaryContext &BC, BinaryBasicBlock *BB) {
+  for (auto Iter = BB->begin(); Iter != BB->end(); ++Iter) {
+    MCInst Inst = *Iter;
+    if (BC.MIB->isPSign(Inst)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool InsertNegateRAState::addNegateRAStateAfterPacOrAuth(BinaryFunction &BF) {
+  BinaryContext &BC = BF.getBinaryContext();
+  bool FoundAny = false;
+  for (BinaryBasicBlock &BB : BF) {
+    for (auto Iter = BB.begin(); Iter != BB.end(); ++Iter) {
+      MCInst Inst = *Iter;
+      if (BC.MIB->isPSign(Inst)) {
+        Iter = BF.addCFIInstruction(
+            &BB, Iter + 1, MCCFIInstruction::createNegateRAState(nullptr));
+        FoundAny = true;
+      }
+
+      if (BC.MIB->isPAuth(Inst)) {
+        Iter = BF.addCFIInstruction(
+            &BB, Iter + 1, MCCFIInstruction::createNegateRAState(nullptr));
+        FoundAny = true;
+      }
+    }
+  }
+  return FoundAny;
+}
+
+Error InsertNegateRAState::runOnFunctions(BinaryContext &BC) {
+  ParallelUtilities::WorkFuncTy WorkFun = [&](BinaryFunction &BF) {
+    runOnFunction(BF);
+  };
+
+  ParallelUtilities::runOnEachFunction(
+      BC, ParallelUtilities::SchedulingPolicy::SP_TRIVIAL, WorkFun, nullptr,
+      "InsertNegateRAStatePass");
+
+  return Error::success();
+}
+
+} // end namespace bolt
+} // end namespace llvm
diff --git a/bolt/lib/Rewrite/BinaryPassManager.cpp b/bolt/lib/Rewrite/BinaryPassManager.cpp
index b0906041833484..d11321d8ef93ae 100644
--- a/bolt/lib/Rewrite/BinaryPassManager.cpp
+++ b/bolt/lib/Rewrite/BinaryPassManager.cpp
@@ -20,6 +20,7 @@
 #include "bolt/Passes/IdenticalCodeFolding.h"
 #include "bolt/Passes/IndirectCallPromotion.h"
 #include "bolt/Passes/Inliner.h"
+#include "bolt/Passes/InsertNegateRAStatePass.h"
 #include "bolt/Passes/Instrumentation.h"
 #include "bolt/Passes/JTFootprintReduction.h"
 #include "bolt/Passes/LongJmp.h"
@@ -499,6 +500,8 @@ Error BinaryFunctionPassManager::runAllPasses(BinaryContext &BC) {
     // targets. No extra instructions after this pass, otherwise we may have
     // relocations out of range and crash during linking.
     Manager.registerPass(std::make_unique<LongJmpPass>(PrintLongJmp));
+
+    Manager.registerPass(std::make_unique<InsertNegateRAState>());
   }
 
   // This pass should always run last.*
diff --git a/bolt/lib/Rewrite/RewriteInstance.cpp b/bolt/lib/Rewrite/RewriteInstance.cpp
index 7059a3dd231099..ce5dc7f10cdaa4 100644
--- a/bolt/lib/Rewrite/R...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/117578

