[llvm] [VPlan] Replace BranchOnCount with Compare + BranchOnCond (NFC). (PR #172181)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 13 14:20:13 PST 2025


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/172181

Expand BranchOnCount to BranchOnCond + ICmp in convertToConcreteRecipes to simplify codegen.

>From 5fd4626d323317e921cb6f6ee2eddbc1785ea243 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 13 Dec 2025 22:15:11 +0000
Subject: [PATCH] [VPlan] Replace BranchOnCount with Compare + BranchOnCond
 (NFC).

Expand BranchOnCount to BranchOnCond + ICmp in convertToConcreteRecipes
to simplify codegen.
---
 llvm/lib/Transforms/Vectorize/VPlan.h         |  2 +
 .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 39 ++++++-------------
 .../Transforms/Vectorize/VPlanTransforms.cpp  | 36 ++++++++++-------
 .../LoopVectorize/AArch64/vplan-printing.ll   |  3 +-
 .../LoopVectorize/vplan-iv-transforms.ll      |  3 +-
 .../LoopVectorize/vplan-predicate-switch.ll   |  3 +-
 .../vplan-printing-reductions.ll              |  6 ++-
 7 files changed, 46 insertions(+), 46 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 0f8f6abab7b0a..42a5baaa5750e 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -1054,6 +1054,8 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
     CalculateTripCountMinusVF,
     // Increment the canonical IV separately for each unrolled part.
     CanonicalIVIncrementForPart,
+    // Abstract instruction that branches on whether its two operands are equal.
+    // Lowered to ICmp + BranchOnCond by convertToConcreteRecipes.
     BranchOnCount,
     BranchOnCond,
     Broadcast,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 24951be232b1f..fab21b34ecfe2 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -510,26 +510,6 @@ bool VPInstruction::canGenerateScalarForFirstLane() const {
   }
 }
 
-/// Create a conditional branch using \p Cond branching to the successors of \p
-/// VPBB. Note that the first successor is always forward (i.e. not created yet)
-/// while the second successor may already have been created (if it is a header
-/// block and VPBB is a latch).
-static BranchInst *createCondBranch(Value *Cond, VPBasicBlock *VPBB,
-                                    VPTransformState &State) {
-  // Replace the temporary unreachable terminator with a new conditional
-  // branch, hooking it up to backward destination (header) for latch blocks
-  // now, and to forward destination(s) later when they are created.
-  // Second successor may be backwards - iff it is already in VPBB2IRBB.
-  VPBasicBlock *SecondVPSucc = cast<VPBasicBlock>(VPBB->getSuccessors()[1]);
-  BasicBlock *SecondIRSucc = State.CFG.VPBB2IRBB.lookup(SecondVPSucc);
-  BasicBlock *IRBB = State.CFG.VPBB2IRBB[VPBB];
-  BranchInst *CondBr = State.Builder.CreateCondBr(Cond, IRBB, SecondIRSucc);
-  // First successor is always forward, reset it to nullptr
-  CondBr->setSuccessor(0, nullptr);
-  IRBB->getTerminator()->eraseFromParent();
-  return CondBr;
-}
-
 Value *VPInstruction::generate(VPTransformState &State) {
   IRBuilderBase &Builder = State.Builder;
 
@@ -659,17 +639,20 @@ Value *VPInstruction::generate(VPTransformState &State) {
   }
   case VPInstruction::BranchOnCond: {
     Value *Cond = State.get(getOperand(0), VPLane(0));
-    auto *Br = createCondBranch(Cond, getParent(), State);
+    // Replace the temporary unreachable terminator with a new conditional
+    // branch. The second successor is hooked up now if its IR block already
+    // exists (backward edge from a latch); forward successors are set later.
+    VPBasicBlock *SecondVPSucc =
+        cast<VPBasicBlock>(getParent()->getSuccessors()[1]);
+    BasicBlock *SecondIRSucc = State.CFG.VPBB2IRBB.lookup(SecondVPSucc);
+    BasicBlock *IRBB = State.CFG.VPBB2IRBB[getParent()];
+    auto *Br = Builder.CreateCondBr(Cond, IRBB, SecondIRSucc);
+    // First successor is always forward, reset it to nullptr.
+    Br->setSuccessor(0, nullptr);
+    IRBB->getTerminator()->eraseFromParent();
     applyMetadata(*Br);
     return Br;
   }
-  case VPInstruction::BranchOnCount: {
-    // First create the compare.
-    Value *IV = State.get(getOperand(0), /*IsScalar*/ true);
-    Value *TC = State.get(getOperand(1), /*IsScalar*/ true);
-    Value *Cond = Builder.CreateICmpEQ(IV, TC);
-    return createCondBranch(Cond, getParent(), State);
-  }
   case VPInstruction::Broadcast: {
     return Builder.CreateVectorSplat(
         State.VF, State.get(getOperand(0), /*IsScalar*/ true), "broadcast");
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 6627133878fdb..f1424df1ce924 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3201,27 +3201,25 @@ void VPlanTransforms::canonicalizeEVLLoops(VPlan &Plan) {
   CanonicalIV->eraseFromParent();
 
   // Replace the use of VectorTripCount in the latch-exiting block.
-  // Before: (branch-on-count EVLIVInc, VectorTripCount)
-  // After: (branch-on-cond eq AVLNext, 0)
-
+  // Before: (branch-on-cond (icmp eq EVLIVInc, VectorTripCount))
+  // After: (branch-on-cond (icmp eq AVLNext, 0))
   VPBasicBlock *LatchExiting =
       HeaderVPBB->getPredecessors()[1]->getEntryBasicBlock();
   auto *LatchExitingBr = cast<VPInstruction>(LatchExiting->getTerminator());
-  // Skip single-iteration loop region
   if (match(LatchExitingBr, m_BranchOnCond(m_True())))
     return;
-  assert(LatchExitingBr &&
-         match(LatchExitingBr,
-               m_BranchOnCount(m_VPValue(EVLIncrement),
-                               m_Specific(&Plan.getVectorTripCount()))) &&
-         "Unexpected terminator in EVL loop");
+
+  assert(match(LatchExitingBr, m_BranchOnCond(m_SpecificCmp(
+                                   CmpInst::ICMP_EQ, m_VPValue(EVLIncrement),
+                                   m_Specific(&Plan.getVectorTripCount())))) &&
+         "Expected BranchOnCond with ICmp comparing EVL increment with vector "
+         "trip count");
 
   Type *AVLTy = VPTypeAnalysis(Plan).inferScalarType(AVLNext);
   VPBuilder Builder(LatchExitingBr);
-  VPValue *Cmp = Builder.createICmp(CmpInst::ICMP_EQ, AVLNext,
-                                    Plan.getConstantInt(AVLTy, 0));
-  Builder.createNaryOp(VPInstruction::BranchOnCond, Cmp);
-  LatchExitingBr->eraseFromParent();
+  LatchExitingBr->setOperand(0,
+                             Builder.createICmp(CmpInst::ICMP_EQ, AVLNext,
+                                                Plan.getConstantInt(AVLTy, 0)));
 }
 
 void VPlanTransforms::replaceSymbolicStrides(
@@ -3740,6 +3738,17 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
         continue;
       }
 
+      // Lower BranchOnCount to ICmp + BranchOnCond.
+      VPValue *IV, *TC;
+      if (match(&R, m_BranchOnCount(m_VPValue(IV), m_VPValue(TC)))) {
+        auto *BranchOnCountInst = cast<VPInstruction>(&R);
+        DebugLoc DL = BranchOnCountInst->getDebugLoc();
+        VPValue *Cond = Builder.createICmp(CmpInst::ICMP_EQ, IV, TC, DL);
+        Builder.createNaryOp(VPInstruction::BranchOnCond, {Cond}, DL);
+        ToRemove.push_back(BranchOnCountInst);
+        continue;
+      }
+
       VPValue *VectorStep;
       VPValue *ScalarStep;
       if (!match(&R, m_VPInstruction<VPInstruction::WideIVStep>(
@@ -3853,6 +3862,7 @@ void VPlanTransforms::handleUncountableEarlyExit(VPBasicBlock *EarlyExitingVPBB,
   // with one exiting if either the original condition of the vector latch is
   // true or the early exit has been taken.
   auto *LatchExitingBranch = cast<VPInstruction>(LatchVPBB->getTerminator());
+  // The latch is expected to still be terminated by BranchOnCount here.
   assert(LatchExitingBranch->getOpcode() == VPInstruction::BranchOnCount &&
          "Unexpected terminator");
   auto *IsLatchExitTaken =
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
index a1d03c4a7fbc6..32ee9a0142a7b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll
@@ -91,7 +91,8 @@ define i32 @print_partial_reduction(ptr %a, ptr %b) "target-features"="+neon,+do
 ; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a>
 ; CHECK-NEXT:   PARTIAL-REDUCE ir<[[RDX_NEXT]]> = ir<[[RDX]]> + reduce.add (ir<%mul>)
 ; CHECK-NEXT:   EMIT vp<[[EP_IV_NEXT:%.+]]> = add nuw vp<[[EP_IV]]>, ir<16>
-; CHECK-NEXT:   EMIT branch-on-count vp<[[EP_IV_NEXT]]>, ir<1024>
+; CHECK-NEXT:   EMIT vp<{{%.+}}> = icmp eq vp<%index.next>, ir<1024>
+; CHECK-NEXT:   EMIT branch-on-cond vp<{{%.+}}>
 ; CHECK-NEXT: Successor(s): middle.block, vector.body
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-iv-transforms.ll b/llvm/test/Transforms/LoopVectorize/vplan-iv-transforms.ll
index 2c260ef9f4963..8f365d2d4dd5e 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-iv-transforms.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-iv-transforms.ll
@@ -115,7 +115,8 @@ define void @iv_expand(ptr %p, i64 %n) {
 ; CHECK-NEXT:   WIDEN store ir<%q>, ir<%y>
 ; CHECK-NEXT:   EMIT vp<%index.next> = add nuw vp<[[SCALAR_PHI]]>, ir<8>
 ; CHECK-NEXT:   EMIT vp<%vec.ind.next> = add ir<%iv>, vp<[[BROADCAST_INC]]>
-; CHECK-NEXT:   EMIT branch-on-count vp<%index.next>, vp<%n.vec>
+; CHECK-NEXT:   EMIT vp<{{%.+}}> = icmp eq vp<%index.next>, vp<%n.vec>
+; CHECK-NEXT:   EMIT branch-on-cond vp<{{%.+}}>
 ; CHECK-NEXT: Successor(s): middle.block, vector.body
 entry:
   br label %loop
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-predicate-switch.ll b/llvm/test/Transforms/LoopVectorize/vplan-predicate-switch.ll
index cf85f26992c2f..917869bf96f02 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-predicate-switch.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-predicate-switch.ll
@@ -84,7 +84,8 @@ define void @switch4_default_common_dest_with_case(ptr %start, ptr %end) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT: default.2:
 ; CHECK-NEXT:   EMIT vp<[[CAN_IV_NEXT]]> = add nuw vp<[[CAN_IV]]>, ir<2>
-; CHECK-NEXT:   EMIT branch-on-count vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
+; CHECK-NEXT:   EMIT vp<{{%.+}}> = icmp eq vp<[[CAN_IV_NEXT]]>, vp<[[VTC]]>
+; CHECK-NEXT:   EMIT branch-on-cond vp<{{%.+}}>
 ; CHECK-NEXT: Successor(s): middle.block, vector.body
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
index 0fde76477b9f9..8e160e57b7703 100644
--- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
+++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll
@@ -547,7 +547,8 @@ define i32 @print_mulacc_sub(ptr %a, ptr %b) {
 ; CHECK-NEXT:   WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a>
 ; CHECK-NEXT:   REDUCE ir<%add> = ir<%accum> + reduce.sub (ir<%mul>)
 ; CHECK-NEXT:   EMIT vp<%index.next> = add nuw vp<%index>, ir<4>
-; CHECK-NEXT:   EMIT branch-on-count vp<%index.next>, ir<1024>
+; CHECK-NEXT:   EMIT vp<{{%.+}}> = icmp eq vp<%index.next>, ir<1024>
+; CHECK-NEXT:   EMIT branch-on-cond vp<{{%.+}}>
 ; CHECK-NEXT: Successor(s): middle.block, vector.body
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:
@@ -667,7 +668,8 @@ define i32 @print_mulacc_negated(ptr %a, ptr %b) {
 ; CHECK-NEXT:   WIDEN ir<%sub> = sub ir<0>, ir<%mul>
 ; CHECK-NEXT:   REDUCE ir<%add> = ir<%accum> + reduce.add (ir<%sub>)
 ; CHECK-NEXT:   EMIT vp<%index.next> = add nuw vp<%index>, ir<4>
-; CHECK-NEXT:   EMIT branch-on-count vp<%index.next>, ir<1024>
+; CHECK-NEXT:   EMIT vp<{{%.+}}> = icmp eq vp<%index.next>, ir<1024>
+; CHECK-NEXT:   EMIT branch-on-cond vp<{{%.+}}>
 ; CHECK-NEXT: Successor(s): middle.block, vector.body
 ; CHECK-EMPTY:
 ; CHECK-NEXT: middle.block:



More information about the llvm-commits mailing list