[llvm] [BOLT] Gadget scanner: factor out utility code (PR #131895)

Anatoly Trosinenko via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 18 12:25:16 PDT 2025


https://github.com/atrosinenko created https://github.com/llvm/llvm-project/pull/131895

Factor out the code for mapping from physical registers to consecutive
array indexes.

Introduce helper functions to print instructions and registers to
prevent mixing of analysis logic and implementation details of debug
output.

Remove the debug printing from `Gadget::generateReport`, as it doesn't
seem to add any important information beyond what is already printed in
the report itself.

From 9761e5d53a0dc620889ca63d5e90d0110afbda7a Mon Sep 17 00:00:00 2001
From: Anatoly Trosinenko <atrosinenko at accesssoftek.com>
Date: Fri, 14 Mar 2025 18:29:05 +0300
Subject: [PATCH] [BOLT] Gadget scanner: factor out utility code

Factor out the code for mapping from physical registers to consecutive
array indexes.

Introduce helper functions to print instructions and registers to
prevent mixing of analysis logic and implementation details of debug
output.

Remove the debug printing from `Gadget::generateReport`, as it doesn't
seem to add any important information beyond what is already printed in
the report itself.
---
 .../lib/Passes/NonPacProtectedRetAnalysis.cpp | 136 +++++++++++-------
 1 file changed, 84 insertions(+), 52 deletions(-)

diff --git a/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp b/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp
index dc7cb275f5664..77a16379a14b9 100644
--- a/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp
+++ b/bolt/lib/Passes/NonPacProtectedRetAnalysis.cpp
@@ -14,6 +14,7 @@
 #include "bolt/Passes/NonPacProtectedRetAnalysis.h"
 #include "bolt/Core/ParallelUtilities.h"
 #include "bolt/Passes/DataflowAnalysis.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/MC/MCInst.h"
 #include "llvm/Support/Format.h"
@@ -58,6 +59,71 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &Ref) {
 
 namespace NonPacProtectedRetAnalysis {
 
+static void traceInst(const BinaryContext &BC, StringRef Label,
+                      const MCInst &MI) {
+  dbgs() << "  " << Label << ": ";
+  BC.printInstruction(dbgs(), MI);
+}
+
+static void traceReg(const BinaryContext &BC, StringRef Label,
+                     ErrorOr<MCPhysReg> Reg) {
+  dbgs() << "    " << Label << ": ";
+  if (Reg.getError())
+    dbgs() << "(error)";
+  else if (*Reg == BC.MIB->getNoRegister())
+    dbgs() << "(none)";
+  else
+    dbgs() << BC.MRI->getName(*Reg);
+  dbgs() << "\n";
+}
+
+static void traceRegMask(const BinaryContext &BC, StringRef Label,
+                         BitVector Mask) {
+  dbgs() << "    " << Label << ": ";
+  RegStatePrinter(BC).print(dbgs(), Mask);
+  dbgs() << "\n";
+}
+
+// This class represents mapping from arbitrary physical registers to
+// consecutive array indexes.
+class TrackedRegisters {
+  static const uint16_t NoIndex = -1;
+  const std::vector<MCPhysReg> Registers;
+  std::vector<uint16_t> RegToIndexMapping;
+
+  static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
+    if (RegsToTrack.empty())
+      return 0;
+    return 1 + *llvm::max_element(RegsToTrack);
+  }
+
+public:
+  TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
+      : Registers(RegsToTrack),
+        RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
+    for (unsigned I = 0; I < RegsToTrack.size(); ++I)
+      RegToIndexMapping[RegsToTrack[I]] = I;
+  }
+
+  const ArrayRef<MCPhysReg> getRegisters() const { return Registers; }
+
+  size_t getNumTrackedRegisters() const { return Registers.size(); }
+
+  bool empty() const { return Registers.empty(); }
+
+  bool isTracked(MCPhysReg Reg) const {
+    bool IsTracked = (unsigned)Reg < RegToIndexMapping.size() &&
+                     RegToIndexMapping[Reg] != NoIndex;
+    assert(IsTracked == llvm::is_contained(Registers, Reg));
+    return IsTracked;
+  }
+
+  unsigned getIndex(MCPhysReg Reg) const {
+    assert(isTracked(Reg) && "Register is not tracked");
+    return RegToIndexMapping[Reg];
+  }
+};
+
 // The security property that is checked is:
 // When a register is used as the address to jump to in a return instruction,
 // that register must either:
@@ -169,52 +235,34 @@ class PacRetAnalysis
   PacRetAnalysis(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
                  const std::vector<MCPhysReg> &RegsToTrackInstsFor)
       : Parent(BF, AllocId), NumRegs(BF.getBinaryContext().MRI->getNumRegs()),
-        RegsToTrackInstsFor(RegsToTrackInstsFor),
-        TrackingLastInsts(!RegsToTrackInstsFor.empty()),
-        Reg2StateIdx(RegsToTrackInstsFor.empty()
-                         ? 0
-                         : *llvm::max_element(RegsToTrackInstsFor) + 1,
-                     -1) {
-    for (unsigned I = 0; I < RegsToTrackInstsFor.size(); ++I)
-      Reg2StateIdx[RegsToTrackInstsFor[I]] = I;
-  }
+        RegsToTrackInstsFor(RegsToTrackInstsFor) {}
   virtual ~PacRetAnalysis() {}
 
 protected:
   const unsigned NumRegs;
   /// RegToTrackInstsFor is the set of registers for which the dataflow analysis
   /// must compute which the last set of instructions writing to it are.
-  const std::vector<MCPhysReg> RegsToTrackInstsFor;
-  const bool TrackingLastInsts;
-  /// Reg2StateIdx maps Register to the index in the vector used in State to
-  /// track which instructions last wrote to this register.
-  std::vector<uint16_t> Reg2StateIdx;
+  const TrackedRegisters RegsToTrackInstsFor;
 
   SmallPtrSet<const MCInst *, 4> &lastWritingInsts(State &S,
                                                    MCPhysReg Reg) const {
-    assert(Reg < Reg2StateIdx.size());
-    assert(isTrackingReg(Reg));
-    return S.LastInstWritingReg[Reg2StateIdx[Reg]];
+    unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
+    return S.LastInstWritingReg[Index];
   }
   const SmallPtrSet<const MCInst *, 4> &lastWritingInsts(const State &S,
                                                          MCPhysReg Reg) const {
-    assert(Reg < Reg2StateIdx.size());
-    assert(isTrackingReg(Reg));
-    return S.LastInstWritingReg[Reg2StateIdx[Reg]];
-  }
-
-  bool isTrackingReg(MCPhysReg Reg) const {
-    return llvm::is_contained(RegsToTrackInstsFor, Reg);
+    unsigned Index = RegsToTrackInstsFor.getIndex(Reg);
+    return S.LastInstWritingReg[Index];
   }
 
   void preflight() {}
 
   State getStartingStateAtBB(const BinaryBasicBlock &BB) {
-    return State(NumRegs, RegsToTrackInstsFor.size());
+    return State(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
   }
 
   State getStartingStateAtPoint(const MCInst &Point) {
-    return State(NumRegs, RegsToTrackInstsFor.size());
+    return State(NumRegs, RegsToTrackInstsFor.getNumTrackedRegisters());
   }
 
   void doConfluence(State &StateOut, const State &StateIn) {
@@ -275,7 +323,7 @@ class PacRetAnalysis
     Next.NonAutClobRegs |= Written;
     // Keep track of this instruction if it writes to any of the registers we
     // need to track that for:
-    for (MCPhysReg Reg : RegsToTrackInstsFor)
+    for (MCPhysReg Reg : RegsToTrackInstsFor.getRegisters())
       if (Written[Reg])
         lastWritingInsts(Next, Reg) = {&Point};
 
@@ -287,7 +335,7 @@ class PacRetAnalysis
       // https://github.com/llvm/llvm-project/pull/122304#discussion_r1939515516
       Next.NonAutClobRegs.reset(
           BC.MIB->getAliases(*AutReg, /*OnlySmaller=*/true));
-      if (TrackingLastInsts && isTrackingReg(*AutReg))
+      if (RegsToTrackInstsFor.isTracked(*AutReg))
         lastWritingInsts(Next, *AutReg).clear();
     }
 
@@ -306,7 +354,7 @@ class PacRetAnalysis
   std::vector<MCInstReference>
   getLastClobberingInsts(const MCInst Ret, BinaryFunction &BF,
                          const BitVector &UsedDirtyRegs) const {
-    if (!TrackingLastInsts)
+    if (RegsToTrackInstsFor.empty())
       return {};
     auto MaybeState = getStateAt(Ret);
     if (!MaybeState)
@@ -355,28 +403,18 @@ Analysis::computeDfState(PacRetAnalysis &PRA, BinaryFunction &BF,
         }
         MCPhysReg RetReg = *MaybeRetReg;
         LLVM_DEBUG({
-          dbgs() << "  Found RET inst: ";
-          BC.printInstruction(dbgs(), Inst);
-          dbgs() << "    RetReg: " << BC.MRI->getName(RetReg)
-                 << "; authenticatesReg: "
-                 << BC.MIB->isAuthenticationOfReg(Inst, RetReg) << "\n";
+          traceInst(BC, "Found RET inst", Inst);
+          traceReg(BC, "RetReg", RetReg);
+          traceReg(BC, "Authenticated reg", BC.MIB->getAuthenticatedReg(Inst));
         });
         if (BC.MIB->isAuthenticationOfReg(Inst, RetReg))
           break;
         BitVector UsedDirtyRegs = PRA.getStateAt(Inst)->NonAutClobRegs;
-        LLVM_DEBUG({
-          dbgs() << "  NonAutClobRegs at Ret: ";
-          RegStatePrinter RSP(BC);
-          RSP.print(dbgs(), UsedDirtyRegs);
-          dbgs() << "\n";
-        });
+        LLVM_DEBUG(
+            { traceRegMask(BC, "NonAutClobRegs at Ret", UsedDirtyRegs); });
         UsedDirtyRegs &= BC.MIB->getAliases(RetReg, /*OnlySmaller=*/true);
-        LLVM_DEBUG({
-          dbgs() << "  Intersection with RetReg: ";
-          RegStatePrinter RSP(BC);
-          RSP.print(dbgs(), UsedDirtyRegs);
-          dbgs() << "\n";
-        });
+        LLVM_DEBUG(
+            { traceRegMask(BC, "Intersection with RetReg", UsedDirtyRegs); });
         if (UsedDirtyRegs.any()) {
           // This return instruction needs to be reported
           Result.Diagnostics.push_back(std::make_shared<Gadget>(
@@ -472,12 +510,6 @@ void Gadget::generateReport(raw_ostream &OS, const BinaryContext &BC) const {
     OS << "  " << (I + 1) << ". ";
     BC.printInstruction(OS, InstRef, InstRef.getAddress(), BF);
   };
-  LLVM_DEBUG({
-    dbgs() << "  .. OverWritingRetRegInst:\n";
-    for (MCInstReference Ref : OverwritingRetRegInst) {
-      dbgs() << "    " << Ref << "\n";
-    }
-  });
   if (OverwritingRetRegInst.size() == 1) {
     const MCInstReference OverwInst = OverwritingRetRegInst[0];
     assert(OverwInst.ParentKind == MCInstReference::BasicBlockParent);



More information about the llvm-commits mailing list