[llvm] [CodeGen] [EarlyIfConversion] Prioritize conversion of hard to predict branches (PR #174457)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 5 10:24:51 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Jonathan Cohen (jcohen-apple)

<details>
<summary>Changes</summary>

This PR is the first in a series which expands the if-conversion heuristics to be more aggressive under certain conditions. Here we distinguish conditional branches whose conditions we can prove depend on values loaded from memory. Since branch history does not help predict the direction of such branches, we can increase the critical-path extension we are willing to accept, because we are more likely to pay the mispredict penalty on these kinds of branches.

This PR and the following one both lead to noticeable runtime improvements on internal benchmarks, with negligible change to compile time (measured on `llvm-test-suite` using an M4 Pro processor).

---
Full diff: https://github.com/llvm/llvm-project/pull/174457.diff


6 Files Affected:

- (modified) llvm/include/llvm/CodeGen/TargetRegisterInfo.h (+4) 
- (modified) llvm/lib/CodeGen/EarlyIfConversion.cpp (+169-2) 
- (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp (+4) 
- (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.h (+2) 
- (added) llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-easy-to-predict.mir (+45) 
- (added) llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-hard-to-predict.mir (+49) 


``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
index 35b14e8b8fd30..921a1e64b2ccb 100644
--- a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
@@ -662,6 +662,10 @@ class LLVM_ABI TargetRegisterInfo : public MCRegisterInfo {
     return false;
   }
 
+  /// Returns the physical register that holds condition flags (status flags)
+  /// for conditional branches on this target.
+  virtual MCRegister getStatusFlagsRegister() const { return MCRegister(); }
+
   /// Prior to adding the live-out mask to a stackmap or patchpoint
   /// instruction, provide the target the opportunity to adjust it (mainly to
   /// remove pseudo-registers that should be ignored).
diff --git a/llvm/lib/CodeGen/EarlyIfConversion.cpp b/llvm/lib/CodeGen/EarlyIfConversion.cpp
index 28993c47c094d..3c728943df7df 100644
--- a/llvm/lib/CodeGen/EarlyIfConversion.cpp
+++ b/llvm/lib/CodeGen/EarlyIfConversion.cpp
@@ -55,10 +55,17 @@ BlockInstrLimit("early-ifcvt-limit", cl::init(30), cl::Hidden,
 static cl::opt<bool> Stress("stress-early-ifcvt", cl::Hidden,
   cl::desc("Turn all knobs to 11"));
 
+// Enable analysis of hard-to-predict branches (conditions derived from loads).
+static cl::opt<bool> EnableHardToPredictAnalysis(
+    "enable-early-ifcvt-hard-to-predict", cl::Hidden, cl::init(true),
+    cl::desc("Enable hard-to-predict branch analysis for if-conversion"));
+
 STATISTIC(NumDiamondsSeen,  "Number of diamonds");
 STATISTIC(NumDiamondsConv,  "Number of diamonds converted");
 STATISTIC(NumTrianglesSeen, "Number of triangles");
 STATISTIC(NumTrianglesConv, "Number of triangles converted");
+STATISTIC(NumDataDependant,
+          "Number of data dependent conditional branches encountered");
 
 //===----------------------------------------------------------------------===//
 //                                 SSAIfConv
@@ -811,6 +818,9 @@ class EarlyIfConverter {
   MachineTraceMetrics::Ensemble *MinInstr = nullptr;
   SSAIfConv IfConv;
 
+  // Cache for operand memory analysis.
+  DenseMap<Register, bool> OperandMemoryCache;
+
 public:
   EarlyIfConverter(MachineDominatorTree &DT, MachineLoopInfo &LI,
                    MachineTraceMetrics &MTM)
@@ -823,6 +833,8 @@ class EarlyIfConverter {
   bool tryConvertIf(MachineBasicBlock *);
   void invalidateTraces();
   bool shouldConvertIf();
+  bool isConditionDataDependent();
+  bool doOperandsComeFromMemory(Register Reg);
 };
 
 class EarlyIfConverterLegacy : public MachineFunctionPass {
@@ -896,6 +908,142 @@ void EarlyIfConverter::invalidateTraces() {
   Traces->verifyAnalysis();
 }
 
+/// Check if a register's value comes from a memory load by walking the
+/// def-use chain. Results are cached for virtual registers only.
+bool EarlyIfConverter::doOperandsComeFromMemory(Register Reg) {
+  if (!Reg.isVirtual())
+    return false;
+
+  // Check cache first.
+  auto It = OperandMemoryCache.find(Reg);
+  if (It != OperandMemoryCache.end())
+    return It->second;
+
+  // Walk the def-use chain.
+  SmallPtrSet<const MachineInstr *, 16> Visited;
+  SmallVector<const MachineInstr *, 16> Worklist;
+  SmallVector<Register, 16> VisitedRegs;
+
+  MachineInstr *DefMI = MRI->getVRegDef(Reg);
+  // The operand is defined outside of the function - it does not
+  // come from memory access.
+  if (!DefMI) {
+    OperandMemoryCache[Reg] = false;
+    return false;
+  }
+
+  Worklist.push_back(DefMI);
+  VisitedRegs.push_back(Reg);
+
+  while (!Worklist.empty()) {
+    const MachineInstr *MI = Worklist.pop_back_val();
+    if (!Visited.insert(MI).second)
+      continue;
+
+    // Check if this instruction is a load.
+    if (MI->mayLoad()) {
+      LLVM_DEBUG(dbgs() << "Operand comes from load: " << *MI);
+      // Cache all visited virtual registers as coming from memory.
+      for (Register Reg : VisitedRegs)
+        OperandMemoryCache[Reg] = true;
+      return true;
+    }
+
+    // Walk through all register use operands and find their definitions.
+    for (const MachineOperand &MO : MI->operands()) {
+      if (!MO.isReg() || !MO.isUse())
+        continue;
+      Register UseReg = MO.getReg();
+      if (!UseReg.isVirtual())
+        continue;
+
+      // Check cache for this operand.
+      auto CacheIt = OperandMemoryCache.find(UseReg);
+      if (CacheIt != OperandMemoryCache.end()) {
+        if (CacheIt->second) {
+          for (Register Reg : VisitedRegs)
+            OperandMemoryCache[Reg] = true;
+          return true;
+        }
+        continue;
+      }
+
+      MachineInstr *UseDef = MRI->getVRegDef(UseReg);
+      if (UseDef && !Visited.count(UseDef)) {
+        Worklist.push_back(UseDef);
+        VisitedRegs.push_back(UseReg);
+      }
+    }
+  }
+
+  // No load found, cache all visited virtual registers as not from memory.
+  for (Register Reg : VisitedRegs)
+    OperandMemoryCache[Reg] = false;
+  return false;
+}
+
+/// Check if the branch condition is data-dependent (comes from memory loads).
+bool EarlyIfConverter::isConditionDataDependent() {
+  MachineBasicBlock::iterator BranchIt = IfConv.Head->getFirstTerminator();
+  if (BranchIt == IfConv.Head->end())
+    return false;
+
+  // Get the status flags register from the target. If the target hasn't
+  // implemented this hook, we skip physical register analysis entirely
+  // and fallback to the previous behavior.
+  MCRegister StatusFlagsReg = TRI->getStatusFlagsRegister();
+
+  if (!StatusFlagsReg.isValid())
+    return false;
+
+  for (const MachineOperand &MO : BranchIt->operands()) {
+    if (!MO.isReg() || !MO.isUse())
+      continue;
+    Register Reg = MO.getReg();
+
+    if (Reg.isVirtual()) {
+      if (doOperandsComeFromMemory(Reg)) {
+        LLVM_DEBUG(dbgs() << "Condition is data-dependent (from memory)\n");
+        ++NumDataDependant;
+        return true;
+      }
+    } else if (Reg.isPhysical() && TRI->regsOverlap(Reg, StatusFlagsReg)) {
+      // For the status flags register (e.g., NZCV on AArch64),
+      // search backwards to find the defining instruction.
+      for (auto I = std::prev(BranchIt);; --I) {
+        // Check if this instruction defines the status flags register.
+        bool FoundDef = false;
+        for (const MachineOperand &DefMO : I->operands()) {
+          if (DefMO.isReg() && DefMO.isDef() &&
+              TRI->regsOverlap(DefMO.getReg(), StatusFlagsReg)) {
+            FoundDef = true;
+            break;
+          }
+        }
+
+        if (FoundDef) {
+          for (const MachineOperand &DefMO : I->operands()) {
+            if (!DefMO.isReg() || !DefMO.isUse())
+              continue;
+            Register DefReg = DefMO.getReg();
+            if (DefReg.isVirtual() && doOperandsComeFromMemory(DefReg)) {
+              LLVM_DEBUG(dbgs()
+                         << "Condition is data-dependent (from memory)\n");
+              ++NumDataDependant;
+              return true;
+            }
+          }
+          break;
+        }
+
+        if (I == IfConv.Head->begin())
+          break;
+      }
+    }
+  }
+  return false;
+}
+
 // Adjust cycles with downward saturation.
 static unsigned adjCycles(unsigned Cyc, int Delta) {
   if (Delta < 0 && Cyc + Delta > Cyc)
@@ -963,12 +1111,31 @@ bool EarlyIfConverter::shouldConvertIf() {
   unsigned MinCrit = std::min(TBBTrace.getCriticalPath(),
                               FBBTrace.getCriticalPath());
 
-  // Set a somewhat arbitrary limit on the critical path extension we accept.
-  unsigned CritLimit = SchedModel.MispredictPenalty/2;
+  // Set a limit on the critical path extension we accept.
+  // When hard-to-predict analysis is enabled, use full MispredictPenalty for
+  // hard-to-predict branches, half for others. Otherwise use half for all.
+  bool DataDependent = false;
+  if (EnableHardToPredictAnalysis) {
+    DataDependent = isConditionDataDependent();
+  }
+  unsigned CritLimit = DataDependent ? SchedModel.MispredictPenalty
+                                     : SchedModel.MispredictPenalty / 2;
 
   MachineBasicBlock &MBB = *IfConv.Head;
   MachineOptimizationRemarkEmitter MORE(*MBB.getParent(), nullptr);
 
+  // Emit analysis remark about data-dependent condition.
+  if (DataDependent) {
+    MORE.emit([&]() {
+      return MachineOptimizationRemarkAnalysis(DEBUG_TYPE,
+                                               "DataDependentCondition",
+                                               MBB.back().getDebugLoc(), &MBB)
+             << "branch condition is data-dependent (from memory load), "
+             << "using higher CritLimit of " << ore::NV("CritLimit", CritLimit)
+             << " cycles";
+    });
+  }
+
   // If-conversion only makes sense when there is unexploited ILP. Compute the
   // maximum-ILP resource length of the trace after if-conversion. Compare it
   // to the shortest critical path.
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index ee43448d5baec..b43c833b096e5 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -1435,3 +1435,7 @@ bool AArch64RegisterInfo::isIgnoredCVReg(MCRegister LLVMReg) const {
   return (LLVMReg >= AArch64::Z0 && LLVMReg <= AArch64::Z31) ||
          (LLVMReg >= AArch64::P0 && LLVMReg <= AArch64::P15);
 }
+
+MCRegister AArch64RegisterInfo::getStatusFlagsRegister() const {
+  return AArch64::NZCV;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
index 89d1802ab98d5..b40b34b6dd0ee 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.h
@@ -154,6 +154,8 @@ class AArch64RegisterInfo final : public AArch64GenRegisterInfo {
   bool shouldAnalyzePhysregInMachineLoopInfo(MCRegister R) const override;
 
   bool isIgnoredCVReg(MCRegister LLVMReg) const override;
+
+  MCRegister getStatusFlagsRegister() const override;
 };
 
 } // end namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-easy-to-predict.mir b/llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-easy-to-predict.mir
new file mode 100644
index 0000000000000..2560c4ade5c73
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-easy-to-predict.mir
@@ -0,0 +1,45 @@
+# RUN: llc -mtriple=aarch64 -mcpu=apple-m1 -run-pass=early-ifcvt -o - %s | FileCheck %s
+
+# CHECK: bb.0:
+# CHECK: SUBSWri
+# CHECK: Bcc
+# CHECK: bb.1:
+# CHECK: bb.2:
+# CHECK: PHI
+
+--- |
+  define i32 @arg_to_cond_branch(i32 %val, i32 %x, i32 %y) {
+    ret i32 0
+  }
+...
+---
+name:            arg_to_cond_branch
+alignment:       4
+tracksRegLiveness: true
+liveins:
+  - { reg: '$w0', virtual-reg: '%0' }
+  - { reg: '$w1', virtual-reg: '%1' }
+  - { reg: '$w2', virtual-reg: '%2' }
+body:             |
+  bb.1:
+    successors: %bb.2(0x40000000), %bb.3(0x40000000)
+    liveins: $w0, $w1, $w2
+
+    %0:gpr32common = COPY $w0
+    %1:gpr32 = COPY $w1
+    %2:gpr32 = COPY $w2
+    %11:gpr32 = SUBSWri %0, 100, 0, implicit-def $nzcv
+    Bcc 12, %bb.3, implicit $nzcv
+    B %bb.2
+
+  bb.2:
+    successors: %bb.3(0x80000000)
+
+    %20:gpr32 = MADDWrrr %1, %2, $wzr
+    %21:gpr32 = MADDWrrr %20, %1, $wzr
+
+  bb.3:
+    %100:gpr32 = PHI %1, %bb.1, %21, %bb.2
+    $w0 = COPY %100
+    RET_ReallyLR implicit $w0
+...
diff --git a/llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-hard-to-predict.mir b/llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-hard-to-predict.mir
new file mode 100644
index 0000000000000..0cffedf4533fc
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/early-ifcvt-cascade-tests/early-ifcvt-hard-to-predict.mir
@@ -0,0 +1,49 @@
+# RUN: llc -mtriple=aarch64 -mcpu=apple-m1 -run-pass=early-ifcvt -o - %s | FileCheck %s
+
+# CHECK-LABEL: name: load_to_cond_branch
+# CHECK: bb.0:
+# CHECK: LDRWui
+# CHECK: SUBSWri
+# CHECK: MADDWrrr
+# CHECK: MADDWrrr
+# CHECK: CSELWr
+# CHECK-NOT: Bcc
+# CHECK-NOT: PHI
+
+--- |
+  define i32 @load_to_cond_branch(ptr %p, i32 %x, i32 %y) {
+    ret i32 0
+  }
+...
+---
+name:            load_to_cond_branch
+alignment:       4
+tracksRegLiveness: true
+liveins:
+  - { reg: '$x0', virtual-reg: '%0' }
+  - { reg: '$w1', virtual-reg: '%1' }
+  - { reg: '$w2', virtual-reg: '%2' }
+body:             |
+  bb.1:
+    successors: %bb.2(0x40000000), %bb.3(0x40000000)
+    liveins: $x0, $w1, $w2
+
+    %0:gpr64common = COPY $x0
+    %1:gpr32 = COPY $w1
+    %2:gpr32 = COPY $w2
+    %10:gpr32common = LDRWui %0, 0 :: (load (s32))
+    %11:gpr32 = SUBSWri %10, 100, 0, implicit-def $nzcv
+    Bcc 12, %bb.3, implicit $nzcv
+    B %bb.2
+
+  bb.2:
+    successors: %bb.3(0x80000000)
+
+    %20:gpr32 = MADDWrrr %1, %2, $wzr
+    %21:gpr32 = MADDWrrr %20, %1, $wzr
+
+  bb.3:
+    %100:gpr32 = PHI %1, %bb.1, %21, %bb.2
+    $w0 = COPY %100
+    RET_ReallyLR implicit $w0
+...

``````````

</details>


https://github.com/llvm/llvm-project/pull/174457


More information about the llvm-commits mailing list