[llvm] [EarlyIfConversion] Determine if branch is predictable using new APIs. (PR #95877)

Mikhail Gudim via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 17 20:34:33 PDT 2024


https://github.com/mgudim updated https://github.com/llvm/llvm-project/pull/95877

>From 3d4d09edc78de4f8253604aec6a58492a250eef2 Mon Sep 17 00:00:00 2001
From: Mikhail Gudim <mgudim at gmail.com>
Date: Mon, 17 Jun 2024 23:17:09 -0400
Subject: [PATCH] [EarlyIfConversion] Determine if branch is predictable using
 new APIs.

---
 llvm/include/llvm/CodeGen/MachineLoopInfo.h  |  6 ++-
 llvm/include/llvm/CodeGen/TargetInstrInfo.h  | 12 ++++++
 llvm/lib/CodeGen/EarlyIfConversion.cpp       | 40 +++++---------------
 llvm/lib/CodeGen/MachineLoopInfo.cpp         | 32 ++++++++++------
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp | 24 +++++++++++-
 llvm/lib/Target/AArch64/AArch64InstrInfo.h   |  7 ++++
 6 files changed, 75 insertions(+), 46 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/MachineLoopInfo.h b/llvm/include/llvm/CodeGen/MachineLoopInfo.h
index 967c4a70ca469..e6a163cfd388e 100644
--- a/llvm/include/llvm/CodeGen/MachineLoopInfo.h
+++ b/llvm/include/llvm/CodeGen/MachineLoopInfo.h
@@ -82,7 +82,13 @@ class MachineLoop : public LoopBase<MachineBasicBlock, MachineLoop> {
   /// ExcludeReg can be used to exclude the given register from the check
   /// i.e. when we're considering hoisting it's definition but not hoisted it
   /// yet
-  bool isLoopInvariant(MachineInstr &I, const Register ExcludeReg = 0) const;
+  ///
+  /// RecursionDepth bounds how far definitions inside the loop are looked
+  /// through: an operand defined inside the loop is accepted only if its
+  /// defining instruction is itself loop invariant at RecursionDepth - 1.
+  /// A depth of 0 always returns false; the default of 1 looks through none.
+  bool isLoopInvariant(const MachineInstr &I, const Register ExcludeReg = 0,
+                       unsigned RecursionDepth = 1) const;
 
   void dump() const;
 
@@ -90,7 +96,11 @@ class MachineLoop : public LoopBase<MachineBasicBlock, MachineLoop> {
   friend class LoopInfoBase<MachineBasicBlock, MachineLoop>;
 
-  /// Returns true if the given physreg has no defs inside the loop.
-  bool isLoopInvariantImplicitPhysReg(Register Reg) const;
+  /// Returns true if the given physreg can be treated as loop invariant. Defs
+  /// inside the loop are tolerated only if the defining instruction is itself
+  /// loop invariant at RecursionDepth - 1; with the default depth of 1 this
+  /// means the physreg must have no defs inside the loop.
+  bool isLoopInvariantImplicitPhysReg(Register Reg, Register ExcludeReg,
+                                      unsigned RecursionDepth = 1) const;
 
   explicit MachineLoop(MachineBasicBlock *MBB)
     : LoopBase<MachineBasicBlock, MachineLoop>(MBB) {}
diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
index d5b1df2114e9e..adf1a43b10d6f 100644
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -45,6 +45,7 @@ class InstrItineraryData;
 class LiveIntervals;
 class LiveVariables;
 class MachineLoop;
+class MachineLoopInfo;
 class MachineMemOperand;
 class MachineRegisterInfo;
 class MCAsmInfo;
@@ -654,6 +655,23 @@ class TargetInstrInfo : public MCInstrInfo {
     return true;
   }
 
+  /// Same as above, but additionally reports whether the target considers the
+  /// analyzed conditional branch to be predictable.
+  ///
+  /// If \p IsPredictable is non-null, overriding targets should set it to true
+  /// when they consider the branch predictable and to false otherwise. \p MLI
+  /// may be null; targets may use it to reason about enclosing loops. This
+  /// default implementation never reports a branch as predictable.
+  virtual bool analyzeBranch(MachineBasicBlock &MBB, MachineBasicBlock *&TBB,
+                             MachineBasicBlock *&FBB,
+                             SmallVectorImpl<MachineOperand> &Cond,
+                             bool *IsPredictable, const MachineLoopInfo *MLI,
+                             bool AllowModify = false) const {
+    if (IsPredictable)
+      *IsPredictable = false;
+    return analyzeBranch(MBB, TBB, FBB, Cond, AllowModify);
+  }
+
   /// Represents a predicate at the MachineFunction level.  The control flow a
   /// MachineBranchPredicate represents is:
   ///
diff --git a/llvm/lib/CodeGen/EarlyIfConversion.cpp b/llvm/lib/CodeGen/EarlyIfConversion.cpp
index 2a7bee1618deb..b668e6f427f6b 100644
--- a/llvm/lib/CodeGen/EarlyIfConversion.cpp
+++ b/llvm/lib/CodeGen/EarlyIfConversion.cpp
@@ -83,6 +83,7 @@ class SSAIfConv {
   const TargetInstrInfo *TII;
   const TargetRegisterInfo *TRI;
   MachineRegisterInfo *MRI;
+  const MachineLoopInfo *MLI;
 
 public:
   /// The block containing the conditional branch.
@@ -121,6 +122,8 @@ class SSAIfConv {
 
   /// The branch condition determined by analyzeBranch.
   SmallVector<MachineOperand, 4> Cond;
+  /// Whether analyzeBranch reported the branch in Head as predictable.
+  bool IsPredictableBranch = false;
 
 private:
   /// Instructions in Head that define values used by the conditional blocks.
@@ -164,7 +167,7 @@ class SSAIfConv {
 
 public:
   /// runOnMachineFunction - Initialize per-function data structures.
-  void runOnMachineFunction(MachineFunction &MF) {
+  void runOnMachineFunction(MachineFunction &MF, const MachineLoopInfo *MLI) {
     TII = MF.getSubtarget().getInstrInfo();
     TRI = MF.getSubtarget().getRegisterInfo();
     MRI = &MF.getRegInfo();
@@ -172,6 +175,7 @@ class SSAIfConv {
     LiveRegUnits.setUniverse(TRI->getNumRegUnits());
     ClobberedRegUnits.clear();
     ClobberedRegUnits.resize(TRI->getNumRegUnits());
+    this->MLI = MLI;
   }
 
   /// canConvertIf - If the sub-CFG headed by MBB can be if-converted,
@@ -485,7 +489,10 @@ bool SSAIfConv::canConvertIf(MachineBasicBlock *MBB, bool Predicate) {
 
   // The branch we're looking to eliminate must be analyzable.
   Cond.clear();
-  if (TII->analyzeBranch(*Head, TBB, FBB, Cond)) {
+  // Targets are not required to write IsPredictableBranch, so reset it here
+  // rather than reuse the answer computed for a previously visited block.
+  IsPredictableBranch = false;
+  if (TII->analyzeBranch(*Head, TBB, FBB, Cond, &IsPredictableBranch, MLI)) {
     LLVM_DEBUG(dbgs() << "Branch not analyzable.\n");
     return false;
   }
@@ -874,33 +881,7 @@ bool EarlyIfConverter::shouldConvertIf() {
   // Do not try to if-convert if the condition has a high chance of being
   // predictable.
   MachineLoop *CurrentLoop = Loops->getLoopFor(IfConv.Head);
-  // If the condition is in a loop, consider it predictable if the condition
-  // itself or all its operands are loop-invariant. E.g. this considers a load
-  // from a loop-invariant address predictable; we were unable to prove that it
-  // doesn't alias any of the memory-writes in the loop, but it is likely to
-  // read to same value multiple times.
-  if (CurrentLoop && any_of(IfConv.Cond, [&](MachineOperand &MO) {
-        if (!MO.isReg() || !MO.isUse())
-          return false;
-        Register Reg = MO.getReg();
-        if (Register::isPhysicalRegister(Reg))
-          return false;
-
-        MachineInstr *Def = MRI->getVRegDef(Reg);
-        return CurrentLoop->isLoopInvariant(*Def) ||
-               all_of(Def->operands(), [&](MachineOperand &Op) {
-                 if (Op.isImm())
-                   return true;
-                 if (!MO.isReg() || !MO.isUse())
-                   return false;
-                 Register Reg = MO.getReg();
-                 if (Register::isPhysicalRegister(Reg))
-                   return false;
-
-                 MachineInstr *Def = MRI->getVRegDef(Reg);
-                 return CurrentLoop->isLoopInvariant(*Def);
-               });
-      }))
+  if (CurrentLoop && IfConv.IsPredictableBranch)
     return false;
 
   if (!MinInstr)
@@ -1095,7 +1076,7 @@ bool EarlyIfConverter::runOnMachineFunction(MachineFunction &MF) {
   MinInstr = nullptr;
 
   bool Changed = false;
-  IfConv.runOnMachineFunction(MF);
+  IfConv.runOnMachineFunction(MF, Loops);
 
   // Visit blocks in dominator tree post-order. The post-order enables nested
   // if-conversion in a single pass. The tryConvertIf() function may erase
@@ -1228,7 +1209,7 @@ bool EarlyIfPredicator::runOnMachineFunction(MachineFunction &MF) {
   MBPI = &getAnalysis<MachineBranchProbabilityInfo>();
 
   bool Changed = false;
-  IfConv.runOnMachineFunction(MF);
+  IfConv.runOnMachineFunction(MF, Loops);
 
   // Visit blocks in dominator tree post-order. The post-order enables nested
   // if-conversion in a single pass. The tryConvertIf() function may erase
diff --git a/llvm/lib/CodeGen/MachineLoopInfo.cpp b/llvm/lib/CodeGen/MachineLoopInfo.cpp
index 1019c53e57c6f..4f8245a35c9ab 100644
--- a/llvm/lib/CodeGen/MachineLoopInfo.cpp
+++ b/llvm/lib/CodeGen/MachineLoopInfo.cpp
@@ -198,7 +198,8 @@ MDNode *MachineLoop::getLoopID() const {
   return LoopID;
 }
 
-bool MachineLoop::isLoopInvariantImplicitPhysReg(Register Reg) const {
+bool MachineLoop::isLoopInvariantImplicitPhysReg(
+    Register Reg, Register ExcludeReg, unsigned RecursionDepth) const {
   MachineFunction *MF = getHeader()->getParent();
   MachineRegisterInfo *MRI = &MF->getRegInfo();
 
@@ -210,15 +211,26 @@ bool MachineLoop::isLoopInvariantImplicitPhysReg(Register Reg) const {
            ->shouldAnalyzePhysregInMachineLoopInfo(Reg))
     return false;
 
-  return !llvm::any_of(
-      MRI->def_instructions(Reg),
-      [this](const MachineInstr &MI) { return this->contains(&MI); });
+  // Defs inside the loop are tolerated only if they are themselves invariant
+  // one level deeper. Clamp at 0: an unsigned wrap would unbound recursion.
+  const unsigned NextDepth = RecursionDepth ? RecursionDepth - 1 : 0;
+  return !llvm::any_of(MRI->def_instructions(Reg),
+                       [this, ExcludeReg, NextDepth](const MachineInstr &MI) {
+                         return contains(&MI) &&
+                                !isLoopInvariant(MI, ExcludeReg, NextDepth);
+                       });
 }
 
-bool MachineLoop::isLoopInvariant(MachineInstr &I,
-                                  const Register ExcludeReg) const {
-  MachineFunction *MF = I.getParent()->getParent();
-  MachineRegisterInfo *MRI = &MF->getRegInfo();
+bool MachineLoop::isLoopInvariant(const MachineInstr &I,
+                                  const Register ExcludeReg,
+                                  unsigned RecursionDepth) const {
+  // A depth of 0 means the recursion budget is exhausted; answer
+  // conservatively.
+  if (RecursionDepth == 0)
+    return false;
+
+  const MachineFunction *MF = I.getParent()->getParent();
+  const MachineRegisterInfo *MRI = &MF->getRegInfo();
   const TargetSubtargetInfo &ST = MF->getSubtarget();
   const TargetRegisterInfo *TRI = ST.getRegisterInfo();
   const TargetInstrInfo *TII = ST.getInstrInfo();
@@ -243,7 +255,7 @@ bool MachineLoop::isLoopInvariant(MachineInstr &I,
         // it could get allocated to something with a def during allocation.
         // However, if the physreg is known to always be caller saved/restored
         // then this use is safe to hoist.
-        if (!isLoopInvariantImplicitPhysReg(Reg) &&
+        if (!isLoopInvariantImplicitPhysReg(Reg, ExcludeReg, RecursionDepth) &&
             !(TRI->isCallerPreservedPhysReg(Reg.asMCReg(), *I.getMF())) &&
             !TII->isIgnorableUse(MO))
           return false;
@@ -265,9 +277,11 @@ bool MachineLoop::isLoopInvariant(MachineInstr &I,
     assert(MRI->getVRegDef(Reg) &&
            "Machine instr not mapped for this vreg?!");
 
-    // If the loop contains the definition of an operand, then the instruction
-    // isn't loop invariant.
-    if (contains(MRI->getVRegDef(Reg)))
+    // If the loop contains the definition of an operand, that definition must
+    // itself be loop invariant (checked with one less level of recursion).
+    const MachineInstr *VRegDefMI = MRI->getVRegDef(Reg);
+    if (contains(VRegDefMI) &&
+        !isLoopInvariant(*VRegDefMI, ExcludeReg, RecursionDepth - 1))
       return false;
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index aa0b7c93f8661..0b1fb4d20801b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/CodeGen/MachineOperand.h"
@@ -327,12 +328,36 @@ void AArch64InstrInfo::insertIndirectBranch(MachineBasicBlock &MBB,
       .addImm(16);
 }
 
-// Branch analysis.
+/// Returns true if \p CondBr lies inside a loop of \p MLI and
+/// MachineLoop::isLoopInvariant accepts it when allowed to look through one
+/// level of operand definitions inside the loop (RecursionDepth = 2).
+/// Branches outside of any loop are never reported as predictable.
+bool AArch64InstrInfo::isCondBranchPredictable(
+    const MachineInstr &CondBr, const MachineLoopInfo &MLI) const {
+  MachineLoop *Loop = MLI.getLoopFor(CondBr.getParent());
+  if (!Loop)
+    return false;
+  return Loop->isLoopInvariant(CondBr, /*ExcludeReg=*/0, /*RecursionDepth=*/2);
+}
+
 bool AArch64InstrInfo::analyzeBranch(MachineBasicBlock &MBB,
                                      MachineBasicBlock *&TBB,
                                      MachineBasicBlock *&FBB,
                                      SmallVectorImpl<MachineOperand> &Cond,
                                      bool AllowModify) const {
+  return analyzeBranch(MBB, TBB, FBB, Cond, nullptr, nullptr, AllowModify);
+}
+
+// Branch analysis.
+bool AArch64InstrInfo::analyzeBranch(
+    MachineBasicBlock &MBB, MachineBasicBlock *&TBB, MachineBasicBlock *&FBB,
+    SmallVectorImpl<MachineOperand> &Cond, bool *IsPredictable,
+    const MachineLoopInfo *MLI, bool AllowModify) const {
+  // Only the conditional-branch paths below may report a predictable branch;
+  // every other outcome leaves this false.
+  if (IsPredictable)
+    *IsPredictable = false;
+
   // If the block has no terminators, it just falls into the block after it.
   MachineBasicBlock::iterator I = MBB.getLastNonDebugInstr();
   if (I == MBB.end())
@@ -360,6 +385,8 @@ bool AArch64InstrInfo::analyzeBranch(MachineBasicBlock &MBB,
     if (isCondBranchOpcode(LastOpc)) {
       // Block ends with fall-through condbranch.
       parseCondBranch(LastInst, TBB, Cond);
+      if (IsPredictable && MLI)
+        *IsPredictable = isCondBranchPredictable(*LastInst, *MLI);
       return false;
     }
     return true; // Can't handle indirect branch.
@@ -402,6 +429,8 @@ bool AArch64InstrInfo::analyzeBranch(MachineBasicBlock &MBB,
       if (isCondBranchOpcode(LastOpc)) {
         // Block ends with fall-through condbranch.
         parseCondBranch(LastInst, TBB, Cond);
+        if (IsPredictable && MLI)
+          *IsPredictable = isCondBranchPredictable(*LastInst, *MLI);
         return false;
       }
       return true; // Can't handle indirect branch.
@@ -418,6 +447,9 @@ bool AArch64InstrInfo::analyzeBranch(MachineBasicBlock &MBB,
   if (isCondBranchOpcode(SecondLastOpc) && isUncondBranchOpcode(LastOpc)) {
     parseCondBranch(SecondLastInst, TBB, Cond);
     FBB = LastInst->getOperand(0).getMBB();
+    // The conditional branch is SecondLastInst, not the trailing uncond one.
+    if (IsPredictable && MLI)
+      *IsPredictable = isCondBranchPredictable(*SecondLastInst, *MLI);
     return false;
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.h b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
index f434799c3982b..1d68a7e25ed3b 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.h
@@ -374,10 +374,22 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
                              MachineBasicBlock &RestoreBB, const DebugLoc &DL,
                              int64_t BrOffset, RegScavenger *RS) const override;
 
+  /// Returns true if \p CondBr lies inside a loop of \p MLI and is loop
+  /// invariant when one level of operand definitions inside the loop is
+  /// looked through.
+  bool isCondBranchPredictable(const MachineInstr &CondBr,
+                               const MachineLoopInfo &MLI) const;
   bool analyzeBranch(MachineBasicBlock &MBB, MachineBasicBlock *&TBB,
                      MachineBasicBlock *&FBB,
                      SmallVectorImpl<MachineOperand> &Cond,
                      bool AllowModify = false) const override;
+  /// Like analyzeBranch above; when both \p IsPredictable and \p MLI are
+  /// non-null, conditional branches are classified via isCondBranchPredictable.
+  bool analyzeBranch(MachineBasicBlock &MBB, MachineBasicBlock *&TBB,
+                     MachineBasicBlock *&FBB,
+                     SmallVectorImpl<MachineOperand> &Cond, bool *IsPredictable,
+                     const MachineLoopInfo *MLI,
+                     bool AllowModify = false) const override;
   bool analyzeBranchPredicate(MachineBasicBlock &MBB,
                               MachineBranchPredicate &MBP,
                               bool AllowModify) const override;



More information about the llvm-commits mailing list