[llvm] [CodeGen] Utilizing register units based liveIns tracking in MBB (PR #129847)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 5 00:14:50 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-arm

Author: Vikash Gupta (vg0204)

<details>
<summary>Changes</summary>

Currently, MachineBasicBlock does not fully utilize the LaneBitmask associated with physical-register live-ins when checking for precise liveness. It is only fully correct when every live-in query for the MBB uses exactly the same register (and lane mask) form in which the live-in was originally added.

With register-unit-based tracking of the MBB's live-ins, it is now possible to track and check liveness for any physical register, including registers that only overlap a recorded live-in.

It is needed to handle #129848.

---
Full diff: https://github.com/llvm/llvm-project/pull/129847.diff


3 Files Affected:

- (modified) llvm/include/llvm/CodeGen/MachineBasicBlock.h (+15) 
- (modified) llvm/lib/CodeGen/MachineBasicBlock.cpp (+29-5) 
- (modified) llvm/test/CodeGen/ARM/aes-erratum-fix.ll (+4-4) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/MachineBasicBlock.h b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
index 2de96fa85b936..ab88177b63fed 100644
--- a/llvm/include/llvm/CodeGen/MachineBasicBlock.h
+++ b/llvm/include/llvm/CodeGen/MachineBasicBlock.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CODEGEN_MACHINEBASICBLOCK_H
 #define LLVM_CODEGEN_MACHINEBASICBLOCK_H
 
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/GraphTraits.h"
 #include "llvm/ADT/SparseBitVector.h"
@@ -158,6 +159,7 @@ class MachineBasicBlock
 
   MachineFunction *xParent;
   Instructions Insts;
+  const TargetRegisterInfo *TRI;
 
   /// Keep track of the predecessor / successor basic blocks.
   SmallVector<MachineBasicBlock *, 4> Predecessors;
@@ -177,6 +179,10 @@ class MachineBasicBlock
   using LiveInVector = std::vector<RegisterMaskPair>;
   LiveInVector LiveIns;
 
+  /// Keeps track of live register units for those physical registers which
+  /// are livein of the basicblock.
+  BitVector LiveInRegUnits;
+
   /// Alignment of the basic block. One if the basic block does not need to be
   /// aligned.
   Align Alignment;
@@ -458,11 +464,17 @@ class MachineBasicBlock
   void addLiveIn(MCRegister PhysReg,
                  LaneBitmask LaneMask = LaneBitmask::getAll()) {
     LiveIns.push_back(RegisterMaskPair(PhysReg, LaneMask));
+    addLiveInRegUnit(PhysReg, LaneMask);
   }
   void addLiveIn(const RegisterMaskPair &RegMaskPair) {
     LiveIns.push_back(RegMaskPair);
+    addLiveInRegUnit(RegMaskPair.PhysReg, RegMaskPair.LaneMask);
   }
 
+  /// Sets the LiveInRegUnits bit of every register unit of \p Reg whose
+  /// lane mask overlaps \p LaneMask.
+  void addLiveInRegUnit(MCRegister Reg, LaneBitmask LaneMask);
+
   /// Sorts and uniques the LiveIns vector. It can be significantly faster to do
   /// this than repeatedly calling isLiveIn before calling addLiveIn for every
   /// LiveIn insertion.
@@ -484,6 +496,9 @@ class MachineBasicBlock
   void removeLiveIn(MCRegister Reg,
                     LaneBitmask LaneMask = LaneBitmask::getAll());
 
+  /// Resets the bits in LiveInRegUnits for all register units of \p Reg.
+  void removeLiveInRegUnit(MCRegister Reg);
+
   /// Return true if the specified register is in the live in set.
   bool isLiveIn(MCRegister Reg,
                 LaneBitmask LaneMask = LaneBitmask::getAll()) const;
diff --git a/llvm/lib/CodeGen/MachineBasicBlock.cpp b/llvm/lib/CodeGen/MachineBasicBlock.cpp
index b3a71d1144726..0e5055d7bec2c 100644
--- a/llvm/lib/CodeGen/MachineBasicBlock.cpp
+++ b/llvm/lib/CodeGen/MachineBasicBlock.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/ModuleSlotTracker.h"
 #include "llvm/MC/MCAsmInfo.h"
 #include "llvm/MC/MCContext.h"
+#include "llvm/MC/MCRegisterInfo.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetMachine.h"
@@ -51,10 +52,12 @@ static cl::opt<bool> PrintSlotIndexes(
     cl::init(true), cl::Hidden);
 
 MachineBasicBlock::MachineBasicBlock(MachineFunction &MF, const BasicBlock *B)
-    : BB(B), Number(-1), xParent(&MF) {
+    : BB(B), Number(-1), xParent(&MF),
+      TRI(MF.getSubtarget().getRegisterInfo()) {
   Insts.Parent = this;
   if (B)
     IrrLoopHeaderWeight = B->getIrrLoopHeaderWeight();
+  LiveInRegUnits.resize(TRI->getNumRegUnits());
 }
 
 MachineBasicBlock::~MachineBasicBlock() = default;
@@ -597,6 +600,14 @@ void MachineBasicBlock::printAsOperand(raw_ostream &OS,
   printName(OS, 0);
 }
 
+void MachineBasicBlock::addLiveInRegUnit(MCRegister Reg, LaneBitmask LaneMask) {
+  for (MCRegUnitMaskIterator Unit(Reg, TRI); Unit.isValid(); ++Unit) {
+    LaneBitmask UnitMask = (*Unit).second;
+    if ((UnitMask & LaneMask).any())
+      LiveInRegUnits.set((*Unit).first);
+  }
+}
+
 void MachineBasicBlock::removeLiveIn(MCRegister Reg, LaneBitmask LaneMask) {
   LiveInVector::iterator I = find_if(
       LiveIns, [Reg](const RegisterMaskPair &LI) { return LI.PhysReg == Reg; });
@@ -604,21 +615,32 @@ void MachineBasicBlock::removeLiveIn(MCRegister Reg, LaneBitmask LaneMask) {
     return;
 
   I->LaneMask &= ~LaneMask;
-  if (I->LaneMask.none())
+  if (I->LaneMask.none()) {
+    removeLiveInRegUnit(I->PhysReg); // Must precede erase(): it invalidates I.
     LiveIns.erase(I);
+  }
 }
 
 MachineBasicBlock::livein_iterator
 MachineBasicBlock::removeLiveIn(MachineBasicBlock::livein_iterator I) {
   // Get non-const version of iterator.
   LiveInVector::iterator LI = LiveIns.begin() + (I - LiveIns.begin());
+  removeLiveInRegUnit(LI->PhysReg);
   return LiveIns.erase(LI);
 }
 
+void MachineBasicBlock::removeLiveInRegUnit(MCRegister Reg) {
+  for (MCRegUnit Unit : TRI->regunits(Reg))
+    LiveInRegUnits.reset(Unit);
+}
+
 bool MachineBasicBlock::isLiveIn(MCRegister Reg, LaneBitmask LaneMask) const {
-  livein_iterator I = find_if(
-      LiveIns, [Reg](const RegisterMaskPair &LI) { return LI.PhysReg == Reg; });
-  return I != livein_end() && (I->LaneMask & LaneMask).any();
+  for (MCRegUnitMaskIterator Unit(Reg, TRI); Unit.isValid(); ++Unit) {
+    LaneBitmask UnitMask = (*Unit).second;
+    if ((UnitMask & LaneMask).any() && LiveInRegUnits.test((*Unit).first))
+      return true;
+  }
+  return false;
 }
 
 void MachineBasicBlock::sortUniqueLiveIns() {
@@ -1751,12 +1773,14 @@ MachineBasicBlock::getEndClobberMask(const TargetRegisterInfo *TRI) const {
 
 void MachineBasicBlock::clearLiveIns() {
   LiveIns.clear();
+  LiveInRegUnits.reset();
 }
 
 void MachineBasicBlock::clearLiveIns(
     std::vector<RegisterMaskPair> &OldLiveIns) {
   assert(OldLiveIns.empty() && "Vector must be empty");
   std::swap(LiveIns, OldLiveIns);
+  LiveInRegUnits.reset();
 }
 
 MachineBasicBlock::livein_iterator MachineBasicBlock::livein_begin() const {
diff --git a/llvm/test/CodeGen/ARM/aes-erratum-fix.ll b/llvm/test/CodeGen/ARM/aes-erratum-fix.ll
index 82f5bfd02a56e..e1361d6efa780 100644
--- a/llvm/test/CodeGen/ARM/aes-erratum-fix.ll
+++ b/llvm/test/CodeGen/ARM/aes-erratum-fix.ll
@@ -68,8 +68,8 @@ define arm_aapcs_vfpcc void @aese_via_call2(half %0, ptr %1) nounwind {
 ; CHECK-FIX-NEXT:    push {r4, lr}
 ; CHECK-FIX-NEXT:    mov r4, r0
 ; CHECK-FIX-NEXT:    bl get_inputf16
-; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    vld1.64 {d16, d17}, [r4]
+; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    aese.8 q8, q0
 ; CHECK-FIX-NEXT:    aesmc.8 q8, q8
 ; CHECK-FIX-NEXT:    vst1.64 {d16, d17}, [r4]
@@ -89,8 +89,8 @@ define arm_aapcs_vfpcc void @aese_via_call3(float %0, ptr %1) nounwind {
 ; CHECK-FIX-NEXT:    push {r4, lr}
 ; CHECK-FIX-NEXT:    mov r4, r0
 ; CHECK-FIX-NEXT:    bl get_inputf32
-; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    vld1.64 {d16, d17}, [r4]
+; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    aese.8 q8, q0
 ; CHECK-FIX-NEXT:    aesmc.8 q8, q8
 ; CHECK-FIX-NEXT:    vst1.64 {d16, d17}, [r4]
@@ -2222,8 +2222,8 @@ define arm_aapcs_vfpcc void @aesd_via_call2(half %0, ptr %1) nounwind {
 ; CHECK-FIX-NEXT:    push {r4, lr}
 ; CHECK-FIX-NEXT:    mov r4, r0
 ; CHECK-FIX-NEXT:    bl get_inputf16
-; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    vld1.64 {d16, d17}, [r4]
+; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    aesd.8 q8, q0
 ; CHECK-FIX-NEXT:    aesimc.8 q8, q8
 ; CHECK-FIX-NEXT:    vst1.64 {d16, d17}, [r4]
@@ -2243,8 +2243,8 @@ define arm_aapcs_vfpcc void @aesd_via_call3(float %0, ptr %1) nounwind {
 ; CHECK-FIX-NEXT:    push {r4, lr}
 ; CHECK-FIX-NEXT:    mov r4, r0
 ; CHECK-FIX-NEXT:    bl get_inputf32
-; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    vld1.64 {d16, d17}, [r4]
+; CHECK-FIX-NEXT:    vorr q0, q0, q0
 ; CHECK-FIX-NEXT:    aesd.8 q8, q0
 ; CHECK-FIX-NEXT:    aesimc.8 q8, q8
 ; CHECK-FIX-NEXT:    vst1.64 {d16, d17}, [r4]

``````````

</details>


https://github.com/llvm/llvm-project/pull/129847


More information about the llvm-commits mailing list