[llvm-branch-commits] [llvm] [BOLT] Gadget scanner: use more appropriate types (NFC) (PR #135661)

Anatoly Trosinenko via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Apr 18 09:35:00 PDT 2025


https://github.com/atrosinenko updated https://github.com/llvm/llvm-project/pull/135661

>From d57dc48475cc025369a003abcb9412833a955d28 Mon Sep 17 00:00:00 2001
From: Anatoly Trosinenko <atrosinenko at accesssoftek.com>
Date: Mon, 14 Apr 2025 14:35:56 +0300
Subject: [PATCH] [BOLT] Gadget scanner: use more appropriate types (NFC)

* use more flexible `const ArrayRef<T>` and `StringRef` types instead of
  `const std::vector<T> &` and `const std::string &`, respectively,
  for function arguments
* return plain `const SrcState &` instead of `ErrorOr<const SrcState &>`
  from `SrcSafetyAnalysis::getStateBefore`, as no caller handles an
  absent state gracefully
---
 bolt/include/bolt/Passes/PAuthGadgetScanner.h |  8 +---
 bolt/lib/Passes/PAuthGadgetScanner.cpp        | 39 ++++++++-----------
 2 files changed, 19 insertions(+), 28 deletions(-)

diff --git a/bolt/include/bolt/Passes/PAuthGadgetScanner.h b/bolt/include/bolt/Passes/PAuthGadgetScanner.h
index 6765e2aff414f..3e39b64e59e0f 100644
--- a/bolt/include/bolt/Passes/PAuthGadgetScanner.h
+++ b/bolt/include/bolt/Passes/PAuthGadgetScanner.h
@@ -12,7 +12,6 @@
 #include "bolt/Core/BinaryContext.h"
 #include "bolt/Core/BinaryFunction.h"
 #include "bolt/Passes/BinaryPasses.h"
-#include "llvm/ADT/SmallSet.h"
 #include "llvm/Support/raw_ostream.h"
 #include <memory>
 
@@ -199,9 +198,6 @@ raw_ostream &operator<<(raw_ostream &OS, const MCInstReference &);
 
 namespace PAuthGadgetScanner {
 
-class SrcSafetyAnalysis;
-struct SrcState;
-
 /// Description of a gadget kind that can be detected. Intended to be
 /// statically allocated to be attached to reports by reference.
 class GadgetKind {
@@ -210,7 +206,7 @@ class GadgetKind {
 public:
   GadgetKind(const char *Description) : Description(Description) {}
 
-  const StringRef getDescription() const { return Description; }
+  StringRef getDescription() const { return Description; }
 };
 
 /// Base report located at some instruction, without any additional information.
@@ -261,7 +257,7 @@ struct GadgetReport : public Report {
 /// Report with a free-form message attached.
 struct GenericReport : public Report {
   std::string Text;
-  GenericReport(MCInstReference Location, const std::string &Text)
+  GenericReport(MCInstReference Location, StringRef Text)
       : Report(Location), Text(Text) {}
   virtual void generateReport(raw_ostream &OS,
                               const BinaryContext &BC) const override;
diff --git a/bolt/lib/Passes/PAuthGadgetScanner.cpp b/bolt/lib/Passes/PAuthGadgetScanner.cpp
index ad47bdff753c8..ed89471cbb8d3 100644
--- a/bolt/lib/Passes/PAuthGadgetScanner.cpp
+++ b/bolt/lib/Passes/PAuthGadgetScanner.cpp
@@ -91,14 +91,14 @@ class TrackedRegisters {
   const std::vector<MCPhysReg> Registers;
   std::vector<uint16_t> RegToIndexMapping;
 
-  static size_t getMappingSize(const std::vector<MCPhysReg> &RegsToTrack) {
+  static size_t getMappingSize(const ArrayRef<MCPhysReg> RegsToTrack) {
     if (RegsToTrack.empty())
       return 0;
     return 1 + *llvm::max_element(RegsToTrack);
   }
 
 public:
-  TrackedRegisters(const std::vector<MCPhysReg> &RegsToTrack)
+  TrackedRegisters(const ArrayRef<MCPhysReg> RegsToTrack)
       : Registers(RegsToTrack),
         RegToIndexMapping(getMappingSize(RegsToTrack), NoIndex) {
     for (unsigned I = 0; I < RegsToTrack.size(); ++I)
@@ -234,7 +234,7 @@ struct SrcState {
 
 static void printLastInsts(
     raw_ostream &OS,
-    const std::vector<SmallPtrSet<const MCInst *, 4>> &LastInstWritingReg) {
+    const ArrayRef<SmallPtrSet<const MCInst *, 4>> LastInstWritingReg) {
   OS << "Insts: ";
   for (unsigned I = 0; I < LastInstWritingReg.size(); ++I) {
     auto &Set = LastInstWritingReg[I];
@@ -295,7 +295,7 @@ void SrcStatePrinter::print(raw_ostream &OS, const SrcState &S) const {
 class SrcSafetyAnalysis {
 public:
   SrcSafetyAnalysis(BinaryFunction &BF,
-                    const std::vector<MCPhysReg> &RegsToTrackInstsFor)
+                    const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
       : BC(BF.getBinaryContext()), NumRegs(BC.MRI->getNumRegs()),
         RegsToTrackInstsFor(RegsToTrackInstsFor) {}
 
@@ -303,11 +303,10 @@ class SrcSafetyAnalysis {
 
   static std::shared_ptr<SrcSafetyAnalysis>
   create(BinaryFunction &BF, MCPlusBuilder::AllocatorIdTy AllocId,
-         const std::vector<MCPhysReg> &RegsToTrackInstsFor);
+         const ArrayRef<MCPhysReg> RegsToTrackInstsFor);
 
   virtual void run() = 0;
-  virtual ErrorOr<const SrcState &>
-  getStateBefore(const MCInst &Inst) const = 0;
+  virtual const SrcState &getStateBefore(const MCInst &Inst) const = 0;
 
 protected:
   BinaryContext &BC;
@@ -348,7 +347,7 @@ class SrcSafetyAnalysis {
   }
 
   BitVector getClobberedRegs(const MCInst &Point) const {
-    BitVector Clobbered(NumRegs, false);
+    BitVector Clobbered(NumRegs);
     // Assume a call can clobber all registers, including callee-saved
     // registers. There's a good chance that callee-saved registers will be
     // saved on the stack at some point during execution of the callee.
@@ -409,8 +408,7 @@ class SrcSafetyAnalysis {
 
       // FirstCheckerInst should belong to the same basic block, meaning
       // it was deterministically processed a few steps before this instruction.
-      const SrcState &StateBeforeChecker =
-          getStateBefore(*FirstCheckerInst).get();
+      const SrcState &StateBeforeChecker = getStateBefore(*FirstCheckerInst);
       if (StateBeforeChecker.SafeToDerefRegs[CheckedReg])
         Regs.push_back(CheckedReg);
     }
@@ -523,10 +521,7 @@ class SrcSafetyAnalysis {
                          const ArrayRef<MCPhysReg> UsedDirtyRegs) const {
     if (RegsToTrackInstsFor.empty())
       return {};
-    auto MaybeState = getStateBefore(Inst);
-    if (!MaybeState)
-      llvm_unreachable("Expected state to be present");
-    const SrcState &S = *MaybeState;
+    const SrcState &S = getStateBefore(Inst);
     // Due to aliasing registers, multiple registers may have been tracked.
     std::set<const MCInst *> LastWritingInsts;
     for (MCPhysReg TrackedReg : UsedDirtyRegs) {
@@ -537,7 +532,7 @@ class SrcSafetyAnalysis {
     for (const MCInst *Inst : LastWritingInsts) {
       MCInstReference Ref = MCInstReference::get(Inst, BF);
       assert(Ref && "Expected Inst to be found");
-      Result.push_back(MCInstReference(Ref));
+      Result.push_back(Ref);
     }
     return Result;
   }
@@ -557,11 +552,11 @@ class DataflowSrcSafetyAnalysis
 public:
   DataflowSrcSafetyAnalysis(BinaryFunction &BF,
                             MCPlusBuilder::AllocatorIdTy AllocId,
-                            const std::vector<MCPhysReg> &RegsToTrackInstsFor)
+                            const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
       : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), DFParent(BF, AllocId) {}
 
-  ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
-    return DFParent::getStateBefore(Inst);
+  const SrcState &getStateBefore(const MCInst &Inst) const override {
+    return DFParent::getStateBefore(Inst).get();
   }
 
   void run() override {
@@ -670,7 +665,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
 public:
   CFGUnawareSrcSafetyAnalysis(BinaryFunction &BF,
                               MCPlusBuilder::AllocatorIdTy AllocId,
-                              const std::vector<MCPhysReg> &RegsToTrackInstsFor)
+                              const ArrayRef<MCPhysReg> RegsToTrackInstsFor)
       : SrcSafetyAnalysis(BF, RegsToTrackInstsFor), BF(BF), AllocId(AllocId) {
     StateAnnotationIndex =
         BC.MIB->getOrCreateAnnotationIndex("CFGUnawareSrcSafetyAnalysis");
@@ -704,7 +699,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
     }
   }
 
-  ErrorOr<const SrcState &> getStateBefore(const MCInst &Inst) const override {
+  const SrcState &getStateBefore(const MCInst &Inst) const override {
     return BC.MIB->getAnnotationAs<SrcState>(Inst, StateAnnotationIndex);
   }
 
@@ -714,7 +709,7 @@ class CFGUnawareSrcSafetyAnalysis : public SrcSafetyAnalysis {
 std::shared_ptr<SrcSafetyAnalysis>
 SrcSafetyAnalysis::create(BinaryFunction &BF,
                           MCPlusBuilder::AllocatorIdTy AllocId,
-                          const std::vector<MCPhysReg> &RegsToTrackInstsFor) {
+                          const ArrayRef<MCPhysReg> RegsToTrackInstsFor) {
   if (BF.hasCFG())
     return std::make_shared<DataflowSrcSafetyAnalysis>(BF, AllocId,
                                                        RegsToTrackInstsFor);
@@ -821,7 +816,7 @@ Analysis::findGadgets(BinaryFunction &BF,
 
   BinaryContext &BC = BF.getBinaryContext();
   iterateOverInstrs(BF, [&](MCInstReference Inst) {
-    const SrcState &S = *Analysis->getStateBefore(Inst);
+    const SrcState &S = Analysis->getStateBefore(Inst);
 
     // If non-empty state was never propagated from the entry basic block
     // to Inst, assume it to be unreachable and report a warning.



More information about the llvm-branch-commits mailing list