[llvm] [ReachingDefAnalysis] Turn MBBReachingDefsInfo into a proper class (NFC) (PR #110432)

Kazu Hirata via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 29 11:35:59 PDT 2024


https://github.com/kazutakahirata created https://github.com/llvm/llvm-project/pull/110432

I'm trying to speed up the reaching def analysis by changing the
underlying data structure.  Turning MBBReachingDefsInfo into a proper
class decouples the data structure from its users, so the storage can
later be replaced without touching them.  This patch does not change
the existing three-dimensional (block, reg unit, def) vector structure.


From a27f7df64aa84de516abbb095d8ae68d51c1f417 Mon Sep 17 00:00:00 2001
From: Kazu Hirata <kazu at google.com>
Date: Sat, 28 Sep 2024 23:22:27 -0700
Subject: [PATCH] [ReachingDefAnalysis] Turn MBBReachingDefsInfo into a proper
 class (NFC)

I'm trying to speed up the reaching def analysis by changing the
underlying data structure.  Turning MBBReachingDefsInfo into a proper
class decouples the data structure from its users, so the storage can
later be replaced without touching them.  This patch does not change
the existing three-dimensional (block, reg unit, def) vector structure.
---
 .../llvm/CodeGen/ReachingDefAnalysis.h        | 50 ++++++++++++++++---
 llvm/lib/CodeGen/ReachingDefAnalysis.cpp      | 40 ++++++++-------
 2 files changed, 65 insertions(+), 25 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
index ec652f448f0f65..52fdf1961385d8 100644
--- a/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
+++ b/llvm/include/llvm/CodeGen/ReachingDefAnalysis.h
@@ -65,6 +65,50 @@ struct PointerLikeTypeTraits<ReachingDef> {
   }
 };
 
+/// Storage for all reaching definitions, indexed by MBB number and reg unit.
+class MBBReachingDefsInfo {
+public:
+  void init(unsigned NumBlockIDs) { AllReachingDefs.resize(NumBlockIDs); }
+  /// Number of basic-block slots currently allocated (see init()).
+  unsigned numBlockIDs() const { return AllReachingDefs.size(); }
+  /// Resizes block MBBNumber's table to hold NumRegUnits def lists.
+  void startBasicBlock(unsigned MBBNumber, unsigned NumRegUnits) {
+    AllReachingDefs[MBBNumber].resize(NumRegUnits);
+  }
+  /// Adds Def after the existing defs of Unit in block MBBNumber.
+  void append(unsigned MBBNumber, unsigned Unit, int Def) {
+    AllReachingDefs[MBBNumber][Unit].push_back(Def);
+  }
+  /// Inserts Def before the existing defs of Unit in block MBBNumber.
+  void prepend(unsigned MBBNumber, unsigned Unit, int Def) {
+    auto &Defs = AllReachingDefs[MBBNumber][Unit];
+    Defs.insert(Defs.begin(), Def);
+  }
+  /// Overwrites the first def of Unit in block MBBNumber; it must exist.
+  void replaceFront(unsigned MBBNumber, unsigned Unit, int Def) {
+    assert(!AllReachingDefs[MBBNumber][Unit].empty());
+    *AllReachingDefs[MBBNumber][Unit].begin() = Def;
+  }
+  /// Drops all per-block storage; call init() again before further use.
+  void clear() { AllReachingDefs.clear(); }
+  /// View of Unit's defs in MBBNumber; append/prepend/clear can invalidate it.
+  ArrayRef<ReachingDef> defs(unsigned MBBNumber, unsigned Unit) const {
+    if (AllReachingDefs[MBBNumber].empty())
+      // Block IDs are not necessarily dense, so this block may have no units.
+      return ArrayRef<ReachingDef>();
+    return AllReachingDefs[MBBNumber][Unit];
+  }
+
+private:
+  /// Reaching defs of one reg unit in one MBB (expected sorted and unique).
+  using MBBRegUnitDefs = TinyPtrVector<ReachingDef>;
+  /// Reaching defs of all reg units in one MBB, indexed by reg unit.
+  using MBBDefsInfo = std::vector<MBBRegUnitDefs>;
+
+  /// Reaching defs of all reg units for all MBBs, indexed by MBB number.
+  SmallVector<MBBDefsInfo, 4> AllReachingDefs;
+};
+
 /// This class provides the reaching def analysis.
 class ReachingDefAnalysis : public MachineFunctionPass {
 private:
@@ -93,12 +137,6 @@ class ReachingDefAnalysis : public MachineFunctionPass {
   /// their basic blocks.
   DenseMap<MachineInstr *, int> InstIds;
 
-  /// All reaching defs of a given RegUnit for a given MBB.
-  using MBBRegUnitDefs = TinyPtrVector<ReachingDef>;
-  /// All reaching defs of all reg units for a given MBB
-  using MBBDefsInfo = std::vector<MBBRegUnitDefs>;
-  /// All reaching defs of all reg units for a all MBBs
-  using MBBReachingDefsInfo = SmallVector<MBBDefsInfo, 4>;
   MBBReachingDefsInfo MBBReachingDefs;
 
   /// Default values are 'nothing happened a long time ago'.
diff --git a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
index 07fa92889d8853..0e8220ec6251cb 100644
--- a/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
+++ b/llvm/lib/CodeGen/ReachingDefAnalysis.cpp
@@ -50,9 +50,9 @@ static bool isValidRegDefOf(const MachineOperand &MO, MCRegister PhysReg,
 
 void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
   unsigned MBBNumber = MBB->getNumber();
-  assert(MBBNumber < MBBReachingDefs.size() &&
+  assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
          "Unexpected basic block number.");
-  MBBReachingDefs[MBBNumber].resize(NumRegUnits);
+  MBBReachingDefs.startBasicBlock(MBBNumber, NumRegUnits);
 
   // Reset instruction counter in each basic block.
   CurInstr = 0;
@@ -71,7 +71,7 @@ void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
         // before the call.
         if (LiveRegs[Unit] != -1) {
           LiveRegs[Unit] = -1;
-          MBBReachingDefs[MBBNumber][Unit].push_back(-1);
+          MBBReachingDefs.append(MBBNumber, Unit, -1);
         }
       }
     }
@@ -97,7 +97,7 @@ void ReachingDefAnalysis::enterBasicBlock(MachineBasicBlock *MBB) {
   // Insert the most recent reaching definition we found.
   for (unsigned Unit = 0; Unit != NumRegUnits; ++Unit)
     if (LiveRegs[Unit] != ReachingDefDefaultVal)
-      MBBReachingDefs[MBBNumber][Unit].push_back(LiveRegs[Unit]);
+      MBBReachingDefs.append(MBBNumber, Unit, LiveRegs[Unit]);
 }
 
 void ReachingDefAnalysis::leaveBasicBlock(MachineBasicBlock *MBB) {
@@ -122,7 +122,7 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
   assert(!MI->isDebugInstr() && "Won't process debug instructions");
 
   unsigned MBBNumber = MI->getParent()->getNumber();
-  assert(MBBNumber < MBBReachingDefs.size() &&
+  assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
          "Unexpected basic block number.");
 
   for (auto &MO : MI->operands()) {
@@ -136,7 +136,7 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
       // How many instructions since this reg unit was last written?
       if (LiveRegs[Unit] != CurInstr) {
         LiveRegs[Unit] = CurInstr;
-        MBBReachingDefs[MBBNumber][Unit].push_back(CurInstr);
+        MBBReachingDefs.append(MBBNumber, Unit, CurInstr);
       }
     }
   }
@@ -146,7 +146,7 @@ void ReachingDefAnalysis::processDefs(MachineInstr *MI) {
 
 void ReachingDefAnalysis::reprocessBasicBlock(MachineBasicBlock *MBB) {
   unsigned MBBNumber = MBB->getNumber();
-  assert(MBBNumber < MBBReachingDefs.size() &&
+  assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
          "Unexpected basic block number.");
 
   // Count number of non-debug instructions for end of block adjustment.
@@ -169,16 +169,16 @@ void ReachingDefAnalysis::reprocessBasicBlock(MachineBasicBlock *MBB) {
       if (Def == ReachingDefDefaultVal)
         continue;
 
-      auto Start = MBBReachingDefs[MBBNumber][Unit].begin();
-      if (Start != MBBReachingDefs[MBBNumber][Unit].end() && *Start < 0) {
-        if (*Start >= Def)
+      auto Defs = MBBReachingDefs.defs(MBBNumber, Unit);
+      if (!Defs.empty() && Defs.front() < 0) {
+        if (Defs.front() >= Def)
           continue;
 
         // Update existing reaching def from predecessor to a more recent one.
-        *Start = Def;
+        MBBReachingDefs.replaceFront(MBBNumber, Unit, Def);
       } else {
         // Insert new reaching def from predecessor.
-        MBBReachingDefs[MBBNumber][Unit].insert(Start, Def);
+        MBBReachingDefs.prepend(MBBNumber, Unit, Def);
       }
 
       // Update reaching def at end of BB. Keep in mind that these are
@@ -234,7 +234,7 @@ void ReachingDefAnalysis::reset() {
 
 void ReachingDefAnalysis::init() {
   NumRegUnits = TRI->getNumRegUnits();
-  MBBReachingDefs.resize(MF->getNumBlockIDs());
+  MBBReachingDefs.init(MF->getNumBlockIDs());
   // Initialize the MBBOutRegsInfos
   MBBOutRegsInfos.resize(MF->getNumBlockIDs());
   LoopTraversal Traversal;
@@ -247,10 +247,11 @@ void ReachingDefAnalysis::traverse() {
     processBasicBlock(TraversedMBB);
 #ifndef NDEBUG
   // Make sure reaching defs are sorted and unique.
-  for (MBBDefsInfo &MBBDefs : MBBReachingDefs) {
-    for (MBBRegUnitDefs &RegUnitDefs : MBBDefs) {
+  for (unsigned MBBNumber = 0, NumBlockIDs = MF->getNumBlockIDs();
+       MBBNumber != NumBlockIDs; ++MBBNumber) {
+    for (unsigned Unit = 0; Unit != NumRegUnits; ++Unit) {
       int LastDef = ReachingDefDefaultVal;
-      for (int Def : RegUnitDefs) {
+      for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) {
         assert(Def > LastDef && "Defs must be sorted and unique");
         LastDef = Def;
       }
@@ -265,11 +266,11 @@ int ReachingDefAnalysis::getReachingDef(MachineInstr *MI,
   int InstId = InstIds.lookup(MI);
   int DefRes = ReachingDefDefaultVal;
   unsigned MBBNumber = MI->getParent()->getNumber();
-  assert(MBBNumber < MBBReachingDefs.size() &&
+  assert(MBBNumber < MBBReachingDefs.numBlockIDs() &&
          "Unexpected basic block number.");
   int LatestDef = ReachingDefDefaultVal;
   for (MCRegUnit Unit : TRI->regunits(PhysReg)) {
-    for (int Def : MBBReachingDefs[MBBNumber][Unit]) {
+    for (int Def : MBBReachingDefs.defs(MBBNumber, Unit)) {
       if (Def >= InstId)
         break;
       DefRes = Def;
@@ -299,7 +300,8 @@ bool ReachingDefAnalysis::hasSameReachingDef(MachineInstr *A, MachineInstr *B,
 
 MachineInstr *ReachingDefAnalysis::getInstFromId(MachineBasicBlock *MBB,
                                                  int InstId) const {
-  assert(static_cast<size_t>(MBB->getNumber()) < MBBReachingDefs.size() &&
+  assert(static_cast<size_t>(MBB->getNumber()) <
+             MBBReachingDefs.numBlockIDs() &&
          "Unexpected basic block number.");
   assert(InstId < static_cast<int>(MBB->size()) &&
          "Unexpected instruction id.");



More information about the llvm-commits mailing list