[llvm] [BOLT] Gadget scanner: detect authentication oracles (PR #135663)

Kristof Beyls via llvm-commits llvm-commits at lists.llvm.org
Wed May 28 02:22:40 PDT 2025


================
@@ -717,6 +740,457 @@ SrcSafetyAnalysis::create(BinaryFunction &BF,
                                                        RegsToTrackInstsFor);
 }
 
+/// A state representing which registers are safe to be used as the destination
+/// operand of an authentication instruction.
+///
+/// Similar to SrcState, it is the responsibility of the analysis to take
+/// register aliasing into account.
+///
+/// Depending on the implementation, it may be possible that an authentication
+/// instruction returns an invalid pointer on failure instead of terminating
+/// the program immediately (assuming the program will crash as soon as that
+/// pointer is dereferenced). To prevent brute-forcing the correct signature,
+/// it should be impossible for an attacker to test if a pointer is correctly
+/// signed - either the program should be terminated on authentication failure
+/// or it should be impossible to tell whether authentication succeeded or not.
+///
+/// For that reason, a restricted set of operations is allowed on any register
+/// containing a value derived from the result of an authentication instruction
+/// until that register is either wiped or checked not to contain a result of a
+/// failed authentication.
+///
+/// Specifically, the safety property for a register is computed by iterating
+/// the instructions in backward order: the source register Xn of an instruction
+/// Inst is safe if at least one of the following is true:
+/// * Inst checks if Xn contains the result of a successful authentication and
+///   terminates the program on failure. Note that Inst can either naturally
+///   dereference Xn (load, branch, return, etc. instructions) or be the first
+///   instruction of an explicit checking sequence.
+/// * Inst performs safe address arithmetic AND both source and result
+///   registers, as well as any temporary registers, must be safe after
+///   execution of Inst (temporaries are not used on AArch64 and thus not
+///   currently supported/allowed).
+///   See MCPlusBuilder::analyzeAddressArithmeticsForPtrAuth for the details.
+/// * Inst fully overwrites Xn with an unrelated value.
+struct DstState {
+  /// The set of registers whose values cannot be inspected by an attacker in
+  /// a way usable as an authentication oracle. The results of authentication
+  /// instructions should only be written to such registers.
+  BitVector CannotEscapeUnchecked;
+
+  /// One set of instructions per tracked register (the vector is sized by the
+  /// NumRegsToTrack constructor argument). Judging by the name, each set holds
+  /// the first instruction(s) leaking the corresponding register; merge()
+  /// combines the sets of two states by taking their per-register union.
+  std::vector<SetOfRelatedInsts> FirstInstLeakingReg;
+
+  /// Constructs an empty state.
+  DstState() = default;
+
+  /// Constructs a non-empty state covering NumRegs registers, none of which is
+  /// marked as safe yet, with an empty instruction set for each of the
+  /// NumRegsToTrack tracked registers.
+  DstState(unsigned NumRegs, unsigned NumRegsToTrack)
+      : CannotEscapeUnchecked(NumRegs), FirstInstLeakingReg(NumRegsToTrack) {}
+
+  /// Merges StateIn into this state and returns *this.
+  ///
+  /// An empty state acts as the identity element: merging an empty StateIn is
+  /// a no-op, and merging into an empty state copies StateIn. Otherwise, a
+  /// register remains in CannotEscapeUnchecked only if it is set in both
+  /// states, and the per-register instruction sets are united.
+  DstState &merge(const DstState &StateIn) {
+    if (StateIn.empty())
+      return *this;
+    if (empty())
+      return (*this = StateIn);
+
+    // The loop below indexes StateIn's vector using this state's size, so
+    // both non-empty states must track the same number of registers.
+    assert(FirstInstLeakingReg.size() == StateIn.FirstInstLeakingReg.size() &&
+           "Merging states that track a different number of registers");
+
+    CannotEscapeUnchecked &= StateIn.CannotEscapeUnchecked;
+    for (unsigned I = 0; I < FirstInstLeakingReg.size(); ++I)
+      for (const MCInst *J : StateIn.FirstInstLeakingReg[I])
+        FirstInstLeakingReg[I].insert(J);
+    return *this;
+  }
+
+  /// Returns true if this object does not store state of any registers -
+  /// neither safe, nor unsafe ones.
+  bool empty() const { return CannotEscapeUnchecked.empty(); }
+
+  /// Two states are equal if both the register set and all per-register
+  /// instruction sets compare equal.
+  bool operator==(const DstState &RHS) const {
+    return CannotEscapeUnchecked == RHS.CannotEscapeUnchecked &&
+           FirstInstLeakingReg == RHS.FirstInstLeakingReg;
+  }
+  bool operator!=(const DstState &RHS) const { return !(*this == RHS); }
+};
+
+/// Writes S to OS as "dst-state<...>": either the word "empty" or the raw
+/// CannotEscapeUnchecked bit vector followed by the short form of the
+/// per-register instruction sets.
+static raw_ostream &operator<<(raw_ostream &OS, const DstState &S) {
+  OS << "dst-state<";
+  if (!S.empty()) {
+    OS << "CannotEscapeUnchecked: " << S.CannotEscapeUnchecked << ", ";
+    printInstsShort(OS, S.FirstInstLeakingReg);
+  } else {
+    OS << "empty";
+  }
+  return OS << ">";
+}
+
+/// Prints a DstState using the BinaryContext passed at construction; the
+/// context is handed to the register printer used by print().
+class DstStatePrinter {
+  const BinaryContext &BC;
+
+public:
+  explicit DstStatePrinter(const BinaryContext &BC) : BC(BC) {}
+
+  /// Writes the "dst-state<...>" representation of S to OS.
+  void print(raw_ostream &OS, const DstState &S) const;
+};
+
+/// Writes S to OS as "dst-state<...>". A non-empty state is rendered as the
+/// CannotEscapeUnchecked set (formatted by RegStatePrinter) followed by the
+/// short form of the per-register instruction sets; an empty one as "empty".
+void DstStatePrinter::print(raw_ostream &OS, const DstState &S) const {
+  RegStatePrinter RegPrinter(BC);
+  OS << "dst-state<";
+  if (!S.empty()) {
+    OS << "CannotEscapeUnchecked: ";
+    RegPrinter.print(OS, S.CannotEscapeUnchecked);
+    OS << ", ";
+    printInstsShort(OS, S.FirstInstLeakingReg);
+  } else {
+    // An empty state is expected to carry no per-register data at all.
+    assert(S.CannotEscapeUnchecked.empty());
+    assert(S.FirstInstLeakingReg.empty());
+    OS << "empty";
+  }
+  OS << ">";
+}
+
+/// Computes which registers are safe to be written to by auth instructions.
+///
+/// This is the base class for two implementations: a dataflow-based analysis
+/// which is intended to be used for most functions and a simplified CFG-unaware
+/// version for functions without reconstructed CFG.
+class DstSafetyAnalysis {
+public:
+  /// Captures the binary context of BF, the number of registers reported by
+  /// BC.MRI and the set of registers to track instructions for.
+  ///
+  /// Note that NumRegs is initialized from BC, so BC has to remain declared
+  /// before NumRegs among the members of this class.
+  DstSafetyAnalysis(BinaryFunction &BF, ArrayRef<MCPhysReg> RegsToTrackInstsFor)
+      : BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
+        RegsToTrackInstsFor(RegsToTrackInstsFor) {}
+
+  /// Virtual, as this class is the base of the concrete analyses.
+  virtual ~DstSafetyAnalysis() = default;
+
+  static std::shared_ptr<DstSafetyAnalysis>
+  create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
+         ArrayRef<MCPhysReg> RegsToTrackInstsFor);
+
+  virtual void run() = 0;
+  virtual const DstState &getStateAfter(const MCInst &Inst) const = 0;
+
+protected:
+  BinaryContext &BC;
+  const unsigned NumRegs;
+
+  const TrackedRegisters RegsToTrackInstsFor;
+
+  /// Stores information about the detected instruction sequences emitted to
+  /// check an authenticated pointer. Specifically, if such sequence is detected
+  /// in a basic block, it maps the first instruction of that sequence to the
+  /// register being checked.
+  ///
+  /// As the detection of such sequences requires iterating over the adjacent
+  /// instructions, it should be done before calling computeNext(), which
+  /// operates on separate instructions.
+  DenseMap<const MCInst *, MCPhysReg> RegCheckedAt;
+
+  /// Returns the element of S.FirstInstLeakingReg that corresponds to Reg,
+  /// using the index assigned to Reg by RegsToTrackInstsFor.
+  SetOfRelatedInsts &firstLeakingInsts(DstState &S, MCPhysReg Reg) const {
+    unsigned TrackedIdx = RegsToTrackInstsFor.getIndex(Reg);
+    SetOfRelatedInsts &Insts = S.FirstInstLeakingReg[TrackedIdx];
+    return Insts;
+  }
+  /// Read-only counterpart of the overload above.
+  const SetOfRelatedInsts &firstLeakingInsts(const DstState &S,
+                                             MCPhysReg Reg) const {
+    unsigned TrackedIdx = RegsToTrackInstsFor.getIndex(Reg);
+    const SetOfRelatedInsts &Insts = S.FirstInstLeakingReg[TrackedIdx];
+    return Insts;
+  }
+
+  /// Builds a non-empty state in which every register is considered unsafe.
+  /// This is different from the empty state, which stores no information
+  /// about any register.
+  DstState createUnsafeState() {
+    DstState Unsafe(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
+    return Unsafe;
+  }
+
+  /// Returns the set of registers that can be leaked by Inst.
+  ///
+  /// The computation mirrors that of the clobbered registers, except that the
+  /// registers read by the instruction (implicit uses and use operands) are
+  /// considered instead of the written ones. Each such register contributes
+  /// all of its aliases, not only the smaller ones.
+  BitVector getLeakedRegs(const MCInst &Inst) const {
+    BitVector Result(NumRegs);
+
+    // Be conservative: a call may read any register.
+    if (BC.MIB->isCall(Inst)) {
+      Result.set();
+      return Result;
+    }
+
+    // Otherwise, collect everything overlapping with a register that is used
+    // by this instruction, starting with the implicit uses...
+    const MCInstrDesc &InstDesc = BC.MII->get(Inst.getOpcode());
+    for (MCPhysReg ImplicitUse : InstDesc.implicit_uses())
+      Result |= BC.MIB->getAliases(ImplicitUse, /*OnlySmaller=*/false);
+
+    // ... followed by the explicit register use operands.
+    for (const MCOperand &Operand : BC.MIB->useOperands(Inst)) {
+      if (!Operand.isReg())
+        continue;
+      Result |= BC.MIB->getAliases(Operand.getReg(), /*OnlySmaller=*/false);
+    }
+
+    return Result;
+  }
+
+  SmallVector<MCPhysReg> getRegsMadeProtected(const MCInst &Inst,
+                                              const BitVector &LeakedRegs,
+                                              const DstState &Cur) const {
+    SmallVector<MCPhysReg> Regs;
+
+    // A pointer can be checked, or
+    if (auto CheckedReg =
+            BC.MIB->getAuthCheckedReg(Inst, /*MayOverwrite=*/true))
+      Regs.push_back(*CheckedReg);
+    if (RegCheckedAt.contains(&Inst))
+      Regs.push_back(RegCheckedAt.at(&Inst));
+
+    // ... it can be used as a branch target, or
+    if (BC.MIB->isIndirectBranch(Inst) || BC.MIB->isIndirectCall(Inst)) {
+      bool IsAuthenticated;
+      MCPhysReg BranchDestReg =
+          BC.MIB->getRegUsedAsIndirectBranchDest(Inst, IsAuthenticated);
+      assert(BranchDestReg != BC.MIB->getNoRegister());
+      if (!IsAuthenticated)
+        Regs.push_back(BranchDestReg);
+    }
+
+    // ... it can be used as a return target, or
+    if (BC.MIB->isReturn(Inst)) {
+      bool IsAuthenticated = false;
+      std::optional<MCPhysReg> RetReg =
+          BC.MIB->getRegUsedAsRetDest(Inst, IsAuthenticated);
+      if (RetReg && !IsAuthenticated)
+        Regs.push_back(*RetReg);
+    }
+
+    // ... an address can be updated in a safe manner, or
+    if (auto DstAndSrc = BC.MIB->analyzeAddressArithmeticsForPtrAuth(Inst)) {
+      MCPhysReg DstReg, SrcReg;
+      std::tie(DstReg, SrcReg) = *DstAndSrc;
+      // Note that *all* registers containing the derived values must be safe,
+      // both source and destination ones. No temporaries are supported for now.
+      if (Cur.CannotEscapeUnchecked[SrcReg] &&
+          Cur.CannotEscapeUnchecked[DstReg])
+        Regs.push_back(SrcReg);
+    }
+
+    // ... the register can be overwritten in whole with an unrelated value -
+    // for that reason, ignore the registers that are both read and written:
+    //
+    //     movk x0, #42, lsl #16  // keeps some bits of x0
+    //     mul  x1, x1, #3        // not all information is actually lost
+    //
+    BitVector FullyOverwrittenRegs;
+    BC.MIB->getWrittenRegs(Inst, FullyOverwrittenRegs);
+    FullyOverwrittenRegs.reset(LeakedRegs);
+    for (MCPhysReg Reg : FullyOverwrittenRegs.set_bits())
+      Regs.push_back(Reg);
----------------
kbeyls wrote:

I couldn't easily map what the comment says to what the code does here.
I'm not fully sure exactly what I'm missing here. Could you explain this in a bit more detail?



https://github.com/llvm/llvm-project/pull/135663


More information about the llvm-commits mailing list