[llvm] [BOLT] Gadget scanner: detect authentication oracles (PR #135663)
Kristof Beyls via llvm-commits
llvm-commits at lists.llvm.org
Wed May 28 02:22:40 PDT 2025
================
@@ -717,6 +740,457 @@ SrcSafetyAnalysis::create(BinaryFunction &BF,
RegsToTrackInstsFor);
}
+/// A state representing which registers are safe to be used as the destination
+/// operand of an authentication instruction.
+///
+/// Similar to SrcState, it is the responsibility of the analysis to take
+/// register aliasing into account.
+///
+/// Depending on the implementation, it may be possible that an authentication
+/// instruction returns an invalid pointer on failure instead of terminating
+/// the program immediately (assuming the program will crash as soon as that
+/// pointer is dereferenced). To prevent brute-forcing the correct signature,
+/// it should be impossible for an attacker to test if a pointer is correctly
+/// signed - either the program should be terminated on authentication failure
+/// or it should be impossible to tell whether authentication succeeded or not.
+///
+/// For that reason, a restricted set of operations is allowed on any register
+/// containing a value derived from the result of an authentication instruction
+/// until that register is either wiped or checked not to contain a result of a
+/// failed authentication.
+///
+/// Specifically, the safety property for a register is computed by iterating
+/// the instructions in backward order: the source register Xn of an instruction
+/// Inst is safe if at least one of the following is true:
+/// * Inst checks if Xn contains the result of a successful authentication and
+/// terminates the program on failure. Note that Inst can either naturally
+/// dereference Xn (load, branch, return, etc. instructions) or be the first
+/// instruction of an explicit checking sequence.
+/// * Inst performs safe address arithmetic AND both source and result
+/// registers, as well as any temporary registers, must be safe after
+/// execution of Inst (temporaries are not used on AArch64 and thus not
+/// currently supported/allowed).
+/// See MCPlusBuilder::analyzeAddressArithmeticsForPtrAuth for the details.
+/// * Inst fully overwrites Xn with an unrelated value.
+struct DstState {
+  /// Registers whose contents an attacker cannot observe in a way that would
+  /// act as an authentication oracle. Authentication instructions should
+  /// write their results only to registers from this set.
+  BitVector CannotEscapeUnchecked;
+
+  /// One set of instructions per tracked register. DstSafetyAnalysis selects
+  /// the element for a given register via RegsToTrackInstsFor.getIndex().
+  std::vector<SetOfRelatedInsts> FirstInstLeakingReg;
+
+  /// Builds the empty state, which holds no information about any register.
+  DstState() = default;
+
+  /// Builds a state covering NumRegs registers, none of them marked as safe,
+  /// with one (initially empty) instruction set for each of the
+  /// NumRegsToTrack tracked registers.
+  DstState(unsigned NumRegs, unsigned NumRegsToTrack)
+      : CannotEscapeUnchecked(NumRegs), FirstInstLeakingReg(NumRegsToTrack) {}
+
+  /// Folds StateIn into this state and returns *this.
+  ///
+  /// An empty StateIn leaves this state untouched; if this state is empty, it
+  /// becomes a copy of StateIn. Otherwise, a register stays in
+  /// CannotEscapeUnchecked only if it is present in both states, and the
+  /// per-register instruction sets are unioned element by element.
+  DstState &merge(const DstState &StateIn) {
+    if (StateIn.empty())
+      return *this;
+    if (empty()) {
+      *this = StateIn;
+      return *this;
+    }
+
+    CannotEscapeUnchecked &= StateIn.CannotEscapeUnchecked;
+    unsigned Slot = 0;
+    for (SetOfRelatedInsts &Merged : FirstInstLeakingReg) {
+      for (const MCInst *Inst : StateIn.FirstInstLeakingReg[Slot])
+        Merged.insert(Inst);
+      ++Slot;
+    }
+    return *this;
+  }
+
+  /// Returns true if this object carries no per-register state at all -
+  /// neither for safe registers, nor for unsafe ones.
+  bool empty() const { return CannotEscapeUnchecked.empty(); }
+
+  /// Two states are equal when both the register set and the per-register
+  /// instruction sets match.
+  bool operator==(const DstState &RHS) const {
+    if (!(CannotEscapeUnchecked == RHS.CannotEscapeUnchecked))
+      return false;
+    return FirstInstLeakingReg == RHS.FirstInstLeakingReg;
+  }
+  bool operator!=(const DstState &RHS) const { return !(*this == RHS); }
+};
+
+/// Prints a DstState in the form "dst-state<...>": either "empty", or the
+/// CannotEscapeUnchecked bit vector followed by the tracked instruction sets.
+static raw_ostream &operator<<(raw_ostream &OS, const DstState &S) {
+  if (S.empty())
+    return OS << "dst-state<" << "empty" << ">";
+
+  OS << "dst-state<";
+  OS << "CannotEscapeUnchecked: " << S.CannotEscapeUnchecked << ", ";
+  printInstsShort(OS, S.FirstInstLeakingReg);
+  return OS << ">";
+}
+
+/// Prints DstState objects with the help of a BinaryContext, which is handed
+/// over to the register-state printer used for the register set.
+class DstStatePrinter {
+  /// Context passed to the register-state printer; not owned.
+  const BinaryContext &BC;
+
+public:
+  explicit DstStatePrinter(const BinaryContext &BC) : BC(BC) {}
+
+  /// Writes a textual representation of S to OS.
+  void print(raw_ostream &OS, const DstState &S) const;
+};
+
+/// Writes S to OS as "dst-state<...>". A non-empty state is rendered as the
+/// CannotEscapeUnchecked registers (via RegStatePrinter) followed by the
+/// tracked instruction sets; an empty one is rendered as "empty".
+void DstStatePrinter::print(raw_ostream &OS, const DstState &S) const {
+  RegStatePrinter RegPrinter(BC);
+  OS << "dst-state<";
+  if (!S.empty()) {
+    OS << "CannotEscapeUnchecked: ";
+    RegPrinter.print(OS, S.CannotEscapeUnchecked);
+    OS << ", ";
+    printInstsShort(OS, S.FirstInstLeakingReg);
+  } else {
+    // An empty state is expected to have both of its containers empty.
+    assert(S.CannotEscapeUnchecked.empty());
+    assert(S.FirstInstLeakingReg.empty());
+    OS << "empty";
+  }
+  OS << ">";
+}
+
+/// Computes which registers are safe to be written to by auth instructions.
+///
+/// This is the base class for two implementations: a dataflow-based analysis
+/// which is intended to be used for most functions and a simplified CFG-unaware
+/// version for functions without reconstructed CFG.
+class DstSafetyAnalysis {
+public:
+ /// Captures the binary context of BF, the number of registers reported by
+ /// its register info, and the registers to track instructions for.
+ DstSafetyAnalysis(BinaryFunction &BF, ArrayRef<MCPhysReg> RegsToTrackInstsFor)
+     // NumRegs is read through BC, which is declared (and therefore
+     // initialized) before it.
+     : BC(BF.getBinaryContext()),
+       NumRegs(BC.MRI->getNumRegs()),
+       RegsToTrackInstsFor(RegsToTrackInstsFor) {}
+
+ /// Virtual, as this class declares pure virtual member functions to be
+ /// implemented by derived analyses.
+ virtual ~DstSafetyAnalysis() = default;
+
+ static std::shared_ptr<DstSafetyAnalysis>
+ create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
+ ArrayRef<MCPhysReg> RegsToTrackInstsFor);
+
+ virtual void run() = 0;
+ virtual const DstState &getStateAfter(const MCInst &Inst) const = 0;
+
+protected:
+ BinaryContext &BC;
+ const unsigned NumRegs;
+
+ const TrackedRegisters RegsToTrackInstsFor;
+
+ /// Stores information about the detected instruction sequences emitted to
+ /// check an authenticated pointer. Specifically, if such sequence is detected
+ /// in a basic block, it maps the first instruction of that sequence to the
+ /// register being checked.
+ ///
+ /// As the detection of such sequences requires iterating over the adjacent
+ /// instructions, it should be done before calling computeNext(), which
+ /// operates on separate instructions.
+ DenseMap<const MCInst *, MCPhysReg> RegCheckedAt;
+
+ /// Returns the element of S.FirstInstLeakingReg that corresponds to Reg,
+ /// using the index assigned to Reg by RegsToTrackInstsFor.
+ SetOfRelatedInsts &firstLeakingInsts(DstState &S, MCPhysReg Reg) const {
+   return S.FirstInstLeakingReg[RegsToTrackInstsFor.getIndex(Reg)];
+ }
+ /// Read-only counterpart of the overload above.
+ const SetOfRelatedInsts &firstLeakingInsts(const DstState &S,
+                                            MCPhysReg Reg) const {
+   return S.FirstInstLeakingReg[RegsToTrackInstsFor.getIndex(Reg)];
+ }
+
+ /// Builds a non-empty state in which every register is marked unsafe.
+ /// This is different from the empty state, which tracks no registers at all.
+ DstState createUnsafeState() {
+   DstState AllUnsafe(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
+   return AllUnsafe;
+ }
+
+ /// Computes the registers that Inst may leak, i.e. the registers it reads.
+ /// The result is built the same way as a set of clobbered registers would
+ /// be, except that the inputs of Inst are considered instead of its outputs.
+ BitVector getLeakedRegs(const MCInst &Inst) const {
+   BitVector Result(NumRegs);
+
+   // Be conservative about calls: assume the callee may read any register.
+   if (BC.MIB->isCall(Inst)) {
+     Result.set();
+     return Result;
+   }
+
+   // Otherwise, gather every register that overlaps with one of the
+   // registers used by this instruction.
+   auto MarkWithAliases = [&](MCPhysReg Reg) {
+     Result |= BC.MIB->getAliases(Reg, /*OnlySmaller=*/false);
+   };
+
+   // Implicitly used registers come from the instruction description ...
+   const MCInstrDesc &Desc = BC.MII->get(Inst.getOpcode());
+   for (MCPhysReg ImplicitUse : Desc.implicit_uses())
+     MarkWithAliases(ImplicitUse);
+
+   // ... and explicitly used ones from the register use operands.
+   for (const MCOperand &UseOp : BC.MIB->useOperands(Inst)) {
+     if (!UseOp.isReg())
+       continue;
+     MarkWithAliases(UseOp.getReg());
+   }
+
+   return Result;
+ }
+
+ SmallVector<MCPhysReg> getRegsMadeProtected(const MCInst &Inst,
+ const BitVector &LeakedRegs,
+ const DstState &Cur) const {
+ SmallVector<MCPhysReg> Regs;
+
+ // A pointer can be checked, or
+ if (auto CheckedReg =
+ BC.MIB->getAuthCheckedReg(Inst, /*MayOverwrite=*/true))
+ Regs.push_back(*CheckedReg);
+ if (RegCheckedAt.contains(&Inst))
+ Regs.push_back(RegCheckedAt.at(&Inst));
+
+ // ... it can be used as a branch target, or
+ if (BC.MIB->isIndirectBranch(Inst) || BC.MIB->isIndirectCall(Inst)) {
+ bool IsAuthenticated;
+ MCPhysReg BranchDestReg =
+ BC.MIB->getRegUsedAsIndirectBranchDest(Inst, IsAuthenticated);
+ assert(BranchDestReg != BC.MIB->getNoRegister());
+ if (!IsAuthenticated)
+ Regs.push_back(BranchDestReg);
+ }
+
+ // ... it can be used as a return target, or
+ if (BC.MIB->isReturn(Inst)) {
+ bool IsAuthenticated = false;
+ std::optional<MCPhysReg> RetReg =
+ BC.MIB->getRegUsedAsRetDest(Inst, IsAuthenticated);
+ if (RetReg && !IsAuthenticated)
+ Regs.push_back(*RetReg);
+ }
+
+ // ... an address can be updated in a safe manner, or
+ if (auto DstAndSrc = BC.MIB->analyzeAddressArithmeticsForPtrAuth(Inst)) {
+ MCPhysReg DstReg, SrcReg;
+ std::tie(DstReg, SrcReg) = *DstAndSrc;
+ // Note that *all* registers containing the derived values must be safe,
+ // both source and destination ones. No temporaries are supported for now.
+ if (Cur.CannotEscapeUnchecked[SrcReg] &&
+ Cur.CannotEscapeUnchecked[DstReg])
+ Regs.push_back(SrcReg);
+ }
+
+ // ... the register can be overwritten in whole with an unrelated value -
+ // for that reason, ignore the registers that are both read and written:
+ //
+ // movk x0, #42, lsl #16 // keeps some bits of x0
+ // mul x1, x1, #3 // not all information is actually lost
+ //
+ BitVector FullyOverwrittenRegs;
+ BC.MIB->getWrittenRegs(Inst, FullyOverwrittenRegs);
+ FullyOverwrittenRegs.reset(LeakedRegs);
+ for (MCPhysReg Reg : FullyOverwrittenRegs.set_bits())
+ Regs.push_back(Reg);
----------------
kbeyls wrote:
I couldn't easily map what the comment says to what the code does here.
I'm not fully sure exactly what I'm missing here. Could you explain this in a bit more detail?
https://github.com/llvm/llvm-project/pull/135663
More information about the llvm-commits
mailing list